Describe the bug
Thanks to the amazing work done in the memory-efficient PR, I can now run Stable Diffusion in fp16 on a TITAN RTX (24 GB VRAM) with batch sizes of up to 31 without any issue.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

prompt = "a photo of an astronaut riding a horse"  # placeholder, any prompt reproduces it
batch_size = 32  # 31 works, 32 triggers the error below

with torch.inference_mode():
    image = pipe([prompt] * batch_size, num_inference_steps=5).images[0]
When I try a batch size of 32, I get the following error:
Traceback (most recent call last):
File "/home/nouamane/projects/diffusers/a.py", line 45, in <module>
image = pipe([prompt] * batch_size, num_inference_steps=5).images[0]
File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/nouamane/projects/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 353, in __call__
image = self.vae.decode(latents).sample
File "/home/nouamane/projects/diffusers/src/diffusers/models/vae.py", line 577, in decode
dec = self.decoder(z)
File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nouamane/projects/diffusers/src/diffusers/models/vae.py", line 217, in forward
sample = up_block(sample)
File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nouamane/projects/diffusers/src/diffusers/models/unet_blocks.py", line 1281, in forward
hidden_states = upsampler(hidden_states)
File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/nouamane/projects/diffusers/src/diffusers/models/resnet.py", line 54, in forward
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/nn/functional.py", line 3910, in interpolate
return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
RuntimeError: upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements
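For context, I think the failure comes from the element count of the interpolate output crossing INT_MAX exactly between batch 31 and 32. Assuming the default SD v1-4 VAE decoder shapes (third upsampler producing a 256-channel feature map at 512x512 per sample, which is my assumption, not something shown in the traceback), a quick sanity check:

# Rough check of why batch 31 fits but batch 32 does not.
# Assumed output shape of the third VAE decoder upsampler: 256 channels at 512x512.
INT_MAX = 2**31 - 1

channels, height, width = 256, 512, 512  # assumed per-sample upsample output shape

for batch in (31, 32):
    n_elements = batch * channels * height * width
    print(batch, n_elements, n_elements <= INT_MAX)

# 31 * 256 * 512 * 512 = 2,080,374,784 -> below INT_MAX
# 32 * 256 * 512 * 512 = 2,147,483,648 -> 2**31, one element above INT_MAX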
Is there a way to fix this issue?
@patrickvonplaten @patil-suraj
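In the meantime, a possible workaround (just a sketch based on the reasoning above, not a verified fix) is to keep every single forward pass below the limit by splitting the prompts into smaller sub-batches and concatenating the results:

# Sketch: run the same pipeline in sub-batches so the VAE decode never produces a
# tensor with >= INT_MAX elements. `pipe` and `prompt` are from the snippet above;
# the chunk size of 16 is an arbitrary choice below the observed limit of 31.
images = []
chunk_size = 16
prompts = [prompt] * 32
with torch.inference_mode():
    for i in range(0, len(prompts), chunk_size):
        images.extend(pipe(prompts[i : i + chunk_size], num_inference_steps=5).images)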
System Info
- diffusers version: 0.7.0.dev0
- Platform: Linux-5.3.0-64-generic-x86_64-with-glibc2.30
- Python version: 3.9.13
- PyTorch version (GPU?): 1.12.1 (True)
- Huggingface_hub version: 0.10.0
- Transformers version: 4.24.0.dev0
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no