Skip to content

Commit 44c58e1

Browse files
Fix DPM single (huggingface#3413)
* Fix DPM single * add test * fix one more bug * Apply suggestions from code review Co-authored-by: StAlKeR7779 <[email protected]> --------- Co-authored-by: StAlKeR7779 <[email protected]>
1 parent 2f29510 commit 44c58e1

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,13 @@
2121
import torch
2222

2323
from ..configuration_utils import ConfigMixin, register_to_config
24+
from ..utils import logging
2425
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
2526

2627

28+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29+
30+
2731
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
2832
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
2933
"""
@@ -251,7 +255,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
251255
self.timesteps = torch.from_numpy(timesteps).to(device)
252256
self.model_outputs = [None] * self.config.solver_order
253257
self.sample = None
254-
self.orders = self.get_order_list(num_inference_steps)
258+
259+
if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
260+
logger.warn(
261+
"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`."
262+
)
263+
self.register_to_config(lower_order_final=True)
264+
265+
self.order_list = self.get_order_list(num_inference_steps)
255266

256267
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
257268
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
@@ -597,6 +608,12 @@ def step(
597608
self.model_outputs[-1] = model_output
598609

599610
order = self.order_list[step_index]
611+
612+
# For img2img denoising might start with order>1 which is not possible
613+
# In this case make sure that the first two steps are both order=1
614+
while self.model_outputs[-order] is None:
615+
order -= 1
616+
600617
# For single-step solvers, we use the initial value at each time with order = 1.
601618
if order == 1:
602619
self.sample = sample

0 commit comments

Comments
 (0)