
Commit 524535b

nipunjindal, njindal, patrickvonplaten, and sayakpaul authored
[2064]: Add Karras to DPMSolverMultistepScheduler (#3001)
* [2737]: Add Karras DPMSolverMultistepScheduler
* [2737]: Add Karras DPMSolverMultistepScheduler
* Add test
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <[email protected]>
* fix: repo consistency.
* remove Copied from statement from the set_timestep method.
* fix: test
* Empty commit.
  Co-authored-by: njindal <[email protected]>

---------

Co-authored-by: njindal <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 7b2407f commit 524535b

File tree

4 files changed: +60 −5 lines changed

src/diffusers/schedulers/scheduling_deis_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_euler_discrete.py
tests/schedulers/test_scheduler_dpm_multi.py

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 0 additions & 1 deletion
@@ -171,7 +171,6 @@ def __init__(
        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0

-    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_timesteps
    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 51 additions & 1 deletion
@@ -114,7 +114,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        lower_order_final (`bool`, default `True`):
            whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
            find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
-
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+            noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+            of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -136,6 +139,7 @@ def __init__(
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
+        use_karras_sigmas: Optional[bool] = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -181,6 +185,7 @@ def __init__(
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
+        self.use_karras_sigmas = use_karras_sigmas

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
@@ -199,6 +204,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
            .astype(np.int64)
        )

+        if self.use_karras_sigmas:
+            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+            log_sigmas = np.log(sigmas)
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps = np.flip(timesteps).copy().astype(np.int64)
+
        # when num_inference_steps == num_train_timesteps, we can end up with
        # duplicates in timesteps.
        _, unique_indices = np.unique(timesteps, return_index=True)
@@ -248,6 +260,44 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

        return sample

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
    def convert_model_output(
        self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
    ) -> torch.FloatTensor:
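For orientation, the two formulas behind the additions above, restated here with the notation of Karras et al. (2022) (this restatement is not part of the diff): the sigmas handed to _convert_to_karras come from the variance-preserving relation on the training schedule, and the helper itself is Equation (5) of the paper with its default rho = 7.

\sigma_t = \sqrt{\frac{1 - \bar{\alpha}_t}{\bar{\alpha}_t}},
\qquad
\sigma_i = \left( \sigma_{\max}^{1/\rho} + \frac{i}{n-1}\left( \sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho} \right) \right)^{\rho},
\quad i = 0, \dots, n-1,\ \rho = 7.

_sigma_to_t then maps each σ_i back to a (fractional) training timestep by log-linear interpolation on the precomputed log_sigmas grid, and set_timesteps rounds, flips, and casts the result to descending integer timesteps.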

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 3 additions & 3 deletions
@@ -206,7 +206,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        )

        if self.use_karras_sigmas:
-            sigmas = self._convert_to_karras(in_sigmas=sigmas)
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -241,14 +241,14 @@ def _sigma_to_t(self, sigma, log_sigmas):
        return t

    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = in_sigmas[-1].item()
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
-        ramp = np.linspace(0, 1, self.num_inference_steps)
+        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
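After this change the helper depends only on its arguments rather than on scheduler state, which is what lets it be copied verbatim into DPMSolverMultistepScheduler above. A minimal standalone sketch of the same ramp, NumPy only; the function name and the numeric inputs below are illustrative, not taken from the diff:

import numpy as np

def karras_sigmas(sigma_min: float, sigma_max: float, num_inference_steps: int, rho: float = 7.0) -> np.ndarray:
    """Karras et al. (2022), Eq. (5): interpolate sigma**(1/rho) linearly, then raise back to the rho power."""
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# Example: 10 sigmas descending from sigma_max to sigma_min over an SD-like range.
print(karras_sigmas(sigma_min=0.03, sigma_max=14.6, num_inference_steps=10))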

tests/schedulers/test_scheduler_dpm_multi.py

Lines changed: 6 additions & 0 deletions
@@ -209,6 +209,12 @@ def test_full_loop_with_v_prediction(self):

        assert abs(result_mean.item() - 0.2251) < 1e-3

+    def test_full_loop_with_karras_and_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2096) < 1e-3
+
    def test_switch(self):
        # make sure that iterating over schedulers with same config names gives same results
        # for defaults
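Outside the test harness, the new flag is just another config field on the scheduler. A minimal, model-free usage sketch; everything except use_karras_sigmas keeps the scheduler's defaults, and the printed timestep values depend on the configured beta schedule:

from diffusers import DPMSolverMultistepScheduler

# Construct the scheduler with the new option enabled.
scheduler = DPMSolverMultistepScheduler(use_karras_sigmas=True)

# With the flag set, set_timesteps routes through _convert_to_karras / _sigma_to_t.
scheduler.set_timesteps(num_inference_steps=20)
print(scheduler.timesteps)  # integer timesteps spaced according to the Karras sigma ramp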
