### Describe the bug

I tried to run the ControlNet example from this blog post, and it turns out that `BasicTransformerBlock` causes a large number of graph breaks (>100) in a single ControlNet pipeline. Ideally the whole `BasicTransformerBlock.forward` should be included in a single frame for speedups. The exact reason for the graph breaks is:

`call_function UserDefinedObjectVariable(AttnProcessor2_0) [NNModuleVariable(), TensorVariable()] {'encoder_hidden_states': TensorVariable(), 'attention_mask': ConstantVariable(NoneType)}`

for both self-attention (`attn1`) and cross-attention (`attn2`). Is there a way to reduce the graph breaks so that `StableDiffusionControlNetPipeline` works better with `torch.compile`?
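For reference, here is a minimal standalone sketch of what I believe the break pattern is (my own toy repro, not diffusers code): `AttnProcessor2_0` is a plain Python object rather than an `nn.Module`, so when `Attention.forward` calls `self.processor(...)`, Dynamo on this build appears to stop the frame. Names like `PlainProcessor` and `Block` are made up for illustration, and newer torch builds may inline this call instead of breaking:

```python
import torch
import torch._dynamo as dynamo
import torch.nn.functional as F

class PlainProcessor:
    # Plain object with __call__, like AttnProcessor2_0 -- not an nn.Module.
    def __call__(self, attn, hidden_states):
        return F.scaled_dot_product_attention(hidden_states, hidden_states, hidden_states)

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.processor = PlainProcessor()

    def forward(self, x):
        # Dynamo traces up to here; the call below surfaces as
        # "call_function UserDefinedObjectVariable(PlainProcessor) ..."
        return self.processor(self, x)

block = Block()
x = torch.randn(2, 4, 8)

def run(inp):
    return block(inp)

# explain's return format differs across torch versions; on the nightly listed
# under "System Info" it returns a tuple whose last element is the verbose report.
print(dynamo.explain(run, x)[-1])
```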
### Reproduction
```python
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import cv2
from PIL import Image
import torch
import torch._dynamo as dynamo
import numpy as np

# Build the Canny conditioning image
image = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
image = np.array(image)

low_threshold = 100
high_threshold = 200

image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)

# Set up the ControlNet pipeline
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

@dynamo.optimize("inductor")
def generate(prompt):
    generator = [torch.Generator(device="cuda").manual_seed(2) for i in range(len(prompt))]
    return pipe(
        prompt,
        canny_image,
        negative_prompt=["monochrome, lowres, bad anatomy, worst quality, low quality"] * len(prompt),
        num_inference_steps=10,
        generator=generator,
    )

prompt = ", best quality, extremely detailed"
prompt = [t + prompt for t in ["Sandra Oh", "Kim Kardashian", "rihanna", "taylor swift"]]
ex = dynamo.explain(generate, prompt)[-1]
print(ex)
```
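One possible mitigation I have been considering (my own suggestion, not from the blog post): instead of wrapping the whole `generate` call in `dynamo.optimize`, compile only the heavy submodules, so the Python-level pipeline loop never enters Dynamo. Note that `torch.compile` may interact badly with `enable_model_cpu_offload()`, so this is only a sketch:

```python
# Hedged workaround sketch: compile only the UNet and ControlNet denoising
# networks rather than the full pipeline call. Untested on this exact nightly;
# model cpu offload may need to be replaced with pipe.to("cuda").
pipe.unet = torch.compile(pipe.unet)
pipe.controlnet = torch.compile(pipe.controlnet)
images = pipe(prompt, canny_image, num_inference_steps=10).images
```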
### Logs
```shell
graph #169 break reason: call_function UserDefinedObjectVariable(AttnProcessor2_0) [NNModuleVariable(), TensorVariable()] {'encoder_hidden_states': ConstantVariable(NoneType), 'attention_mask': ConstantVariable(NoneType)} after 3
stack: File "/home/yj/diffusers/src/diffusers/models/attention.py", line 313, in forward
attn_output = self.attn1(
File "/home/yj/pytorch/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yj/diffusers/src/diffusers/models/attention_processor.py", line 267, in forward
return self.processor(
graph #171 break reason: call_function UserDefinedObjectVariable(AttnProcessor2_0) [NNModuleVariable(), TensorVariable()] {'encoder_hidden_states': TensorVariable(), 'attention_mask': ConstantVariable(NoneType)} after 1
stack: File "/home/yj/diffusers/src/diffusers/models/attention.py", line 331, in <resume in forward>
attn_output = self.attn2(
File "/home/yj/pytorch/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yj/diffusers/src/diffusers/models/attention_processor.py", line 267, in forward
return self.processor(
```

### System Info

- Ubuntu 20.04 with CUDA 11.8
- diffusers 0.16.0.dev0 (/home/yj/diffusers)
- torch 2.1.0a0+git0bbf8a9 (/home/yj/pytorch)