@@ -108,7 +108,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
            `algorithm_type="dpmsolver++"`.
        algorithm_type (`str`, defaults to `dpmsolver++`):
-           Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
+           Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The
            `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
            paper, and the `dpmsolver++` type implements the algorithms in the
            [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
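For orientation, here is a minimal sketch (assuming a diffusers build that ships `DPMSolverSinglestepScheduler`) of selecting the documented `algorithm_type` when constructing the scheduler:

```python
# Minimal sketch: choosing the solver algorithm type described above.
# Assumes `diffusers` with DPMSolverSinglestepScheduler is installed.
from diffusers import DPMSolverSinglestepScheduler

# `dpmsolver++` is the type recommended by the docstring above.
scheduler = DPMSolverSinglestepScheduler(algorithm_type="dpmsolver++")
print(scheduler.config.algorithm_type)  # "dpmsolver++"
```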
@@ -122,6 +122,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
+       final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
+           The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+           sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
        lambda_min_clipped (`float`, defaults to `-inf`):
            Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
            cosine (`squaredcos_cap_v2`) noise schedule.
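A short sketch of what the new `final_sigmas_type` option changes in practice, assuming the constructor argument and the `set_timesteps` logic introduced later in this diff (the appended final sigma is the last entry of `scheduler.sigmas`):

```python
# Sketch: comparing the two documented `final_sigmas_type` values.
# Assumes a diffusers version that already includes this change.
from diffusers import DPMSolverSinglestepScheduler

zero = DPMSolverSinglestepScheduler(final_sigmas_type="zero")
zero.set_timesteps(num_inference_steps=25)
print(zero.sigmas[-1])  # expected: tensor(0.), the final sigma is forced to zero

sig_min = DPMSolverSinglestepScheduler(final_sigmas_type="sigma_min")
sig_min.set_timesteps(num_inference_steps=25)
print(sig_min.sigmas[-1])  # expected: the last sigma of the training schedule
```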
@@ -150,9 +153,14 @@ def __init__(
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
        use_karras_sigmas: Optional[bool] = False,
+       final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
    ):
+       if algorithm_type == "dpmsolver":
+           deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+           deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
+
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
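As a sketch of the deprecation branch above (assuming `deprecate` surfaces as a `FutureWarning`, as the helper does elsewhere in diffusers), passing the legacy `dpmsolver` type now warns:

```python
# Sketch: triggering the deprecation path added above. `final_sigmas_type="sigma_min"`
# is passed so that construction also survives the compatibility check added
# further down in this diff.
import warnings

from diffusers import DPMSolverSinglestepScheduler

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    DPMSolverSinglestepScheduler(algorithm_type="dpmsolver", final_sigmas_type="sigma_min")

print([str(w.message) for w in caught])  # should include the deprecation message
```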
@@ -189,6 +197,11 @@ def __init__(
        else:
            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

+       if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
+           raise ValueError(
+               f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
+           )
+
        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
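A sketch of the effect of this guard, assuming the arguments introduced in this diff: the default `final_sigmas_type="zero"` is rejected for any `algorithm_type` other than `dpmsolver++`, and `sigma_min` is the documented fallback:

```python
# Sketch: the new final_sigmas_type / algorithm_type compatibility check.
from diffusers import DPMSolverSinglestepScheduler

try:
    # The default final_sigmas_type is "zero", which is only valid for dpmsolver++.
    DPMSolverSinglestepScheduler(algorithm_type="dpmsolver")
except ValueError as err:
    print(err)

# Fallback suggested by the error message above.
scheduler = DPMSolverSinglestepScheduler(algorithm_type="dpmsolver", final_sigmas_type="sigma_min")
```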
@@ -267,11 +280,18 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-           sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+       if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
-           sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+       elif self.config.final_sigmas_type == "zero":
+           sigma_last = 0
+       else:
+           raise ValueError(
+               f"`final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}"
+           )
+       sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas).to(device=device)
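A self-contained NumPy sketch of how the final sigma is appended, mirroring the branch above; the schedule values are made up purely for illustration:

```python
# Sketch: appending the final sigma, mirroring the branch above.
import numpy as np

sigmas = np.array([14.6, 8.3, 3.1, 0.7, 0.03], dtype=np.float32)  # high -> low noise (illustrative values)
final_sigmas_type = "zero"  # or "sigma_min"

if final_sigmas_type == "sigma_min":
    sigma_last = sigmas[-1]  # keep the smallest training sigma as the final value
elif final_sigmas_type == "zero":
    sigma_last = 0  # denoise all the way down to zero noise
else:
    raise ValueError(f"unsupported final_sigmas_type: {final_sigmas_type}")

sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
print(sigmas)  # approximately [14.6 8.3 3.1 0.7 0.03 0.] for final_sigmas_type == "zero"
```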
@@ -285,6 +305,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
            )
            self.register_to_config(lower_order_final=True)

+       if not self.config.lower_order_final and self.config.final_sigmas_type == "zero":
+           logger.warn(
+               " `final_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True."
+           )
+           self.register_to_config(lower_order_final=True)
+
        self.order_list = self.get_order_list(num_inference_steps)

        # add an index counter for schedulers that allow duplicated timesteps
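A sketch of the interaction this warning guards against, again assuming the new argument: requesting `lower_order_final=False` together with `final_sigmas_type="zero"` is overridden when timesteps are set:

```python
# Sketch: lower_order_final is forced back to True when the final sigma is zero.
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler(lower_order_final=False, final_sigmas_type="zero")
scheduler.set_timesteps(num_inference_steps=24)
print(scheduler.config.lower_order_final)  # True, overridden inside set_timesteps
```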