[2905]: Add Karras pattern to discrete euler #2956

Merged
merged 6 commits on Apr 6, 2023
48 changes: 48 additions & 0 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -103,6 +103,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        interpolation_type (`str`, default `"linear"`, optional):
            The interpolation type used to compute intermediate sigmas for the scheduler denoising steps. Should be
            one of [`"linear"`, `"log_linear"`].
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas (the Karras et al. (2022) scheme) for the step sizes of the noise schedule
            during sampling. If `True`, the sigmas are determined according to the sequence of noise levels {σi}
            defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -118,6 +122,7 @@ def __init__(
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        interpolation_type: str = "linear",
        use_karras_sigmas: Optional[bool] = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -149,6 +154,7 @@ def __init__(
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.is_scale_input_called = False
        self.use_karras_sigmas = use_karras_sigmas

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
@@ -187,6 +193,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)

        if self.config.interpolation_type == "linear":
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
@@ -198,6 +205,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
                " 'linear' or 'log_linear'"
            )

        if self.use_karras_sigmas:
Contributor: Very cool! This works for me :-)

            sigmas = self._convert_to_karras(in_sigmas=sigmas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        if str(device).startswith("mps"):
@@ -206,6 +217,43 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)
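With the flag enabled, `set_timesteps` converts the interpolated sigmas to the Karras spacing and then maps each sigma back to a (now fractional) timestep. A quick inspection sketch (assumed standalone usage with a toy step count):

```python
from diffusers import EulerDiscreteScheduler

sched = EulerDiscreteScheduler(use_karras_sigmas=True)
sched.set_timesteps(10)
print(sched.sigmas)     # Karras-spaced, densest near low noise, ending in 0.0
print(sched.timesteps)  # fractional timesteps recovered via _sigma_to_t
```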

    def _sigma_to_t(self, sigma, log_sigmas):
        # log of the query sigma
        log_sigma = np.log(sigma)

        # distance from the query to every log-sigma in the schedule
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # indices of the two schedule entries that bracket the query sigma
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # linear interpolation weight between the two bracketing log-sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform the interpolation weight into a (fractional) timestep
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t
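In other words, `_sigma_to_t` inverts the training sigma schedule: the returned fractional `t` is chosen so that linearly interpolating `log_sigmas` at `t` recovers `log(sigma)`. A round-trip sketch (assumed default scheduler config, illustrative only):

```python
import numpy as np
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler()
sigmas = np.array(((1 - scheduler.alphas_cumprod) / scheduler.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)

t = scheduler._sigma_to_t(np.array(2.5), log_sigmas)
# interpolating the schedule at the fractional t recovers the query sigma
assert np.allclose(np.interp(t, np.arange(len(log_sigmas)), log_sigmas), np.log(2.5))
```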

    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
Contributor: Perfect, this function works for me! Happy to just go with this function as it was before :-)

Contributor (Author): Thanks. Fixed.

"""Constructs the noise schedule of Karras et al. (2022)."""

sigma_min: float = in_sigmas[-1].item()
Contributor (Author): Should a config param for sigma min/max also be introduced that could be used here?

Member: Since sigma_min and sigma_max are derived values (computed from in_sigmas), I think it's fine not to register them separately.
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, self.num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas
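This implements Equation (5) of Karras et al. (2022), σ_i = (σ_max^(1/ρ) + (i/(n−1)) · (σ_min^(1/ρ) − σ_max^(1/ρ)))^ρ with ρ = 7, which spends more of the step budget at low noise levels than a linear spacing does. A standalone numeric sketch (toy values, no diffusers dependency):

```python
import numpy as np

sigma_min, sigma_max, n, rho = 0.1, 10.0, 5, 7.0
ramp = np.linspace(0, 1, n)
karras = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
linear = np.linspace(sigma_max, sigma_min, n)
print(karras)  # roughly [10.0, 4.07, 1.45, 0.43, 0.1], clustered toward low sigma
print(linear)  # [10.0, 7.525, 5.05, 2.575, 0.1]
```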

    def step(
        self,
        model_output: torch.FloatTensor,
27 changes: 27 additions & 0 deletions tests/schedulers/test_scheduler_euler.py
@@ -117,3 +117,30 @@ def test_full_loop_device(self):

        assert abs(result_sum.item() - 10.0807) < 1e-2
        assert abs(result_mean.item() - 0.0131) < 1e-3

    def test_full_loop_device_karras_sigmas(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True)

        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

        generator = torch.manual_seed(0)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
        sample = sample.to(torch_device)

        for t in scheduler.timesteps:
            sample = scheduler.scale_model_input(sample, t)

            model_output = model(sample, t)

            output = scheduler.step(model_output, t, sample, generator=generator)
            sample = output.prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 124.52299499511719) < 1e-2
        assert abs(result_mean.item() - 0.16213932633399963) < 1e-3