add index_counter to DPMSolverMultistepScheduler #4187

Closed · wants to merge 2 commits
29 changes: 21 additions & 8 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -15,6 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
 
 import math
+from collections import defaultdict
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -274,11 +275,6 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
 
         self.sigmas = torch.from_numpy(sigmas)
 
-        # when num_inference_steps == num_train_timesteps, we can end up with
-        # duplicates in timesteps.
-        _, unique_indices = np.unique(timesteps, return_index=True)
-        timesteps = timesteps[np.sort(unique_indices)]
-
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
         self.num_inference_steps = len(timesteps)
@@ -288,6 +284,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
         ] * self.config.solver_order
         self.lower_order_nums = 0
 
+        # add an index counter for schedulers that allow duplicated timesteps
+        self._index_counter = defaultdict(int)
+
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         """
@@ -660,11 +659,25 @@ def step(
 
         if isinstance(timestep, torch.Tensor):
             timestep = timestep.to(self.timesteps.device)
-        step_index = (self.timesteps == timestep).nonzero()
-        if len(step_index) == 0:
+        indices = (self.timesteps == timestep).nonzero()
+        timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
+
+        if len(indices) == 0:
             step_index = len(self.timesteps) - 1
         else:
-            step_index = step_index.item()
+            # The sigma index that is taken for the **very** first `step`
+            # is always the second index (or the last index if there is only 1)
+            # This way we can ensure we don't accidentally skip a sigma in
+            # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+            if len(self._index_counter) == 0:
+                pos = 1 if len(indices) > 1 else 0
+            else:
+                pos = self._index_counter[timestep_int]
+            step_index = indices[pos].item()
+
+        # advance index counter by 1
+        self._index_counter[timestep_int] += 1
 
         prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
         lower_order_final = (
             (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
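For reference, the new lookup can be exercised in isolation. The sketch below is a standalone toy, not scheduler code; `lookup_step_index` and the example `timesteps` values are made up for illustration. It mirrors the diff's logic: the very first `step` call picks the second occurrence of a duplicated timestep (so a mid-schedule start doesn't skip a sigma), and the counter selects later occurrences on subsequent calls.

```python
from collections import defaultdict

import torch

# Toy schedule with a duplicated leading timestep, as can happen when
# num_inference_steps == num_train_timesteps.
timesteps = torch.tensor([999, 999, 998, 997])
index_counter = defaultdict(int)


def lookup_step_index(timestep: int) -> int:
    indices = (timesteps == timestep).nonzero()
    if len(indices) == 0:
        # unknown timestep: fall back to the last index
        step_index = len(timesteps) - 1
    elif len(index_counter) == 0:
        # very first call: take the second occurrence if one exists, so
        # starting mid-schedule (e.g. image-to-image) doesn't skip a sigma
        step_index = indices[1 if len(indices) > 1 else 0].item()
    else:
        # n-th call for this timestep selects its n-th occurrence
        step_index = indices[index_counter[timestep]].item()
    index_counter[timestep] += 1
    return step_index


print(lookup_step_index(999))  # 1 -> second occurrence on the first call
print(lookup_step_index(999))  # 1 -> counter for 999 has advanced to 1
print(lookup_step_index(998))  # 2
```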
4 changes: 2 additions & 2 deletions tests/schedulers/test_scheduler_dpm_multi.py
@@ -264,10 +264,10 @@ def test_fp16_support(self):
 
         assert sample.dtype == torch.float16
 
-    def test_unique_timesteps(self, **config):
+    def test_duplicated_timesteps(self, **config):
         for scheduler_class in self.scheduler_classes:
             scheduler_config = self.get_scheduler_config(**config)
             scheduler = scheduler_class(**scheduler_config)
 
             scheduler.set_timesteps(scheduler.config.num_train_timesteps)
-            assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
+            assert len(scheduler.timesteps) == scheduler.num_inference_steps
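The renamed test pins down the behavioral change: `set_timesteps` no longer deduplicates, so every scheduled entry counts as a step even when timestep values repeat. A rough usage sketch against the public API (the config value is illustrative, and whether duplicates actually appear depends on the timestep spacing):

```python
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(scheduler.config.num_train_timesteps)

# Duplicates are now kept rather than filtered, so the schedule length
# matches num_inference_steps even when some timestep values repeat.
assert len(scheduler.timesteps) == scheduler.num_inference_steps
print(len(scheduler.timesteps), scheduler.timesteps.unique().numel())
```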