12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import math
15
16
from dataclasses import dataclass
16
17
from typing import Optional , Tuple , Union
17
18
@@ -65,6 +66,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
65
66
range is [0.2, 80.0].
66
67
sigma_data (`float`, *optional*, defaults to 0.5):
67
68
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
69
+ sigma_schedule (`str`, *optional*, defaults to `karras`):
70
+ Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
71
+ (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
72
+ incorporated in this model: https://huggingface.co/stabilityai/cosxl.
68
73
num_train_timesteps (`int`, defaults to 1000):
69
74
The number of diffusion steps to train the model.
70
75
prediction_type (`str`, defaults to `epsilon`, *optional*):
@@ -84,15 +89,23 @@ def __init__(
84
89
sigma_min : float = 0.002 ,
85
90
sigma_max : float = 80.0 ,
86
91
sigma_data : float = 0.5 ,
92
+ sigma_schedule : str = "karras" ,
87
93
num_train_timesteps : int = 1000 ,
88
94
prediction_type : str = "epsilon" ,
89
95
rho : float = 7.0 ,
90
96
):
97
+ if sigma_schedule not in ["karras" , "exponential" ]:
98
+ raise ValueError (f"Wrong value for provided for `{ sigma_schedule = } `.`" )
99
+
91
100
# setable values
92
101
self .num_inference_steps = None
93
102
94
103
ramp = torch .linspace (0 , 1 , num_train_timesteps )
95
- sigmas = self ._compute_sigmas (ramp )
104
+ if sigma_schedule == "karras" :
105
+ sigmas = self ._compute_karras_sigmas (ramp )
106
+ elif sigma_schedule == "exponential" :
107
+ sigmas = self ._compute_exponential_sigmas (ramp )
108
+
96
109
self .timesteps = self .precondition_noise (sigmas )
97
110
98
111
self .sigmas = torch .cat ([sigmas , torch .zeros (1 , device = sigmas .device )])
@@ -200,7 +213,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
200
213
self .num_inference_steps = num_inference_steps
201
214
202
215
ramp = np .linspace (0 , 1 , self .num_inference_steps )
203
- sigmas = self ._compute_sigmas (ramp )
216
+ if self .config .sigma_schedule == "karras" :
217
+ sigmas = self ._compute_karras_sigmas (ramp )
218
+ elif self .config .sigma_schedule == "exponential" :
219
+ sigmas = self ._compute_exponential_sigmas (ramp )
204
220
205
221
sigmas = torch .from_numpy (sigmas ).to (dtype = torch .float32 , device = device )
206
222
self .timesteps = self .precondition_noise (sigmas )
@@ -211,16 +227,26 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
211
227
self .sigmas = self .sigmas .to ("cpu" ) # to avoid too much CPU/GPU communication
212
228
213
229
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
214
- def _compute_sigmas (self , ramp , sigma_min = None , sigma_max = None ) -> torch .FloatTensor :
230
+ def _compute_karras_sigmas (self , ramp , sigma_min = None , sigma_max = None ) -> torch .FloatTensor :
215
231
"""Constructs the noise schedule of Karras et al. (2022)."""
216
-
217
232
sigma_min = sigma_min or self .config .sigma_min
218
233
sigma_max = sigma_max or self .config .sigma_max
219
234
220
235
rho = self .config .rho
221
236
min_inv_rho = sigma_min ** (1 / rho )
222
237
max_inv_rho = sigma_max ** (1 / rho )
223
238
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho )) ** rho
239
+
240
+ return sigmas
241
+
242
+ def _compute_exponential_sigmas (self , ramp , sigma_min = None , sigma_max = None ) -> torch .FloatTensor :
243
+ """Implementation closely follows k-diffusion.
244
+
245
+ https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
246
+ """
247
+ sigma_min = sigma_min or self .config .sigma_min
248
+ sigma_max = sigma_max or self .config .sigma_max
249
+ sigmas = torch .linspace (math .log (sigma_min ), math .log (sigma_max ), len (ramp )).exp ().flip (0 )
224
250
return sigmas
225
251
226
252
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
0 commit comments