[2905]: Add Karras pattern to discrete euler #2956
@@ -103,6 +103,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        interpolation_type (`str`, default `"linear"`, optional):
            interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
            [`"linear"`, `"log_linear"`].
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas (the Karras et al. (2022) scheme) for the step sizes in the noise
            schedule during sampling. If `True`, the sigmas are determined according to a sequence of noise
            levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
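For reference, the sequence the docstring points to is Equation (5) of Karras et al. (2022). With N inference steps and ρ = 7 (the value used in the code below), the noise levels are

```latex
\sigma_i = \left( \sigma_{\max}^{1/\rho} + \frac{i}{N-1}\,\bigl( \sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho} \bigr) \right)^{\rho},
\qquad i = 0, \ldots, N-1,
```

with σ_N = 0 appended at the end, which matches the `np.concatenate([sigmas, [0.0]])` step later in `set_timesteps`.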
@@ -118,6 +122,7 @@ def __init__(
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        interpolation_type: str = "linear",
        use_karras_sigmas: Optional[bool] = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -149,6 +154,7 @@ def __init__(
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.is_scale_input_called = False
        self.use_karras_sigmas = use_karras_sigmas

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
@@ -187,6 +193,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):

        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)

        if self.config.interpolation_type == "linear":
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

@@ -198,6 +205,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
                " 'linear' or 'log_linear'"
            )

        if self.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        if str(device).startswith("mps"):
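For intuition, the training-sigma grid above is the variance-preserving conversion sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t). A minimal standalone sketch, assuming an illustrative linear beta schedule rather than the scheduler's configured one:

```python
import numpy as np

# Illustrative linear beta schedule (an assumption for this sketch; the real
# scheduler derives alphas_cumprod from its configured betas).
num_train_timesteps = 1000
betas = np.linspace(0.0001, 0.02, num_train_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)

# Same conversion as in set_timesteps: sigma grows with the noise level.
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
print(sigmas[0], sigmas[-1])  # tiny sigma at t=0, large sigma at t=T-1
```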
@@ -206,6 +217,43 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(sigma)

        # distances from the query to every training log-sigma
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # indices of the two grid sigmas that bracket the query
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolation weight between the two bracketing sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform the interpolation weight into a fractional timestep
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t
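To see what `_sigma_to_t` computes, here is a toy, self-contained run of the same log-space interpolation (a sketch with made-up grid values, not scheduler internals):

```python
import numpy as np

# Toy grid: training sigmas at timesteps t = 0, 1, 2 (made-up values).
log_sigmas = np.log(np.array([0.1, 1.0, 10.0]))
log_sigma = np.log(3.0)  # query sigma to map back to a (fractional) timestep

# Find the bracketing grid indices, exactly as _sigma_to_t does.
dists = log_sigma - log_sigmas[:, np.newaxis]
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1

# Interpolate linearly in log-sigma space, then map back to the time axis.
w = np.clip((log_sigmas[low_idx] - log_sigma) / (log_sigmas[low_idx] - log_sigmas[high_idx]), 0, 1)
t = (1 - w) * low_idx + w * high_idx
print(t)  # ≈ [1.477]: log(3) sits about 47.7% of the way from log(1) to log(10)
```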
    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:

Review comment: Perfect, this function works for me! Happy to just go with this function as it was before :-)
Reply: Thanks. Fixed.
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = in_sigmas[-1].item()

Review comment: Should a config param for sigma min/max also be introduced, which could be utilized here?
Reply: I think since
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, self.num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas
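As a standalone check, the same arithmetic as `_convert_to_karras`, with an illustrative sigma range (the method itself reads `sigma_min`/`sigma_max` from `in_sigmas`):

```python
import numpy as np

num_inference_steps = 10
sigma_min, sigma_max, rho = 0.0292, 14.6146, 7.0  # illustrative range

ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
print(sigmas)  # falls from sigma_max to sigma_min, steps clustered at low noise
```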
    def step(
        self,
        model_output: torch.FloatTensor,
Review comment: Very cool! This works for me :-)
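For completeness, a sketch of how the new flag can be switched on for an existing pipeline. The model id is illustrative, and passing the override through `from_config` relies on standard `ConfigMixin` behavior:

```python
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

# Model id is illustrative; any pipeline built around EulerDiscreteScheduler works.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Rebuild the scheduler from its own config with the Karras schedule enabled.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=True
)

image = pipe("an astronaut riding a horse", num_inference_steps=30).images[0]
```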