
Commit 0ffac97

Add use_Karras_sigmas to LMSDiscreteScheduler (#3351)
* add karras sigma to lms discrete scheduler
* add test for lms_scheduler karras
* reformat test lms
1 parent b0966f5 commit 0ffac97

File tree

2 files changed: +77 −2 lines changed

src/diffusers/schedulers/scheduling_lms_discrete.py
tests/schedulers/test_scheduler_lms.py

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 52 additions & 2 deletions
```diff
@@ -94,6 +94,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             `linear` or `scaled_linear`.
         trained_betas (`np.ndarray`, optional):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+            noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+            of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
         prediction_type (`str`, default `epsilon`, optional):
             prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
             process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
```
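
In practice, the new flag is passed at construction time. A minimal usage sketch, assuming a diffusers build that includes this commit (the 10-step count is arbitrary):

```python
from diffusers import LMSDiscreteScheduler

# Enable the Karras et al. (2022) sigma schedule added in this commit.
scheduler = LMSDiscreteScheduler(use_karras_sigmas=True)
scheduler.set_timesteps(num_inference_steps=10)

# The sigmas now follow Eq. (5) of the paper, and the timesteps are the
# (generally fractional) values interpolated back from those sigmas.
print(scheduler.sigmas)     # descending, with a trailing 0.0 appended
print(scheduler.timesteps)
```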
```diff
@@ -111,6 +115,7 @@ def __init__(
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        use_karras_sigmas: Optional[bool] = False,
         prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
@@ -140,8 +145,8 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
-        self.timesteps = torch.from_numpy(timesteps)
+        self.use_karras_sigmas = use_karras_sigmas
+        self.set_timesteps(num_train_timesteps, None)
         self.derivatives = []
         self.is_scale_input_called = False
 
@@ -201,8 +206,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         self.num_inference_steps = num_inference_steps
 
         timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+        if self.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
 
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
```
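
For context, the sigmas interpolated here are sigma_t = sqrt((1 − alpha_bar_t) / alpha_bar_t), one per training timestep. A standalone sketch of the default (non-Karras) path, mirroring the scheduler's default linear beta schedule but written outside the class; the step counts are illustrative:

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10  # illustrative values

betas = np.linspace(0.0001, 0.02, num_train_timesteps)  # default "linear" beta schedule
alphas_cumprod = np.cumprod(1.0 - betas)

# sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), one sigma per training step.
train_sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

# Default path: sample the training sigmas at evenly spaced, descending timesteps.
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps)[::-1]
inference_sigmas = np.interp(timesteps, np.arange(num_train_timesteps), train_sigmas)
print(inference_sigmas.round(3))
```

With `use_karras_sigmas=True`, these interpolated sigmas are replaced by the Karras sequence (see `_convert_to_karras` below) and the timesteps are recomputed from them via `_sigma_to_t`.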
```diff
@@ -214,6 +226,44 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         self.derivatives = []
 
+    # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, self.num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
     def step(
         self,
         model_output: torch.FloatTensor,
```
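
`_convert_to_karras` is Equation (5) of Karras et al. (2022): the sigmas are spaced uniformly in sigma^(1/rho) space with rho = 7 and raised back to the rho-th power, which concentrates steps toward the low-noise end. A self-contained sketch; the sigma_min/sigma_max range is illustrative, not the scheduler's actual values:

```python
import numpy as np

def karras_sigmas(sigma_min: float, sigma_max: float, n: int, rho: float = 7.0) -> np.ndarray:
    """Eq. (5) of https://arxiv.org/abs/2206.00364: interpolate uniformly in sigma**(1/rho)."""
    ramp = np.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# Illustrative range only; the scheduler derives sigma_min/sigma_max from its own schedule.
print(karras_sigmas(sigma_min=0.03, sigma_max=15.0, n=10).round(3))
```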

tests/schedulers/test_scheduler_lms.py

Lines changed: 25 additions & 0 deletions
```diff
@@ -113,3 +113,28 @@ def test_full_loop_device(self):
 
         assert abs(result_sum.item() - 1006.388) < 1e-2
         assert abs(result_mean.item() - 1.31) < 1e-3
+
+    def test_full_loop_device_karras_sigmas(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True)
+
+        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for t in scheduler.timesteps:
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 3812.9927) < 1e-2
+        assert abs(result_mean.item() - 4.9648) < 1e-3
```
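
For end-to-end use, the typical pattern is to rebuild the scheduler from a pipeline's existing config with the new flag turned on. A hedged sketch; the checkpoint name, prompt, and step count are illustrative assumptions, not part of this commit:

```python
import torch
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline

# Illustrative checkpoint; any diffusers-format Stable Diffusion checkpoint works.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Recreate the scheduler from the pipeline's config, enabling Karras sigmas.
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)

image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("astronaut_karras_lms.png")
```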
