Skip to content

[draft] refactor DPMSolverMultistepScheduler using sigmas #4690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 111 additions & 63 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,6 @@ def __init__(
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)

# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0

# settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
if algorithm_type == "deis":
Expand All @@ -200,9 +197,26 @@ def __init__(
# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()

self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None

@property
def init_noise_sigma(self):
"""
Standard deviation of the initial noise distribution, computed from the largest entry of `self.sigmas`.
"""
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
# for "linspace" and "trailing" spacing the largest sigma is returned as-is
return self.sigmas.max()

# for every other spacing value the result is sqrt(max_sigma**2 + 1)
return (self.sigmas.max() ** 2 + 1) ** 0.5

@property
def step_index(self):
"""
The index counter for the current timestep. It is `None` after `__init__` and `set_timesteps`, is set by
`_init_step_index` on first use, and is increased by 1 at the end of each `step` call.
"""
return self._step_index

def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
Expand All @@ -221,39 +235,34 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc

# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
)
timesteps = np.linspace(0, last_timestep - 1, num_inference_steps)[::-1].copy().astype(np.float32)
elif self.config.timestep_spacing == "leading":
step_ratio = last_timestep // (num_inference_steps + 1)
step_ratio = last_timestep // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.float32)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas)
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
timesteps = np.flip(timesteps).copy().astype(np.int64)

self.sigmas = torch.from_numpy(sigmas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)

self.timesteps = torch.from_numpy(timesteps).to(device)

Expand All @@ -264,6 +273,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
] * self.config.solver_order
self.lower_order_nums = 0

# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
"""
Expand Down Expand Up @@ -371,13 +383,13 @@ def convert_model_output(
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3]
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
sigma = self.sigmas[self.step_index]
x0_pred = sample - sigma * model_output
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
sigma = self.sigmas[self.step_index]
x0_pred = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
Expand Down Expand Up @@ -442,19 +454,24 @@ def dpm_solver_first_order_update(
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
h = lambda_t - lambda_s

def t_fn(_sigma):
return -torch.log(_sigma)

# YiYi notes: keep these for now to avoid an error; not needed once fully refactored
#alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]

sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
h = t_fn(sigma_t) - t_fn(sigma_s)
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
x_t = (sigma_t / sigma_s) * sample - (torch.exp(-h) - 1.0) * model_output
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ (1 - torch.exp(-2.0 * h)) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
Expand Down Expand Up @@ -491,27 +508,34 @@ def multistep_dpm_solver_second_order_update(
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]

def t_fn(_sigma):
return -torch.log(_sigma)

# YiYi notes: keep these for now to avoid an error; not needed once fully refactored
#t, s0 = prev_timestep, timestep_list[-1]
#alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]

sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1

h, h_0 = t_fn(sigma_t) - t_fn(sigma_s0), t_fn(sigma_s0) - t_fn(sigma_s1)
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)

if self.config.algorithm_type == "dpmsolver++":
# See https://arxiv.org/abs/2211.01095 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
- 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
)
x_t = (sigma_t / sigma_s0) * sample - (torch.exp(-h) - 1.0) * D0 - 0.5 * (torch.exp(-h) - 1.0) * D1
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
- (torch.exp(-h) - 1.0) * D0
+ ((torch.exp(-h) - 1.0) / h + 1.0) * D1
)
elif self.config.algorithm_type == "dpmsolver":
# See https://arxiv.org/abs/2206.00927 for detailed derivations
Expand All @@ -532,15 +556,15 @@ def multistep_dpm_solver_second_order_update(
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ (1 - torch.exp(-2.0 * h)) * D0
+ 0.5 * (1 - torch.exp(-2.0 * h)) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ (1 - torch.exp(-2.0 * h)) * D0
+ ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
Expand Down Expand Up @@ -619,6 +643,23 @@ def multistep_dpm_solver_third_order_update(
)
return x_t

def _init_step_index(self, timestep):
    """
    Initialize the internal step counter (`self._step_index`) from a timestep value.

    Args:
        timestep (`int`, `float` or `torch.Tensor`):
            The timestep to look up in `self.timesteps`. Tensors are moved to the device of
            `self.timesteps` before the comparison.
    """
    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)

    index_candidates = (self.timesteps == timestep).nonzero()

    if len(index_candidates) == 0:
        # The timestep is not part of the current schedule: fall back to the final index
        # instead of raising an IndexError on an empty match.
        step_index = len(self.timesteps) - 1
    elif len(index_candidates) > 1:
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        step_index = index_candidates[1].item()
    else:
        step_index = index_candidates[0].item()

    self._step_index = step_index

def step(
self,
model_output: torch.FloatTensor,
Expand Down Expand Up @@ -654,19 +695,13 @@ def step(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)

if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
lower_order_final = (
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
)
if self.step_index is None:
self._init_step_index(timestep)

prev_timestep = 0 if self.step_index == len(self.timesteps) - 1 else self.timesteps[self.step_index + 1]
lower_order_final = self.step_index == len(self.timesteps) - 1
lower_order_second = (
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
)

model_output = self.convert_model_output(model_output, timestep, sample)
Expand All @@ -686,37 +721,50 @@ def step(
model_output, timestep, prev_timestep, sample, noise=noise
)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
timestep_list = [self.timesteps[step_index - 1], timestep]
timestep_list = [self.timesteps[self.step_index - 1], timestep]
prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
)
else:
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
timestep_list = [self.timesteps[self.step_index - 2], self.timesteps[self.step_index - 1], timestep]
prev_sample = self.multistep_dpm_solver_third_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
)

if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1

# upon completion increase step index by one
self._step_index += 1

if not return_dict:
return (prev_sample,)

return SchedulerOutput(prev_sample=prev_sample)

def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.scale_model_input
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

Args:
sample (`torch.FloatTensor`):
The input sample.
sample (`torch.FloatTensor`): input sample
timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain

Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
if self.step_index is None:
self._init_step_index(timestep)

sigma = self.sigmas[self.step_index]

sample = sample / ((sigma**2 + 1) ** 0.5)

self.is_scale_input_called = True
return sample

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
Expand Down
4 changes: 2 additions & 2 deletions tests/schedulers/test_scheduler_dpm_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,10 @@ def test_fp16_support(self):

assert sample.dtype == torch.float16

def test_unique_timesteps(self, **config):
def test_duplicated_timesteps(self, **config):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)

scheduler.set_timesteps(scheduler.config.num_train_timesteps)
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
assert len(scheduler.timesteps) == scheduler.num_inference_steps