Skip to content

Commit 3dc2362

Browse files
Beinsezii and sayakpaul
authored and committed
EulerDiscreteScheduler add rescale_betas_zero_snr (#6024)
* EulerDiscreteScheduler add `rescale_betas_zero_snr`
1 parent 821726d commit 3dc2362

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,43 @@ def alpha_bar_fn(t):
9292
return torch.tensor(betas, dtype=torch.float32)
9393

9494

95+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
96+
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that its terminal SNR is zero, following Algorithm 1 of
    https://arxiv.org/pdf/2305.08891.pdf


    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Move into sqrt(alpha_bar) space: cumulative product of (1 - beta), then square root.
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Snapshot both endpoints; clones are required because the tensor is modified in place below.
    first_value = sqrt_alpha_bar[0].clone()
    last_value = sqrt_alpha_bar[-1].clone()

    # Translate the curve so that the final timestep becomes exactly zero ...
    sqrt_alpha_bar -= last_value

    # ... then stretch it so that the first timestep returns to its original value.
    sqrt_alpha_bar *= first_value / (first_value - last_value)

    # Undo the square root, then undo the cumulative product by dividing neighbouring entries;
    # the first entry has no predecessor and is carried over unchanged.
    alpha_bar = sqrt_alpha_bar**2
    per_step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])

    return 1 - per_step_alphas
130+
131+
95132
class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
96133
"""
97134
Euler scheduler.
@@ -128,6 +165,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
128165
An offset added to the inference steps. You can use a combination of `offset=1` and
129166
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
130167
Diffusion.
168+
rescale_betas_zero_snr (`bool`, defaults to `False`):
169+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
170+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
171+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
131172
"""
132173

133174
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -149,6 +190,7 @@ def __init__(
149190
timestep_spacing: str = "linspace",
150191
timestep_type: str = "discrete", # can be "discrete" or "continuous"
151192
steps_offset: int = 0,
193+
rescale_betas_zero_snr: bool = False,
152194
):
153195
if trained_betas is not None:
154196
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -163,9 +205,17 @@ def __init__(
163205
else:
164206
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
165207

208+
if rescale_betas_zero_snr:
209+
self.betas = rescale_zero_terminal_snr(self.betas)
210+
166211
self.alphas = 1.0 - self.betas
167212
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
168213

214+
if rescale_betas_zero_snr:
215+
# Close to 0 without being 0 so first sigma is not inf
216+
# FP16 smallest positive subnormal works well here
217+
self.alphas_cumprod[-1] = 2**-24
218+
169219
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
170220
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
171221

@@ -420,6 +470,9 @@ def step(
420470
if self.step_index is None:
421471
self._init_step_index(timestep)
422472

473+
# Upcast to avoid precision issues when computing prev_sample
474+
sample = sample.to(torch.float32)
475+
423476
sigma = self.sigmas[self.step_index]
424477

425478
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
@@ -456,6 +509,9 @@ def step(
456509

457510
prev_sample = sample + derivative * dt
458511

512+
# Cast sample back to model compatible dtype
513+
prev_sample = prev_sample.to(model_output.dtype)
514+
459515
# upon completion increase step index by one
460516
self._step_index += 1
461517

tests/schedulers/test_scheduler_euler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ def test_timestep_type(self):
4545
def test_karras_sigmas(self):
    # Run the shared config check with Karras sigmas enabled and explicit sigma bounds.
    karras_config = {"use_karras_sigmas": True, "sigma_min": 0.02, "sigma_max": 700.0}
    self.check_over_configs(**karras_config)
4747

48+
def test_rescale_betas_zero_snr(self):
    # Run the shared config check with zero-terminal-SNR rescaling switched on, then off.
    for enabled in (True, False):
        self.check_over_configs(rescale_betas_zero_snr=enabled)
51+
4852
def test_full_loop_no_noise(self):
4953
scheduler_class = self.scheduler_classes[0]
5054
scheduler_config = self.get_scheduler_config()

0 commit comments

Comments
 (0)