@@ -118,6 +118,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
118
118
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
119
119
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
120
120
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
121
+ lambda_min_clipped (`float`, default `-inf`):
122
+ The clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for
123
+ cosine (squaredcos_cap_v2) noise schedule.
124
+ variance_type (`str`, *optional*):
125
+ Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
126
+ guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
127
+ Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
128
+ diffusion ODEs.
121
132
"""
122
133
123
134
_compatibles = [e .name for e in KarrasDiffusionSchedulers ]
@@ -140,6 +151,8 @@ def __init__(
140
151
solver_type : str = "midpoint" ,
141
152
lower_order_final : bool = True ,
142
153
use_karras_sigmas : Optional [bool ] = False ,
154
+ lambda_min_clipped : float = - float ("inf" ),
155
+ variance_type : Optional [str ] = None ,
143
156
):
144
157
if trained_betas is not None :
145
158
self .betas = torch .tensor (trained_betas , dtype = torch .float32 )
@@ -187,7 +200,7 @@ def __init__(
187
200
self .lower_order_nums = 0
188
201
self .use_karras_sigmas = use_karras_sigmas
189
202
190
- def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ):
203
+ def set_timesteps (self , num_inference_steps : int = None , device : Union [str , torch .device ] = None ):
191
204
"""
192
205
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
193
206
@@ -197,8 +210,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
197
210
device (`str` or `torch.device`, optional):
198
211
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
199
212
"""
213
+ # Clipping the minimum of all lambda(t) for numerical stability.
214
+ # This is critical for cosine (squaredcos_cap_v2) noise schedule.
215
+ clipped_idx = torch .searchsorted (torch .flip (self .lambda_t , [0 ]), self .lambda_min_clipped )
200
216
timesteps = (
201
- np .linspace (0 , self .config .num_train_timesteps - 1 , num_inference_steps + 1 )
217
+ np .linspace (0 , self .config .num_train_timesteps - 1 - clipped_idx , num_inference_steps + 1 )
202
218
.round ()[::- 1 ][:- 1 ]
203
219
.copy ()
204
220
.astype (np .int64 )
@@ -320,9 +336,13 @@ def convert_model_output(
320
336
Returns:
321
337
`torch.FloatTensor`: the converted model output.
322
338
"""
339
+
323
340
# DPM-Solver++ needs to solve an integral of the data prediction model.
324
341
if self .config .algorithm_type == "dpmsolver++" :
325
342
if self .config .prediction_type == "epsilon" :
343
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
344
+ if self .config .variance_type in ["learned_range" ]:
345
+ model_output = model_output [:, :3 ]
326
346
alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
327
347
x0_pred = (sample - sigma_t * model_output ) / alpha_t
328
348
elif self .config .prediction_type == "sample" :
@@ -343,6 +363,9 @@ def convert_model_output(
343
363
# DPM-Solver needs to solve an integral of the noise prediction model.
344
364
elif self .config .algorithm_type == "dpmsolver" :
345
365
if self .config .prediction_type == "epsilon" :
366
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
367
+ if self .config .variance_type in ["learned_range" ]:
368
+ model_output = model_output [:, :3 ]
346
369
return model_output
347
370
elif self .config .prediction_type == "sample" :
348
371
alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
0 commit comments