excessive graph breaks on attention.py and attention_processor.py for control_net on torch.compile #3218

Closed
@shingjan

Describe the bug

I tried to run the ControlNet example from this blog post, and it turned out that BasicTransformerBlock causes a large number of graph breaks (>100) in a single ControlNet pipeline. Ideally, the whole BasicTransformerBlock.forward should be included in a single frame for speedups. The exact reason for the graph breaks is:

call_function UserDefinedObjectVariable(AttnProcessor2_0) [NNModuleVariable(), TensorVariable()] {'encoder_hidden_states': TensorVariable(), 'attention_mask': ConstantVariable(NoneType)}

for both self-attention and cross-attention. Is there a way to reduce the graph breaks so that StableDiffusionControlNetPipeline works better with torch.compile?
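To check whether this one call site really accounts for most of the breaks, the "break reason" lines can be tallied by call target. The following is a minimal sketch using only the standard library. It assumes the log text has the same shape as the break-reason line quoted above (optionally prefixed with "graph #N break reason:" and suffixed with "after N"); other torch versions may print a different format. The function name `tally_break_targets` and both regular expressions are illustrative, not part of torch or diffusers.

```python
import re
from collections import Counter

# Assumed line shape (taken from the output in this issue):
#   graph #169 break reason: call_function UserDefinedObjectVariable(...) [...] {...} after 3
BREAK_LINE = re.compile(r"break reason:\s*(?P<reason>.+?)(?:\s+after\s+\d+)?\s*$")
# Pulls out the callee of a call_function break, e.g. UserDefinedObjectVariable(AttnProcessor2_0)
TARGET = re.compile(r"call_function\s+(?P<target>\w+\([^)]*\))")


def tally_break_targets(log_text: str) -> Counter:
    """Count graph breaks per call target; fall back to the full reason text."""
    counts = Counter()
    for line in log_text.splitlines():
        match = BREAK_LINE.search(line)
        if match is None:
            continue  # stack-trace lines and blank lines carry no break reason
        reason = match.group("reason")
        target = TARGET.search(reason)
        counts[target.group("target") if target else reason] += 1
    return counts


if __name__ == "__main__":
    sample = (
        "graph #169 break reason: call_function "
        "UserDefinedObjectVariable(AttnProcessor2_0) "
        "[NNModuleVariable(), TensorVariable()] "
        "{'encoder_hidden_states': ConstantVariable(NoneType), "
        "'attention_mask': ConstantVariable(NoneType)} after 3\n"
        'stack:   File "attention.py", line 313, in forward\n'
    )
    for target, count in tally_break_targets(sample).most_common():
        print(count, target)
```

Feeding the full explain output through this helper gives one count per distinct callee, which makes it easy to confirm whether AttnProcessor2_0 dominates or other break reasons are also present.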

Reproduction

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import cv2
from PIL import Image
import torch
import numpy as np

image = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)

image = np.array(image)

low_threshold = 100
high_threshold = 200

image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

import torch._dynamo as dynamo  # torch itself is already imported above

@dynamo.optimize("inductor")
def generate(prompt):
    generator = [torch.Generator(device="cuda").manual_seed(2) for i in range(len(prompt))]
    return pipe(
        prompt,
        canny_image,
        negative_prompt=["monochrome, lowres, bad anatomy, worst quality, low quality"] * len(prompt),
        num_inference_steps=10,
        generator=generator,
    )

prompt = ", best quality, extremely detailed"
prompt = [t + prompt for t in ["Sandra Oh", "Kim Kardashian", "rihanna", "taylor swift"]]
ex = dynamo.explain(generate, prompt)[-1]
print(ex)

Logs

graph #169 break reason: call_function UserDefinedObjectVariable(AttnProcessor2_0) [NNModuleVariable(), TensorVariable()] {'encoder_hidden_states': ConstantVariable(NoneType), 'attention_mask': ConstantVariable(NoneType)} after 3
stack:   File "/home/yj/diffusers/src/diffusers/models/attention.py", line 313, in forward
    attn_output = self.attn1(
  File "/home/yj/pytorch/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yj/diffusers/src/diffusers/models/attention_processor.py", line 267, in forward
    return self.processor(
 
graph #171 break reason: call_function UserDefinedObjectVariable(AttnProcessor2_0) [NNModuleVariable(), TensorVariable()] {'encoder_hidden_states': TensorVariable(), 'attention_mask': ConstantVariable(NoneType)} after 1
stack:   File "/home/yj/diffusers/src/diffusers/models/attention.py", line 331, in <resume in forward>
    attn_output = self.attn2(
  File "/home/yj/pytorch/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yj/diffusers/src/diffusers/models/attention_processor.py", line 267, in forward
    return self.processor(

System Info

Ubuntu 20.04 with cuda 11.8

diffusers 0.16.0.dev0 /home/yj/diffusers
torch 2.1.0a0+git0bbf8a9 /home/yj/pytorch
