 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from .scheduling_utils import SchedulerMixin
 
 
+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
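The `scipy` import is guarded because only the new beta schedule needs it; the other additions stay dependency-free. As a quick illustration (a minimal sketch, not part of the diff, assuming `scipy` is installed), the percent-point function that `_convert_to_beta` relies on maps evenly spaced quantiles to values clustered near 0 and 1:

import numpy as np
import scipy.stats

# Quantiles of Beta(0.6, 0.6): for evenly spaced inputs the outputs bunch up
# near 0 and 1, which is what concentrates sampling steps at the schedule ends.
ppf = scipy.stats.beta.ppf(1 - np.linspace(0, 1, 6), 0.6, 0.6)
print(ppf)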
@@ -72,7 +75,16 @@ def __init__(
         base_image_seq_len: Optional[int] = 256,
         max_image_seq_len: Optional[int] = 4096,
         invert_sigmas: bool = False,
+        use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
 
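A hedged usage sketch (not part of the diff) of the new constructor flags; the `shift` value is purely illustrative:

from diffusers import FlowMatchEulerDiscreteScheduler

# Exactly one of the three flags may be enabled; combining them triggers the
# ValueError added above, and use_beta_sigmas additionally requires scipy.
scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0, use_karras_sigmas=True)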
@@ -185,12 +197,14 @@ def set_timesteps(
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
         """
+        if num_inference_steps is None:
+            num_inference_steps = len(sigmas) - 1
+        self.num_inference_steps = num_inference_steps
 
         if self.config.use_dynamic_shifting and mu is None:
             raise ValueError("You have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`.")
 
         if sigmas is None:
-            self.num_inference_steps = num_inference_steps
             timesteps = np.linspace(
                 self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
             )
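With the change above, `num_inference_steps` may be omitted when custom `sigmas` are supplied and is inferred as `len(sigmas) - 1`. A hedged sketch with hypothetical values, assuming `scheduler` was constructed as in the earlier sketch:

import numpy as np

custom_sigmas = np.linspace(1.0, 0.05, 5)  # hypothetical custom schedule
scheduler.set_timesteps(sigmas=custom_sigmas, device="cpu")
# Per the diff above, scheduler.num_inference_steps is inferred as len(custom_sigmas) - 1 == 4.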
@@ -202,6 +216,15 @@ def set_timesteps(
         else:
             sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
 
+        if self.config.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=len(sigmas))
+
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=len(sigmas))
+
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=len(sigmas))
+
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
         timesteps = sigmas * self.config.num_train_timesteps
 
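For intuition about what the Karras branch does to the shifted sigmas, here is a standalone sketch of the spacing formula implemented by `_convert_to_karras` below; the endpoint values are illustrative only:

import numpy as np

sigma_max, sigma_min, rho = 1.0, 0.02, 7.0  # rho = 7.0 as in Karras et al. (2022)
ramp = np.linspace(0, 1, 10)
karras = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
# Relative to a linear ramp, more of the steps end up at small sigmas.
print(karras)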
@@ -314,5 +337,85 @@ def step(
 
         return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def __len__(self):
         return self.config.num_train_timesteps
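Finally, a hedged end-to-end sketch (not part of the diff; values are hypothetical) exercising each of the three new schedules; the beta case assumes scipy is installed:

from diffusers import FlowMatchEulerDiscreteScheduler

for flag in ("use_karras_sigmas", "use_exponential_sigmas", "use_beta_sigmas"):
    scheduler = FlowMatchEulerDiscreteScheduler(**{flag: True})
    scheduler.set_timesteps(num_inference_steps=8, device="cpu")
    print(flag, scheduler.sigmas)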