Skip to content

Add support Karras sigmas for StableDiffusionKDiffusionPipeline #2874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 31, 2023
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from k_diffusion.sampling import get_sigmas_karras

from ...pipelines import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler
Expand Down Expand Up @@ -400,6 +401,7 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
use_karras_sigmas: Optional[bool] = False,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -456,7 +458,9 @@ def __call__(
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.

use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Use Karras sigmas. For example, passing `sample_dpmpp_2m` to `set_scheduler` is equivalent to
`DPM++2M` in stable-diffusion-webui. Additionally setting this option to `True` makes it `DPM++2M Karras`.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
Expand Down Expand Up @@ -494,10 +498,18 @@ def __call__(

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device)
sigmas = self.scheduler.sigmas

# 5. Prepare sigmas
if use_karras_sigmas:
sigma_min: float = self.k_diffusion_model.sigmas[0].item()
sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
sigmas = sigmas.to(device)
else:
sigmas = self.scheduler.sigmas
sigmas = sigmas.to(prompt_embeds.dtype)

# 5. Prepare latent variables
# 6. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
Expand All @@ -513,7 +525,7 @@ def __call__(
self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)

# 6. Define model function
# 7. Define model function
def model_fn(x, t):
latent_model_input = torch.cat([x] * 2)
t = torch.cat([t] * 2)
Expand All @@ -524,16 +536,16 @@ def model_fn(x, t):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
return noise_pred

# 7. Run k-diffusion solver
# 8. Run k-diffusion solver
latents = self.sampler(model_fn, latents, sigmas)

# 8. Post-processing
# 9. Post-processing
image = self.decode_latents(latents)

# 9. Run safety checker
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

# 10. Convert to PIL
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,32 @@ def test_stable_diffusion_2(self):
expected_slice = np.array([0.1237, 0.1320, 0.1438, 0.1359, 0.1390, 0.1132, 0.1277, 0.1175, 0.1112])

assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-1

def test_stable_diffusion_karras_sigmas(self):
    """Compare a `sample_dpmpp_2m` + Karras-sigmas generation with a stored reference.

    Loads the `stabilityai/stable-diffusion-2-1-base` checkpoint, selects the
    `sample_dpmpp_2m` sampler, and generates one image in 15 steps with
    `use_karras_sigmas=True` from a generator seeded with 0. The trailing 3x3
    patch of the last channel of the first image must match the recorded
    values to within 1e-2.
    """
    pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base").to(
        torch_device
    )
    pipe.set_progress_bar_config(disable=None)
    pipe.set_scheduler("sample_dpmpp_2m")

    # Seed the generator inside the call so the run is reproducible.
    images = pipe(
        ["A painting of a squirrel eating a burger"],
        generator=torch.manual_seed(0),
        guidance_scale=7.5,
        num_inference_steps=15,
        output_type="np",
        use_karras_sigmas=True,
    ).images

    # Bottom-right 3x3 corner of the last channel of the first image.
    corner = images[0, -3:, -3:, -1]

    assert images.shape == (1, 512, 512, 3)

    reference = np.array(
        [0.11381689, 0.12112921, 0.1389457, 0.12549606, 0.1244964, 0.10831517, 0.11562866, 0.10867816, 0.10499048]
    )
    max_abs_error = np.abs(corner.flatten() - reference).max()

    assert max_abs_error < 1e-2