
[Text-to-video] Add torch.compile() compatibility #3949


Merged: 5 commits, Jul 6, 2023


sayakpaul
Member

What does this PR do?

Fixes #3915

Description

torch.compile() for the repeat_interleave() function was added in a nightly build. See: pytorch/pytorch#99929.

So, once I upgraded to Torch 2.1 nightly, the issue went away. However, there were other issues which are fixed in this PR. The PR takes inspiration from #3313.
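For anyone who wants to guard against older installs, here is a minimal sketch of a version gate. It assumes, based only on the description above, that any Torch version from 2.1 onward can compile repeat_interleave(); the exact nightly that introduced support is not pinned down here, so the check is deliberately coarse. The names MIN_VERSION, parse_major_minor and can_compile_repeat_interleave are hypothetical helpers, not part of PyTorch or diffusers.

```python
import re
from typing import Tuple

# Assumption taken from the PR description: support arrived during the
# Torch 2.1 development cycle, so 2.1 is used as the threshold.
MIN_VERSION = (2, 1)


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings such as '2.1.0.dev20230705+cu118'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def can_compile_repeat_interleave(version: str) -> bool:
    """Return True if the given Torch version is at least MIN_VERSION."""
    return parse_major_minor(version) >= MIN_VERSION
```

In practice this would be called as can_compile_repeat_interleave(torch.__version__) before wrapping the UNet in torch.compile().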

Even though we're able to successfully compile the model, it takes a hefty amount of time after torch.compile() is called on the UNet:

import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

# Load the text-to-video pipeline in half precision and move it to the GPU.
pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
pipe.to("cuda")
pipe.enable_vae_slicing()

# Compile only the UNet; fullgraph=True raises an error on any graph break.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

# The first call triggers compilation, so it is much slower than later calls.
prompt = "Darth Vader is surfing on waves"
video_frames = pipe(prompt, num_inference_steps=40, height=320, width=576, num_frames=36).frames
video_path = export_to_video(video_frames, output_video_path="video_576_darth_vader_36.mp4")

The first call to pipe is really time-consuming, which is understandable because that is when the compiled UNet is used for the first time. But even in subsequent calls, the timing doesn't seem to improve much. In my experiments, the pipeline actually ran much faster without torch.compile().
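To make the comparison between the first call and later calls concrete, a small timing helper along these lines could be used. This is only a sketch, not the measurement code from the notebook: time_calls is a hypothetical helper, and the optional sync argument is an assumption meant for GPU runs, where one would pass torch.cuda.synchronize so that queued kernels are included in the measurement.

```python
import time
from statistics import mean
from typing import Callable, Dict, Optional


def time_calls(
    fn: Callable[[], object],
    n_calls: int = 3,
    sync: Optional[Callable[[], None]] = None,
) -> Dict[str, float]:
    """Time n_calls invocations of fn, reporting the first call separately.

    The first call includes any one-off cost such as compilation; the
    remaining calls are averaged to estimate steady-state runtime.
    """
    if n_calls < 2:
        raise ValueError("n_calls must be at least 2 to separate warm-up from steady state")

    durations = []
    for _ in range(n_calls):
        if sync is not None:
            sync()  # make sure earlier work has finished before starting the clock
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for asynchronous work launched by fn
        durations.append(time.perf_counter() - start)

    return {
        "first_call_s": durations[0],
        "steady_state_s": mean(durations[1:]),
    }
```

Running it once with the compiled UNet and once with the original one, for example time_calls(lambda: pipe(prompt, num_inference_steps=40), sync=torch.cuda.synchronize), would give the two steady-state numbers to compare.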

Let me know if anything is unclear.

I am leaving the outputs of the progress bars below:

With torch.compile()

image

Without torch.compile()

image

My explorations can be found in this Colab Notebook.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jul 5, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten
Contributor

Great job! Looks clean - also ok for me to not add a test at the moment, given the problems we have with memory leaks. Good to merge for me

@sayakpaul
Member Author

Do we have a handle on why these leaks happen? Why does torch.compile() perform so poorly when we have 3D inputs like videos?

@sayakpaul sayakpaul merged commit b62d9a1 into main Jul 6, 2023
@sayakpaul sayakpaul deleted the fix/unet3d-compile branch July 6, 2023 09:00
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* use sample directly instead of the dataclass.

* more usage of directly samples instead of dataclasses

* more usage of directly samples instead of dataclasses

* use direct sample in the pipeline.

* direct usage of sample in the img2img case.
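One plausible reading of these commit messages, sketched below with toy stand-ins, is that calls which used to build an output dataclass and read its .sample attribute were switched to return a plain tuple and index into it. ToyOutput and toy_block are hypothetical and are not the diffusers classes; the return_dict flag mirrors the keyword that diffusers models accept, but the actual diff should be consulted for the real changes.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class ToyOutput:
    """Hypothetical stand-in for a model output dataclass with a .sample field."""
    sample: List[float]


def toy_block(x: List[float], return_dict: bool = True) -> Union[ToyOutput, Tuple[List[float]]]:
    """Toy forward pass: doubles every value and wraps the result on request."""
    out = [2.0 * v for v in x]
    if not return_dict:
        # Plain tuple: the caller takes the tensor directly with [0].
        return (out,)
    return ToyOutput(sample=out)


# Before: go through the dataclass and read its attribute.
via_dataclass = toy_block([1.0, 2.0]).sample

# After: skip the dataclass and take the first tuple element.
via_tuple = toy_block([1.0, 2.0], return_dict=False)[0]
```

Both paths produce the same values; the tuple path simply avoids constructing a custom output object inside the region being compiled.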
AmericanPresidentJimmyCarter pushed a commit with the same five commit messages to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
Successfully merging this pull request may close these issues.

torch.compile doesn't seem to be working for text-to-video pipelines