Fix DPM single for different strength #3413
Changes from all commits: f65a58e, c228daf, 70e58d9, cfb3d2d, f64e139
@@ -21,9 +21,13 @@
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
     """
@@ -251,7 +255,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         self.timesteps = torch.from_numpy(timesteps).to(device)
         self.model_outputs = [None] * self.config.solver_order
         self.sample = None
-        self.orders = self.get_order_list(num_inference_steps)
+
+        if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
+            logger.warning(
+                f"Changing scheduler {self.config} to have `lower_order_final` set to True to handle an uneven number of inference steps. Please make sure to use a number of `num_inference_steps` that is divisible by `solver_order` when setting `lower_order_final=False`."
+            )
+            self.register_to_config(lower_order_final=True)
+
+        self.order_list = self.get_order_list(num_inference_steps)
    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
Review thread on the `# Copied from` comment:
- Does this still hold?
- Yes, I think so! It's a different function I think (…
- What I mean is, the … Is that allowed in …
- … was added to the …
- Ah! Sorry.
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
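As a quick illustration of the `set_timesteps` change above, the sketch below (not part of the PR; the scheduler settings and step count are arbitrary) requests a step count that is not divisible by `solver_order`, which should hit the new guard, log the warning, and flip `lower_order_final` to True before the order list is built:

# Illustrative sketch only; assumes a diffusers version that includes this PR.
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler(solver_order=2, lower_order_final=False)

# 25 % 2 != 0, so set_timesteps warns and registers lower_order_final=True
# before calling get_order_list.
scheduler.set_timesteps(num_inference_steps=25)

print(scheduler.config.lower_order_final)  # expected: True after the guard runs
print(scheduler.order_list[:6])            # e.g. [1, 2, 1, 2, 1, 2] for solver_order=2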
@@ -597,6 +608,12 @@ def step(
         self.model_outputs[-1] = model_output

+        order = self.order_list[step_index]
+
+        # For img2img, denoising might start with order > 1, which is not possible.
+        # In this case make sure that the first two steps are both order=1.
+        while self.model_outputs[-order] is None:
+            order -= 1
+
         # For single-step solvers, we use the initial value at each time with order = 1.
         if order == 1:
             self.sample = sample
Review comment:
- Love this!
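To make the order fallback in `step` concrete, here is a small standalone sketch (hypothetical names, not library code) of the same while-loop logic: a solver of order k needs the k most recent model outputs, so when denoising starts partway through the schedule (img2img with strength < 1.0), the order is lowered until the required history actually exists.

# Standalone sketch of the fallback loop from the diff above; `model_outputs`
# and `clamp_order` are hypothetical stand-ins for the scheduler's attributes.
def clamp_order(order, model_outputs):
    # Drop the order until the oldest required model output is available.
    while model_outputs[-order] is None:
        order -= 1
    return order

# Only the most recent output has been recorded so far (typical right after the
# first img2img step), so an order-3 update falls back to order 1.
print(clamp_order(3, [None, None, "latest"]))        # -> 1
# With two outputs recorded, an order-2 update can proceed unchanged.
print(clamp_order(2, [None, "previous", "latest"]))  # -> 2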
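For context, the scenario the PR title refers to looks roughly like the img2img setup below; the model id, image URL, and strength value are placeholders rather than values taken from the PR. With strength < 1.0 the pipeline skips the head of the timestep schedule, so the first executed step can land on an order > 1 entry of the scheduler's order list, which is exactly the case the fallback above handles.

# Rough img2img sketch of the scenario the fix targets (placeholder model id, URL, and strength).
import torch
from diffusers import DPMSolverSinglestepScheduler, StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config)

init_image = load_image("https://example.com/input.png")  # placeholder input image

# strength=0.6 starts denoising partway through the schedule; before this fix the
# single-step solver could request an order whose earlier model outputs were missing.
image = pipe(
    "a fantasy landscape", image=init_image, strength=0.6, num_inference_steps=25
).images[0]
image.save("out.png")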