@@ -92,6 +92,43 @@ def alpha_bar_fn(t):
     return torch.tensor(betas, dtype=torch.float32)
 
 
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
+
+
+    Args:
+        betas (`torch.FloatTensor`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
 class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """
     Euler scheduler.
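For intuition, here is a minimal standalone check, not part of the diff, that the rescaling drives the terminal SNR to zero while leaving the first step essentially untouched. The scaled-linear schedule below is an assumption chosen to mimic Stable Diffusion's defaults; `rescale_zero_terminal_snr` is the function added in the hunk above:

```python
import torch

# Assumed schedule: scaled-linear betas similar to Stable Diffusion's defaults.
betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2
rescaled = rescale_zero_terminal_snr(betas)

def snr(betas):
    # SNR(t) = alpha_bar(t) / (1 - alpha_bar(t))
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    return alphas_cumprod / (1.0 - alphas_cumprod)

print(snr(betas)[-1].item())     # small but nonzero terminal SNR
print(snr(rescaled)[-1].item())  # 0.0 -- terminal SNR is now exactly zero
print(snr(betas)[0].item(), snr(rescaled)[0].item())  # first step preserved (up to rounding)
```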
@@ -128,6 +165,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             An offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
             Diffusion.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
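As a usage sketch (the checkpoint name is illustrative, not taken from this diff), the new flag is an ordinary config argument, so an existing pipeline's scheduler can be re-created with it enabled:

```python
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

# Illustrative checkpoint; any pipeline with a compatible scheduler config works.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True
)
```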
@@ -149,6 +190,7 @@ def __init__(
         timestep_spacing: str = "linspace",
         timestep_type: str = "discrete",  # can be "discrete" or "continuous"
         steps_offset: int = 0,
+        rescale_betas_zero_snr: bool = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -163,9 +205,17 @@ def __init__(
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
 
+        if rescale_betas_zero_snr:
+            # Close to 0 without being 0 so first sigma is not inf
+            # FP16 smallest positive subnormal works well here
+            self.alphas_cumprod[-1] = 2**-24
+
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
 
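The `2**-24` value can be sanity-checked outside the scheduler: a terminal alpha-bar of exactly zero makes the corresponding sigma infinite, while the FP16 smallest positive subnormal keeps it large but finite (a standalone sketch, not part of the diff):

```python
# Standalone check of the 2**-24 choice.
alpha_bar_T = 2**-24  # FP16 smallest positive subnormal
sigma_T = ((1 - alpha_bar_T) / alpha_bar_T) ** 0.5
print(sigma_T)  # ~4096.0: large but finite, so the first sigma is usable

# With alpha_bar_T == 0 the same formula divides by zero: ((1 - 0) / 0) ** 0.5 -> inf
```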
@@ -420,6 +470,9 @@ def step(
         if self.step_index is None:
             self._init_step_index(timestep)
 
+        # Upcast to avoid precision issues when computing prev_sample
+        sample = sample.to(torch.float32)
+
         sigma = self.sigmas[self.step_index]
 
         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
@@ -456,6 +509,9 @@ def step(
 
         prev_sample = sample + derivative * dt
 
+        # Cast sample back to model-compatible dtype
+        prev_sample = prev_sample.to(model_output.dtype)
+
         # upon completion increase step index by one
         self._step_index += 1
 
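Finally, the upcast/downcast pair brackets the Euler update so the arithmetic runs in float32 even when the model runs in half precision; a sketch of the pattern in isolation (all tensors and the step size are made-up stand-ins):

```python
import torch

model_output = torch.randn(4, dtype=torch.float16)  # stand-in for an fp16 UNet output
sample = torch.randn(4, dtype=torch.float16)        # stand-in for the current latents

sample = sample.to(torch.float32)                   # upcast before the update
derivative = model_output.to(torch.float32)         # stand-in for the real derivative
dt = -0.1                                           # stand-in for sigmas[i + 1] - sigma_hat
prev_sample = sample + derivative * dt              # Euler step computed in fp32
prev_sample = prev_sample.to(model_output.dtype)    # cast back for the next model call
```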