Describe the bug
Trying to use torch.compile on a text-to-video model doesn't work.
If I follow the docs and first call pipe.unet.to(memory_format=torch.channels_last):
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
pipe.unet.to(memory_format=torch.channels_last)
I get:
RuntimeError: required rank 4 tensor to use channels_last format
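My own guess at what is happening (an assumption, not something from the diffusers docs): the text-to-video UNet is a 3D UNet, so some of its conv weights are rank-5 tensors, while torch.channels_last only accepts rank-4 tensors. A small diagnostic sketch along those lines, not a verified fix:

import torch
from collections import Counter
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "cerspense/zeroscope_v2_576w", torch_dtype=torch.float16
)

# Check how many rank-4 vs. rank-5 weights the UNet holds; any rank-5
# weights are what trip the "required rank 4 tensor" error above.
print(Counter(p.dim() for p in pipe.unet.parameters()))

# Convert only the rank-4 conv weights to channels_last and leave rank-5
# weights untouched, instead of calling .to(memory_format=...) on the module.
with torch.no_grad():
    for p in pipe.unet.parameters():
        if p.dim() == 4:
            p.data = p.data.to(memory_format=torch.channels_last)

Whether this partial channels_last conversion actually speeds the model up I have not measured; it just avoids the RuntimeError.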
If I instead skip the torch.channels_last conversion and go directly to torch.compile:
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
I get:
RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'
Reproduction
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
Logs
Full traceback for pipe.unet.to(memory_format=torch.channels_last)
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <cell line: 1>:1 │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1145 in to │
│ │
│ 1142 │ │ │ │ │ │ │ non_blocking, memory_format=convert_to_format) │
│ 1143 │ │ │ return t.to(device, dtype if t.is_floating_point() or t.is_complex() else No │
│ 1144 │ │ │
│ ❱ 1145 │ │ return self._apply(convert) │
│ 1146 │ │
│ 1147 │ def register_full_backward_pre_hook( │
│ 1148 │ │ self, │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797 in _apply │
│ │
│ 794 │ │
│ 795 │ def _apply(self, fn): │
│ 796 │ │ for module in self.children(): │
│ ❱ 797 │ │ │ module._apply(fn) │
│ 798 │ │ │
│ 799 │ │ def compute_should_use_set_data(tensor, tensor_applied): │
│ 800 │ │ │ if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797 in _apply │
│ │
│ 794 │ │
│ 795 │ def _apply(self, fn): │
│ 796 │ │ for module in self.children(): │
│ ❱ 797 │ │ │ module._apply(fn) │
│ 798 │ │ │
│ 799 │ │ def compute_should_use_set_data(tensor, tensor_applied): │
│ 800 │ │ │ if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797 in _apply │
│ │
│ 794 │ │
│ 795 │ def _apply(self, fn): │
│ 796 │ │ for module in self.children(): │
│ ❱ 797 │ │ │ module._apply(fn) │
│ 798 │ │ │
│ 799 │ │ def compute_should_use_set_data(tensor, tensor_applied): │
│ 800 │ │ │ if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797 in _apply │
│ │
│ 794 │ │
│ 795 │ def _apply(self, fn): │
│ 796 │ │ for module in self.children(): │
│ ❱ 797 │ │ │ module._apply(fn) │
│ 798 │ │ │
│ 799 │ │ def compute_should_use_set_data(tensor, tensor_applied): │
│ 800 │ │ │ if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797 in _apply │
│ │
│ 794 │ │
│ 795 │ def _apply(self, fn): │
│ 796 │ │ for module in self.children(): │
│ ❱ 797 │ │ │ module._apply(fn) │
│ 798 │ │ │
│ 799 │ │ def compute_should_use_set_data(tensor, tensor_applied): │
│ 800 │ │ │ if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797 in _apply │
│ │
│ 794 │ │
│ 795 │ def _apply(self, fn): │
│ 796 │ │ for module in self.children(): │
│ ❱ 797 │ │ │ module._apply(fn) │
│ 798 │ │ │
│ 799 │ │ def compute_should_use_set_data(tensor, tensor_applied): │
│ 800 │ │ │ if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797 in _apply │
│ │
│ 794 │ │
│ 795 │ def _apply(self, fn): │
│ 796 │ │ for module in self.children(): │
│ ❱ 797 │ │ │ module._apply(fn) │
│ 798 │ │ │
│ 799 │ │ def compute_should_use_set_data(tensor, tensor_applied): │
│ 800 │ │ │ if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:820 in _apply │
│ │
│ 817 │ │ │ # track autograd history of `param_applied`, so we have to use │
│ 818 │ │ │ # `with torch.no_grad():` │
│ 819 │ │ │ with torch.no_grad(): │
│ ❱ 820 │ │ │ │ param_applied = fn(param) │
│ 821 │ │ │ should_use_set_data = compute_should_use_set_data(param, param_applied) │
│ 822 │ │ │ if should_use_set_data: │
│ 823 │ │ │ │ param.data = param_applied │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1141 in convert │
│ │
│ 1138 │ │ │
│ 1139 │ │ def convert(t): │
│ 1140 │ │ │ if convert_to_format is not None and t.dim() in (4, 5): │
│ ❱ 1141 │ │ │ │ return t.to(device, dtype if t.is_floating_point() or t.is_complex() els │
│ 1142 │ │ │ │ │ │ │ non_blocking, memory_format=convert_to_format) │
│ 1143 │ │ │ return t.to(device, dtype if t.is_floating_point() or t.is_complex() else No │
│ 1144 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: required rank 4 tensor to use channels_last format
Full traceback for pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True), raised when the pipeline is then called
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/gradio/routes.py", line 437, in run_predict
output = await app.get_blocks().process_api(
File "/usr/local/lib/python3.10/dist-packages/gradio/blocks.py", line 1352, in process_api
result = await self.call_function(
File "/usr/local/lib/python3.10/dist-packages/gradio/blocks.py", line 1077, in call_function
prediction = await anyio.to_thread.run_sync(
File "/usr/local/lib/python3.10/dist-packages/anyio/to_thread.py", line 33, in run_sync
return await get_asynclib().run_sync_in_worker_thread(
File "/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
return await future
File "/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py", line 807, in run
result = context.run(func, *args)
File "<ipython-input-13-947ecc021452>", line 13, in infer
video_frames = pipe(prompt, num_inference_steps=40, height=320, width=576, num_frames=24).frames
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py", line 605, in __call__
prompt_embeds = self._encode_prompt(
File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py", line 298, in _encode_prompt
prompt_embeds = self.text_encoder(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/clip/modeling_clip.py", line 822, in forward
return self.text_model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/clip/modeling_clip.py", line 740, in forward
encoder_outputs = self.encoder(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/clip/modeling_clip.py", line 654, in forward
layer_outputs = encoder_layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/clip/modeling_clip.py", line 382, in forward
hidden_states = self.layer_norm1(hidden_states)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/normalization.py", line 190, in forward
return F.layer_norm(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py", line 2515, in layer_norm
return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'
Keyboard interruption in main thread... closing server.
System Info
diffusers==0.17.1