[2905]: Add Karras pattern to discrete euler #2956
Changes from 2 commits
@@ -103,6 +103,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         interpolation_type (`str`, default `"linear"`, optional):
             interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
             [`"linear"`, `"log_linear"`].
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use Karras sigmas. For example, passing `sample_dpmpp_2m` to `set_scheduler` is equivalent
+            to `DPM++2M` in stable-diffusion-webui; setting this option to `True` on top of that makes it
+            `DPM++2M Karras`.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
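
For context, a minimal usage sketch of the new flag (not part of this diff; the pipeline class and model id below are illustrative assumptions):

```python
# Sketch: enable the Karras sigma schedule on an Euler scheduler.
# Assumes a diffusers build that includes this PR; the model id is illustrative.
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=True  # option added by this PR
)
image = pipe("an astronaut riding a horse").images[0]
```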
@@ -118,6 +122,7 @@ def __init__(
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
         interpolation_type: str = "linear",
+        use_karras_sigmas: Optional[bool] = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -149,6 +154,7 @@ def __init__(
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
         self.timesteps = torch.from_numpy(timesteps)
         self.is_scale_input_called = False
+        self.use_karras_sigmas = use_karras_sigmas

     def scale_model_input(
         self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
@@ -187,6 +193,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

         timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        log_sigmas = np.log(sigmas)

         if self.config.interpolation_type == "linear":
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

@@ -198,6 +205,10 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
                 " 'linear' or 'log_linear'"
             )

+        if self.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
         if str(device).startswith("mps"):

Reviewer (on the `if self.use_karras_sigmas:` branch): Very cool! This works for me :-)
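A quick way to see what this branch changes (a sketch; constructor defaults are assumed, and it requires a build with this PR applied):

```python
# Sketch: compare sigma spacing with and without use_karras_sigmas.
from diffusers import EulerDiscreteScheduler

for flag in (False, True):
    sched = EulerDiscreteScheduler(use_karras_sigmas=flag)
    sched.set_timesteps(8)
    print(flag, [round(s, 3) for s in sched.sigmas.tolist()])
# With the flag on, the sigmas follow the rho=7 Karras spacing and
# cluster near sigma_min, spending more steps in the low-noise regime.
```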
@@ -206,6 +217,45 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         else:
             self.timesteps = torch.from_numpy(timesteps).to(device=device)

+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+        print(sigma_min, sigma_max)
+
+        rho = 7.0
+        # ramp = torch.linspace(0, 1, self.num_inference_steps)
+        ramp = np.linspace(0, 1, self.num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
     def step(
         self,
         model_output: torch.FloatTensor,

Reviewer (on `_convert_to_karras`): Perfect, this function works for me! Happy to just go with this function as it was before :-)
Author: Thanks. Fixed.

Reviewer (on the `sigma_min` line): Should a config param for sigma min/max also be introduced that could be utilized here?
Author: I think since …

Reviewer (on the `print(sigma_min, sigma_max)` debug line): Needs to go.
Author: Removed the section.

Reviewer (on the commented-out `torch.linspace` ramp): Probably needs to go?
Author: Removed the section.
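For intuition, here is a self-contained NumPy sketch of what `_convert_to_karras` computes; `_sigma_to_t` then maps each new sigma back to a fractional timestep by piecewise-linear interpolation in log-sigma space. All values below are illustrative:

```python
import numpy as np

def convert_to_karras(in_sigmas, num_steps, rho=7.0):
    # Karras et al. (2022): space sigma**(1/rho) linearly, then raise back to rho.
    sigma_min, sigma_max = in_sigmas[-1], in_sigmas[0]
    ramp = np.linspace(0, 1, num_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# Toy descending sigma table standing in for the scheduler's train-time sigmas.
train_sigmas = np.geomspace(10.0, 0.1, num=1000)
print(convert_to_karras(train_sigmas, num_steps=6))
# -> 6 sigmas from 10.0 down to 0.1, packed most densely near 0.1
```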
@@ -117,3 +117,30 @@ def test_full_loop_device(self):

         assert abs(result_sum.item() - 10.0807) < 1e-2
         assert abs(result_mean.item() - 0.0131) < 1e-3
+
+    def test_full_loop_device_karras_sigmas(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True)
+
+        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
+
+        generator = torch.manual_seed(0)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+        sample = sample.to(torch_device)
+
+        for t in scheduler.timesteps:
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample, generator=generator)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+        print(result_sum.item(), result_mean.item())
+
+        assert abs(result_sum.item() - 124.52299499511719) < 1e-2
+        assert abs(result_mean.item() - 0.16213932633399963) < 1e-3

Reviewer (on the `print(result_sum.item(), result_mean.item())` line): Needs to go?
Author: Removed the section.
Reviewer: Could we also include the reference paper that introduced it? #2874 (comment)
Feel free to also include it in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py.
Author: Fixed.
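
The referenced paper is Karras et al., "Elucidating the Design Space of Diffusion-Based Generative Models" (arXiv:2206.00364); `_convert_to_karras` implements its rho-spaced schedule, sigma_i = (sigma_max**(1/rho) + i/(n-1) * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho, with rho = 7.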