@@ -94,6 +94,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             `linear` or `scaled_linear`.
         trained_betas (`np.ndarray`, optional):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use Karras sigmas (the Karras et al. (2022) scheme) for the step sizes in the noise schedule
+            during sampling. If `True`, the sigmas are determined according to a sequence of noise levels {σi} as
+            defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
         prediction_type (`str`, default `epsilon`, optional):
             prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
             process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
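
For context, a minimal usage sketch of the new flag (the `beta_*` values below are illustrative defaults, not part of this diff; only `use_karras_sigmas` is new):

    from diffusers import LMSDiscreteScheduler

    scheduler = LMSDiscreteScheduler(
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        use_karras_sigmas=True,  # the flag added in this diff
    )
    scheduler.set_timesteps(25)
    print(scheduler.sigmas)  # 26 values: 25 Karras-spaced noise levels plus the trailing 0.0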
@@ -111,6 +115,7 @@ def __init__(
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        use_karras_sigmas: Optional[bool] = False,
         prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
@@ -140,8 +145,8 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
-        self.timesteps = torch.from_numpy(timesteps)
+        self.use_karras_sigmas = use_karras_sigmas
+        self.set_timesteps(num_train_timesteps, None)
         self.derivatives = []
         self.is_scale_input_called = False
 
@@ -201,8 +206,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         self.num_inference_steps = num_inference_steps
 
         timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+        if self.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
 
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
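
The effect of the new branch: without the flag, sigmas are sampled uniformly in timestep space via `np.interp`; with it, the same [sigma_min, sigma_max] range is re-spaced by `_convert_to_karras` (defined below) and `_sigma_to_t` maps each new sigma back to a fractional timestep so the LMS coefficients stay consistent. A standalone numpy sketch of the default path, with toy betas for illustration only:

    import numpy as np

    betas = np.linspace(0.0001, 0.02, 1000)                  # toy linear beta schedule
    alphas_cumprod = np.cumprod(1.0 - betas)
    sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5  # ascending in t

    timesteps = np.linspace(0, 999, 10, dtype=float)[::-1]
    default_sigmas = np.interp(timesteps, np.arange(1000), sigmas)  # uniform in t, not in sigma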
@@ -214,6 +226,44 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         self.derivatives = []
 
+    # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
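
`_sigma_to_t` inverts the discrete sigma table by linear interpolation in log space, returning a fractional (non-integer) timestep. A toy check of the interpolation it performs, with made-up values:

    import numpy as np

    log_sigmas = np.log(np.array([0.1, 1.0, 10.0]))  # ascending toy table at t = 0, 1, 2
    log_sigma = np.log(3.0)                          # query sigma between entries 1 and 2
    w = (log_sigmas[1] - log_sigma) / (log_sigmas[1] - log_sigmas[2])
    t = (1 - w) * 1 + w * 2                          # ~1.477, the fractional timestep for sigma = 3.0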
+    # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, self.num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
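
`_convert_to_karras` is Equation (5) of the paper: interpolate linearly between sigma_max^(1/rho) and sigma_min^(1/rho), then raise back to the rho-th power, which concentrates steps near sigma_min. A standalone check with toy endpoints:

    import numpy as np

    sigma_min, sigma_max, rho = 0.1, 10.0, 7.0
    ramp = np.linspace(0, 1, 5)
    sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    # ~[10.0, 4.07, 1.45, 0.43, 0.1]: descending, with spacing densest near sigma_min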
     def step(
         self,
         model_output: torch.FloatTensor,