Skip to content

Commit 0e82fb1

Browse files
Torch compile graph fix (#3286)
* fix more
* Fix more
* fix more
* Apply suggestions from code review
* fix
* make style
* make fix-copies
* fix
* make sure torch compile
* Clean
* fix test
1 parent 709cf55 commit 0e82fb1

36 files changed

+109 -60 lines changed

src/diffusers/models/attention.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
import torch.nn.functional as F
1919
from torch import nn
2020

21+
from ..utils import maybe_allow_in_graph
2122
from ..utils.import_utils import is_xformers_available
2223
from .attention_processor import Attention
2324
from .embeddings import CombinedTimestepLabelEmbeddings
@@ -193,6 +194,7 @@ def forward(self, hidden_states):
193194
return hidden_states
194195

195196

197+
@maybe_allow_in_graph
196198
class BasicTransformerBlock(nn.Module):
197199
r"""
198200
A basic Transformer block.

src/diffusers/models/attention_processor.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
import torch.nn.functional as F
1818
from torch import nn
1919

20-
from ..utils import deprecate, logging
20+
from ..utils import deprecate, logging, maybe_allow_in_graph
2121
from ..utils.import_utils import is_xformers_available
2222

2323

@@ -31,6 +31,7 @@
3131
xformers = None
3232

3333

34+
@maybe_allow_in_graph
3435
class Attention(nn.Module):
3536
r"""
3637
A cross attention layer.

src/diffusers/models/modeling_utils.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -77,8 +77,14 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
7777

7878
def get_parameter_dtype(parameter: torch.nn.Module):
7979
try:
80-
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
81-
return next(parameters_and_buffers).dtype
80+
params = tuple(parameter.parameters())
81+
if len(params) > 0:
82+
return params[0].dtype
83+
84+
buffers = tuple(parameter.buffers())
85+
if len(buffers) > 0:
86+
return buffers[0].dtype
87+
8288
except StopIteration:
8389
# For torch.nn.DataParallel compatibility in PyTorch 1.5
8490

src/diffusers/models/unet_2d_blocks.py

Lines changed: 14 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -560,7 +560,8 @@ def forward(
560560
hidden_states,
561561
encoder_hidden_states=encoder_hidden_states,
562562
cross_attention_kwargs=cross_attention_kwargs,
563-
).sample
563+
return_dict=False,
564+
)[0]
564565
hidden_states = resnet(hidden_states, temb)
565566

566567
return hidden_states
@@ -868,15 +869,16 @@ def custom_forward(*inputs):
868869
hidden_states,
869870
encoder_hidden_states=encoder_hidden_states,
870871
cross_attention_kwargs=cross_attention_kwargs,
871-
).sample
872+
return_dict=False,
873+
)[0]
872874

873-
output_states += (hidden_states,)
875+
output_states = output_states + (hidden_states,)
874876

875877
if self.downsamplers is not None:
876878
for downsampler in self.downsamplers:
877879
hidden_states = downsampler(hidden_states)
878880

879-
output_states += (hidden_states,)
881+
output_states = output_states + (hidden_states,)
880882

881883
return hidden_states, output_states
882884

@@ -949,13 +951,13 @@ def custom_forward(*inputs):
949951
else:
950952
hidden_states = resnet(hidden_states, temb)
951953

952-
output_states += (hidden_states,)
954+
output_states = output_states + (hidden_states,)
953955

954956
if self.downsamplers is not None:
955957
for downsampler in self.downsamplers:
956958
hidden_states = downsampler(hidden_states)
957959

958-
output_states += (hidden_states,)
960+
output_states = output_states + (hidden_states,)
959961

960962
return hidden_states, output_states
961963

@@ -1342,13 +1344,13 @@ def custom_forward(*inputs):
13421344
else:
13431345
hidden_states = resnet(hidden_states, temb)
13441346

1345-
output_states += (hidden_states,)
1347+
output_states = output_states + (hidden_states,)
13461348

13471349
if self.downsamplers is not None:
13481350
for downsampler in self.downsamplers:
13491351
hidden_states = downsampler(hidden_states, temb)
13501352

1351-
output_states += (hidden_states,)
1353+
output_states = output_states + (hidden_states,)
13521354

13531355
return hidden_states, output_states
13541356

@@ -1466,13 +1468,13 @@ def forward(
14661468
**cross_attention_kwargs,
14671469
)
14681470

1469-
output_states += (hidden_states,)
1471+
output_states = output_states + (hidden_states,)
14701472

14711473
if self.downsamplers is not None:
14721474
for downsampler in self.downsamplers:
14731475
hidden_states = downsampler(hidden_states, temb)
14741476

1475-
output_states += (hidden_states,)
1477+
output_states = output_states + (hidden_states,)
14761478

14771479
return hidden_states, output_states
14781480

@@ -1859,7 +1861,8 @@ def custom_forward(*inputs):
18591861
hidden_states,
18601862
encoder_hidden_states=encoder_hidden_states,
18611863
cross_attention_kwargs=cross_attention_kwargs,
1862-
).sample
1864+
return_dict=False,
1865+
)[0]
18631866

