[2905]: Add Karras pattern to discrete euler #2956
Changes from 5 commits
```diff
@@ -103,6 +103,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         interpolation_type (`str`, default `"linear"`, optional):
             interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
             [`"linear"`, `"log_linear"`].
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            Use Karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to
+            `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to `True` will make it `DPM++2M
+            Karras`. Please see equation (5) of https://arxiv.org/pdf/2206.00364.pdf for more details.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
```
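For anyone who wants to try the new flag, here is a minimal usage sketch. The checkpoint id is only an example; any pipeline whose scheduler config is compatible with `EulerDiscreteScheduler` should work the same way:

```python
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

# Example checkpoint id (an assumption, not part of this PR).
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Rebuild the scheduler from the pipeline's existing config with the new flag on.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)

image = pipe("an astronaut riding a horse on mars", num_inference_steps=30).images[0]
```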
```diff
@@ -118,6 +122,7 @@ def __init__(
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
         interpolation_type: str = "linear",
+        use_karras_sigmas: Optional[bool] = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
```
```diff
@@ -149,6 +154,7 @@ def __init__(
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
         self.timesteps = torch.from_numpy(timesteps)
         self.is_scale_input_called = False
+        self.use_karras_sigmas = use_karras_sigmas

     def scale_model_input(
         self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
```
```diff
@@ -187,6 +193,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):

         timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)

         if self.config.interpolation_type == "linear":
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
```
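For context, the `sigmas` computed at the top of `set_timesteps` are the usual noise levels derived from the cumulative product of alphas:

```latex
\sigma_t = \sqrt{\frac{1 - \bar{\alpha}_t}{\bar{\alpha}_t}}
```

The new `log_sigmas` caches their logarithms so that `_sigma_to_t` can later map Karras sigmas back to (fractional) training timesteps by interpolating in log space.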
```diff
@@ -198,6 +205,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
                 " 'linear' or 'log_linear'"
             )

+        if self.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
         if str(device).startswith("mps"):
```

> **Review comment** (on `if self.use_karras_sigmas:`): Very cool! This works for me :-)
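A quick way to see what this branch changes is to compare the two schedules directly. A sketch, assuming the standard SD v1 beta settings (exact values differ per config):

```python
from diffusers import EulerDiscreteScheduler

common = dict(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
regular = EulerDiscreteScheduler(**common)
karras = EulerDiscreteScheduler(**common, use_karras_sigmas=True)

regular.set_timesteps(10)
karras.set_timesteps(10)

print(regular.sigmas)  # sigmas interpolated linearly over the training schedule
print(karras.sigmas)   # same endpoints, but spaced per Karras et al. eq. (5)
```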
```diff
@@ -206,6 +217,43 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         else:
             self.timesteps = torch.from_numpy(timesteps).to(device=device)

+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, self.num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
     def step(
         self,
         model_output: torch.FloatTensor,
```

> **Review comment** (on `def _convert_to_karras`): Perfect, this function works for me! Happy to just go with this function as it was before :-)
>
> **Reply:** Thanks. Fixed.

> **Review comment** (on `sigma_min: float = in_sigmas[-1].item()`): Should a config param for sigma min/max also be introduced, which could be utilized here?
>
> **Reply:** I think since …
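For reference, `_convert_to_karras` implements equation (5) of Karras et al. (2022). With `num_inference_steps` written as $N$ and $\rho = 7$:

```latex
\sigma_i = \left( \sigma_{\max}^{1/\rho} + \frac{i}{N - 1}\left( \sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho} \right) \right)^{\rho}, \qquad i = 0, \ldots, N - 1
```

so $\sigma_0 = \sigma_{\max}$ and $\sigma_{N-1} = \sigma_{\min}$, with the steps near $\sigma_{\min}$ shortened at the expense of longer steps near $\sigma_{\max}$.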
> **Review comment** (on the `use_karras_sigmas` docstring): I might be wrong, but I think this comment is only appropriate for the custom k-diffusion pipeline, whereas these sigmas can be used in other pipelines, if I understand it correctly. Also, I'm not sure why we are mentioning `sample_dpmpp_2m` in the Euler scheduler.
>
> **Reply:** Right, these sigmas can be used generically in other scheduler pipelines as well. Updated the README.
>
> **Reply:** Thanks for iterating 🙏 I love the new text!