Skip to content

Commit 6dd3871

Browse files
Fix DPM single (#3413)
* Fix DPM single * add test * fix one more bug * Apply suggestions from code review Co-authored-by: StAlKeR7779 <[email protected]> --------- Co-authored-by: StAlKeR7779 <[email protected]>
1 parent 51843fd commit 6dd3871

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,13 @@
2121
import torch
2222

2323
from ..configuration_utils import ConfigMixin, register_to_config
24+
from ..utils import logging
2425
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
2526

2627

28+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29+
30+
2731
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
2832
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
2933
"""
@@ -251,7 +255,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
251255
self.timesteps = torch.from_numpy(timesteps).to(device)
252256
self.model_outputs = [None] * self.config.solver_order
253257
self.sample = None
254-
self.orders = self.get_order_list(num_inference_steps)
258+
259+
if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
260+
logger.warn(
261+
"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`."
262+
)
263+
self.register_to_config(lower_order_final=True)
264+
265+
self.order_list = self.get_order_list(num_inference_steps)
255266

256267
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
257268
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
@@ -597,6 +608,12 @@ def step(
597608
self.model_outputs[-1] = model_output
598609

599610
order = self.order_list[step_index]
611+
612+
# For img2img denoising might start with order>1 which is not possible
613+
# In this case make sure that the first two steps are both order=1
614+
while self.model_outputs[-order] is None:
615+
order -= 1
616+
600617
# For single-step solvers, we use the initial value at each time with order = 1.
601618
if order == 1:
602619
self.sample = sample

tests/schedulers/test_scheduler_dpm_single.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,22 @@ def full_loop(self, scheduler=None, **config):
116116

117117
return sample
118118

119+
def test_full_uneven_loop(self):
120+
scheduler = DPMSolverSinglestepScheduler(**self.get_scheduler_config())
121+
num_inference_steps = 50
122+
model = self.dummy_model()
123+
sample = self.dummy_sample_deter
124+
scheduler.set_timesteps(num_inference_steps)
125+
126+
# make sure that the first t is uneven
127+
for i, t in enumerate(scheduler.timesteps[3:]):
128+
residual = model(sample, t)
129+
sample = scheduler.step(residual, t, sample).prev_sample
130+
131+
result_mean = torch.mean(torch.abs(sample))
132+
133+
assert abs(result_mean.item() - 0.2574) < 1e-3
134+
119135
def test_timesteps(self):
120136
for timesteps in [25, 50, 100, 999, 1000]:
121137
self.check_over_configs(num_train_timesteps=timesteps)

0 commit comments

Comments
 (0)