@@ -17,6 +17,7 @@
 
 import torch
 from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
+from k_diffusion.sampling import get_sigmas_karras
 
 from ...loaders import TextualInversionLoaderMixin
 from ...pipelines import DiffusionPipeline
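For reference, the new import pulls in the Karras et al. (2022) noise schedule. Below is a minimal standalone sketch of the schedule that `get_sigmas_karras` produces under its default `rho=7.0`; the reimplementation is illustrative, not the k-diffusion library source.

import torch

def karras_schedule_sketch(n: int, sigma_min: float, sigma_max: float, rho: float = 7.0) -> torch.Tensor:
    # n sigmas spaced uniformly in sigma ** (1 / rho), descending from sigma_max to sigma_min
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    # the k-diffusion samplers expect a trailing zero at the end of the schedule
    return torch.cat([sigmas, sigmas.new_zeros(1)])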
@@ -409,6 +410,7 @@ def __call__(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
+        use_karras_sigmas: Optional[bool] = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -465,7 +467,10 @@ def __call__(
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
-
+            use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+                Whether to use Karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` is
+                equivalent to `DPM++ 2M` in stable-diffusion-webui; setting this option to `True` on top of that
+                makes it `DPM++ 2M Karras`.
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
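As a usage illustration of the docstring above, here is a hedged sketch of how the flag would be passed when calling the pipeline; the checkpoint id and prompt are placeholders, and it assumes a standard `StableDiffusionKDiffusionPipeline` setup with the `set_scheduler` call the docstring refers to.

import torch
from diffusers import StableDiffusionKDiffusionPipeline

pipe = StableDiffusionKDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.set_scheduler("sample_dpmpp_2m")  # roughly "DPM++ 2M" in stable-diffusion-webui terms

# Setting use_karras_sigmas=True corresponds to "DPM++ 2M Karras"
image = pipe("a photograph of an astronaut riding a horse", use_karras_sigmas=True).images[0]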
@@ -503,10 +508,18 @@ def __call__(
 
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device)
-        sigmas = self.scheduler.sigmas
+
+        # 5. Prepare sigmas
+        if use_karras_sigmas:
+            sigma_min: float = self.k_diffusion_model.sigmas[0].item()
+            sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
+            sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
+            sigmas = sigmas.to(device)
+        else:
+            sigmas = self.scheduler.sigmas
         sigmas = sigmas.to(prompt_embeds.dtype)
 
-        # 5. Prepare latent variables
+        # 6. Prepare latent variables
         num_channels_latents = self.unet.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
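A note on the `sigmas[0]` / `sigmas[-1]` indexing above: the CompVis-style sigma table kept by the k-diffusion wrapper is derived from a decreasing `alphas_cumprod` and is therefore increasing in `t`, so index 0 holds the smallest sigma and index -1 the largest. A minimal sketch under that assumption (toy numbers, not pipeline code):

import torch

alphas_cumprod = torch.linspace(0.999, 0.001, 1000)      # toy schedule, decreasing in t
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5   # CompVis-style sigmas, increasing in t
assert sigmas[0] < sigmas[-1]                             # [0] is sigma_min, [-1] is sigma_max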
@@ -522,7 +535,7 @@ def __call__(
         self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
         self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
 
-        # 6. Define model function
+        # 7. Define model function
         def model_fn(x, t):
             latent_model_input = torch.cat([x] * 2)
             t = torch.cat([t] * 2)
@@ -533,16 +546,16 @@ def model_fn(x, t):
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
             return noise_pred
 
-        # 7. Run k-diffusion solver
+        # 8. Run k-diffusion solver
         latents = self.sampler(model_fn, latents, sigmas)
 
-        # 8. Post-processing
+        # 9. Post-processing
         image = self.decode_latents(latents)
 
-        # 9. Run safety checker
+        # 10. Run safety checker
         image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
 
-        # 10. Convert to PIL
+        # 11. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 