18641867
if self.upsamplers is not None:
18651868
for upsampler in self.upsamplers:

src/diffusers/models/unet_2d_condition.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -682,7 +682,7 @@ def forward(
682682
# `Timesteps` does not contain any weights and will always return f32 tensors
683683
# but time_embedding might actually be running in fp16. so we need to cast here.
684684
# there might be better ways to encapsulate this.
685-
t_emb = t_emb.to(dtype=self.dtype)
685+
t_emb = t_emb.to(dtype=sample.dtype)
686686

687687
emb = self.time_embedding(t_emb, timestep_cond)
688688

@@ -697,7 +697,7 @@ def forward(
697697
# there might be better ways to encapsulate this.
698698
class_labels = class_labels.to(dtype=sample.dtype)
699699

700-
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
700+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
701701

702702
if self.config.class_embeddings_concat:
703703
emb = torch.cat([emb, class_emb], dim=-1)

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -437,7 +437,7 @@ def run_safety_checker(self, image, device, dtype):
437437

438438
def decode_latents(self, latents):
439439
latents = 1 / self.vae.config.scaling_factor * latents
440-
image = self.vae.decode(latents).sample
440+
image = self.vae.decode(latents, return_dict=False)[0]
441441
image = (image / 2 + 0.5).clamp(0, 1)
442442
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
443443
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -683,15 +683,16 @@ def __call__(
683683
t,
684684
encoder_hidden_states=prompt_embeds,
685685
cross_attention_kwargs=cross_attention_kwargs,
686-
).sample
686+
return_dict=False,
687+
)[0]
687688

688689
# perform guidance
689690
if do_classifier_free_guidance:
690691
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
691692
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
692693

693694
# compute the previous noisy sample x_t -> x_t-1
694-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
695+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
695696

696697
# call the callback, if provided
697698
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

src/diffusers/pipelines/deepfloyd_if/pipeline_if.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -793,7 +793,8 @@ def __call__(
793793
t,
794794
encoder_hidden_states=prompt_embeds,
795795
cross_attention_kwargs=cross_attention_kwargs,
796-
).sample
796+
return_dict=False,
797+
)[0]
797798

798799
# perform guidance
799800
if do_classifier_free_guidance:
@@ -805,8 +806,8 @@ def __call__(
805806

806807
# compute the previous noisy sample x_t -> x_t-1
807808
intermediate_images = self.scheduler.step(
808-
noise_pred, t, intermediate_images, **extra_step_kwargs
809-
).prev_sample
809+
noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
810+
)[0]
810811

