Skip to content

Commit 56bd7e6

Browse files
[Scheduler] introduce sigma schedule. (#7649)
* introduce sigma schedule. Co-authored-by: Suraj Patil <[email protected]> * address yiyi * update docstrings. * implement the schedule for EDMDPMSolverMultistepScheduler --------- Co-authored-by: Suraj Patil <[email protected]>
1 parent 9d16daa commit 56bd7e6

File tree

2 files changed

+59
-9
lines changed

2 files changed

+59
-9
lines changed

src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
1616

17+
import math
1718
from typing import List, Optional, Tuple, Union
1819

1920
import numpy as np
@@ -44,6 +45,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
4445
range is [0.2, 80.0].
4546
sigma_data (`float`, *optional*, defaults to 0.5):
4647
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
48+
sigma_schedule (`str`, *optional*, defaults to `karras`):
49+
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
50+
(https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
51+
incorporated in this model: https://huggingface.co/stabilityai/cosxl.
4752
num_train_timesteps (`int`, defaults to 1000):
4853
The number of diffusion steps to train the model.
4954
solver_order (`int`, defaults to 2):
@@ -89,6 +94,7 @@ def __init__(
8994
sigma_min: float = 0.002,
9095
sigma_max: float = 80.0,
9196
sigma_data: float = 0.5,
97+
sigma_schedule: str = "karras",
9298
num_train_timesteps: int = 1000,
9399
prediction_type: str = "epsilon",
94100
rho: float = 7.0,
@@ -121,7 +127,11 @@ def __init__(
121127
)
122128

123129
ramp = torch.linspace(0, 1, num_train_timesteps)
124-
sigmas = self._compute_sigmas(ramp)
130+
if sigma_schedule == "karras":
131+
sigmas = self._compute_karras_sigmas(ramp)
132+
elif sigma_schedule == "exponential":
133+
sigmas = self._compute_exponential_sigmas(ramp)
134+
125135
self.timesteps = self.precondition_noise(sigmas)
126136

127137
self.sigmas = self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
@@ -236,7 +246,10 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
236246
self.num_inference_steps = num_inference_steps
237247

238248
ramp = np.linspace(0, 1, self.num_inference_steps)
239-
sigmas = self._compute_sigmas(ramp)
249+
if self.config.sigma_schedule == "karras":
250+
sigmas = self._compute_karras_sigmas(ramp)
251+
elif self.config.sigma_schedule == "exponential":
252+
sigmas = self._compute_exponential_sigmas(ramp)
240253

241254
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
242255
self.timesteps = self.precondition_noise(sigmas)
@@ -262,17 +275,28 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
262275
self._begin_index = None
263276
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
264277

265-
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
266-
def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
278+
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
279+
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
267280
"""Constructs the noise schedule of Karras et al. (2022)."""
268-
269281
sigma_min = sigma_min or self.config.sigma_min
270282
sigma_max = sigma_max or self.config.sigma_max
271283

272284
rho = self.config.rho
273285
min_inv_rho = sigma_min ** (1 / rho)
274286
max_inv_rho = sigma_max ** (1 / rho)
275287
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
288+
289+
return sigmas
290+
291+
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
292+
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
293+
"""Implementation closely follows k-diffusion.
294+
295+
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
296+
"""
297+
sigma_min = sigma_min or self.config.sigma_min
298+
sigma_max = sigma_max or self.config.sigma_max
299+
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
276300
return sigmas
277301

278302
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample

src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import math
1516
from dataclasses import dataclass
1617
from typing import Optional, Tuple, Union
1718

@@ -65,6 +66,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
6566
range is [0.2, 80.0].
6667
sigma_data (`float`, *optional*, defaults to 0.5):
6768
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
69+
sigma_schedule (`str`, *optional*, defaults to `karras`):
70+
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
71+
(https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
72+
incorporated in this model: https://huggingface.co/stabilityai/cosxl.
6873
num_train_timesteps (`int`, defaults to 1000):
6974
The number of diffusion steps to train the model.
7075
prediction_type (`str`, defaults to `epsilon`, *optional*):
@@ -84,15 +89,23 @@ def __init__(
8489
sigma_min: float = 0.002,
8590
sigma_max: float = 80.0,
8691
sigma_data: float = 0.5,
92+
sigma_schedule: str = "karras",
8793
num_train_timesteps: int = 1000,
8894
prediction_type: str = "epsilon",
8995
rho: float = 7.0,
9096
):
97+
if sigma_schedule not in ["karras", "exponential"]:
98+
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
99+
91100
# setable values
92101
self.num_inference_steps = None
93102

94103
ramp = torch.linspace(0, 1, num_train_timesteps)
95-
sigmas = self._compute_sigmas(ramp)
104+
if sigma_schedule == "karras":
105+
sigmas = self._compute_karras_sigmas(ramp)
106+
elif sigma_schedule == "exponential":
107+
sigmas = self._compute_exponential_sigmas(ramp)
108+
96109
self.timesteps = self.precondition_noise(sigmas)
97110

98111
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
@@ -200,7 +213,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
200213
self.num_inference_steps = num_inference_steps
201214

202215
ramp = np.linspace(0, 1, self.num_inference_steps)
203-
sigmas = self._compute_sigmas(ramp)
216+
if self.config.sigma_schedule == "karras":
217+
sigmas = self._compute_karras_sigmas(ramp)
218+
elif self.config.sigma_schedule == "exponential":
219+
sigmas = self._compute_exponential_sigmas(ramp)
204220

205221
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
206222
self.timesteps = self.precondition_noise(sigmas)
@@ -211,16 +227,26 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
211227
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
212228

213229
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
214-
def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
230+
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
215231
"""Constructs the noise schedule of Karras et al. (2022)."""
216-
217232
sigma_min = sigma_min or self.config.sigma_min
218233
sigma_max = sigma_max or self.config.sigma_max
219234

220235
rho = self.config.rho
221236
min_inv_rho = sigma_min ** (1 / rho)
222237
max_inv_rho = sigma_max ** (1 / rho)
223238
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
239+
240+
return sigmas
241+
242+
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
243+
"""Implementation closely follows k-diffusion.
244+
245+
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
246+
"""
247+
sigma_min = sigma_min or self.config.sigma_min
248+
sigma_max = sigma_max or self.config.sigma_max
249+
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
224250
return sigmas
225251

226252
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep

0 commit comments

Comments
 (0)