|
21 | 21 | import torch
|
22 | 22 |
|
23 | 23 | from ..configuration_utils import ConfigMixin, register_to_config
|
| 24 | +from ..utils import logging |
24 | 25 | from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
25 | 26 |
|
26 | 27 |
|
| 28 | +logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
| 29 | + |
| 30 | + |
27 | 31 | # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
28 | 32 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
29 | 33 | """
|
@@ -251,7 +255,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
|
251 | 255 | self.timesteps = torch.from_numpy(timesteps).to(device)
|
252 | 256 | self.model_outputs = [None] * self.config.solver_order
|
253 | 257 | self.sample = None
|
254 |
| - self.orders = self.get_order_list(num_inference_steps) |
| 258 | + |
| 259 | + if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0: |
| 260 | + logger.warning(
| 261 | + f"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference_steps` when using `lower_order_final=True`."
| 262 | + ) |
| 263 | + self.register_to_config(lower_order_final=True) |
| 264 | + |
| 265 | + self.order_list = self.get_order_list(num_inference_steps) |
255 | 266 |
|
256 | 267 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
257 | 268 | def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
@@ -597,6 +608,12 @@ def step(
|
597 | 608 | self.model_outputs[-1] = model_output
|
598 | 609 |
|
599 | 610 | order = self.order_list[step_index]
|
| 611 | + |
| 612 | + # For img2img denoising might start with order>1 which is not possible |
| 613 | + # In this case make sure that the first two steps are both order=1 |
| 614 | + while self.model_outputs[-order] is None: |
| 615 | + order -= 1 |
| 616 | + |
600 | 617 | # For single-step solvers, we use the initial value at each time with order = 1.
|
601 | 618 | if order == 1:
|
602 | 619 | self.sample = sample
|
|
0 commit comments