
F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") breaks for large bsz #984

Closed
@NouamaneTazi

Description


Describe the bug

Thanks to the amazing work done in the memory-efficient PR, I can now run Stable Diffusion in fp16 on a TITAN RTX (24 GB VRAM) up to a batch size of 31 with no issue.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

prompt = "a photo of an astronaut riding a horse"  # any prompt
batch_size = 32

with torch.inference_mode():
    image = pipe([prompt] * batch_size, num_inference_steps=5).images[0]

When I try a batch size of 32, I get the following error:

Traceback (most recent call last):
  File "/home/nouamane/projects/diffusers/a.py", line 45, in <module>
    image = pipe([prompt] * batch_size, num_inference_steps=5).images[0]
  File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/nouamane/projects/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 353, in __call__
    image = self.vae.decode(latents).sample
  File "/home/nouamane/projects/diffusers/src/diffusers/models/vae.py", line 577, in decode
    dec = self.decoder(z)
  File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nouamane/projects/diffusers/src/diffusers/models/vae.py", line 217, in forward
    sample = up_block(sample)
  File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nouamane/projects/diffusers/src/diffusers/models/unet_blocks.py", line 1281, in forward
    hidden_states = upsampler(hidden_states)
  File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nouamane/projects/diffusers/src/diffusers/models/resnet.py", line 54, in forward
    hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
  File "/home/nouamane/miniconda3/envs/dev/lib/python3.9/site-packages/torch/nn/functional.py", line 3910, in interpolate
    return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
RuntimeError: upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements
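A back-of-the-envelope element count lines up with the 31-vs-32 threshold. Assuming the SD v1-4 VAE decoder's final upsampler doubles a 256-channel, 256×256 activation to 512×512 (these shapes are my reading of the decoder config, not taken from the traceback), batch size 32 produces exactly 2^31 output elements, one past `INT_MAX`:

```python
INT_MAX = 2**31 - 1  # limit enforced by upsample_nearest_nhwc

def upsample_out_elems(bsz, channels=256, h=512, w=512):
    """Element count of the upsampled output tensor (assumed decoder shapes)."""
    return bsz * channels * h * w

print(upsample_out_elems(31))  # 2_080_374_784, under INT_MAX -> works
print(upsample_out_elems(32))  # 2_147_483_648, over INT_MAX -> RuntimeError
```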

Is there a way to fix this issue?
@patrickvonplaten @patil-suraj
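One possible workaround, sketched below under the assumption that only the single `F.interpolate` call overflows, is to split the input along the batch dimension so each call stays under `INT_MAX` and concatenate the results (nearest-neighbor upsampling is independent per sample, so the output is identical). `interpolate_chunked` and `max_elems` are hypothetical names, not diffusers API:

```python
import torch
import torch.nn.functional as F

def interpolate_chunked(x, scale_factor=2.0, mode="nearest", max_elems=2**31 - 1):
    # Output elements produced per sample after upsampling both spatial dims.
    out_elems_per_sample = x[0].numel() * int(scale_factor) ** 2
    chunk = max(1, max_elems // out_elems_per_sample)
    if chunk >= x.shape[0]:
        # Whole batch fits under the limit: single call, no overhead.
        return F.interpolate(x, scale_factor=scale_factor, mode=mode)
    # Otherwise upsample batch slices separately and stitch them back together.
    return torch.cat(
        [F.interpolate(c, scale_factor=scale_factor, mode=mode)
         for c in x.split(chunk)],
        dim=0,
    )
```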

System Info

  • diffusers version: 0.7.0.dev0
  • Platform: Linux-5.3.0-64-generic-x86_64-with-glibc2.30
  • Python version: 3.9.13
  • PyTorch version (GPU?): 1.12.1 (True)
  • Huggingface_hub version: 0.10.0
  • Transformers version: 4.24.0.dev0
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no


Labels

bug (Something isn't working)
