Skip to content

Commit 8263cf0

Browse files
yiyixuxu and patrickvonplaten
authored
refactor DPMSolverMultistepScheduler using sigmas (#4986)
--------- Co-authored-by: yiyixuxu <yixu310@gmail.com> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 74e43a4 commit 8263cf0

10 files changed

+1165
-407
lines changed

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 234 additions & 73 deletions
Large diffs are not rendered by default.

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 213 additions & 90 deletions
Large diffs are not rendered by default.

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 265 additions & 91 deletions
Large diffs are not rendered by default.

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 236 additions & 74 deletions
Large diffs are not rendered by default.

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 178 additions & 68 deletions
Large diffs are not rendered by default.

tests/schedulers/test_scheduler_deis.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def check_over_configs(self, time_step=0, **config):
5151

5252
output, new_output = sample, sample
5353
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
54+
t = scheduler.timesteps[t]
5455
output = scheduler.step(residual, t, output, **kwargs).prev_sample
5556
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
5657

tests/schedulers/test_scheduler_dpm_multi.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def check_over_configs(self, time_step=0, **config):
5959

6060
output, new_output = sample, sample
6161
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
62+
t = new_scheduler.timesteps[t]
6263
output = scheduler.step(residual, t, output, **kwargs).prev_sample
6364
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
6465

@@ -91,6 +92,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
9192
# copy over dummy past residual (must be after setting timesteps)
9293
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
9394

95+
time_step = new_scheduler.timesteps[time_step]
9496
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
9597
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
9698

@@ -264,10 +266,10 @@ def test_fp16_support(self):
264266

265267
assert sample.dtype == torch.float16
266268

267-
def test_unique_timesteps(self, **config):
269+
def test_duplicated_timesteps(self, **config):
268270
for scheduler_class in self.scheduler_classes:
269271
scheduler_config = self.get_scheduler_config(**config)
270272
scheduler = scheduler_class(**scheduler_config)
271273

272274
scheduler.set_timesteps(scheduler.config.num_train_timesteps)
273-
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
275+
assert len(scheduler.timesteps) == scheduler.num_inference_steps

tests/schedulers/test_scheduler_dpm_multi_inverse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def check_over_configs(self, time_step=0, **config):
5454

5555
output, new_output = sample, sample
5656
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
57+
t = scheduler.timesteps[t]
5758
output = scheduler.step(residual, t, output, **kwargs).prev_sample
5859
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
5960

@@ -222,7 +223,7 @@ def test_full_loop_with_karras_and_v_prediction(self):
222223
sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
223224
result_mean = torch.mean(torch.abs(sample))
224225

225-
assert abs(result_mean.item() - 1.7833) < 1e-3
226+
assert abs(result_mean.item() - 1.7833) < 2e-3
226227

227228
def test_switch(self):
228229
# make sure that iterating over schedulers with same config names gives same results

tests/schedulers/test_scheduler_dpm_single.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def check_over_configs(self, time_step=0, **config):
5858

5959
output, new_output = sample, sample
6060
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
61+
t = scheduler.timesteps[t]
6162
output = scheduler.step(residual, t, output, **kwargs).prev_sample
6263
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
6364

@@ -248,3 +249,33 @@ def test_fp16_support(self):
248249
sample = scheduler.step(residual, t, sample).prev_sample
249250

250251
assert sample.dtype == torch.float16
252+
253+
def test_step_shape(self):
254+
kwargs = dict(self.forward_default_kwargs)
255+
256+
num_inference_steps = kwargs.pop("num_inference_steps", None)
257+
258+
for scheduler_class in self.scheduler_classes:
259+
scheduler_config = self.get_scheduler_config()
260+
scheduler = scheduler_class(**scheduler_config)
261+
262+
sample = self.dummy_sample
263+
residual = 0.1 * sample
264+
265+
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
266+
scheduler.set_timesteps(num_inference_steps)
267+
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
268+
kwargs["num_inference_steps"] = num_inference_steps
269+
270+
# copy over dummy past residuals (must be done after set_timesteps)
271+
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
272+
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
273+
274+
time_step_0 = scheduler.timesteps[0]
275+
time_step_1 = scheduler.timesteps[1]
276+
277+
output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
278+
output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
279+
280+
self.assertEqual(output_0.shape, sample.shape)
281+
self.assertEqual(output_0.shape, output_1.shape)

tests/schedulers/test_scheduler_unipc.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def check_over_configs(self, time_step=0, **config):
5252

5353
output, new_output = sample, sample
5454
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
55+
t = scheduler.timesteps[t]
5556
output = scheduler.step(residual, t, output, **kwargs).prev_sample
5657
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
5758

@@ -241,11 +242,3 @@ def test_fp16_support(self):
241242
sample = scheduler.step(residual, t, sample).prev_sample
242243

243244
assert sample.dtype == torch.float16
244-
245-
def test_unique_timesteps(self, **config):
246-
for scheduler_class in self.scheduler_classes:
247-
scheduler_config = self.get_scheduler_config(**config)
248-
scheduler = scheduler_class(**scheduler_config)
249-
250-
scheduler.set_timesteps(scheduler.config.num_train_timesteps)
251-
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps

0 commit comments

Comments (0)