-
Notifications
You must be signed in to change notification settings - Fork 6k
refactor DPMSolverMultistepScheduler using sigmas #4986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 29 commits
8f78025
3b886af
a05a13a
670c782
c95b545
9eeb5e9
515c105
a85a18c
f238e0d
7198998
67ef0e3
86601bc
0445412
f687c7d
48a9b1e
b3bf644
c5a6cdb
3932182
2f785fb
f39cc96
b28320e
8431b9a
3e0826b
3d31065
dd8ec06
b7f910c
709095a
af5fcd6
9128336
7ff1d84
160216e
25c0432
911ca6a
65fe7c3
d53056c
13729a9
b7223a8
02d07d4
466cb53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -203,6 +203,14 @@ def __init__( | |
self.timesteps = torch.from_numpy(timesteps) | ||
self.model_outputs = [None] * solver_order | ||
self.lower_order_nums = 0 | ||
self._step_index = None | ||
|
||
@property | ||
def step_index(self): | ||
""" | ||
The index counter for current timestep. It will increae 1 after each scheduler step. | ||
""" | ||
return self._step_index | ||
|
||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): | ||
""" | ||
|
@@ -242,19 +250,20 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc | |
) | ||
|
||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) | ||
log_sigmas = np.log(sigmas) | ||
|
||
if self.config.use_karras_sigmas: | ||
log_sigmas = np.log(sigmas) | ||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) | ||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() | ||
timesteps = np.flip(timesteps).copy().astype(np.int64) | ||
sigmas = np.flip(sigmas).copy() | ||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) | ||
else: | ||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) | ||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 | ||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) | ||
|
||
self.sigmas = torch.from_numpy(sigmas) | ||
|
||
# when num_inference_steps == num_train_timesteps, we can end up with | ||
# duplicates in timesteps. | ||
_, unique_indices = np.unique(timesteps, return_index=True) | ||
timesteps = timesteps[np.sort(unique_indices)] | ||
|
||
self.sigmas = torch.from_numpy(sigmas).to(device=device) | ||
self.timesteps = torch.from_numpy(timesteps).to(device) | ||
|
||
self.num_inference_steps = len(timesteps) | ||
|
@@ -264,6 +273,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc | |
] * self.config.solver_order | ||
self.lower_order_nums = 0 | ||
|
||
# add an index counter for schedulers that allow duplicated timesteps | ||
self._step_index = None | ||
|
||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample | ||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: | ||
""" | ||
|
@@ -323,6 +335,12 @@ def _sigma_to_t(self, sigma, log_sigmas): | |
t = t.reshape(sigma.shape) | ||
return t | ||
|
||
def _sigma_to_alpha_sigma_t(self, sigma): | ||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5) | ||
sigma_t = sigma * alpha_t | ||
|
||
return alpha_t, sigma_t | ||
|
||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras | ||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: | ||
"""Constructs the noise schedule of Karras et al. (2022).""" | ||
|
@@ -337,9 +355,7 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) | |
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho | ||
return sigmas | ||
|
||
def convert_model_output( | ||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor | ||
) -> torch.FloatTensor: | ||
def convert_model_output(self, model_output: torch.FloatTensor, sample: torch.FloatTensor) -> torch.FloatTensor: | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is | ||
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an | ||
|
@@ -355,8 +371,6 @@ def convert_model_output( | |
Args: | ||
model_output (`torch.FloatTensor`): | ||
The direct output from the learned diffusion model. | ||
timestep (`int`): | ||
The current discrete timestep in the diffusion chain. | ||
sample (`torch.FloatTensor`): | ||
A current instance of a sample created by the diffusion process. | ||
|
||
|
@@ -371,12 +385,14 @@ def convert_model_output( | |
# DPM-Solver and DPM-Solver++ only need the "mean" output. | ||
if self.config.variance_type in ["learned", "learned_range"]: | ||
model_output = model_output[:, :3] | ||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] | ||
sigma = self.sigmas[self.step_index] | ||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) | ||
x0_pred = (sample - sigma_t * model_output) / alpha_t | ||
elif self.config.prediction_type == "sample": | ||
x0_pred = model_output | ||
elif self.config.prediction_type == "v_prediction": | ||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] | ||
sigma = self.sigmas[self.step_index] | ||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) | ||
x0_pred = alpha_t * sample - sigma_t * model_output | ||
else: | ||
raise ValueError( | ||
|
@@ -398,10 +414,12 @@ def convert_model_output( | |
else: | ||
epsilon = model_output | ||
elif self.config.prediction_type == "sample": | ||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] | ||
sigma = self.sigmas[self.step_index] | ||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) | ||
epsilon = (sample - alpha_t * model_output) / sigma_t | ||
elif self.config.prediction_type == "v_prediction": | ||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] | ||
sigma = self.sigmas[self.step_index] | ||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) | ||
epsilon = alpha_t * model_output + sigma_t * sample | ||
else: | ||
raise ValueError( | ||
|
@@ -410,7 +428,8 @@ def convert_model_output( | |
) | ||
|
||
if self.config.thresholding: | ||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] | ||
sigma = self.sigmas[self.step_index] | ||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) | ||
x0_pred = (sample - sigma_t * epsilon) / alpha_t | ||
x0_pred = self._threshold_sample(x0_pred) | ||
epsilon = (sample - alpha_t * x0_pred) / sigma_t | ||
|
@@ -420,8 +439,6 @@ def convert_model_output( | |
def dpm_solver_first_order_update( | ||
self, | ||
model_output: torch.FloatTensor, | ||
timestep: int, | ||
prev_timestep: int, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's also correctly deprecate those |
||
sample: torch.FloatTensor, | ||
noise: Optional[torch.FloatTensor] = None, | ||
) -> torch.FloatTensor: | ||
|
@@ -442,9 +459,13 @@ def dpm_solver_first_order_update( | |
`torch.FloatTensor`: | ||
The sample tensor at the previous timestep. | ||
""" | ||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] | ||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] | ||
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] | ||
|
||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] | ||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) | ||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) | ||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t) | ||
lambda_s = torch.log(alpha_s) - torch.log(sigma_s) | ||
|
||
h = lambda_t - lambda_s | ||
if self.config.algorithm_type == "dpmsolver++": | ||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output | ||
|
@@ -469,8 +490,6 @@ def dpm_solver_first_order_update( | |
def multistep_dpm_solver_second_order_update( | ||
self, | ||
model_output_list: List[torch.FloatTensor], | ||
timestep_list: List[int], | ||
prev_timestep: int, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add *args and **kwargs to deprecate |
||
sample: torch.FloatTensor, | ||
noise: Optional[torch.FloatTensor] = None, | ||
) -> torch.FloatTensor: | ||
|
@@ -491,11 +510,23 @@ def multistep_dpm_solver_second_order_update( | |
`torch.FloatTensor`: | ||
The sample tensor at the previous timestep. | ||
""" | ||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] | ||
|
||
sigma_t, sigma_s0, sigma_s1 = ( | ||
self.sigmas[self.step_index + 1], | ||
self.sigmas[self.step_index], | ||
self.sigmas[self.step_index - 1], | ||
) | ||
|
||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) | ||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) | ||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) | ||
|
||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t) | ||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) | ||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) | ||
|
||
m0, m1 = model_output_list[-1], model_output_list[-2] | ||
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] | ||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] | ||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] | ||
|
||
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 | ||
r0 = h_0 / h | ||
D0, D1 = m0, (1.0 / r0) * (m0 - m1) | ||
|
@@ -564,8 +595,6 @@ def multistep_dpm_solver_second_order_update( | |
def multistep_dpm_solver_third_order_update( | ||
self, | ||
model_output_list: List[torch.FloatTensor], | ||
timestep_list: List[int], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add *args and **kwargs to deprecate |
||
prev_timestep: int, | ||
sample: torch.FloatTensor, | ||
) -> torch.FloatTensor: | ||
""" | ||
|
@@ -585,16 +614,26 @@ def multistep_dpm_solver_third_order_update( | |
`torch.FloatTensor`: | ||
The sample tensor at the previous timestep. | ||
""" | ||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] | ||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] | ||
lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( | ||
self.lambda_t[t], | ||
self.lambda_t[s0], | ||
self.lambda_t[s1], | ||
self.lambda_t[s2], | ||
|
||
sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( | ||
self.sigmas[self.step_index + 1], | ||
self.sigmas[self.step_index], | ||
self.sigmas[self.step_index - 1], | ||
self.sigmas[self.step_index - 2], | ||
) | ||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] | ||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] | ||
|
||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) | ||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) | ||
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) | ||
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) | ||
|
||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t) | ||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) | ||
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) | ||
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) | ||
|
||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] | ||
|
||
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 | ||
r0, r1 = h_0 / h, h_1 / h | ||
D0 = m0 | ||
|
@@ -619,6 +658,25 @@ def multistep_dpm_solver_third_order_update( | |
) | ||
return x_t | ||
|
||
def _init_step_index(self, timestep): | ||
if isinstance(timestep, torch.Tensor): | ||
timestep = timestep.to(self.timesteps.device) | ||
|
||
index_candidates = (self.timesteps == timestep).nonzero() | ||
|
||
if len(index_candidates) == 0: | ||
step_index = len(self.timesteps) - 1 | ||
# The sigma index that is taken for the **very** first `step` | ||
# is always the second index (or the last index if there is only 1) | ||
# This way we can ensure we don't accidentally skip a sigma in | ||
# case we start in the middle of the denoising schedule (e.g. for image-to-image) | ||
elif len(index_candidates) > 1: | ||
step_index = index_candidates[1].item() | ||
else: | ||
step_index = index_candidates[0].item() | ||
|
||
self._step_index = step_index | ||
|
||
def step( | ||
self, | ||
model_output: torch.FloatTensor, | ||
|
@@ -654,22 +712,17 @@ def step( | |
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" | ||
) | ||
|
||
if isinstance(timestep, torch.Tensor): | ||
timestep = timestep.to(self.timesteps.device) | ||
step_index = (self.timesteps == timestep).nonzero() | ||
if len(step_index) == 0: | ||
step_index = len(self.timesteps) - 1 | ||
else: | ||
step_index = step_index.item() | ||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] | ||
if self.step_index is None: | ||
self._init_step_index(timestep) | ||
|
||
lower_order_final = ( | ||
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 | ||
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 | ||
) | ||
lower_order_second = ( | ||
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 | ||
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 | ||
) | ||
|
||
model_output = self.convert_model_output(model_output, timestep, sample) | ||
model_output = self.convert_model_output(model_output, sample) | ||
for i in range(self.config.solver_order - 1): | ||
self.model_outputs[i] = self.model_outputs[i + 1] | ||
self.model_outputs[-1] = model_output | ||
|
@@ -682,23 +735,18 @@ def step( | |
noise = None | ||
|
||
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: | ||
prev_sample = self.dpm_solver_first_order_update( | ||
model_output, timestep, prev_timestep, sample, noise=noise | ||
) | ||
prev_sample = self.dpm_solver_first_order_update(model_output, sample, noise=noise) | ||
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: | ||
timestep_list = [self.timesteps[step_index - 1], timestep] | ||
prev_sample = self.multistep_dpm_solver_second_order_update( | ||
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise | ||
) | ||
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample, noise=noise) | ||
else: | ||
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] | ||
prev_sample = self.multistep_dpm_solver_third_order_update( | ||
self.model_outputs, timestep_list, prev_timestep, sample | ||
) | ||
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample) | ||
|
||
if self.lower_order_nums < self.config.solver_order: | ||
self.lower_order_nums += 1 | ||
|
||
# upon completion increase step index by one | ||
self._step_index += 1 | ||
|
||
if not return_dict: | ||
return (prev_sample,) | ||
|
||
|
@@ -719,28 +767,30 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch | |
""" | ||
return sample | ||
|
||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise | ||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise | ||
def add_noise( | ||
self, | ||
original_samples: torch.FloatTensor, | ||
noise: torch.FloatTensor, | ||
timesteps: torch.IntTensor, | ||
timesteps: torch.FloatTensor, | ||
) -> torch.FloatTensor: | ||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples | ||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) | ||
timesteps = timesteps.to(original_samples.device) | ||
# Make sure sigmas and timesteps have the same device and dtype as original_samples | ||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) | ||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): | ||
# mps does not support float64 | ||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) | ||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32) | ||
else: | ||
schedule_timesteps = self.timesteps.to(original_samples.device) | ||
timesteps = timesteps.to(original_samples.device) | ||
|
||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 | ||
sqrt_alpha_prod = sqrt_alpha_prod.flatten() | ||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape): | ||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) | ||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] | ||
|
||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 | ||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() | ||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): | ||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) | ||
sigma = sigmas[step_indices].flatten() | ||
while len(sigma.shape) < len(original_samples.shape): | ||
sigma = sigma.unsqueeze(-1) | ||
|
||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise | ||
noisy_samples = original_samples + noise * sigma | ||
return noisy_samples | ||
|
||
def __len__(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a copied from here from another sampler
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually, originated here and not copied from anywhere 😛
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hi @yiyixuxu, question from a casual observer here, following along and converting UniPC into jax... it seems the
_sigma_to_alpha_sigma_t
is used up to 9X in every scheduler step (once inconvert_model_output
, twice in s0 and t versions in each the C and P multistep update functions, and up to twice again in each according to theorder
determination of that step, which would most commonly be 2 or 3, so 1X or 2X more again)is this function used elsewhere in other schedulers in a way that makes it impossible to run the calculation once and set all
sigmas
andalphas
within theset_timesteps
function? from what i can tell during my tinkering,alphas
,lambdas
,sigmas
(and scheduler-specific things likeR
,rks
,rhos
,order
, etc.) are only dependent ontimestep
and input parameters, not any data/sample values during thestep
process, so they can be created in advance and then referenced from thestate
(orself
for torch) variable using[step_index]
instead of[timestep]
, which all could maybe save some compute (and lines of code?)Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@krahnikblis hey! sorry I missed this comment. a good suggestion! feel free to open an PR, otherwise I will find a time to work on this soon