@@ -92,6 +92,43 @@ def alpha_bar_fn(t):
     return torch.tensor(betas, dtype=torch.float32)


+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
+
+    Args:
+        betas (`torch.FloatTensor`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
 class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """
     Ancestral sampling with Euler method steps.
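As a quick sanity check on the helper above, a minimal sketch (not part of the patch; the 1000-step linear beta schedule is illustrative):

import torch

betas = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float32)
print(torch.cumprod(1.0 - betas, dim=0)[-1])  # small but nonzero -> terminal SNR > 0

rescaled = rescale_zero_terminal_snr(betas)
alphas_bar = torch.cumprod(1.0 - rescaled, dim=0)
print(alphas_bar[0])   # first value preserved (up to float rounding)
print(alphas_bar[-1])  # exactly 0.0 -> terminal SNR is zero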
@@ -122,6 +159,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             An offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
             Diffusion.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
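This is how a user would typically enable the new flag (a sketch following the usual diffusers from_config override pattern; the checkpoint name is only an example):

from diffusers import EulerAncestralDiscreteScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Rebuild the scheduler from the pipeline's existing config, overriding the new flag
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True
)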
@@ -138,6 +179,7 @@ def __init__(
         prediction_type: str = "epsilon",
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
+        rescale_betas_zero_snr: bool = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
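Constructing the scheduler directly shows the new argument's effect on the schedule (a sketch using the constructor defaults):

sched = EulerAncestralDiscreteScheduler(rescale_betas_zero_snr=True)
print(sched.betas[-1])  # tensor(1.): the terminal alpha is zero after rescaling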
@@ -152,9 +194,17 @@ def __init__(
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

+        if rescale_betas_zero_snr:
+            # Close to 0 without being 0 so the first sigma is not inf
+            # FP16 smallest positive subnormal works well here
+            self.alphas_cumprod[-1] = 2**-24
+
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas)
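A note on the 2**-24 clamp (reasoning sketch, not part of the patch): if alphas_cumprod[-1] stayed at exactly zero, the sigma computation above would divide by zero, and 2**-24 is the smallest positive subnormal in float16, so the value also survives a later half-precision cast:

import numpy as np

ac_T = 2**-24
print(np.float16(ac_T) > 0)        # True: still representable in fp16
print(np.float16(2**-26) == 0)     # True: much smaller values underflow to zero
print(((1 - ac_T) / ac_T) ** 0.5)  # ~4096: a large but finite terminal sigma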
@@ -327,6 +377,9 @@ def step(

         sigma = self.sigmas[self.step_index]

+        # Upcast to avoid precision issues when computing prev_sample
+        sample = sample.to(torch.float32)
+
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         if self.config.prediction_type == "epsilon":
             pred_original_sample = sample - sigma * model_output
@@ -357,6 +410,9 @@ def step(

         prev_sample = prev_sample + noise * sigma_up

+        # Cast sample back to the model-compatible dtype
+        prev_sample = prev_sample.to(model_output.dtype)
+
         # upon completion increase step index by one
         self._step_index += 1

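The float32 upcast added earlier in step() and this cast back form one mixed-precision pattern: do the step arithmetic in float32, then hand prev_sample back in the model's dtype. A standalone illustration of why fp16 accumulation drifts over many small updates (values are illustrative, not from this scheduler):

import torch

x16 = torch.zeros((), dtype=torch.float16)
x32 = torch.zeros((), dtype=torch.float32)
for _ in range(10000):
    x16 = x16 + 1e-4  # fp16 rounding compounds and the sum stalls near 0.25
    x32 = x32 + 1e-4
print(x16)                    # far below the true sum of 1.0
print(x32.to(torch.float16))  # ~1.0 when cast back only at the end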