Skip to content

fix DPM Scheduler with use_karras_sigmas option #6477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@
"DEISMultistepScheduler",
"DPMSolverMultistepInverseScheduler",
"DPMSolverMultistepScheduler",
"DPMSolverMultistepSchedulerLegacy",
"DPMSolverSinglestepScheduler",
"DPMSolverSinglestepSchedulerLegacy",
"EulerAncestralDiscreteScheduler",
"EulerDiscreteScheduler",
"HeunDiscreteScheduler",
Expand Down Expand Up @@ -519,7 +521,9 @@
DEISMultistepScheduler,
DPMSolverMultistepInverseScheduler,
DPMSolverMultistepScheduler,
DPMSolverMultistepSchedulerLegacy,
DPMSolverSinglestepScheduler,
DPMSolverSinglestepSchedulerLegacy,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
Expand Down
14 changes: 12 additions & 2 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
_dummy_modules.update(get_objects_from_module(dummy_pt_objects))

else:
_import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
_import_structure["deprecated"] = [
"DPMSolverMultistepSchedulerLegacy",
"DPMSolverSinglestepSchedulerLegacy",
"KarrasVeScheduler",
"ScoreSdeVpScheduler",
]
_import_structure["scheduling_amused"] = ["AmusedScheduler"]
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
Expand Down Expand Up @@ -129,7 +134,12 @@
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
from .deprecated import (
DPMSolverMultistepSchedulerLegacy,
DPMSolverSinglestepSchedulerLegacy,
KarrasVeScheduler,
ScoreSdeVpScheduler,
)
from .scheduling_amused import AmusedScheduler
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
from .scheduling_consistency_models import CMStochasticIterativeScheduler
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/schedulers/deprecated/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
_import_structure["scheduling_dpmsolver_multistep_legacy"] = ["DPMSolverMultistepSchedulerLegacy"]
_import_structure["scheduling_dpmsolver_singlestep_legacy"] = ["DPMSolverSinglestepSchedulerLegacy"]
_import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"]
_import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"]

Expand All @@ -32,6 +34,8 @@
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .scheduling_dpmsolver_multistep_legacy import DPMSolverMultistepSchedulerLegacy
from .scheduling_dpmsolver_singlestep_legacy import DPMSolverSinglestepSchedulerLegacy
from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler

Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

132 changes: 39 additions & 93 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
algorithm_type (`str`, defaults to `dpmsolver++`):
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
paper, and the `dpmsolver++` type implements the algorithms in the
Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. It implements the algorithms in the
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
solver_type (`str`, defaults to `midpoint`):
Expand Down Expand Up @@ -164,6 +162,7 @@ def __init__(
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
final_sigmas_type: Optional[str] = "default", # "denoise_to_zero", "default"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are both options necessary? Does each have an advantage in certain scenarios?

Copy link
Collaborator Author

@yiyixuxu yiyixuxu Jan 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • denoise_to_zero matches the last step in k-diffusion. We can apply this with or without the karras sigmas. See some of my comparison results here: fix DPM Scheduler with use_karras_sigmas option #6477 (comment)
  • default is what we currently have:
    • in my experiments, it was worse than the other options we just added. But obviously, I did not test many different settings, and I believe it will do better with a higher number of steps
    • I want to point out that there is some inconsistency between the current behavior of use_karras_sigmas = True and use_karras_sigmas = False: with karras_sigmas, we are skipping the last denoising step. I'm not entirely sure why it is done this way, so I'm not comfortable just updating it. I'm hoping to address that in a new PR so we don't delay this one getting merged in

current code for use_karras_sigmas

            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)

I think it should be

            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=(num_inference_steps+1))
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we should change the default to "denoise_to_zero" here though if you think it always gives better results. Also cc @sayakpaul

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do that! I will deprecate the `dpmsolver` and `sde-dpmsolver` algorithm types, since they won't work with `denoise_to_zero` and I think they are not really used at all.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok for me!

use_lu_lambdas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
Expand Down Expand Up @@ -195,9 +194,13 @@ def __init__(
self.init_noise_sigma = 1.0

# settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]:
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
elif algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
raise ValueError(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just throw a deprecation warning directly here instead of a ValueError?

f"`algorithm_type` {algorithm_type} is no longer supported in {self.__class__}. Please use `DPMSolverMultistepSchedulerLegacy` instead."
)
else:
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")

Expand Down Expand Up @@ -267,16 +270,38 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@patrickvonplaten

Was there a reason to set the sigmas this way, or is it just a mistake that we didn't catch in the PR?

By repeating the last sigma twice, we make the very last step a dummy step with zero step size

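i.e. this code here reduces to x_t = sample: because sigma_t = sigma_s0 we have h = 0, so sigma_t / sigma_s0 = 1 and exp(-h) - 1 = 0, which makes both the D0 and D1 terms vanish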

x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
                )

if self.config.solver_type == "midpoint":

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch! Was this the culprit?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes this was it!

if self.config.final_sigmas_type == "default":
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.final_sigmas_type == "denoise_to_zero":
sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32)
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}"
)
elif self.config.use_lu_lambdas:
lambdas = np.flip(log_sigmas.copy())
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
sigmas = np.exp(lambdas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
if self.config.final_sigmas_type == "default":
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.final_sigmas_type == "denoise_to_zero":
sigmas = np.concatenate([sigmas, np.array([0])]).astype(np.float32)
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}"
)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
if self.config.final_sigmas_type == "default":
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
elif self.config.final_sigmas_type == "denoise_to_zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}"
)

sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

self.sigmas = torch.from_numpy(sigmas)
Expand Down Expand Up @@ -404,14 +429,11 @@ def convert_model_output(
**kwargs,
) -> torch.FloatTensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
Convert the model output to predict data. DPM-Solver++ is designed to discretize an integral of the data prediction model.

<Tip>

The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
prediction and data prediction models.
The algorithm and model type are decoupled. You can use DPMSolver++ for both noise prediction and data prediction models.

</Tip>

Expand Down Expand Up @@ -441,7 +463,7 @@ def convert_model_output(
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
# DPM-Solver++ only needs the "mean" output.
if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3]
sigma = self.sigmas[self.step_index]
Expand All @@ -464,37 +486,6 @@ def convert_model_output(

return x0_pred

# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned", "learned_range"]:
epsilon = model_output[:, :3]
else:
epsilon = model_output
elif self.config.prediction_type == "sample":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = (sample - alpha_t * model_output) / sigma_t
elif self.config.prediction_type == "v_prediction":
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = alpha_t * model_output + sigma_t * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the DPMSolverMultistepScheduler."
)

if self.config.thresholding:
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t

return epsilon

def dpm_solver_first_order_update(
self,
model_output: torch.FloatTensor,
Expand Down Expand Up @@ -546,22 +537,13 @@ def dpm_solver_first_order_update(
h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
x_t = (
(alpha_t / alpha_s) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
return x_t

def multistep_dpm_solver_second_order_update(
Expand Down Expand Up @@ -639,20 +621,6 @@ def multistep_dpm_solver_second_order_update(
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
)
elif self.config.algorithm_type == "dpmsolver":
# See https://arxiv.org/abs/2206.00927 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (torch.exp(h) - 1.0)) * D0
- 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
)
elif self.config.solver_type == "heun":
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
)
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
if self.config.solver_type == "midpoint":
Expand All @@ -669,22 +637,6 @@ def multistep_dpm_solver_second_order_update(
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * (torch.exp(h) - 1.0)) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(alpha_t / alpha_s0) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
return x_t

def multistep_dpm_solver_third_order_update(
Expand Down Expand Up @@ -762,14 +714,6 @@ def multistep_dpm_solver_third_order_update(
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
)
elif self.config.algorithm_type == "dpmsolver":
# See https://arxiv.org/abs/2206.00927 for detailed derivations
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
return x_t

def _init_step_index(self, timestep):
Expand Down Expand Up @@ -831,7 +775,9 @@ def step(

# Improve numerical stability for small number of steps
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
self.config.euler_at_final
or (self.config.lower_order_final and len(self.timesteps) < 15)
or self.config.final_sigmas_type == "denoise_to_zero"
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
Expand All @@ -842,7 +788,7 @@ def step(
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output

if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
if self.config.algorithm_type in ["sde-dpmsolver++"]:
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,6 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32)

self.sigmas = torch.from_numpy(sigmas)

# when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]

self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

self.num_inference_steps = len(timesteps)
Expand Down
Loading