
Commit 61d3764

Authored by darhsu, sayakpaul, and yiyixuxu
Support bfloat16 for Upsample2D (#9480)
* Support bfloat16 for Upsample2D
* Add test and use is_torch_version
* Resolve comments and add decorator
* Simplify require_torch_version_greater_equal decorator
* Run make style

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
1 parent 33fafe3 commit 61d3764

File tree: 3 files changed, +34 −6 lines

src/diffusers/models/upsampling.py

Lines changed: 6 additions & 6 deletions
@@ -19,6 +19,7 @@
 import torch.nn.functional as F

 from ..utils import deprecate
+from ..utils.import_utils import is_torch_version
 from .normalization import RMSNorm


@@ -151,11 +152,10 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None
         if self.use_conv_transpose:
             return self.conv(hidden_states)

-        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
-        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
-        # https://github.com/pytorch/pytorch/issues/86679
+        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1
+        # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767
         dtype = hidden_states.dtype
-        if dtype == torch.bfloat16:
+        if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
             hidden_states = hidden_states.to(torch.float32)

         # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
@@ -170,8 +170,8 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None
         else:
             hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

-        # If the input is bfloat16, we cast back to bfloat16
-        if dtype == torch.bfloat16:
+        # Cast back to original dtype
+        if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
             hidden_states = hidden_states.to(dtype)

         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
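
The behavioral change is easiest to see outside the class: the float32 round-trip is now applied only when the input is bfloat16 and the installed PyTorch predates 2.1, where the nearest-neighbor upsampling kernel lacks bfloat16 support. A minimal sketch of the pattern; the helper function below is illustrative, not part of diffusers:

import torch
import torch.nn.functional as F

from diffusers.utils.import_utils import is_torch_version


def upsample_nearest_any_dtype(hidden_states: torch.Tensor, scale_factor: float = 2.0) -> torch.Tensor:
    """Illustrative helper (not in diffusers): nearest-neighbor upsampling that
    works around the missing bfloat16 kernel on PyTorch < 2.1."""
    dtype = hidden_states.dtype
    # Only round-trip through float32 when the kernel actually lacks bfloat16 support.
    needs_cast = dtype == torch.bfloat16 and is_torch_version("<", "2.1")
    if needs_cast:
        hidden_states = hidden_states.to(torch.float32)

    hidden_states = F.interpolate(hidden_states, scale_factor=scale_factor, mode="nearest")

    if needs_cast:
        hidden_states = hidden_states.to(dtype)
    return hidden_states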

src/diffusers/utils/testing_utils.py

Lines changed: 12 additions & 0 deletions
@@ -252,6 +252,18 @@ def require_torch_2(test_case):
     )


+def require_torch_version_greater_equal(torch_version):
+    """Decorator marking a test that requires torch with a specific version or greater."""
+
+    def decorator(test_case):
+        correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version)
+        return unittest.skipUnless(
+            correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}"
+        )(test_case)
+
+    return decorator
+
+
 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
     return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
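
The new decorator composes is_torch_available, is_torch_version, and unittest.skipUnless, so a decorated test is skipped rather than failed on older installs. A minimal usage sketch; the test class and method below are hypothetical, not from the repository:

import unittest

import torch

from diffusers.utils.testing_utils import require_torch_version_greater_equal


class UpsampleBF16Tests(unittest.TestCase):
    # Hypothetical test case, shown only to illustrate the decorator.
    @require_torch_version_greater_equal("2.1")
    def test_bfloat16_path(self):
        # Skipped automatically when torch is unavailable or older than 2.1.
        x = torch.ones(2, 2, dtype=torch.bfloat16)
        self.assertEqual(x.dtype, torch.bfloat16)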

tests/models/test_layers_utils.py

Lines changed: 16 additions & 0 deletions
@@ -27,6 +27,7 @@
 from diffusers.utils.testing_utils import (
     backend_manual_seed,
     require_torch_accelerator_with_fp64,
+    require_torch_version_greater_equal,
     torch_device,
 )

@@ -120,6 +121,21 @@ def test_upsample_default(self):
         expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

+    @require_torch_version_greater_equal("2.1")
+    def test_upsample_bfloat16(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32).to(torch.bfloat16)
+        upsample = Upsample2D(channels=32, use_conv=False)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 32, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254], dtype=torch.bfloat16
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
     def test_upsample_with_conv(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
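
Taken together, the change means a caller can pass bfloat16 tensors straight into Upsample2D on any supported PyTorch version. A minimal end-to-end sketch, assuming the diffusers.models.upsampling import path shown in this commit:

import torch

from diffusers.models.upsampling import Upsample2D

# Sketch of the behavior this commit enables: a bfloat16 tensor can flow through
# Upsample2D without the caller casting it manually.
upsample = Upsample2D(channels=32, use_conv=False)
sample = torch.randn(1, 32, 32, 32, dtype=torch.bfloat16)

with torch.no_grad():
    upsampled = upsample(sample)

print(upsampled.shape)  # torch.Size([1, 32, 64, 64])
print(upsampled.dtype)  # torch.bfloat16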
