Skip to content

Commit 7337b87

Browse files
authored
Add support for Karras sigmas in StableDiffusionKDiffusionPipeline (huggingface#2874)
* add use_karras_sigmas option (thanks @Stax124)
* fix sigma_min/max from scheduler.sigmas
* add docstring
* revert to using k_diffusion_model.sigmas, with to(device)
* add integration test
* make style
1 parent ce93063 commit 7337b87

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import torch
1919
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
20+
from k_diffusion.sampling import get_sigmas_karras
2021

2122
from ...loaders import TextualInversionLoaderMixin
2223
from ...pipelines import DiffusionPipeline
@@ -409,6 +410,7 @@ def __call__(
409410
return_dict: bool = True,
410411
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
411412
callback_steps: int = 1,
413+
use_karras_sigmas: Optional[bool] = False,
412414
):
413415
r"""
414416
Function invoked when calling the pipeline for generation.
@@ -465,7 +467,10 @@ def __call__(
465467
callback_steps (`int`, *optional*, defaults to 1):
466468
The frequency at which the `callback` function will be called. If not specified, the callback will be
467469
called at every step.
468-
470+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
471+
Use Karras sigmas. For example, passing `sample_dpmpp_2m` to `set_scheduler` will be equivalent to
472+
`DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M
473+
Karras`.
469474
Returns:
470475
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
471476
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
@@ -503,10 +508,18 @@ def __call__(
503508

504509
# 4. Prepare timesteps
505510
self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device)
506-
sigmas = self.scheduler.sigmas
511+
512+
# 5. Prepare sigmas
513+
if use_karras_sigmas:
514+
sigma_min: float = self.k_diffusion_model.sigmas[0].item()
515+
sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
516+
sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
517+
sigmas = sigmas.to(device)
518+
else:
519+
sigmas = self.scheduler.sigmas
507520
sigmas = sigmas.to(prompt_embeds.dtype)
508521

509-
# 5. Prepare latent variables
522+
# 6. Prepare latent variables
510523
num_channels_latents = self.unet.in_channels
511524
latents = self.prepare_latents(
512525
batch_size * num_images_per_prompt,
@@ -522,7 +535,7 @@ def __call__(
522535
self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
523536
self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
524537

525-
# 6. Define model function
538+
# 7. Define model function
526539
def model_fn(x, t):
527540
latent_model_input = torch.cat([x] * 2)
528541
t = torch.cat([t] * 2)
@@ -533,16 +546,16 @@ def model_fn(x, t):
533546
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
534547
return noise_pred
535548

536-
# 7. Run k-diffusion solver
549+
# 8. Run k-diffusion solver
537550
latents = self.sampler(model_fn, latents, sigmas)
538551

539-
# 8. Post-processing
552+
# 9. Post-processing
540553
image = self.decode_latents(latents)
541554

542-
# 9. Run safety checker
555+
# 10. Run safety checker
543556
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
544557

545-
# 10. Convert to PIL
558+
# 11. Convert to PIL
546559
if output_type == "pil":
547560
image = self.numpy_to_pil(image)
548561

0 commit comments

Comments
 (0)