refactor DPMSolverMultistepScheduler using sigmas #4986
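The refactor reworks DPMSolverMultistepScheduler's internals around a sigma schedule rather than raw timestep arithmetic; the test changes below follow from that. As orientation, a minimal sketch of how the scheduler is typically configured and how its timestep schedule is built (the values are illustrative and not taken from this PR):

```python
from diffusers import DPMSolverMultistepScheduler

# Illustrative configuration; use_karras_sigmas exercises the sigma-based
# code path that this refactor reworks.
scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,
    solver_order=2,
    prediction_type="v_prediction",
    use_karras_sigmas=True,
)

# set_timesteps derives the discrete inference timesteps from the sigma schedule.
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.timesteps)
```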
Changes from all commits
Large diffs are not rendered by default.
File 1:

@@ -54,6 +54,7 @@ def check_over_configs(self, time_step=0, **config):
         output, new_output = sample, sample
         for t in range(time_step, time_step + scheduler.config.solver_order + 1):
+            t = scheduler.timesteps[t]
             output = scheduler.step(residual, t, output, **kwargs).prev_sample
             new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

Review comment on `t = scheduler.timesteps[t]`: We used …
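The added line converts the loop index into an actual timestep value taken from `scheduler.timesteps` before calling `step()`. A hedged, self-contained sketch of that test pattern, with dummy tensors standing in for real model output (shapes and step counts are illustrative):

```python
import torch
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(solver_order=2)
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 3, 8, 8)   # stand-in for a latent sample
residual = 0.1 * sample            # stand-in for a model prediction

# Iterate over schedule *indices*, then look up the actual timestep value,
# mirroring the change to check_over_configs above.
for i in range(scheduler.config.solver_order + 1):
    t = scheduler.timesteps[i]
    sample = scheduler.step(residual, t, sample).prev_sample

print(sample.shape)
```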
@@ -222,7 +223,7 @@ def test_full_loop_with_karras_and_v_prediction(self):
         sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
         result_mean = torch.mean(torch.abs(sample))

-        assert abs(result_mean.item() - 1.7833) < 1e-3
+        assert abs(result_mean.item() - 1.7833) < 2e-3

     def test_switch(self):
         # make sure that iterating over schedulers with same config names gives same results

Review comment on the relaxed tolerance: Would this be ok?
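The tolerance question refers to a full-loop regression check: run the scheduler end to end with v-prediction and Karras sigmas, then compare the mean absolute value of the final sample against a stored reference. A hedged sketch of that pattern with a random stand-in for the model (the reference value and tolerance here are illustrative, not the repository's):

```python
import torch
from diffusers import DPMSolverMultistepScheduler

torch.manual_seed(0)

scheduler = DPMSolverMultistepScheduler(
    prediction_type="v_prediction",
    use_karras_sigmas=True,
)
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 3, 8, 8)

for t in scheduler.timesteps:
    # A real test calls the UNet here; random noise is only a placeholder.
    model_output = torch.randn_like(sample)
    sample = scheduler.step(model_output, t, sample).prev_sample

result_mean = torch.mean(torch.abs(sample))
print(result_mean.item())
# The repository asserts something like abs(result_mean.item() - REFERENCE) < 2e-3,
# where REFERENCE is a precomputed value; loosening 1e-3 to 2e-3 is what the comment asks about.
```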
File 2:

@@ -52,6 +52,7 @@ def check_over_configs(self, time_step=0, **config):
         output, new_output = sample, sample
         for t in range(time_step, time_step + scheduler.config.solver_order + 1):
+            t = scheduler.timesteps[t]
             output = scheduler.step(residual, t, output, **kwargs).prev_sample
             new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
@@ -241,11 +242,3 @@ def test_fp16_support(self):
             sample = scheduler.step(residual, t, sample).prev_sample

         assert sample.dtype == torch.float16

-    def test_unique_timesteps(self, **config):
-        for scheduler_class in self.scheduler_classes:
-            scheduler_config = self.get_scheduler_config(**config)
-            scheduler = scheduler_class(**scheduler_config)
-
-            scheduler.set_timesteps(scheduler.config.num_train_timesteps)
-            assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps

Review comment on the removed test: don't need this test anymore because we allow duplicated timesteps now
Review comment: nice!
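For context on the duplicated-timesteps remark: when the schedule is built from sigmas (in particular Karras sigmas) and then mapped back to discrete timesteps, several sigmas can land on the same integer timestep, so a uniqueness assertion no longer holds. A hedged numpy sketch of that effect, using a simplified proportional log-sigma mapping with rounding rather than the scheduler's actual interpolation helper:

```python
import numpy as np

num_train_timesteps = 1000
num_inference_steps = 1000  # as in the removed test: one inference step per training timestep

# Karras-style sigma schedule (rho = 7) over an illustrative sigma range.
rho = 7.0
sigma_min, sigma_max = 0.002, 80.0
ramp = np.linspace(0, 1, num_inference_steps)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

# Map each sigma back to a discrete timestep. The real scheduler interpolates in
# log-sigma space against the training schedule; plain rounding of a proportional
# log mapping is used here only to make the collisions visible.
timesteps = np.round(
    (num_train_timesteps - 1)
    * (np.log(sigmas) - np.log(sigma_min))
    / (np.log(sigma_max) - np.log(sigma_min))
).astype(int)

print(len(timesteps), "timesteps,", len(np.unique(timesteps)), "unique")
# Non-uniform sigma spacing makes several entries round to the same timestep,
# which is why asserting uniqueness no longer makes sense.
```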