Skip to content

Commit 188bca3

Browse files
zhangp365hlky
andauthored
fixed a dtype bfloat16 bug in torch_utils.py (#10125)
* fixed a dtype bfloat16 bug in torch_utils.py when generating 1024*1024 image with bfloat16 dtype, there is an exception: File "/opt/conda/lib/python3.10/site-packages/diffusers/utils/torch_utils.py", line 107, in fourier_filter x_freq = fftn(x, dim=(-2, -1)) RuntimeError: Unsupported dtype BFloat16 * remove whitespace in torch_utils.py * Update src/diffusers/utils/torch_utils.py * Update torch_utils.py --------- Co-authored-by: hlky <[email protected]>
1 parent cd89204 commit 188bca3

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

src/diffusers/utils/torch_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
102102
# Non-power of 2 images must be float32
103103
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
104104
x = x.to(dtype=torch.float32)
105+
# fftn does not support bfloat16
106+
elif x.dtype == torch.bfloat16:
107+
x = x.to(dtype=torch.float32)
105108

106109
# FFT
107110
x_freq = fftn(x, dim=(-2, -1))

0 commit comments

Comments
 (0)