Skip to content

Commit bab455a

Browse files
takuma104 authored and w4ffl35 committed
Add support Karras sigmas for StableDiffusionKDiffusionPipeline (huggingface#2874)
* add use_karras_sigmas option thanks @Stax124
* fix sigma_min/max from scheduler.sigmas
* add docstring
* revert to use k_diffusion_model.sigma, to(device)
* add integration test
* make style
1 parent 68741c1 commit bab455a

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 21 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717

1818
import torch
1919
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
20+
from k_diffusion.sampling import get_sigmas_karras
2021

2122
from ...loaders import TextualInversionLoaderMixin
2223
from ...pipelines import DiffusionPipeline
@@ -409,6 +410,7 @@ def __call__(
409410
return_dict: bool = True,
410411
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
411412
callback_steps: int = 1,
413+
use_karras_sigmas: Optional[bool] = False,
412414
):
413415
r"""
414416
Function invoked when calling the pipeline for generation.
@@ -465,7 +467,10 @@ def __call__(
465467
callback_steps (`int`, *optional*, defaults to 1):
466468
The frequency at which the `callback` function will be called. If not specified, the callback will be
467469
called at every step.
468-
470+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
471+
Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to
472+
`DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M
473+
Karras`.
469474
Returns:
470475
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
471476
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
@@ -503,10 +508,18 @@ def __call__(
503508

504509
# 4. Prepare timesteps
505510
self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device)
506-
sigmas = self.scheduler.sigmas
511+
512+
# 5. Prepare sigmas
513+
if use_karras_sigmas:
514+
sigma_min: float = self.k_diffusion_model.sigmas[0].item()
515+
sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
516+
sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
517+
sigmas = sigmas.to(device)
518+
else:
519+
sigmas = self.scheduler.sigmas
507520
sigmas = sigmas.to(prompt_embeds.dtype)
508521

509-
# 5. Prepare latent variables
522+
# 6. Prepare latent variables
510523
num_channels_latents = self.unet.in_channels
511524
latents = self.prepare_latents(
512525
batch_size * num_images_per_prompt,
@@ -522,7 +535,7 @@ def __call__(
522535
self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
523536
self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
524537

525-
# 6. Define model function
538+
# 7. Define model function
526539
def model_fn(x, t):
527540
latent_model_input = torch.cat([x] * 2)
528541
t = torch.cat([t] * 2)
@@ -533,16 +546,16 @@ def model_fn(x, t):
533546
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
534547
return noise_pred
535548

536-
# 7. Run k-diffusion solver
549+
# 8. Run k-diffusion solver
537550
latents = self.sampler(model_fn, latents, sigmas)
538551

539-
# 8. Post-processing
552+
# 9. Post-processing
540553
image = self.decode_latents(latents)
541554

542-
# 9. Run safety checker
555+
# 10. Run safety checker
543556
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
544557

545-
# 10. Convert to PIL
558+
# 11. Convert to PIL
546559
if output_type == "pil":
547560
image = self.numpy_to_pil(image)
548561

tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,3 +75,32 @@ def test_stable_diffusion_2(self):
7575
expected_slice = np.array([0.1237, 0.1320, 0.1438, 0.1359, 0.1390, 0.1132, 0.1277, 0.1175, 0.1112])
7676

7777
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-1
78+
79+
def test_stable_diffusion_karras_sigmas(self):
80+
sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
81+
sd_pipe = sd_pipe.to(torch_device)
82+
sd_pipe.set_progress_bar_config(disable=None)
83+
84+
sd_pipe.set_scheduler("sample_dpmpp_2m")
85+
86+
prompt = "A painting of a squirrel eating a burger"
87+
generator = torch.manual_seed(0)
88+
output = sd_pipe(
89+
[prompt],
90+
generator=generator,
91+
guidance_scale=7.5,
92+
num_inference_steps=15,
93+
output_type="np",
94+
use_karras_sigmas=True,
95+
)
96+
97+
image = output.images
98+
99+
image_slice = image[0, -3:, -3:, -1]
100+
101+
assert image.shape == (1, 512, 512, 3)
102+
expected_slice = np.array(
103+
[0.11381689, 0.12112921, 0.1389457, 0.12549606, 0.1244964, 0.10831517, 0.11562866, 0.10867816, 0.10499048]
104+
)
105+
106+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

0 commit comments

Comments
 (0)