@@ -114,7 +114,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        lower_order_final (`bool`, default `True`):
            whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
            find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
-
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+            noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+            of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
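For reference, a minimal usage sketch of the new flag. The pipeline class, checkpoint name, prompt, and step count below are illustrative assumptions and not part of this change; only the `use_karras_sigmas` override relies on this diff.

# Sketch: enable the Karras sigma schedule on DPMSolverMultistepScheduler.
# "runwayml/stable-diffusion-v1-5" and StableDiffusionPipeline are placeholders for any
# pipeline that uses this scheduler.
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
image = pipe("an astronaut riding a horse", num_inference_steps=20).images[0]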
@@ -136,6 +139,7 @@ def __init__(
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
+        use_karras_sigmas: Optional[bool] = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -181,6 +185,7 @@ def __init__(
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
+        self.use_karras_sigmas = use_karras_sigmas

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
@@ -199,6 +204,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
            .astype(np.int64)
        )

+        if self.use_karras_sigmas:
+            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+            log_sigmas = np.log(sigmas)
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps = np.flip(timesteps).copy().astype(np.int64)
+
        # when num_inference_steps == num_train_timesteps, we can end up with
        # duplicates in timesteps.
        _, unique_indices = np.unique(timesteps, return_index=True)
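A rough sketch of the effect of this branch (assumes a diffusers build that includes this change; the comments describe the expected behavior, not exact values): with the flag enabled, `set_timesteps` derives non-uniformly spaced integer timesteps from the Karras sigmas instead of an evenly spaced `linspace` grid.

# Sketch: compare the default timestep grid with the Karras-sigma-derived grid.
from diffusers import DPMSolverMultistepScheduler

uniform = DPMSolverMultistepScheduler()
karras = DPMSolverMultistepScheduler(use_karras_sigmas=True)

uniform.set_timesteps(10)
karras.set_timesteps(10)

print(uniform.timesteps)  # roughly evenly spaced indices over [0, num_train_timesteps)
print(karras.timesteps)   # non-uniform spacing that follows the Karras sigma schedule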
@@ -248,6 +260,44 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        return sample

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
    def convert_model_output(
        self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
    ) -> torch.FloatTensor:
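A self-contained NumPy sketch of what the two helpers compute, using a toy linear-beta schedule (the beta range and step count are illustrative assumptions): Equation (5) of Karras et al. interpolates between the largest and smallest training sigma in 1/rho power space, and `_sigma_to_t` then maps each resulting sigma back to a fractional training timestep by linear interpolation in log-sigma space.

# Sketch: Equation (5) of https://arxiv.org/pdf/2206.00364.pdf and its log-space inversion,
# mirroring _convert_to_karras and _sigma_to_t above.
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)                    # illustrative training schedule
alphas_cumprod = np.cumprod(1.0 - betas)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5  # ascending with t
log_sigmas = np.log(sigmas)

rho, n = 7.0, 10                                         # rho = 7 as in the paper
ramp = np.linspace(0, 1, n)
sigma_max, sigma_min = sigmas[-1], sigmas[0]
# Equation (5): interpolate from sigma_max down to sigma_min in 1/rho power space.
karras_sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

# Map each sigma back to a fractional timestep by piecewise-linear interpolation of the index in
# log-sigma space; np.interp matches _sigma_to_t here because log_sigmas is increasing and every
# Karras sigma lies inside the training range.
timesteps = np.round(np.interp(np.log(karras_sigmas), log_sigmas, np.arange(len(sigmas))))
print(timesteps.astype(np.int64))                        # descending, non-uniformly spaced

Note that the `set_timesteps` branch above passes `in_sigmas` in ascending order, so `_convert_to_karras` builds the same set of sigmas from the small end and the trailing `np.flip` restores the descending timestep order.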