fix DPM Scheduler with `use_karras_sigmas` option #6477
@@ -106,9 +106,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
         algorithm_type (`str`, defaults to `dpmsolver++`):
-            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
-            `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
-            paper, and the `dpmsolver++` type implements the algorithms in the
+            Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. It implements the algorithms in the
             [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
             `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
         solver_type (`str`, defaults to `midpoint`):
@@ -164,6 +162,7 @@ def __init__(
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
+        final_sigmas_type: Optional[str] = "default",  # "denoise_to_zero", "default"
         use_lu_lambdas: Optional[bool] = False,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
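
To see the new option in context, here is a minimal usage sketch. It assumes a diffusers build that includes this PR; the value strings "default" and "denoise_to_zero" are the ones this diff defines:

    # Minimal sketch, assuming a diffusers build with this PR applied.
    from diffusers import DPMSolverMultistepScheduler

    scheduler = DPMSolverMultistepScheduler(
        algorithm_type="dpmsolver++",
        solver_order=2,
        use_karras_sigmas=True,
        final_sigmas_type="denoise_to_zero",  # append sigma = 0 instead of repeating the last sigma
    )
    scheduler.set_timesteps(num_inference_steps=20)
    print(scheduler.sigmas[-2:])  # with "denoise_to_zero" the schedule now ends at 0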
@@ -195,9 +194,13 @@ def __init__(
         self.init_noise_sigma = 1.0

         # settings for DPM-Solver
-        if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
+        if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]:
             if algorithm_type == "deis":
                 self.register_to_config(algorithm_type="dpmsolver++")
+            elif algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+                raise ValueError(
+                    f"`algorithm_type` {algorithm_type} is no longer supported in {self.__class__}. Please use `DPMSolverMultistepSchedulerLegacy` instead."
+                )
             else:
                 raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")

Review comment (on the new `raise ValueError`): Can we just throw a deprecation warning directly here instead of a ValueError?
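
The suggested alternative might look roughly like the sketch below. This is hypothetical, not what the diff does (the diff raises); it uses the `deprecate` helper from `diffusers.utils`, and the removal version string is made up:

    # Hypothetical sketch of the reviewer's suggestion: warn instead of raising.
    from diffusers.utils import deprecate

    algorithm_type = "dpmsolver"  # a legacy value a user might still pass

    if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
        deprecate(
            "algorithm_type",  # name of the deprecated argument
            "1.0.0",           # hypothetical removal version
            f"`algorithm_type={algorithm_type}` is deprecated. Please use `dpmsolver++`, "
            "`sde-dpmsolver++`, or `DPMSolverMultistepSchedulerLegacy` instead.",
        )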
@@ -267,16 +270,38 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+            if self.config.final_sigmas_type == "default":
+                sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+            elif self.config.final_sigmas_type == "denoise_to_zero":
+                sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32)
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}"
+                )
         elif self.config.use_lu_lambdas:
             lambdas = np.flip(log_sigmas.copy())
             lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
             sigmas = np.exp(lambdas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+            if self.config.final_sigmas_type == "default":
+                sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+            elif self.config.final_sigmas_type == "denoise_to_zero":
+                sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32)
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}"
+                )
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+            if self.config.final_sigmas_type == "default":
+                sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+            elif self.config.final_sigmas_type == "denoise_to_zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}"
+                )

             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

         self.sigmas = torch.from_numpy(sigmas)

Review comment (on the removed `sigmas = np.concatenate([sigmas, sigmas[-1:]])` line): Was there a reason to set the sigmas this way, or is it just a mistake that we didn't catch in the PR? By repeating the last sigma twice, we make the very last step a dummy step with zero step size.

Reply: Nice catch! Was this the culprit?

Reply: yes this was it!
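
A standalone illustration of the difference (made-up five-entry schedule, plain numpy; only the tail handling mirrors the code above):

    import numpy as np

    sigmas = np.array([14.6, 6.5, 2.6, 0.9, 0.03])  # hypothetical descending Karras schedule

    default = np.concatenate([sigmas, sigmas[-1:]])              # ..., 0.03, 0.03
    denoise_to_zero = np.concatenate([sigmas, np.array([0.0])])  # ..., 0.03, 0.0

    # Step size of the final update, sigma[-2] -> sigma[-1]:
    print(default[-2] - default[-1])                  # 0.0  -> dummy step, no denoising happens
    print(denoise_to_zero[-2] - denoise_to_zero[-1])  # 0.03 -> a real final step down to sigma = 0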
@@ -404,14 +429,11 @@ def convert_model_output(
         **kwargs,
     ) -> torch.FloatTensor:
         """
-        Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
-        designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
-        integral of the data prediction model.
+        Convert the model output to predict data. DPM-Solver++ is designed to discretize an integral of the data prediction model.

         <Tip>

-        The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
-        prediction and data prediction models.
+        The algorithm and model type are decoupled. You can use DPMSolver++ for either noise prediction and data prediction models.

         </Tip>
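
For reference, the epsilon-to-data conversion that the kept `dpmsolver++` branch performs can be sketched standalone. This mirrors `_sigma_to_alpha_sigma_t` (where alpha_t**2 + sigma_t**2 = 1); `to_data_prediction` is an illustrative name, not the scheduler's actual method:

    import torch

    def to_data_prediction(model_output: torch.Tensor, sample: torch.Tensor, sigma: float) -> torch.Tensor:
        # alpha_t, sigma_t as in `_sigma_to_alpha_sigma_t`
        alpha_t = 1.0 / (sigma**2 + 1.0) ** 0.5
        sigma_t = sigma * alpha_t
        # epsilon prediction: x0 = (x_t - sigma_t * eps) / alpha_t
        return (sample - sigma_t * model_output) / alpha_t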
@@ -441,7 +463,7 @@ def convert_model_output(
         # DPM-Solver++ needs to solve an integral of the data prediction model.
         if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
             if self.config.prediction_type == "epsilon":
-                # DPM-Solver and DPM-Solver++ only need the "mean" output.
+                # DPM-Solver++ only need the "mean" output.
                 if self.config.variance_type in ["learned", "learned_range"]:
                     model_output = model_output[:, :3]
                 sigma = self.sigmas[self.step_index]
@@ -464,37 +486,6 @@ def convert_model_output(

             return x0_pred

-        # DPM-Solver needs to solve an integral of the noise prediction model.
-        elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
-            if self.config.prediction_type == "epsilon":
-                # DPM-Solver and DPM-Solver++ only need the "mean" output.
-                if self.config.variance_type in ["learned", "learned_range"]:
-                    epsilon = model_output[:, :3]
-                else:
-                    epsilon = model_output
-            elif self.config.prediction_type == "sample":
-                sigma = self.sigmas[self.step_index]
-                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
-                epsilon = (sample - alpha_t * model_output) / sigma_t
-            elif self.config.prediction_type == "v_prediction":
-                sigma = self.sigmas[self.step_index]
-                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
-                epsilon = alpha_t * model_output + sigma_t * sample
-            else:
-                raise ValueError(
-                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                    " `v_prediction` for the DPMSolverMultistepScheduler."
-                )
-
-            if self.config.thresholding:
-                sigma = self.sigmas[self.step_index]
-                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
-                x0_pred = (sample - sigma_t * epsilon) / alpha_t
-                x0_pred = self._threshold_sample(x0_pred)
-                epsilon = (sample - alpha_t * x0_pred) / sigma_t
-
-            return epsilon
-
     def dpm_solver_first_order_update(
         self,
         model_output: torch.FloatTensor,
@@ -546,22 +537,13 @@ def dpm_solver_first_order_update(
         h = lambda_t - lambda_s
         if self.config.algorithm_type == "dpmsolver++":
             x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
-        elif self.config.algorithm_type == "dpmsolver":
-            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
         elif self.config.algorithm_type == "sde-dpmsolver++":
             assert noise is not None
             x_t = (
                 (sigma_t / sigma_s * torch.exp(-h)) * sample
                 + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
                 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
             )
-        elif self.config.algorithm_type == "sde-dpmsolver":
-            assert noise is not None
-            x_t = (
-                (alpha_t / alpha_s) * sample
-                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
-                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
-            )
         return x_t

     def multistep_dpm_solver_second_order_update(
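
For reference, the kept `dpmsolver++` first-order update is a direct transcription of the line above into the paper's notation, with lambda = log(alpha/sigma), h = lambda_t - lambda_s, and the model output being the data prediction:

    x_t = \frac{\sigma_t}{\sigma_s} x_s - \alpha_t \left( e^{-h} - 1 \right) \hat{x}_0(x_s)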
@@ -639,20 +621,6 @@ def multistep_dpm_solver_second_order_update(
                     - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                     + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                 )
-        elif self.config.algorithm_type == "dpmsolver":
-            # See https://arxiv.org/abs/2206.00927 for detailed derivations
-            if self.config.solver_type == "midpoint":
-                x_t = (
-                    (alpha_t / alpha_s0) * sample
-                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
-                    - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
-                )
-            elif self.config.solver_type == "heun":
-                x_t = (
-                    (alpha_t / alpha_s0) * sample
-                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
-                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
-                )
         elif self.config.algorithm_type == "sde-dpmsolver++":
             assert noise is not None
             if self.config.solver_type == "midpoint":
@@ -669,22 +637,6 @@ def multistep_dpm_solver_second_order_update(
                     + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                     + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                 )
-        elif self.config.algorithm_type == "sde-dpmsolver":
-            assert noise is not None
-            if self.config.solver_type == "midpoint":
-                x_t = (
-                    (alpha_t / alpha_s0) * sample
-                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
-                    - (sigma_t * (torch.exp(h) - 1.0)) * D1
-                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
-                )
-            elif self.config.solver_type == "heun":
-                x_t = (
-                    (alpha_t / alpha_s0) * sample
-                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
-                    - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
-                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
-                )
         return x_t

     def multistep_dpm_solver_third_order_update(
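
A note to make the kept second-order branches easier to read: D0 and D1 are finite-difference estimates built from the two most recent data predictions. Assuming the definitions used earlier in this method (they sit outside this diff window), with h = lambda_t - lambda_{s0}, h_0 = lambda_{s0} - lambda_{s1}, and r_0 = h_0 / h:

    D_0 = \hat{x}_0(x_{s_0}), \qquad D_1 = \frac{1}{r_0} \left( \hat{x}_0(x_{s_0}) - \hat{x}_0(x_{s_1}) \right)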
@@ -762,14 +714,6 @@ def multistep_dpm_solver_third_order_update(
                 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                 - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
             )
-        elif self.config.algorithm_type == "dpmsolver":
-            # See https://arxiv.org/abs/2206.00927 for detailed derivations
-            x_t = (
-                (alpha_t / alpha_s0) * sample
-                - (sigma_t * (torch.exp(h) - 1.0)) * D0
-                - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
-                - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
-            )
         return x_t

     def _init_step_index(self, timestep):
@@ -831,7 +775,9 @@ def step(

         # Improve numerical stability for small number of steps
         lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
-            self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
+            self.config.euler_at_final
+            or (self.config.lower_order_final and len(self.timesteps) < 15)
+            or self.config.final_sigmas_type == "denoise_to_zero"
         )
         lower_order_second = (
             (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
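
A short note on why `denoise_to_zero` joins the `lower_order_final` condition (my reading of the change, not stated in the diff): with sigma_t = 0 at the final step, lambda_t diverges and h -> infinity, so the higher-order difference terms are ill-conditioned, while the first-order update degrades gracefully:

    \lim_{\sigma_t \to 0} \left[ \frac{\sigma_t}{\sigma_s} x_s - \alpha_t \left( e^{-h} - 1 \right) \hat{x}_0 \right] = \alpha_t \hat{x}_0 = \hat{x}_0,

since alpha_t = 1 when sigma_t = 0. The last step simply returns the model's data prediction.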
@@ -842,7 +788,7 @@ def step(
             self.model_outputs[i] = self.model_outputs[i + 1]
         self.model_outputs[-1] = model_output

-        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
+        if self.config.algorithm_type in ["sde-dpmsolver++"]:
             noise = randn_tensor(
                 model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
             )
Review discussion on `final_sigmas_type`:

Reviewer: Are both options necessary? Each has an advantage in certain scenarios?

Author: `denoise_to_zero` matches the last step in k-diffusion. We can apply this with or without the karras sigmas; see some of my comparison results in #6477 (comment). `default` is what we currently have for both `use_karras_sigmas = True` and `use_karras_sigmas = False`: with karras sigmas, we are skipping the last denoising step. I'm not entirely sure why it is done this way, so I'm not comfortable just updating it. I'm hoping to address that in a new PR so we don't delay this one getting merged in. The current code for `use_karras_sigmas` repeats the last sigma (`sigmas = np.concatenate([sigmas, sigmas[-1:]])`); I think it should append a real final sigma instead.

Reviewer: I'm wondering if we should change the default to "denoise_to_zero" here though, if you think it always gives better results. Also cc @sayakpaul.

Author: We can do that! I will deprecate `dpmsolver` and `sde-dpmsolver` since they won't work with `denoise_to_zero`, and I think they are not really used at all.

Reviewer: Ok for me!
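
An end-to-end sketch of how the two options could be compared in practice. Hedged: the model id, prompt, and step count are arbitrary examples, and the `final_sigmas_type` values are the ones this PR defines:

    import torch
    from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    for final_sigmas_type in ["default", "denoise_to_zero"]:
        # Rebuild the scheduler from the pipeline's config, overriding the new option.
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config,
            use_karras_sigmas=True,
            final_sigmas_type=final_sigmas_type,
        )
        image = pipe(
            "a photo of an astronaut riding a horse",
            num_inference_steps=20,
            generator=torch.Generator("cuda").manual_seed(0),  # fixed seed for a fair comparison
        ).images[0]
        image.save(f"karras_{final_sigmas_type}.png")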