Skip to content

Commit b62d9a1

Browse files
authored
[Text-to-video] Add torch.compile() compatibility (#3949)
* Use the sample tensor directly instead of the dataclass (unet_3d_blocks).
* More usage of direct samples instead of dataclasses (unet_3d_blocks, remaining blocks).
* More usage of direct samples instead of dataclasses (unet_3d_condition).
* Use the direct sample in the text-to-video pipeline.
* Direct usage of the sample in the img2img pipeline.
1 parent 46af982 commit b62d9a1

File tree

4 files changed

+21
-13
lines changed

4 files changed

+21
-13
lines changed

src/diffusers/models/unet_3d_blocks.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,11 @@ def forward(
250250
hidden_states,
251251
encoder_hidden_states=encoder_hidden_states,
252252
cross_attention_kwargs=cross_attention_kwargs,
253-
).sample
253+
return_dict=False,
254+
)[0]
254255
hidden_states = temp_attn(
255-
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
256-
).sample
256+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
257+
)[0]
257258
hidden_states = resnet(hidden_states, temb)
258259
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
259260

@@ -377,10 +378,11 @@ def forward(
377378
hidden_states,
378379
encoder_hidden_states=encoder_hidden_states,
379380
cross_attention_kwargs=cross_attention_kwargs,
380-
).sample
381+
return_dict=False,
382+
)[0]
381383
hidden_states = temp_attn(
382-
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
383-
).sample
384+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
385+
)[0]
384386

385387
output_states += (hidden_states,)
386388

@@ -590,10 +592,11 @@ def forward(
590592
hidden_states,
591593
encoder_hidden_states=encoder_hidden_states,
592594
cross_attention_kwargs=cross_attention_kwargs,
593-
).sample
595+
return_dict=False,
596+
)[0]
594597
hidden_states = temp_attn(
595-
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
596-
).sample
598+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
599+
)[0]
597600

598601
if self.upsamplers is not None:
599602
for upsampler in self.upsamplers:

src/diffusers/models/unet_3d_condition.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,11 @@ def forward(
526526
sample = self.conv_in(sample)
527527

528528
sample = self.transformer_in(
529-
sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
530-
).sample
529+
sample,
530+
num_frames=num_frames,
531+
cross_attention_kwargs=cross_attention_kwargs,
532+
return_dict=False,
533+
)[0]
531534

532535
# 3. down
533536
down_block_res_samples = (sample,)

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,8 @@ def __call__(
648648
t,
649649
encoder_hidden_states=prompt_embeds,
650650
cross_attention_kwargs=cross_attention_kwargs,
651-
).sample
651+
return_dict=False,
652+
)[0]
652653

653654
# perform guidance
654655
if do_classifier_free_guidance:

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,8 @@ def __call__(
723723
t,
724724
encoder_hidden_states=prompt_embeds,
725725
cross_attention_kwargs=cross_attention_kwargs,
726-
).sample
726+
return_dict=False,
727+
)[0]
727728

728729
# perform guidance
729730
if do_classifier_free_guidance:

0 commit comments

Comments (0)