Commit ad8f985
Allow low precision vae sd xl (#4083)
* Allow low precision sd xl
* finish
* finish
* make style
1 parent ee2f277 commit ad8f985

4 files changed, +79 -62 lines changed

src/diffusers/models/autoencoder_kl.py

Lines changed: 5 additions & 0 deletions
@@ -64,6 +64,10 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
             / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
             Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+        force_upcast (`bool`, *optional*, default to `True`):
+            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
+            can be fine-tuned / trained to a lower range without loosing too much precision in which case
+            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
     """
 
     _supports_gradient_checkpointing = True
@@ -82,6 +86,7 @@ def __init__(
         norm_num_groups: int = 32,
         sample_size: int = 32,
         scaling_factor: float = 0.18215,
+        force_upcast: float = True,
     ):
         super().__init__()
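The new `force_upcast` flag lets a suitably fine-tuned VAE stay in half precision end to end. A minimal usage sketch, not part of this commit: the fp16-friendly checkpoint is the one referenced in the docstring above, the SDXL model id is illustrative, and overriding the flag via `register_to_config` is an assumption about how a user would opt out:

    import torch
    from diffusers import AutoencoderKL, StableDiffusionXLPipeline

    # VAE fine-tuned to stay numerically stable in float16 (see the link in the docstring above)
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    vae.register_to_config(force_upcast=False)  # assumed opt-out; a no-op if the checkpoint already sets it

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
    ).to("cuda")
    image = pipe(prompt="a photo of an astronaut riding a horse").images[0]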

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Lines changed: 22 additions & 20 deletions
@@ -501,6 +501,25 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
     @torch.no_grad()
     def __call__(
         self,
@@ -746,26 +765,9 @@ def __call__(
 
         # 10. Post-processing
        # make sure the VAE is in float32 mode, as it overflows in float16
-        self.vae.to(dtype=torch.float32)
-
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
-        )
-
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(latents.dtype)
-            self.vae.decoder.conv_in.to(latents.dtype)
-            self.vae.decoder.mid_block.to(latents.dtype)
-        else:
-            latents = latents.float()
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+        latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
         # post-processing
         if not output_type == "latent":
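The refactor replaces the unconditional float32 cast with two steps: upcast only when the VAE is in float16 and `config.force_upcast` asks for it, then let the latents follow whatever dtype `post_quant_conv` ended up in. A small self-contained sketch of that dtype-matching pattern on a toy module (names here are illustrative, not the pipeline's API):

    import torch
    from torch import nn

    class ToyVae(nn.Module):
        def __init__(self):
            super().__init__()
            self.post_quant_conv = nn.Conv2d(4, 4, kernel_size=1)

    vae = ToyVae().to(torch.float16)
    force_upcast = True  # stands in for vae.config.force_upcast
    latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)

    # upcast only if needed, then match the latents to the conv's parameter dtype
    if next(vae.parameters()).dtype == torch.float16 and force_upcast:
        vae.to(torch.float32)
    latents = latents.to(next(iter(vae.post_quant_conv.parameters())).dtype)
    print(latents.dtype)  # torch.float32 here; torch.float16 if force_upcast were False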

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 23 additions & 19 deletions
@@ -537,6 +537,26 @@ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, d
         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
         return add_time_ids
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -799,25 +819,9 @@ def __call__(
                     callback(i, t, latents)
 
         # make sure the VAE is in float32 mode, as it overflows in float16
-        self.vae.to(dtype=torch.float32)
-
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(latents.dtype)
-            self.vae.decoder.conv_in.to(latents.dtype)
-            self.vae.decoder.mid_block.to(latents.dtype)
-        else:
-            latents = latents.float()
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+        latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
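With the default `force_upcast=True`, a pure float16 SDXL run now handles the decode step automatically: only the VAE (and, with xformers or torch 2.0 attention, only the parts of it that really need float32) is upcast. A rough usage sketch; the model id, `variant` argument, and prompt are illustrative rather than taken from this commit:

    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
    ).to("cuda")

    # UNet and text encoders stay in fp16; the VAE is upcast to fp32 only at decode time
    image = pipe(prompt="a red panda in the snow, photograph").images[0]
    image.save("panda.png")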

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 29 additions & 23 deletions
@@ -542,8 +542,9 @@ def prepare_latents(
 
         else:
             # make sure the VAE is in float32 mode, as it overflows in float16
-            image = image.float()
-            self.vae.to(dtype=torch.float32)
+            if self.vae.config.force_upcast:
+                image = image.float()
+                self.vae.to(dtype=torch.float32)
 
             if isinstance(generator, list) and len(generator) != batch_size:
                 raise ValueError(
@@ -559,9 +560,10 @@ def prepare_latents(
             else:
                 init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
-            self.vae.to(dtype)
-            init_latents = init_latents.to(dtype)
+            if self.vae.config.force_upcast:
+                self.vae.to(dtype)
 
+            init_latents = init_latents.to(dtype)
             init_latents = self.vae.config.scaling_factor * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
@@ -624,6 +626,26 @@ def _get_add_time_ids(
 
         return add_time_ids, add_neg_time_ids
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -932,25 +954,9 @@ def __call__(
                     callback(i, t, latents)
 
         # make sure the VAE is in float32 mode, as it overflows in float16
-        self.vae.to(dtype=torch.float32)
-
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(latents.dtype)
-            self.vae.decoder.conv_in.to(latents.dtype)
-            self.vae.decoder.mid_block.to(latents.dtype)
-        else:
-            latents = latents.float()
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+        latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
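On the img2img side the flag also gates the encode path: with `force_upcast=False`, `prepare_latents` keeps the input image and the VAE in half precision instead of round-tripping through float32. A hedged sketch along the same lines as above; the refiner model id and the input URL are placeholders, and the `register_to_config` override is again an assumption:

    import torch
    from diffusers import AutoencoderKL, StableDiffusionXLImg2ImgPipeline
    from diffusers.utils import load_image

    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    vae.register_to_config(force_upcast=False)  # assumed opt-out: keep encode and decode fully in fp16

    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0", vae=vae, torch_dtype=torch.float16
    ).to("cuda")

    init_image = load_image("https://example.com/input.png")  # placeholder URL
    image = pipe(prompt="a watercolor painting", image=init_image, strength=0.3).images[0]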
