
Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) #3314


Merged · 9 commits · May 3, 2023
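For orientation, here is a minimal usage sketch of the two options this PR introduces, `lambda_min_clipped` and `variance_type`, on a cosine (`squaredcos_cap_v2`) schedule of the kind DeepFloyd IF-style pixel models use. It assumes the post-merge diffusers API; the threshold `-5.1` and the `learned_range` setting are illustrative choices, not values taken from any released model config.

```python
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,
    beta_schedule="squaredcos_cap_v2",  # cosine noise schedule
    prediction_type="epsilon",
    lambda_min_clipped=-5.1,            # clip lambda(t) near t = T for numerical stability
    variance_type="learned_range",      # the UNet predicts both mean and variance
)
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.timesteps[:5])          # first few inference timesteps, with the unstable tail clipped
```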
27 changes: 25 additions & 2 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -118,6 +118,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
lambda_min_clipped (`float`, default `-inf`):
the clipping threshold for the minimum value of lambda(t), used for numerical stability. This is critical
for the cosine (`squaredcos_cap_v2`) noise schedule.
variance_type (`str`, *optional*):
Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
diffusion ODEs.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -140,6 +151,8 @@ def __init__(
solver_type: str = "midpoint",
lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -187,7 +200,7 @@ def __init__(
self.lower_order_nums = 0
self.use_karras_sigmas = use_karras_sigmas

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -197,8 +210,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
"""
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped)
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
@@ -320,9 +336,13 @@ def convert_model_output(
Returns:
`torch.FloatTensor`: the converted model output.
"""

# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
model_output = model_output[:, :3]
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
Expand All @@ -343,6 +363,9 @@ def convert_model_output(
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
model_output = model_output[:, :3]
return model_output
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
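To make the new clipping step in `set_timesteps` concrete, the following standalone sketch (not part of the PR) recomputes `lambda_t` for a `squaredcos_cap_v2` schedule the same way the scheduler's `__init__` does, then applies the same `searchsorted` trick to see how many trailing timesteps a given threshold trims. Because lambda(t) = log(alpha_t) - log(sigma_t) falls toward -inf as the cumulative alpha approaches zero, sampling too close to t = T - 1 is numerically fragile for the cosine schedule, which is what the clipping avoids. The `-5.1` threshold is an illustrative assumption.

```python
import math
import numpy as np
import torch

num_train_timesteps = 1000

def alpha_bar(t):
    # squaredcos_cap_v2 alpha_bar, as used by diffusers' betas_for_alpha_bar
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

betas = torch.tensor(
    [
        min(1 - alpha_bar((i + 1) / num_train_timesteps) / alpha_bar(i / num_train_timesteps), 0.999)
        for i in range(num_train_timesteps)
    ],
    dtype=torch.float32,
)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
alpha_t, sigma_t = torch.sqrt(alphas_cumprod), torch.sqrt(1.0 - alphas_cumprod)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)  # falls steeply as t approaches T - 1

lambda_min_clipped = -5.1  # illustrative threshold, not a recommended value
clipped_idx = torch.searchsorted(torch.flip(lambda_t, [0]), torch.tensor(lambda_min_clipped))
print("trailing timesteps trimmed:", int(clipped_idx))

# Build the inference grid over the clipped range, mirroring set_timesteps above.
last_timestep = num_train_timesteps - 1 - int(clipped_idx)
timesteps = np.linspace(0, last_timestep, 25 + 1).round()[::-1][:-1].copy().astype(np.int64)
print(timesteps[:5])
```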
24 changes: 23 additions & 1 deletion src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -113,6 +113,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend enabling
this to use up all the function evaluations.
lambda_min_clipped (`float`, default `-inf`):
the clipping threshold for the minimum value of lambda(t), used for numerical stability. This is critical
for the cosine (`squaredcos_cap_v2`) noise schedule.
variance_type (`str`, *optional*):
Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
diffusion ODEs.

"""

@@ -135,6 +146,8 @@ def __init__(
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -226,8 +239,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped)
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
@@ -297,6 +313,9 @@ def convert_model_output(
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
model_output = model_output[:, :3]
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
@@ -317,6 +336,9 @@ def convert_model_output(
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":
if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]:
model_output = model_output[:, :3]
return model_output
elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
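The `model_output[:, :3]` slice added in both schedulers assumes a 3-channel prediction where the UNet stacks its variance parameters after the mean channels. A toy illustration (not from the PR) of what the slice keeps:

```python
import torch

batch, channels, height, width = 2, 3, 64, 64
# A learned_range UNet stacks variance parameters after the mean (epsilon) channels.
model_output = torch.randn(batch, 2 * channels, height, width)
epsilon = model_output[:, :channels]  # what the scheduler keeps (here, [:, :3])
print(epsilon.shape)                  # torch.Size([2, 3, 64, 64])
```

The idea is the same for other channel counts: DPM-Solver is an ODE solver, so only the mean/epsilon half of the output is consumed and the predicted variance is discarded.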
10 changes: 10 additions & 0 deletions tests/schedulers/test_scheduler_dpm_multi.py
@@ -29,6 +29,8 @@ def get_scheduler_config(self, **kwargs):
"algorithm_type": "dpmsolver++",
"solver_type": "midpoint",
"lower_order_final": False,
"lambda_min_clipped": -float("inf"),
"variance_type": None,
}

config.update(**kwargs)
@@ -187,6 +189,14 @@ def test_lower_order_final(self):
self.check_over_configs(lower_order_final=True)
self.check_over_configs(lower_order_final=False)

def test_lambda_min_clipped(self):
self.check_over_configs(lambda_min_clipped=-float("inf"))
self.check_over_configs(lambda_min_clipped=-5.1)

def test_variance_type(self):
self.check_over_configs(variance_type=None)
self.check_over_configs(variance_type="learned_range")

def test_inference_steps(self):
for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
10 changes: 10 additions & 0 deletions tests/schedulers/test_scheduler_dpm_single.py
@@ -28,6 +28,8 @@ def get_scheduler_config(self, **kwargs):
"sample_max_value": 1.0,
"algorithm_type": "dpmsolver++",
"solver_type": "midpoint",
"lambda_min_clipped": -float("inf"),
"variance_type": None,
}

config.update(**kwargs)
@@ -179,6 +181,14 @@ def test_lower_order_final(self):
self.check_over_configs(lower_order_final=True)
self.check_over_configs(lower_order_final=False)

def test_lambda_min_clipped(self):
self.check_over_configs(lambda_min_clipped=-float("inf"))
self.check_over_configs(lambda_min_clipped=-5.1)

def test_variance_type(self):
self.check_over_configs(variance_type=None)
self.check_over_configs(variance_type="learned_range")

def test_inference_steps(self):
for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
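Assuming a checkout of the repository with pytest installed, the new test cases can be run in isolation, for example:

```python
import pytest

# Equivalent to: pytest tests/schedulers/... -k "lambda_min_clipped or variance_type"
pytest.main([
    "tests/schedulers/test_scheduler_dpm_multi.py",
    "tests/schedulers/test_scheduler_dpm_single.py",
    "-k", "lambda_min_clipped or variance_type",
])
```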