@@ -160,7 +160,6 @@ def __init__(
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
-        final_sigmas_type: Optional[str] = "default",  # "denoise_to_zero", "default"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
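For context, a minimal usage sketch (not part of the patch): after this hunk the constructor no longer accepts `final_sigmas_type`. The class name below is an assumption based on the `noisiest_timestep` attribute used later in the file, which points to the multistep inverse scheduler; the keyword values mirror the context lines above.

# Hypothetical construction call showing only the arguments visible in this
# hunk; every other argument keeps its library default.
from diffusers import DPMSolverMultistepInverseScheduler

scheduler = DPMSolverMultistepInverseScheduler(
    lower_order_final=True,
    euler_at_final=False,
    use_karras_sigmas=False,  # final_sigmas_type is no longer accepted here
    lambda_min_clipped=-float("inf"),
    variance_type=None,
    timestep_spacing="linspace",
)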
@@ -203,11 +202,6 @@ def __init__(
         else:
             raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
 
-        if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "denoise_to_zero":
-            raise ValueError(
-                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}."
-            )
-
         # setable values
         self.num_inference_steps = None
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy()
@@ -270,27 +264,13 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             timesteps = timesteps.copy().astype(np.int64)
-            if self.config.final_sigmas_type == "default":
-                sigmas = np.concatenate([sigmas[0], sigmas]).astype(np.float32)
-            elif self.config.final_sigmas_type == "denoise_to_zero":
-                sigmas = np.concatenate([np.array([0]), sigmas]).astype(np.float32)
-            else:
-                raise ValueError(
-                    f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}"
-                )
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-            if self.config.final_sigmas_type == "default":
-                sigma_last = (
-                    (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep]
-                ) ** 0.5
-            elif self.config.final_sigmas_type == "denoise_to_zero":
-                sigma_last = 0
-            else:
-                raise ValueError(
-                    f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}"
-                )
-            sigmas = np.concatenate([[sigma_last], sigmas]).astype(np.float32)
+            sigma_max = (
+                (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep]
+            ) ** 0.5
+            sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32)
 
         self.sigmas = torch.from_numpy(sigmas)
         self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
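For context, this hunk restores the pre-`final_sigmas_type` behavior: the inverse scheduler walks from clean to noisy, so the schedule is capped by appending the largest sigma (derived from the cumulative alphas at the noisiest timestep) at the end, rather than prepending a terminal sigma. A standalone sketch, with toy values standing in for the scheduler attributes:

import numpy as np

# Toy stand-ins (assumptions): in the scheduler these are self.alphas_cumprod
# and self.noisiest_timestep.
alphas_cumprod = np.linspace(0.9999, 0.005, 1000)
noisiest_timestep = 999

# sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) is largest where alpha_bar
# is smallest, i.e. at the noisiest timestep.
sigma_max = ((1 - alphas_cumprod[noisiest_timestep]) / alphas_cumprod[noisiest_timestep]) ** 0.5

sigmas = np.linspace(0.01, 12.0, 10)  # hypothetical increasing schedule
sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32)
print(sigmas[-1])  # ~14.1: the noisiest level now terminates the schedule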
@@ -797,7 +777,6 @@ def _init_step_index(self, timestep):
 
         self._step_index = step_index
 
-    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
     def step(
         self,
         model_output: torch.FloatTensor,
@@ -838,9 +817,7 @@ def step(
 
         # Improve numerical stability for small number of steps
         lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
-            self.config.euler_at_final
-            or (self.config.lower_order_final and len(self.timesteps) < 15)
-            or self.config.final_sigmas_type == "denoise_to_zero"
+            self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
         )
         lower_order_second = (
             (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
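For context, the restored condition drops to a lower-order (Euler-style) update at the very last step only when `euler_at_final` is set or the schedule is short, and no longer when denoising to zero. A standalone sketch with assumed values:

# Assumed values for illustration; in the scheduler they come from
# self.config and self.timesteps.
step_index = 9
num_steps = 10
euler_at_final = False
lower_order_final = True

use_lower_order_final = (step_index == num_steps - 1) and (
    euler_at_final or (lower_order_final and num_steps < 15)
)
print(use_lower_order_final)  # True: final step of a schedule shorter than 15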