Commit 9c99d2a

nipunjindal and njindal authored
[2905]: Add Karras pattern to discrete euler (huggingface#2956)
* [2905]: Add Karras pattern to discrete euler
* [2905]: Add Karras pattern to discrete euler
* Review comments
* Review comments
* Review comments
* Review comments

Co-authored-by: njindal <[email protected]>
1 parent 43efe0c commit 9c99d2a

File tree: 1 file changed (+48, -0 lines)

schedulers/scheduling_euler_discrete.py

Lines changed: 48 additions & 0 deletions
@@ -103,6 +103,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         interpolation_type (`str`, default `"linear"`, optional):
             interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
             [`"linear"`, `"log_linear"`].
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+            noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+            of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
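For reference, the noise-level sequence {σi} that the new docstring cites is Equation (5) of Karras et al. (2022). In LaTeX, with N the number of inference steps and the paper's ρ = 7 (a final σ_N = 0 is appended later in set_timesteps):

\sigma_i = \left( \sigma_{\max}^{1/\rho} + \frac{i}{N-1} \left( \sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho} \right) \right)^{\rho}, \qquad i = 0, \dots, N-1

This runs from σ_max down to σ_min, with the steps clustered toward σ_min.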
@@ -118,6 +122,7 @@ def __init__(
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
         interpolation_type: str = "linear",
+        use_karras_sigmas: Optional[bool] = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
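A minimal usage sketch of the new constructor argument. The model id below is only an example, and it assumes the usual diffusers behavior of forwarding config-override kwargs through from_pretrained:

# Hedged sketch, not part of the diff: opt in to Karras sigmas at load time.
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # example checkpoint only
    subfolder="scheduler",
    use_karras_sigmas=True,
)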
@@ -149,6 +154,7 @@ def __init__(
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
         self.timesteps = torch.from_numpy(timesteps)
         self.is_scale_input_called = False
+        self.use_karras_sigmas = use_karras_sigmas

     def scale_model_input(
         self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
@@ -187,6 +193,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

         timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)

         if self.config.interpolation_type == "linear":
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
@@ -198,6 +205,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
                 " 'linear' or 'log_linear'"
             )

+        if self.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
         if str(device).startswith("mps"):
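To see what the new branch changes, one could compare the two spacings side by side. A sketch under assumed default constructor settings (not part of the diff):

# Compare default and Karras sigma spacing for the same step count.
from diffusers import EulerDiscreteScheduler

default = EulerDiscreteScheduler()
karras = EulerDiscreteScheduler(use_karras_sigmas=True)
for sched in (default, karras):
    sched.set_timesteps(num_inference_steps=10)

print(default.sigmas)  # sigmas interpolated at evenly spaced timesteps
print(karras.sigmas)   # Equation (5) spacing, denser near sigma_min; both end in 0.0

Note that with the flag enabled the timesteps also become fractional, since each Karras sigma is mapped back onto the training schedule by _sigma_to_t below.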
@@ -206,6 +217,43 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         else:
             self.timesteps = torch.from_numpy(timesteps).to(device=device)

+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, self.num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
     def step(
         self,
         model_output: torch.FloatTensor,
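As a sanity check on _sigma_to_t, here is a self-contained toy round trip (toy sigma table; the helper body is copied from the hunk above, minus self and the final reshape). Feeding a tabulated sigma back through the interpolation should recover its own index:

import numpy as np

log_sigmas = np.log(np.linspace(0.1, 10.0, 100))  # toy ascending log-sigma table

def sigma_to_t(sigma, log_sigmas):
    # same interpolation logic as the _sigma_to_t method added above
    log_sigma = np.log(sigma)
    dists = log_sigma - log_sigmas[:, np.newaxis]
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    w = np.clip((low - log_sigma) / (low - high), 0, 1)
    return (1 - w) * low_idx + w * high_idx

print(sigma_to_t(np.exp(log_sigmas[42]), log_sigmas))  # [42.]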
