Skip to content

Commit 85f28b3

Browse files
tolgacangozyiyixuxu
authored andcommitted
Fix image upcasting (huggingface#7858)
Fix image's upcasting before `vae.encode()` when using `fp16` Co-authored-by: YiYi Xu <[email protected]>
1 parent ae4a515 commit 85f28b3

File tree

2 files changed

+1
-2
lines changed

2 files changed

+1
-2
lines changed

src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1419,7 +1419,6 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
14191419
if needs_upcasting:
14201420
image = image.float()
14211421
self.upcast_vae()
1422-
image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
14231422

14241423
x0 = self.vae.encode(image).latent_dist.mode()
14251424
x0 = x0.to(dtype)

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,8 +525,8 @@ def prepare_image_latents(
525525
# make sure the VAE is in float32 mode, as it overflows in float16
526526
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
527527
if needs_upcasting:
528+
image = image.float()
528529
self.upcast_vae()
529-
image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
530530

531531
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
532532

0 commit comments

Comments
 (0)