
Commit 22bfb08

LuChengTHU authored and patrickvonplaten committed
Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) (huggingface#3314)
* fix multistep dpmsolver for cosine schedule (deepfloyd-if)
* fix a typo
* Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py (five review suggestions, each co-authored by Patrick von Platen <[email protected]>)
* update all dpmsolver (singlestep, multistep, dpm, dpm++) for cosine noise schedule
* add test, fix style

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 7be5b0a · commit 22bfb08

File tree: 4 files changed, +68 −3 lines changed

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Lines changed: 25 additions & 2 deletions

@@ -118,6 +118,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
         noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
         of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
+        lambda_min_clipped (`float`, default `-inf`):
+            the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical
+            for the cosine (squaredcos_cap_v2) noise schedule.
+        variance_type (`str`, *optional*):
+            Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
+            guided-diffusion (https://github.com/openai/guided-diffusion) predicts both the mean and variance of
+            the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it
+            is based on diffusion ODEs.
     """

@@ -140,6 +151,8 @@ def __init__(
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
         use_karras_sigmas: Optional[bool] = False,
+        lambda_min_clipped: float = -float("inf"),
+        variance_type: Optional[str] = None,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -187,7 +200,7 @@ def __init__(
         self.lower_order_nums = 0
         self.use_karras_sigmas = use_karras_sigmas

-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -197,8 +210,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device]
             device (`str` or `torch.device`, optional):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
+        # Clipping the minimum of all lambda(t) for numerical stability.
+        # This is critical for the cosine (squaredcos_cap_v2) noise schedule.
+        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped)
         timesteps = (
-            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
+            np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)

@@ -320,9 +336,13 @@ def convert_model_output(
         Returns:
             `torch.FloatTensor`: the converted model output.
         """
+
         # DPM-Solver++ needs to solve an integral of the data prediction model.
         if self.config.algorithm_type == "dpmsolver++":
             if self.config.prediction_type == "epsilon":
+                # DPM-Solver and DPM-Solver++ only need the "mean" output.
+                if self.config.variance_type in ["learned_range"]:
+                    model_output = model_output[:, :3]
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                 x0_pred = (sample - sigma_t * model_output) / alpha_t
             elif self.config.prediction_type == "sample":

@@ -343,6 +363,9 @@ def convert_model_output(
         # DPM-Solver needs to solve an integral of the noise prediction model.
         elif self.config.algorithm_type == "dpmsolver":
             if self.config.prediction_type == "epsilon":
+                # DPM-Solver and DPM-Solver++ only need the "mean" output.
+                if self.config.variance_type in ["learned_range"]:
+                    model_output = model_output[:, :3]
                 return model_output
             elif self.config.prediction_type == "sample":
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
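For context on the clipping above: lambda(t) = log(alpha_t) - log(sigma_t) is the half-log-SNR, which decreases monotonically in t; under a cosine schedule it plunges toward -inf near the last training step, destabilizing the solver. Since lambda_t is stored in decreasing order, flipping it yields the ascending tensor torch.searchsorted expects, and the returned index counts how many final training steps fall below the threshold. A minimal sketch on a toy cosine-style schedule (the alphas_cumprod construction here is illustrative, not the scheduler's exact squaredcos_cap_v2 computation, which also clips betas at 0.999):

import torch

# Toy stand-in for the scheduler's internals; names mirror the diff above.
num_train_timesteps = 1000
t = torch.arange(num_train_timesteps, dtype=torch.float64) / num_train_timesteps
# Cosine-style alpha_bar, in the spirit of squaredcos_cap_v2 (illustrative only).
alphas_cumprod = torch.cos((t + 0.008) / 1.008 * torch.pi / 2) ** 2

alpha_t = torch.sqrt(alphas_cumprod)
sigma_t = torch.sqrt(1 - alphas_cumprod)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)  # decreasing in t

# Flip to ascending order so searchsorted can locate the threshold.
lambda_min_clipped = -5.1
clipped_idx = torch.searchsorted(torch.flip(lambda_t, [0]), lambda_min_clipped)

# The sampling grid then ends at the last timestep whose lambda is still
# above the threshold, exactly as in the np.linspace change above.
last_timestep = num_train_timesteps - 1 - int(clipped_idx)
print(int(clipped_idx), last_timestep)  # only a handful of final steps get clipped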

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
Lines changed: 23 additions & 1 deletion

@@ -113,6 +113,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final (`bool`, default `True`):
             whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend
             enabling this to use up all the function evaluations.
+        lambda_min_clipped (`float`, default `-inf`):
+            the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical
+            for the cosine (squaredcos_cap_v2) noise schedule.
+        variance_type (`str`, *optional*):
+            Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
+            guided-diffusion (https://github.com/openai/guided-diffusion) predicts both the mean and variance of
+            the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it
+            is based on diffusion ODEs.

     """

@@ -135,6 +146,8 @@ def __init__(
         algorithm_type: str = "dpmsolver++",
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
+        lambda_min_clipped: float = -float("inf"),
+        variance_type: Optional[str] = None,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -226,8 +239,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device]
             the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
         self.num_inference_steps = num_inference_steps
+        # Clipping the minimum of all lambda(t) for numerical stability.
+        # This is critical for the cosine (squaredcos_cap_v2) noise schedule.
+        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped)
         timesteps = (
-            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
+            np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)

@@ -297,6 +313,9 @@ def convert_model_output(
         # DPM-Solver++ needs to solve an integral of the data prediction model.
         if self.config.algorithm_type == "dpmsolver++":
             if self.config.prediction_type == "epsilon":
+                # DPM-Solver and DPM-Solver++ only need the "mean" output.
+                if self.config.variance_type in ["learned_range"]:
+                    model_output = model_output[:, :3]
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                 x0_pred = (sample - sigma_t * model_output) / alpha_t
             elif self.config.prediction_type == "sample":

@@ -317,6 +336,9 @@ def convert_model_output(
         # DPM-Solver needs to solve an integral of the noise prediction model.
         elif self.config.algorithm_type == "dpmsolver":
             if self.config.prediction_type == "epsilon":
+                # DPM-Solver and DPM-Solver++ only need the "mean" output.
+                if self.config.variance_type in ["learned_range"]:
+                    model_output = model_output[:, :3]
                 return model_output
             elif self.config.prediction_type == "sample":
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
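The model_output[:, :3] slice in both schedulers exists because models such as OpenAI's guided-diffusion, and the IF-style models this commit targets, predict the variance alongside the mean by doubling the output channels; DPM-Solver, being ODE-based, consumes only the mean. A toy sketch of that shape handling, assuming an RGB model (note the commit hardcodes three channels):

import torch

# Hypothetical model output with mean and variance concatenated on the
# channel axis, as guided-diffusion does: (batch, 2 * 3, H, W) for RGB.
batch, height, width = 2, 8, 8
model_output = torch.randn(batch, 6, height, width)

variance_type = "learned_range"
if variance_type in ["learned_range"]:
    # Keep only the predicted mean; the solver never uses the variance.
    model_output = model_output[:, :3]

assert model_output.shape == (batch, 3, height, width)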

tests/schedulers/test_scheduler_dpm_multi.py
Lines changed: 10 additions & 0 deletions

@@ -29,6 +29,8 @@ def get_scheduler_config(self, **kwargs):
             "algorithm_type": "dpmsolver++",
             "solver_type": "midpoint",
             "lower_order_final": False,
+            "lambda_min_clipped": -float("inf"),
+            "variance_type": None,
         }

         config.update(**kwargs)

@@ -187,6 +189,14 @@ def test_lower_order_final(self):
         self.check_over_configs(lower_order_final=True)
         self.check_over_configs(lower_order_final=False)

+    def test_lambda_min_clipped(self):
+        self.check_over_configs(lambda_min_clipped=-float("inf"))
+        self.check_over_configs(lambda_min_clipped=-5.1)
+
+    def test_variance_type(self):
+        self.check_over_configs(variance_type=None)
+        self.check_over_configs(variance_type="learned_range")
+
     def test_inference_steps(self):
         for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
             self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)

tests/schedulers/test_scheduler_dpm_single.py
Lines changed: 10 additions & 0 deletions

@@ -28,6 +28,8 @@ def get_scheduler_config(self, **kwargs):
             "sample_max_value": 1.0,
             "algorithm_type": "dpmsolver++",
             "solver_type": "midpoint",
+            "lambda_min_clipped": -float("inf"),
+            "variance_type": None,
         }

         config.update(**kwargs)

@@ -179,6 +181,14 @@ def test_lower_order_final(self):
         self.check_over_configs(lower_order_final=True)
         self.check_over_configs(lower_order_final=False)

+    def test_lambda_min_clipped(self):
+        self.check_over_configs(lambda_min_clipped=-float("inf"))
+        self.check_over_configs(lambda_min_clipped=-5.1)
+
+    def test_variance_type(self):
+        self.check_over_configs(variance_type=None)
+        self.check_over_configs(variance_type="learned_range")
+
     def test_inference_steps(self):
         for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
             self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
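Taken together, the two new constructor arguments make the DPM-Solver family usable with cosine-schedule, variance-predicting models such as DeepFloyd IF. A minimal usage sketch, assuming a diffusers build containing this commit; the -5.1 threshold is just the value the tests above exercise, not a universal recommendation:

from diffusers import DPMSolverMultistepScheduler

# Configure the multistep solver for a cosine (squaredcos_cap_v2) noise
# schedule on a model that predicts both mean and variance.
scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,
    beta_schedule="squaredcos_cap_v2",
    algorithm_type="dpmsolver++",
    lambda_min_clipped=-5.1,        # clip lambda(t) for numerical stability
    variance_type="learned_range",  # drop variance channels in convert_model_output
)
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.timesteps)  # the largest timesteps are clipped away vs. the -inf default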
