
fixed a dtype bfloat16 bug in torch_utils.py #10125


Merged

merged 6 commits into huggingface:main from zhangp365:main on Dec 6, 2024

Conversation

zhangp365 (Contributor)

When generating a 1024*1024 image with bfloat16 dtype, there is an exception:
File "/opt/conda/lib/python3.10/site-packages/diffusers/utils/torch_utils.py", line 107, in fourier_filter
x_freq = fftn(x, dim=(-2, -1))
RuntimeError: Unsupported dtype BFloat16

What does this PR do?

Fixes the crash above by casting bfloat16 inputs to float32 before calling fftn, which does not support bfloat16.

@hlky
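
As a minimal sketch of the underlying limitation (illustrative, independent of this repository's code):

    import torch
    from torch.fft import fftn

    x = torch.randn(1, 4, 128, 128, dtype=torch.bfloat16)
    # fftn(x, dim=(-2, -1))  # raises RuntimeError: Unsupported dtype BFloat16
    y = fftn(x.to(torch.float32), dim=(-2, -1))  # casting to float32 first works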

zhangp365 and others added 2 commits December 5, 2024 11:08
when generating 1024*1024 image with bfloat16 dtype, there is an exception:
  File "/opt/conda/lib/python3.10/site-packages/diffusers/utils/torch_utils.py", line 107, in fourier_filter
    x_freq = fftn(x, dim=(-2, -1))
RuntimeError: Unsupported dtype BFloat16
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

hlky (Contributor) left a comment
@zhangp365 Thanks! Is this when using FreeU? Can we keep the check for non-power-of-2 images and add another for bfloat16? I think that makes it clearer why we're casting to float32.

zhangp365 (Contributor, Author) commented Dec 5, 2024

@hlky Yes, this uses FreeU and sets the pipeline dtype to bfloat16. In this case, a non-power-of-2 image runs successfully (the existing check already casts it to float32), but a standard power-of-2 size like 1024*1024 fails. Therefore, I believe casting to float32 is a safe operation that makes the code more robust.
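
To reproduce, a minimal sketch (the checkpoint, prompt, and FreeU parameters here are illustrative assumptions, not taken from this PR):

    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
    ).to("cuda")
    # FreeU routes UNet features through fourier_filter
    pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)

    # 1024 is a power of two, so the old code skipped the float32 cast and
    # fftn raised "RuntimeError: Unsupported dtype BFloat16" before this fix.
    image = pipe("an astronaut riding a horse", height=1024, width=1024).images[0]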

hlky (Contributor) commented Dec 5, 2024

    # Non-power of 2 images must be float32
    if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
        x = x.to(dtype=torch.float32)
    # fftn does not support bfloat16
    elif x.dtype == torch.bfloat16:
        x = x.to(dtype=torch.float32)

If we always cast, someone looking at the function in the future may wonder why. cc @sayakpaul @DN6 WDYT?
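
For context, a minimal sketch of how these checks might sit inside fourier_filter (only the two dtype checks above are from this discussion; the FFT and masking details are an illustrative reconstruction):

    import torch
    from torch.fft import fftn, fftshift, ifftn, ifftshift

    def fourier_filter(x_in: torch.Tensor, threshold: int, scale: float) -> torch.Tensor:
        x = x_in
        B, C, H, W = x.shape

        # Non-power of 2 images must be float32
        if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
            x = x.to(dtype=torch.float32)
        # fftn does not support bfloat16
        elif x.dtype == torch.bfloat16:
            x = x.to(dtype=torch.float32)

        # FFT, shifting low frequencies to the center
        x_freq = fftshift(fftn(x, dim=(-2, -1)), dim=(-2, -1))

        # Scale a low-frequency box around the center (FreeU-style filtering)
        mask = torch.ones((B, C, H, W), device=x.device)
        crow, ccol = H // 2, W // 2
        mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
        x_freq = x_freq * mask

        # Inverse FFT, then cast back so callers keep their original dtype
        x_filtered = ifftn(ifftshift(x_freq, dim=(-2, -1)), dim=(-2, -1)).real
        return x_filtered.to(dtype=x_in.dtype)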

sayakpaul (Member)

Makes sense to me!

hlky (Contributor) commented Dec 5, 2024

@zhangp365 Can you run make style?

zhangp365 (Contributor, Author)

@zhangp365 Can you run make style?

I tried make style, and there are a lot of errors:

examples/community/adaptive_mask_inpainting.py:1412:12: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
     |
1411 |     def vis_seg_on_img(self, image, mask):
1412 |         if type(mask) == np.ndarray:
     |            ^^^^^^^^^^^^^^^^^^^^^^^^ E721
1413 |             mask = torch.tensor(mask)
1414 |         v = Visualizer(image, self.coco_metadata, scale=0.5, instance_mode=ColorMode.IMAGE_BW)
     |

examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb:cell 53:59:7: F821 Undefined name `pickle`
   |
58 |   with open(save_path, 'wb') as f:
59 |       pickle.dump(results, f)
   |       ^^^^^^ F821
   |

examples/research_projects/gligen/demo.ipynb:cell 5:13:1: E402 Module level import not at top of cell
   |
11 | gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]
12 |
13 | import numpy as np
   | ^^^^^^^^^^^^^^^^^^ E402
   |

src/diffusers/configuration_utils.py:691:16: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
689 |             if field.name in self._flax_internal_args:
690 |                 continue
691 |             if type(field.default) == dataclasses._MISSING_TYPE:
    |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
692 |                 default_kwargs[field.name] = None
693 |             else:
    |

tests/models/test_modeling_common.py:398:20: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
397 |         model.set_default_attn_processor()
398 |         assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
    |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
399 |         with torch.no_grad():
400 |             if self.forward_requires_fresh_args:
    |

tests/models/test_modeling_common.py:406:20: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
405 |         model.enable_npu_flash_attention()
406 |         assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
    |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
407 |         with torch.no_grad():
408 |             if self.forward_requires_fresh_args:
    |

tests/models/test_modeling_common.py:414:20: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
413 |         model.set_attn_processor(AttnProcessorNPU())
414 |         assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
    |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
415 |         with torch.no_grad():
416 |             if self.forward_requires_fresh_args:
    |

tests/models/test_modeling_common.py:449:20: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
448 |         model.set_default_attn_processor()
449 |         assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
    |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
450 |         with torch.no_grad():
451 |             if self.forward_requires_fresh_args:
    |

tests/models/test_modeling_common.py:457:20: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
456 |         model.enable_xformers_memory_efficient_attention()
457 |         assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
    |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
458 |         with torch.no_grad():
459 |             if self.forward_requires_fresh_args:
    |

tests/models/test_modeling_common.py:465:20: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
464 |         model.set_attn_processor(XFormersAttnProcessor())
465 |         assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
    |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
466 |         with torch.no_grad():
467 |             if self.forward_requires_fresh_args:
    |

tests/models/test_modeling_common.py:496:20: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
494 |             return
495 |
496 |         assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
    |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
497 |         with torch.no_grad():
498 |             if self.forward_requires_fresh_args:
    |

tests/models/test_modeling_common.py:504:20: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
503 |         model.set_default_attn_processor()
504 |         assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
    |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
505 |         with torch.no_grad():
506 |             if self.forward_requires_fresh_args:
    |

tests/models/test_modeling_common.py:512:20: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
511 |         model.set_attn_processor(AttnProcessor2_0())
512 |         assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
    |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
513 |         with torch.no_grad():
514 |             if self.forward_requires_fresh_args:
    |

tests/models/test_modeling_common.py:520:20: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
519 |         model.set_attn_processor(AttnProcessor())
520 |         assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
    |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
521 |         with torch.no_grad():
522 |             if self.forward_requires_fresh_args:
    |

tests/pipelines/test_pipelines_common.py:772:21: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
770 |             if hasattr(component, "attn_processors"):
771 |                 assert all(
772 |                     type(proc) == AttnProcessor for proc in component.attn_processors.values()
    |                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
773 |                 ), "`from_pipe` changed the attention processor in original pipeline."
    |

tests/schedulers/test_schedulers.py:827:16: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
825 |         scheduler_loaded = DDIMScheduler.from_pretrained(f"{USER}/{self.repo_id}")
826 |
827 |         assert type(scheduler) == type(scheduler_loaded)
    |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
828 |
829 |         # Reset repo
    |

tests/schedulers/test_schedulers.py:838:16: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
836 |         scheduler_loaded = DDIMScheduler.from_pretrained(f"{USER}/{self.repo_id}")
837 |
838 |         assert type(scheduler) == type(scheduler_loaded)
    |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
839 |
840 |         # Reset repo
    |

tests/schedulers/test_schedulers.py:854:16: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
852 |         scheduler_loaded = DDIMScheduler.from_pretrained(self.org_repo_id)
853 |
854 |         assert type(scheduler) == type(scheduler_loaded)
    |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
855 |
856 |         # Reset repo
    |

tests/schedulers/test_schedulers.py:865:16: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
863 |         scheduler_loaded = DDIMScheduler.from_pretrained(self.org_repo_id)
864 |
865 |         assert type(scheduler) == type(scheduler_loaded)
    |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
866 |
867 |         # Reset repo

But these errors are not from this PR; I think this PR will not affect the make process.
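
For reference, the rewrite E721 asks for looks like this (an illustrative sketch of the pattern from the output above; these were pre-existing issues, not part of this PR):

    import numpy as np
    import torch

    mask = np.zeros((4, 4))  # hypothetical input

    # Flagged by E721: type(mask) == np.ndarray
    # Preferred form:
    if isinstance(mask, np.ndarray):
        mask = torch.tensor(mask)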

hlky (Contributor) commented Dec 6, 2024

Here's the error from the last run (link). The extra errors you're seeing are because of the ruff version; we use 0.1.5, so pip install ruff==0.1.5 or pip install -e ".[dev]" in Diffusers. Apologies for the inconvenience; your PR is from your main branch, which seems to prevent us from pushing to it.

zhangp365 (Contributor, Author)

Here's the error from the last run (link). The extra errors you're seeing are because of the ruff version; we use 0.1.5, so pip install ruff==0.1.5 or pip install -e ".[dev]" in Diffusers. Apologies for the inconvenience; your PR is from your main branch, which seems to prevent us from pushing to it.

Yes, after running pip install -e ".[dev]" in Diffusers, make style is successful now. The error in the linked run was correct: at that time there was whitespace on an empty line, but I have since removed the empty line, so the issue no longer exists.

hlky (Contributor) commented Dec 6, 2024

Thanks @zhangp365!

yiyixuxu merged commit 188bca3 into huggingface:main Dec 6, 2024
15 checks passed
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* fixed a dtype bfloat16 bug in torch_utils.py

when generating 1024*1024 image with bfloat16 dtype, there is an exception:
  File "/opt/conda/lib/python3.10/site-packages/diffusers/utils/torch_utils.py", line 107, in fourier_filter
    x_freq = fftn(x, dim=(-2, -1))
RuntimeError: Unsupported dtype BFloat16

* remove whitespace in torch_utils.py

* Update src/diffusers/utils/torch_utils.py

* Update torch_utils.py

---------

Co-authored-by: hlky <[email protected]>