811812
# call the callback, if provided
812813
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -829,7 +830,7 @@ def __call__(
829830

830831
# 11. Apply watermark
831832
if self.watermarker is not None:
832-
self.watermarker.apply_watermark(image, self.unet.config.sample_size)
833+
image = self.watermarker.apply_watermark(image, self.unet.config.sample_size)
833834
elif output_type == "pt":
834835
nsfw_detected = None
835836
watermark_detected = None

src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -256,7 +256,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
256256
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
257257
def decode_latents(self, latents):
258258
latents = 1 / self.vae.config.scaling_factor * latents
259-
image = self.vae.decode(latents).sample
259+
image = self.vae.decode(latents, return_dict=False)[0]
260260
image = (image / 2 + 0.5).clamp(0, 1)
261261
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
262262
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -134,7 +134,7 @@ def __init__(
134134
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
135135
def decode_latents(self, latents):
136136
latents = 1 / self.vae.config.scaling_factor * latents
137-
image = self.vae.decode(latents).sample
137+
image = self.vae.decode(latents, return_dict=False)[0]
138138
image = (image / 2 + 0.5).clamp(0, 1)
139139
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
140140
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -516,7 +516,7 @@ def run_safety_checker(self, image, device, dtype):
516516
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
517517
def decode_latents(self, latents):
518518
latents = 1 / self.vae.config.scaling_factor * latents
519-
image = self.vae.decode(latents).sample
519+
image = self.vae.decode(latents, return_dict=False)[0]
520520
image = (image / 2 + 0.5).clamp(0, 1)
521521
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
522522
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -440,7 +440,7 @@ def run_safety_checker(self, image, device, dtype):
440440

441441
def decode_latents(self, latents):
442442
latents = 1 / self.vae.config.scaling_factor * latents
443-
image = self.vae.decode(latents).sample
443+
image = self.vae.decode(latents, return_dict=False)[0]
444444
image = (image / 2 + 0.5).clamp(0, 1)
445445
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
446446
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -686,15 +686,16 @@ def __call__(
686686
t,
687687
encoder_hidden_states=prompt_embeds,
688688
cross_attention_kwargs=cross_attention_kwargs,
689-
).sample
689+
return_dict=False,
690+
)[0]
690691

691692
# perform guidance
692693
if do_classifier_free_guidance:
693694
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
694695
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
695696

696697
# compute the previous noisy sample x_t -> x_t-1
697-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
698+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
698699

699700
# call the callback, if provided
700701
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -454,7 +454,7 @@ def run_safety_checker(self, image, device, dtype):
454454
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
455455
def decode_latents(self, latents):
456456
latents = 1 / self.vae.config.scaling_factor * latents
457-
image = self.vae.decode(latents).sample
457+
image = self.vae.decode(latents, return_dict=False)[0]
458458
image = (image / 2 + 0.5).clamp(0, 1)
459459
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
460460
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -496,7 +496,7 @@ def run_safety_checker(self, image, device, dtype):
496496
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
497497
def decode_latents(self, latents):
498498
latents = 1 / self.vae.config.scaling_factor * latents
499-
image = self.vae.decode(latents).sample
499+
image = self.vae.decode(latents, return_dict=False)[0]
500500
image = (image / 2 + 0.5).clamp(0, 1)
501501
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
502502
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -326,7 +326,7 @@ def run_safety_checker(self, image, device, dtype):
326326
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
327327
def decode_latents(self, latents):
328328
latents = 1 / self.vae.config.scaling_factor * latents
329-
image = self.vae.decode(latents).sample
329+
image = self.vae.decode(latents, return_dict=False)[0]
330330
image = (image / 2 + 0.5).clamp(0, 1)
331331
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
332332
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -648,7 +648,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
648648
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
649649
def decode_latents(self, latents):
650650
latents = 1 / self.vae.config.scaling_factor * latents
651-
image = self.vae.decode(latents).sample
651+
image = self.vae.decode(latents, return_dict=False)[0]
652652
image = (image / 2 + 0.5).clamp(0, 1)
653653
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
654654
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -195,7 +195,7 @@ def run_safety_checker(self, image, device, dtype):
195195
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
196196
def decode_latents(self, latents):
197197
latents = 1 / self.vae.config.scaling_factor * latents
198-
image = self.vae.decode(latents).sample
198+
image = self.vae.decode(latents, return_dict=False)[0]
199199
image = (image / 2 + 0.5).clamp(0, 1)
200200
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
201201
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -525,7 +525,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
525525
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
526526
def decode_latents(self, latents):
527527
latents = 1 / self.vae.config.scaling_factor * latents
528-
image = self.vae.decode(latents).sample
528+
image = self.vae.decode(latents, return_dict=False)[0]
529529
image = (image / 2 + 0.5).clamp(0, 1)
530530
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
531531
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -446,7 +446,7 @@ def run_safety_checker(self, image, device, dtype):
446446
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
447447
def decode_latents(self, latents):
448448
latents = 1 / self.vae.config.scaling_factor * latents
449-
image = self.vae.decode(latents).sample
449+
image = self.vae.decode(latents, return_dict=False)[0]
450450
image = (image / 2 + 0.5).clamp(0, 1)
451451
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
452452
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -656,7 +656,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
656656
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
657657
def decode_latents(self, latents):
658658
latents = 1 / self.vae.config.scaling_factor * latents
659-
image = self.vae.decode(latents).sample
659+
image = self.vae.decode(latents, return_dict=False)[0]
660660
image = (image / 2 + 0.5).clamp(0, 1)
661661
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
662662
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -358,7 +358,7 @@ def run_safety_checker(self, image, device, dtype):
358358
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
359359
def decode_latents(self, latents):
360360
latents = 1 / self.vae.config.scaling_factor * latents
361-
image = self.vae.decode(latents).sample
361+
image = self.vae.decode(latents, return_dict=False)[0]
362362
image = (image / 2 + 0.5).clamp(0, 1)
363363
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
364364
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -221,7 +221,7 @@ def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_p
221221
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
222222
def decode_latents(self, latents):
223223
latents = 1 / self.vae.config.scaling_factor * latents
224-
image = self.vae.decode(latents).sample
224+
image = self.vae.decode(latents, return_dict=False)[0]
225225
image = (image / 2 + 0.5).clamp(0, 1)
226226
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
227227
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -385,7 +385,7 @@ def run_safety_checker(self, image, device, dtype):
385385
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
386386
def decode_latents(self, latents):
387387
latents = 1 / self.vae.config.scaling_factor * latents
388-
image = self.vae.decode(latents).sample
388+
image = self.vae.decode(latents, return_dict=False)[0]
389389
image = (image / 2 + 0.5).clamp(0, 1)
390390
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
391391
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -349,7 +349,7 @@ def run_safety_checker(self, image, device, dtype):
349349
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
350350
def decode_latents(self, latents):
351351
latents = 1 / self.vae.config.scaling_factor * latents
352-
image = self.vae.decode(latents).sample
352+
image = self.vae.decode(latents, return_dict=False)[0]
353353
image = (image / 2 + 0.5).clamp(0, 1)
354354
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
355355
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -590,7 +590,7 @@ def run_safety_checker(self, image, device, dtype):
590590
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
591591
def decode_latents(self, latents):
592592
latents = 1 / self.vae.config.scaling_factor * latents
593-
image = self.vae.decode(latents).sample
593+
image = self.vae.decode(latents, return_dict=False)[0]
594594
image = (image / 2 + 0.5).clamp(0, 1)
595595
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
596596
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

0 commit comments

Comments
 (0)