refactor DPMSolverMultistepScheduler using sigmas #4986

Merged
merged 39 commits on Sep 19, 2023
Changes from 29 commits
202 changes: 126 additions & 76 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -203,6 +203,14 @@ def __init__(
self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None

@property
def step_index(self):
"""
The index counter for the current timestep. It will increase by 1 after each scheduler step.
"""
return self._step_index

def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
@@ -242,19 +250,20 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
)

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)

if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas)
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
timesteps = np.flip(timesteps).copy().astype(np.int64)
sigmas = np.flip(sigmas).copy()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

self.sigmas = torch.from_numpy(sigmas)

# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]

self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device)

self.num_inference_steps = len(timesteps)
@@ -264,6 +273,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
] * self.config.solver_order
self.lower_order_nums = 0

# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
"""
@@ -323,6 +335,12 @@ def _sigma_to_t(self, sigma, log_sigmas):
t = t.reshape(sigma.shape)
return t

def _sigma_to_alpha_sigma_t(self, sigma):
Contributor: Can we add a "Copied from" here, from another sampler?

Collaborator Author (yiyixuxu): actually, originated here and not copied from anywhere 😛

krahnikblis: hi @yiyixuxu, question from a casual observer here, following along and converting UniPC into JAX... it seems `_sigma_to_alpha_sigma_t` is used up to 9X in every scheduler step (once in `convert_model_output`; twice, for the s0 and t versions, in each of the C and P multistep update functions; and up to twice more in each depending on the order determined for that step, most commonly 2 or 3, so 1X or 2X more again).

Is this function used elsewhere in other schedulers in a way that makes it impossible to run the calculation once and set all sigmas and alphas within the `set_timesteps` function? From what I can tell during my tinkering, alphas, lambdas, sigmas (and scheduler-specific things like R, rks, rhos, order, etc.) depend only on the timestep and input parameters, not on any data/sample values during the step process, so they could be created in advance and then referenced from the state (or `self`, for torch) variable using `[step_index]` instead of `[timestep]`, which could maybe save some compute (and lines of code?).

Collaborator Author (yiyixuxu), Jan 28, 2024: @krahnikblis hey! sorry I missed this comment. A good suggestion! Feel free to open a PR, otherwise I will find time to work on this soon.

alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t

return alpha_t, sigma_t
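Illustrating the suggestion in the review thread above, here is a minimal, hypothetical sketch (not part of this PR) of precomputing `alpha_t`, `sigma_t`, and `lambda_t` once for the whole sigma schedule and then indexing by `step_index`; the helper name `precompute_alpha_sigma_lambda` and the attribute names in the trailing comment are assumptions for illustration only.

```python
import numpy as np
import torch


def precompute_alpha_sigma_lambda(sigmas: np.ndarray):
    """Hypothetical helper: derive alpha_t, sigma_t and lambda_t for every sigma
    up front, using the same identities as _sigma_to_alpha_sigma_t."""
    sigmas = torch.from_numpy(sigmas)
    alpha_t = 1.0 / ((sigmas**2 + 1.0) ** 0.5)
    sigma_t = sigmas * alpha_t
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    return alpha_t, sigma_t, lambda_t


# set_timesteps could then store the results, e.g.
#   self.alpha_ts, self.sigma_ts, self.lambda_ts = precompute_alpha_sigma_lambda(sigmas)
# and the update functions would index them with self.step_index instead of
# recomputing alpha/sigma/lambda from self.sigmas[self.step_index] on every call.
```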

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
@@ -337,9 +355,7 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas

def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
def convert_model_output(self, model_output: torch.FloatTensor, sample: torch.FloatTensor) -> torch.FloatTensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
@@ -355,8 +371,6 @@ def convert_model_output(
Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.

@@ -371,12 +385,14 @@
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3]
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
@@ -398,10 +414,12 @@ convert_model_output
else:
epsilon = model_output
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = (sample - alpha_t * model_output) / sigma_t
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = alpha_t * model_output + sigma_t * sample
else:
raise ValueError(
@@ -410,7 +428,8 @@
)

if self.config.thresholding:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t
@@ -420,8 +439,6 @@
def dpm_solver_first_order_update(
self,
model_output: torch.FloatTensor,
timestep: int,
prev_timestep: int,
Contributor: Let's also correctly deprecate those.
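A rough sketch of the kind of backwards-compatible shim being asked for here: accept the removed `timestep` / `prev_timestep` arguments through `*args`/`**kwargs`, recover `sample` if it was passed positionally the old way, and warn. This uses the plain `warnings` module rather than the library's own deprecation helper, and the message text is illustrative, not what eventually landed.

```python
import warnings
from typing import Optional

import torch


def dpm_solver_first_order_update(
    self,
    model_output: torch.FloatTensor,
    *args,
    sample: torch.FloatTensor = None,
    noise: Optional[torch.FloatTensor] = None,
    **kwargs,
) -> torch.FloatTensor:
    # Legacy call sites used:
    #   update(model_output, timestep, prev_timestep, sample, noise=noise)
    # New call sites pass `sample` (and `noise`) as keywords, so anything left in
    # *args / **kwargs is treated as the deprecated timestep arguments.
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]  # old positional layout
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep is not None or prev_timestep is not None:
        warnings.warn(
            "`timestep` and `prev_timestep` are deprecated and ignored; the scheduler "
            "now uses its internal `step_index` counter.",
            DeprecationWarning,
        )
    # ... the sigma-based first-order update from this PR continues here unchanged ...
```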

sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
@@ -442,9 +459,13 @@ def dpm_solver_first_order_update(
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]

sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
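# lambda here is the half log-SNR used by DPM-Solver: lambda = log(alpha) - log(sigma)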
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
@@ -469,8 +490,6 @@ def dpm_solver_first_order_update(
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.FloatTensor],
timestep_list: List[int],
prev_timestep: int,
Contributor: add *args and **kwargs to deprecate.

sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
@@ -491,11 +510,23 @@ def multistep_dpm_solver_second_order_update(
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]

sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)

alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]

h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
@@ -564,8 +595,6 @@ def multistep_dpm_solver_second_order_update(
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.FloatTensor],
timestep_list: List[int],
Contributor: add *args and **kwargs to deprecate.

prev_timestep: int,
sample: torch.FloatTensor,
) -> torch.FloatTensor:
"""
@@ -585,16 +614,26 @@ def multistep_dpm_solver_third_order_update(
`torch.FloatTensor`:
The sample tensor at the previous timestep.
"""
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
self.lambda_t[t],
self.lambda_t[s0],
self.lambda_t[s1],
self.lambda_t[s2],

sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
self.sigmas[self.step_index - 2],
)
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]

alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h
D0 = m0
@@ -619,6 +658,25 @@ def multistep_dpm_solver_third_order_update(
)
return x_t

def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)

index_candidates = (self.timesteps == timestep).nonzero()

if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()

self._step_index = step_index
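A small, self-contained illustration (not part of the diff) of the duplicate rule above: when the requested timestep appears more than once in `self.timesteps`, the second match is used so the very first `step` call does not skip a sigma.

```python
import torch

# Hypothetical schedule with a duplicated leading timestep (e.g. when
# num_inference_steps == num_train_timesteps, or for image-to-image).
timesteps = torch.tensor([999, 999, 998, 997])

index_candidates = (timesteps == 999).nonzero()
# index_candidates is tensor([[0], [1]]): two matches, so _init_step_index would
# pick index_candidates[1].item() == 1, i.e. start at the second occurrence.
step_index = index_candidates[1].item() if len(index_candidates) > 1 else index_candidates[0].item()
print(step_index)  # 1
```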

def step(
self,
model_output: torch.FloatTensor,
@@ -654,22 +712,17 @@ def step(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)

if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
if self.step_index is None:
self._init_step_index(timestep)

lower_order_final = (
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
)
lower_order_second = (
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
)

model_output = self.convert_model_output(model_output, timestep, sample)
model_output = self.convert_model_output(model_output, sample)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
@@ -682,23 +735,18 @@ def step(
noise = None

if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update(
model_output, timestep, prev_timestep, sample, noise=noise
)
prev_sample = self.dpm_solver_first_order_update(model_output, sample, noise=noise)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
timestep_list = [self.timesteps[step_index - 1], timestep]
prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
)
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample, noise=noise)
else:
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
prev_sample = self.multistep_dpm_solver_third_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
)
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample)

if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1

# upon completion increase step index by one
self._step_index += 1

if not return_dict:
return (prev_sample,)

@@ -719,28 +767,30 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch
"""
return sample

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
noisy_samples = original_samples + noise * sigma
return noisy_samples

def __len__(self):
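For context, a minimal usage sketch (not part of the diff) of how the refactored scheduler is typically driven; `fake_denoiser` stands in for a real model such as a UNet. Note that `step` no longer needs externally computed timestep indices: the internal `_step_index` counter is initialized on the first call and advanced after each one.

```python
import torch

from diffusers import DPMSolverMultistepScheduler


def fake_denoiser(sample: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Stand-in for a real denoising model (e.g. a UNet); returns noise-shaped output.
    return torch.randn_like(sample)


scheduler = DPMSolverMultistepScheduler(use_karras_sigmas=True)
scheduler.set_timesteps(num_inference_steps=25)

sample = torch.randn(1, 4, 64, 64)  # start from pure noise
for t in scheduler.timesteps:
    model_output = fake_denoiser(sample, t)
    # step() converts the model output using the current sigma and advances the
    # internal _step_index counter, so no index bookkeeping is needed in the loop.
    sample = scheduler.step(model_output, t, sample).prev_sample
```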