
Commit 457abdf

Beinsezii and sayakpaul authored
EulerAncestral add rescale_betas_zero_snr (#6187)

* EulerAncestral add `rescale_betas_zero_snr`

  Uses the same infinite-sigma fix from EulerDiscrete. Interestingly, the ancestral version had the opposite problem: too much contrast instead of too little.

* UT for EulerAncestral `rescale_betas_zero_snr`

* EulerAncestral upcast samples during `step()`

  It helps this scheduler too, particularly when the model is using bf16. While the noise dtype is still the model's, it is automatically upcast for the add, so all it affects is determinism.

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent ff43dba commit 457abdf
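
For context, a minimal usage sketch of the new flag (the checkpoint id is a placeholder, not part of this commit; `from_config` with override kwargs is the standard diffusers way to swap scheduler settings):

    import torch
    from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler

    # Hypothetical checkpoint id; any SD-style pipeline works the same way.
    pipe = StableDiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.bfloat16)

    # Rebuild the scheduler from the existing config with the new flag enabled.
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe.scheduler.config, rescale_betas_zero_snr=True
    )

Note that zero-terminal-SNR schedules are typically paired with models trained for `v_prediction`, per the paper cited in the diff below.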

File tree

2 files changed: +60 -0 lines changed

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 56 additions & 0 deletions
@@ -92,6 +92,43 @@ def alpha_bar_fn(t):
 
     return torch.tensor(betas, dtype=torch.float32)
 
 
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR. Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+    Args:
+        betas (`torch.FloatTensor`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
 class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """
     Ancestral sampling with Euler method steps.
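
As a sanity check on the helper above, a small sketch (assuming `torch` and `rescale_zero_terminal_snr` are in scope; the schedule values are illustrative, not from this commit) showing that the rescaled schedule zeroes the terminal alphas_cumprod, and hence the terminal SNR, while leaving the first step unchanged:

    import torch

    # Scaled-linear (Stable Diffusion style) beta schedule; values illustrative.
    betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2
    rescaled = rescale_zero_terminal_snr(betas)

    alphas_cumprod = torch.cumprod(1.0 - rescaled, dim=0)
    print(alphas_cumprod[0])   # ~0.99915: first step preserved exactly
    print(alphas_cumprod[-1])  # 0.0: zero terminal SNR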
@@ -122,6 +159,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             An offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
             Diffusion.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
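
For contrast with the loosely related technique linked in that docstring: offset noise perturbs the training noise with a low-frequency per-channel shift rather than fixing the schedule itself. Roughly, as a sketch of the linked training-script idea (the 0.1 weight is illustrative, not from this commit):

    import torch

    noise = torch.randn(4, 4, 64, 64)  # hypothetical latent-shaped noise batch
    # Offset noise: add a per-image, per-channel constant so the model learns
    # to shift overall brightness, instead of rescaling the betas.
    noise = noise + 0.1 * torch.randn(noise.shape[0], noise.shape[1], 1, 1)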
@@ -138,6 +179,7 @@ def __init__(
         prediction_type: str = "epsilon",
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
+        rescale_betas_zero_snr: bool = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -152,9 +194,17 @@ def __init__(
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
 
+        if rescale_betas_zero_snr:
+            # Close to 0 without being 0 so first sigma is not inf
+            # FP16 smallest positive subnormal works well here
+            self.alphas_cumprod[-1] = 2**-24
+
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas)
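
Why the `2**-24` backstop matters, sketched numerically (a standalone check under assumed values, not code from the diff): with a true zero at the end of `alphas_cumprod` the sigma conversion divides by zero, while the FP16 smallest positive subnormal yields a large but finite first sigma:

    import numpy as np

    alphas_cumprod = np.array([0.9999, 0.5, 0.0])          # true zero at the end
    print(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5)  # last sigma is inf

    alphas_cumprod[-1] = 2**-24                            # FP16 smallest positive subnormal
    print(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5)  # last sigma ~ 4096, finite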
@@ -327,6 +377,9 @@ def step(
 
         sigma = self.sigmas[self.step_index]
 
+        # Upcast to avoid precision issues when computing prev_sample
+        sample = sample.to(torch.float32)
+
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         if self.config.prediction_type == "epsilon":
             pred_original_sample = sample - sigma * model_output
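
The motivation for the upcast, sketched with plain tensors (not scheduler code): bf16 carries only about 8 mantissa bits, so a small per-step update can round away entirely, while an fp32 copy of the sample keeps it:

    import torch

    sample = torch.full((1,), 1.0, dtype=torch.bfloat16)
    update = torch.full((1,), 1e-3, dtype=torch.bfloat16)

    print(sample + update)                    # 1.0 in bf16: the update is rounded away
    print(sample.to(torch.float32) + update)  # ~1.0010 in fp32: the update survives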
@@ -357,6 +410,9 @@ def step(
 
         prev_sample = prev_sample + noise * sigma_up
 
+        # Cast sample back to model compatible dtype
+        prev_sample = prev_sample.to(model_output.dtype)
+
         # upon completion increase step index by one
         self._step_index += 1

tests/schedulers/test_scheduler_euler_ancestral.py

Lines changed: 4 additions & 0 deletions
@@ -37,6 +37,10 @@ def test_prediction_type(self):
         for prediction_type in ["epsilon", "v_prediction"]:
             self.check_over_configs(prediction_type=prediction_type)
 
+    def test_rescale_betas_zero_snr(self):
+        for rescale_betas_zero_snr in [True, False]:
+            self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
+
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
