Fix DPM single for different strength #3413

Merged
5 commits merged on May 22, 2023
19 changes: 18 additions & 1 deletion src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -21,9 +21,13 @@
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
"""
@@ -251,7 +255,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self.timesteps = torch.from_numpy(timesteps).to(device)
self.model_outputs = [None] * self.config.solver_order
self.sample = None
self.orders = self.get_order_list(num_inference_steps)

if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
logger.warn(
f"Changing scheduler {self.config} to have `lower_order_final` set to True to handle an uneven number of inference steps. Please make sure `num_inference_steps` is divisible by `solver_order` when using `lower_order_final=False`."
)
Comment on lines +261 to +262

Member: Love this!

self.register_to_config(lower_order_final=True)

self.order_list = self.get_order_list(num_inference_steps)
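
A minimal sketch of when the branch above fires (the argument values are illustrative, not from this PR): a `num_inference_steps` that is not divisible by `solver_order` now triggers the warning and flips `lower_order_final` before the order list is built.

from diffusers import DPMSolverSinglestepScheduler

# 50 % 3 != 0, so set_timesteps logs the warning and switches
# `lower_order_final` to True before calling get_order_list.
scheduler = DPMSolverSinglestepScheduler(solver_order=3, lower_order_final=False)
scheduler.set_timesteps(num_inference_steps=50)
assert scheduler.config.lower_order_final is True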

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
Member: Does this still hold?

Contributor Author: Yes, I think so! It's a different function I think (_threshold_sample vs. step) :-)

Member: What I mean is, the

def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

doesn't seem to have:

order = order if self.model_outputs[-order] is not None else order - 1

Is that allowed in # Copied from ... statements?

Contributor Author:

order = order if self.model_outputs[-order] is not None else order - 1

was added to the def step(...) function, not the def _threshold_sample function :-) I think the git diff review made it hard to read ;-)

Member: Ah! Sorry.
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
@@ -597,6 +608,12 @@ def step(
self.model_outputs[-1] = model_output

order = self.order_list[step_index]

# For img2img, denoising might start with order > 1, which is not possible;
# in this case make sure that the first two steps both use order = 1
while self.model_outputs[-order] is None:
order -= 1

# For single-step solvers, we use the initial value at each time with order = 1.
if order == 1:
self.sample = sample
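
The case this fallback targets is img2img with strength < 1, where the pipeline skips the leading timesteps and the first call to step() can land on an order > 1 slot while model_outputs is still empty. A minimal sketch of that scenario (the model id, prompt, and dummy init image are placeholders, not part of this PR):

import torch
from PIL import Image
from diffusers import DPMSolverSinglestepScheduler, StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config)

# strength=0.5 drops roughly the first half of the timesteps, so the very first
# scheduler.step() call would previously request a higher-order update with no
# stored model outputs; the while-loop above now falls back to order 1 instead.
init_image = Image.new("RGB", (512, 512))
image = pipe("a fantasy landscape", image=init_image, strength=0.5, num_inference_steps=50).images[0]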
16 changes: 16 additions & 0 deletions tests/schedulers/test_scheduler_dpm_single.py
@@ -116,6 +116,22 @@ def full_loop(self, scheduler=None, **config):

return sample

def test_full_uneven_loop(self):
scheduler = DPMSolverSinglestepScheduler(**self.get_scheduler_config())
num_inference_steps = 50
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)

# skip the first timesteps so the loop starts at an uneven position in the order list
for i, t in enumerate(scheduler.timesteps[3:]):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample

result_mean = torch.mean(torch.abs(sample))

assert abs(result_mean.item() - 0.2574) < 1e-3

def test_timesteps(self):
for timesteps in [25, 50, 100, 999, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
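
A quick way to see why starting the loop at scheduler.timesteps[3:] hits the new fallback (a sketch assuming the default test config, i.e. solver_order=2):

from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler()
scheduler.set_timesteps(50)
# With solver_order=2 the per-step orders alternate 1, 2, 1, 2, ...; skipping the
# first three timesteps makes the first step() call ask for order 2 while
# model_outputs still holds only None entries, which used to fail.
print(scheduler.order_list[:4])  # expected: [1, 2, 1, 2]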