Skip to content

Commit 4271d2f

Browse files
author
yiyixuxu
committed
revert changes to inverse scheduler
1 parent 3afde1f commit 4271d2f

File tree

1 file changed

+6
-29
lines changed

1 file changed

+6
-29
lines changed

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@ def __init__(
160160
lower_order_final: bool = True,
161161
euler_at_final: bool = False,
162162
use_karras_sigmas: Optional[bool] = False,
163-
final_sigmas_type: Optional[str] = "default", # "denoise_to_zero", "default"
164163
lambda_min_clipped: float = -float("inf"),
165164
variance_type: Optional[str] = None,
166165
timestep_spacing: str = "linspace",
@@ -203,11 +202,6 @@ def __init__(
203202
else:
204203
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
205204

206-
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "denoise_to_zero":
207-
raise ValueError(
208-
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}."
209-
)
210-
211205
# setable values
212206
self.num_inference_steps = None
213207
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy()
@@ -270,27 +264,13 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
270264
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
271265
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
272266
timesteps = timesteps.copy().astype(np.int64)
273-
if self.config.final_sigmas_type == "default":
274-
sigmas = np.concatenate([sigmas[0], sigmas]).astype(np.float32)
275-
elif self.config.final_sigmas_type == "denoise_to_zero":
276-
sigmas = np.concatenate([np.array([0]), sigmas]).astype(np.float32)
277-
else:
278-
raise ValueError(
279-
f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}"
280-
)
267+
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
281268
else:
282269
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
283-
if self.config.final_sigmas_type == "default":
284-
sigma_last = (
285-
(1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep]
286-
) ** 0.5
287-
elif self.config.final_sigmas_type == "denoise_to_zero":
288-
sigma_last = 0
289-
else:
290-
raise ValueError(
291-
f"`final_sigmas_type` must be one of 'default', or 'denoise_to_zero', but got {self.config.final_sigmas_type}"
292-
)
293-
sigmas = np.concatenate([[sigma_last], sigmas]).astype(np.float32)
270+
sigma_max = (
271+
(1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep]
272+
) ** 0.5
273+
sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32)
294274

295275
self.sigmas = torch.from_numpy(sigmas)
296276
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
@@ -797,7 +777,6 @@ def _init_step_index(self, timestep):
797777

798778
self._step_index = step_index
799779

800-
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
801780
def step(
802781
self,
803782
model_output: torch.FloatTensor,
@@ -838,9 +817,7 @@ def step(
838817

839818
# Improve numerical stability for small number of steps
840819
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
841-
self.config.euler_at_final
842-
or (self.config.lower_order_final and len(self.timesteps) < 15)
843-
or self.config.final_sigmas_type == "denoise_to_zero"
820+
self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
844821
)
845822
lower_order_second = (
846823
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15

0 commit comments

Comments
 (0)