
Commit ac61eef

yiyixuxu and patrickvonplaten authored
fix DPM Scheduler with use_karras_sigmas option (#6477)
* fix

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent f95615b commit ac61eef

5 files changed: +61 -8 lines changed
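
For context, a minimal usage sketch of the configuration this commit fixes. It is not part of the commit and assumes a diffusers version that includes this change:

# Minimal sketch: DPM-Solver multistep with Karras sigmas.
# Before this commit, the Karras branch duplicated its last sigma instead of
# appending a proper final sigma to the schedule.
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    use_karras_sigmas=True,
    final_sigmas_type="zero",  # new option; "sigma_min" keeps the previous tail value
)
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.sigmas[-1])  # 0 when final_sigmas_type="zero"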

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 26 additions & 4 deletions
@@ -128,6 +128,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
             the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
             `lambda(t)`.
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
+            is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
@@ -165,11 +168,16 @@ def __init__(
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
         use_lu_lambdas: Optional[bool] = False,
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+            deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+            deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
+
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -207,6 +215,11 @@ def __init__(
         else:
             raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
 
+        if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
+            raise ValueError(
+                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
+            )
+
         # setable values
         self.num_inference_steps = None
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
@@ -267,17 +280,24 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         elif self.config.use_lu_lambdas:
             lambdas = np.flip(log_sigmas.copy())
             lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
             sigmas = np.exp(lambdas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+        if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
-            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+
+        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
 
         self.sigmas = torch.from_numpy(sigmas)
         self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
@@ -831,7 +851,9 @@ def step(
 
         # Improve numerical stability for small number of steps
         lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
-            self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
+            self.config.euler_at_final
+            or (self.config.lower_order_final and len(self.timesteps) < 15)
+            or self.config.final_sigmas_type == "zero"
         )
         lower_order_second = (
             (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
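
As a reading aid for the `set_timesteps` hunk above, here is a standalone sketch of the tail-sigma selection it now performs; the helper name `append_final_sigma` is illustrative only and does not exist in diffusers:

# Illustrative-only helper mirroring the logic added in set_timesteps above:
# every branch now shares one config-driven final sigma instead of per-branch tails.
import numpy as np

def append_final_sigma(sigmas, alphas_cumprod, final_sigmas_type="zero"):
    if final_sigmas_type == "sigma_min":
        # same value as the last sigma of the training schedule
        sigma_last = ((1 - alphas_cumprod[0]) / alphas_cumprod[0]) ** 0.5
    elif final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(f"`final_sigmas_type` must be 'zero' or 'sigma_min', got {final_sigmas_type}")
    return np.concatenate([sigmas, [sigma_last]]).astype(np.float32)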

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 4 additions & 1 deletion
@@ -165,6 +165,10 @@ def __init__(
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
+        if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
+            deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+            deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
+
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -783,7 +787,6 @@ def _init_step_index(self, timestep):
 
         self._step_index = step_index
 
-    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
     def step(
         self,
         model_output: torch.FloatTensor,
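
Both multistep variants now emit the same deprecation warning for the legacy algorithm types; a sketch of what triggers it (assumes a release containing this commit):

# Sketch: the legacy algorithm types still work but now emit a deprecation warning.
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    algorithm_type="dpmsolver",      # deprecated; scheduled for removal in 1.0.0
    final_sigmas_type="sigma_min",   # required here, since "zero" raises a ValueError for non-"++" types
)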

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 29 additions & 3 deletions
@@ -108,7 +108,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
         algorithm_type (`str`, defaults to `dpmsolver++`):
-            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
+            Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The
             `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
             paper, and the `dpmsolver++` type implements the algorithms in the
             [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
@@ -122,6 +122,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
+            is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
@@ -150,9 +153,14 @@ def __init__(
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
         use_karras_sigmas: Optional[bool] = False,
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
     ):
+        if algorithm_type == "dpmsolver":
+            deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
+            deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
+
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -189,6 +197,11 @@ def __init__(
         else:
             raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
 
+        if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
+            raise ValueError(
+                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead."
+            )
+
         # setable values
         self.num_inference_steps = None
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
@@ -267,11 +280,18 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+        if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
-            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f" `final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}"
+            )
+        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
 
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
 
@@ -285,6 +305,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
             )
             self.register_to_config(lower_order_final=True)
 
+        if not self.config.lower_order_final and self.config.final_sigmas_type == "zero":
+            logger.warn(
+                " `last_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True."
+            )
+            self.register_to_config(lower_order_final=True)
+
         self.order_list = self.get_order_list(num_inference_steps)
 
         # add an index counter for schedulers that allow duplicated timesteps
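
A sketch of the new constructor validation in `DPMSolverSinglestepScheduler` (illustrative, assuming a diffusers build that contains this commit):

# Sketch: final_sigmas_type="zero" (the new default) is only accepted with
# algorithm_type="dpmsolver++"; other combinations raise at construction time.
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler(algorithm_type="dpmsolver++", final_sigmas_type="zero")

try:
    DPMSolverSinglestepScheduler(algorithm_type="dpmsolver", final_sigmas_type="zero")
except ValueError as err:
    print(err)  # `final_sigmas_type` zero is not supported for `algorithm_type` dpmsolver ...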

tests/schedulers/test_scheduler_dpm_multi.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ def get_scheduler_config(self, **kwargs):
             "euler_at_final": False,
             "lambda_min_clipped": -float("inf"),
             "variance_type": None,
+            "final_sigmas_type": "sigma_min",
         }
 
         config.update(**kwargs)

tests/schedulers/test_scheduler_dpm_single.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ def get_scheduler_config(self, **kwargs):
             "solver_type": "midpoint",
             "lambda_min_clipped": -float("inf"),
             "variance_type": None,
+            "final_sigmas_type": "sigma_min",
         }
 
         config.update(**kwargs)
