Skip to content

Commit 6d78118

Browse files
committed
EulerDiscrete upcast samples during step()
Fixes the ZSNR precision issues on fp16/bf16 with no measurable performance loss. Now that the full 2 ** -24 is used, the results are effectively equivalent to DDIM's ZSNR rescaling
1 parent 4f580e4 commit 6d78118

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,9 @@ def __init__(
211211
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
212212

213213
if rescale_betas_zero_snr:
214-
# Bandaid so first sigma isn't inf
215-
# Lower values that follow the 'proper' curve have precision issues on fp16/bf16
216-
self.alphas_cumprod[-1] = 2 ** -16
214+
# Close to 0 without being 0 so first sigma is not inf
215+
# FP16 smallest positive subnormal works well here
216+
self.alphas_cumprod[-1] = 2 ** -24
217217

218218
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
219219
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
@@ -468,6 +468,9 @@ def step(
468468
if self.step_index is None:
469469
self._init_step_index(timestep)
470470

471+
# Upcast to avoid precision issues when computing prev_sample
472+
sample = sample.to(torch.float32)
473+
471474
sigma = self.sigmas[self.step_index]
472475

473476
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
@@ -504,6 +507,9 @@ def step(
504507

505508
prev_sample = sample + derivative * dt
506509

510+
# Cast sample back to model compatible dtype
511+
prev_sample = prev_sample.to(model_output.dtype)
512+
507513
# upon completion increase step index by one
508514
self._step_index += 1
509515

0 commit comments

Comments
 (0)