-
Notifications
You must be signed in to change notification settings - Fork 6k
fixed a dtype bfloat16 bug in torch_utils.py #10125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
when generating 1024*1024 image with bfloat16 dtype, there is an exception: File "/opt/conda/lib/python3.10/site-packages/diffusers/utils/torch_utils.py", line 107, in fourier_filter x_freq = fftn(x, dim=(-2, -1)) RuntimeError: Unsupported dtype BFloat16
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zhangp365 Thanks! this is when using freeu
? Can we keep the check for non-power of 2 images and add another for bfloat16? I think that makes it clearer why we're casting to float32.
@hlky Yes, this uses FreeU and sets the pipeline dtype to bfloat16. In this case, when the image is not a non-power-of-2 size, the pipeline runs successfully. However, the standard size image fails to run. Therefore, I believe casting the type to float32 is a safe operation, making the code more robust. |
# Non-power of 2 images must be float32
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
x = x.to(dtype=torch.float32)
# fftn does not support bfloat16
elif x.dtype == torch.bfloat16:
x = x.to(dtype=torch.float32) If we always cast someone looking at the function in the future may wonder why. cc @sayakpaul @DN6 WDYT? |
Makes sense to me! |
@zhangp365 Can you run |
I tried
but the errors are not from this pr. I think this pr will not affect the make process. |
Here's the error from the last run, link. The extra errors you're seeing are because of |
Yes, after running |
Thanks @zhangp365! |
* fixed a dtype bfloat16 bug in torch_utils.py when generating 1024*1024 image with bfloat16 dtype, there is an exception: File "/opt/conda/lib/python3.10/site-packages/diffusers/utils/torch_utils.py", line 107, in fourier_filter x_freq = fftn(x, dim=(-2, -1)) RuntimeError: Unsupported dtype BFloat16 * remove whitespace in torch_utils.py * Update src/diffusers/utils/torch_utils.py * Update torch_utils.py --------- Co-authored-by: hlky <[email protected]>
when generating 1024*1024 image with bfloat16 dtype, there is an exception:
File "/opt/conda/lib/python3.10/site-packages/diffusers/utils/torch_utils.py", line 107, in fourier_filter
x_freq = fftn(x, dim=(-2, -1))
RuntimeError: Unsupported dtype BFloat16
What does this PR do?
fix a bug.
@hlky