
refactor DPMSolverMultistepScheduler using sigmas #4986

Merged · 39 commits · Sep 19, 2023
307 changes: 234 additions & 73 deletions src/diffusers/schedulers/scheduling_deis_multistep.py

Large diffs are not rendered by default.

303 changes: 213 additions & 90 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Large diffs are not rendered by default.

356 changes: 265 additions & 91 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Large diffs are not rendered by default.

310 changes: 236 additions & 74 deletions src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Large diffs are not rendered by default.

246 changes: 178 additions & 68 deletions src/diffusers/schedulers/scheduling_unipc_multistep.py

Large diffs are not rendered by default.
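
The five large diffs above apply the refactor the PR title describes: each scheduler now keeps one sigma per timestep and drives its solver math from sigmas rather than from raw timestep indices. As a rough illustration of the bookkeeping involved, here is a minimal sketch, assuming the usual scaled-linear beta schedule and Karras-style spacing (constants and variable names are illustrative, not the exact diffusers internals):

import numpy as np

# Assumed scaled-linear beta schedule (Stable Diffusion defaults).
num_train_timesteps = 1000
betas = np.linspace(0.00085**0.5, 0.012**0.5, num_train_timesteps) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)

# One sigma per train timestep: sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t).
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5

# Karras et al. (2022) spacing: interpolate in sigma**(1/rho) space.
num_inference_steps, rho = 20, 7.0
ramp = np.linspace(0, 1, num_inference_steps)
max_inv = sigmas[-1] ** (1 / rho)
min_inv = sigmas[0] ** (1 / rho)
karras_sigmas = (max_inv + ramp * (min_inv - max_inv)) ** rho  # descending

# Map each continuous sigma back to the nearest discrete train timestep.
log_sigmas = np.log(sigmas)
timesteps = np.abs(np.log(karras_sigmas)[:, None] - log_sigmas[None, :]).argmin(axis=1)

Nothing forces this sigma-to-timestep mapping to be injective, which is why several tests below stop asserting that the resulting timesteps are unique.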

1 change: 1 addition & 0 deletions tests/schedulers/test_scheduler_deis.py
@@ -51,6 +51,7 @@ def check_over_configs(self, time_step=0, **config):

         output, new_output = sample, sample
         for t in range(time_step, time_step + scheduler.config.solver_order + 1):
+            t = scheduler.timesteps[t]
             output = scheduler.step(residual, t, output, **kwargs).prev_sample
             new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

6 changes: 4 additions & 2 deletions tests/schedulers/test_scheduler_dpm_multi.py
@@ -59,6 +59,7 @@ def check_over_configs(self, time_step=0, **config):

         output, new_output = sample, sample
         for t in range(time_step, time_step + scheduler.config.solver_order + 1):
+            t = new_scheduler.timesteps[t]
             output = scheduler.step(residual, t, output, **kwargs).prev_sample
             new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

@@ -91,6 +92,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
         # copy over dummy past residual (must be after setting timesteps)
         new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

+        time_step = new_scheduler.timesteps[time_step]
         output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
         new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

@@ -264,10 +266,10 @@ def test_fp16_support(self):

         assert sample.dtype == torch.float16

-    def test_unique_timesteps(self, **config):
+    def test_duplicated_timesteps(self, **config):
         for scheduler_class in self.scheduler_classes:
             scheduler_config = self.get_scheduler_config(**config)
             scheduler = scheduler_class(**scheduler_config)

             scheduler.set_timesteps(scheduler.config.num_train_timesteps)
-            assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
+            assert len(scheduler.timesteps) == scheduler.num_inference_steps
Contributor commented: nice!

3 changes: 2 additions & 1 deletion tests/schedulers/test_scheduler_dpm_multi_inverse.py
@@ -54,6 +54,7 @@ def check_over_configs(self, time_step=0, **config):

         output, new_output = sample, sample
         for t in range(time_step, time_step + scheduler.config.solver_order + 1):
+            t = scheduler.timesteps[t]
yiyixuxu (Collaborator, Author) commented on Sep 13, 2023: We used 0, 1, 2 directly as timestep indices here, but they are not even in self.timesteps. In the previous implementation it would default to the last timestep whenever the index fell outside the timestep range, which I don't think was intended. I changed it here; I think it makes more sense this way. Let me know if that's not the case.

             output = scheduler.step(residual, t, output, **kwargs).prev_sample
             new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
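
To spell out the pattern the tests now follow (a sketch; the surrounding test code is abbreviated): with, say, scheduler.timesteps == tensor([999, 888, ..., 0]), a raw loop counter like 0 or 1 is a position, not a timestep value, so it has to be mapped through scheduler.timesteps before calling step():

# Old behaviour: passing the raw counter as the timestep silently fell back
# to the last timestep whenever the value was not in scheduler.timesteps.
# New behaviour: treat the counter as an index into scheduler.timesteps.
for i in range(time_step, time_step + scheduler.config.solver_order + 1):
    t = scheduler.timesteps[i]  # index -> actual timestep value
    output = scheduler.step(residual, t, output, **kwargs).prev_sample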

@@ -222,7 +223,7 @@ def test_full_loop_with_karras_and_v_prediction(self):
         sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
         result_mean = torch.mean(torch.abs(sample))

-        assert abs(result_mean.item() - 1.7833) < 1e-3
+        assert abs(result_mean.item() - 1.7833) < 2e-3

Collaborator (Author) commented: Would this be ok?

     def test_switch(self):
         # make sure that iterating over schedulers with same config names gives same results
31 changes: 31 additions & 0 deletions tests/schedulers/test_scheduler_dpm_single.py
@@ -58,6 +58,7 @@ def check_over_configs(self, time_step=0, **config):

         output, new_output = sample, sample
         for t in range(time_step, time_step + scheduler.config.solver_order + 1):
+            t = scheduler.timesteps[t]
             output = scheduler.step(residual, t, output, **kwargs).prev_sample
             new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

@@ -248,3 +249,33 @@ def test_fp16_support(self):
             sample = scheduler.step(residual, t, sample).prev_sample

         assert sample.dtype == torch.float16
+
+    def test_step_shape(self):
+        kwargs = dict(self.forward_default_kwargs)
+
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            # copy over dummy past residuals (must be done after set_timesteps)
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            time_step_0 = scheduler.timesteps[0]
+            time_step_1 = scheduler.timesteps[1]
+
+            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
+            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
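
Seeding scheduler.model_outputs above is what makes the test self-contained: single- and multistep DPM solvers reuse past model outputs, so step() expects that buffer to be populated. For context, a minimal end-to-end sketch of the same API (the constant residual is a stand-in for a real model's noise prediction):

import torch
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler()
scheduler.set_timesteps(num_inference_steps=25)

sample = torch.randn(1, 3, 8, 8)
for t in scheduler.timesteps:
    residual = 0.1 * sample  # stand-in for model(sample, t)
    sample = scheduler.step(residual, t, sample).prev_sample
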
9 changes: 1 addition & 8 deletions tests/schedulers/test_scheduler_unipc.py
@@ -52,6 +52,7 @@ def check_over_configs(self, time_step=0, **config):

         output, new_output = sample, sample
         for t in range(time_step, time_step + scheduler.config.solver_order + 1):
+            t = scheduler.timesteps[t]
             output = scheduler.step(residual, t, output, **kwargs).prev_sample
             new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

@@ -241,11 +242,3 @@ def test_fp16_support(self):
             sample = scheduler.step(residual, t, sample).prev_sample

         assert sample.dtype == torch.float16
-
-    def test_unique_timesteps(self, **config):
-        for scheduler_class in self.scheduler_classes:
-            scheduler_config = self.get_scheduler_config(**config)
-            scheduler = scheduler_class(**scheduler_config)
-
-            scheduler.set_timesteps(scheduler.config.num_train_timesteps)
-            assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps

Collaborator (Author) commented: We don't need this test anymore because we allow duplicated timesteps now.
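
To see why duplicates can occur: with use_karras_sigmas=True, neighbouring continuous sigmas can round to the same discrete train timestep once num_inference_steps is large. A sketch of the post-PR behaviour (whether duplicates actually appear depends on the configured schedule):

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(use_karras_sigmas=True)
scheduler.set_timesteps(scheduler.config.num_train_timesteps)

# The timestep count matches the requested number of inference steps,
# but the unique count may now be strictly smaller.
print(len(scheduler.timesteps), len(scheduler.timesteps.unique()))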