
Improve FluxPipeline checks and logging #9064


Open · wants to merge 6 commits into main

Conversation


@PDillis PDillis commented Aug 3, 2024

What does this PR do?

  • Remove the unnecessary logging of the txt_ids and img_ids shapes, which bloats the console, especially when generating large quantities of images. This is done inside models/transformers/transformer_flux.py.
  • If the user passes their own latents to FluxPipeline.__call__, additional checks now ensure the right shape is used and, if not, report the expected shape to the user (see the sketch below). Otherwise, PyTorch's usual tensor size mismatch error is uninformative at best. The docstring has also been updated accordingly.
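
For illustration, the added check amounts to something like the following sketch (the helper name _check_packed_latents and the exact message are illustrative, not the final code in this PR):

import torch

def _check_packed_latents(latents):
    # Illustrative sketch: packed Flux latents are expected to be 3D
    # (batch_size, num_patches, channels), so anything else gets an explicit
    # error instead of an opaque size mismatch further down the pipeline.
    if not isinstance(latents, torch.Tensor):
        raise ValueError(f"`latents` has to be of type `torch.Tensor` but is {type(latents)}")
    if latents.ndim != 3:
        raise ValueError(
            "`latents` should be a packed 3D tensor of shape (batch_size, num_patches, channels), "
            f"but got shape {tuple(latents.shape)}. Use `FluxPipeline._pack_latents` to pack 4D latents."
        )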

@sayakpaul

@@ -373,7 +373,6 @@ def forward(
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

print(f"{txt_ids.shape=}, {img_ids.shape=}")
Member

I was actually fixing it here #9057. But okay.

Author

Understood, we can ignore this then.

Member

@sayakpaul sayakpaul left a comment

Thank you. I have left some comments. LMK what you think.

"Unpacked latents detected. These will be automatically packed. "
"In the future, please provide packed latents to improve performance."
)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
Member

Packing the latents is an inexpensive operation. So, I think this warning is unnecessary.
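
For context, packing is roughly a view/permute/reshape that folds each 2x2 spatial patch into the channel dimension; a simplified sketch (not the exact library code) of what FluxPipeline._pack_latents does:

def pack_latents_sketch(latents, batch_size, num_channels_latents, height, width):
    # Fold each 2x2 spatial patch into the channels:
    # (batch, channels, height, width) -> (batch, (height // 2) * (width // 2), channels * 4)
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)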

Author

@PDillis PDillis Aug 3, 2024

Agreed, the warning can be removed. Note also that I moved the height and width definitions for the case where no latents are provided; otherwise, we would be passing the wrong dimensions to self._prepare_latent_image_ids.

Comment on lines 150 to 152
This pipeline expects `latents` to be in a packed format. If you're providing
custom latents, make sure to use the `_pack_latents` method to prepare them.
Packed latents should be a 3D tensor of shape (batch_size, num_patches, channels).
Member

I think this note is unnecessary as well given that _pack_latents() is an inexpensive operation. I think your updates to check_inputs() should do the trick.

Author

Agreed. Maybe I can then add this note as part of the latents docstring in __call__.

if not isinstance(latents, torch.Tensor):
raise ValueError(f"`latents` has to be of type `torch.Tensor` but is {type(latents)}")

if not _are_latents_packed(latents):
Member

But later in prepare_latents(), we are packing the latents and throwing a warning, no? So, I think we can remove this check.

Author

Same, this check can be removed from the check_inputs function.

Comment on lines 708 to 711
if not self._are_latents_packed(latents):
raise ValueError(
"Latents are not in the correct packed format. Please use `_pack_latents` to prepare them."
)
Member

This can be removed too since essentially check_inputs() should catch these.

Author

Yes, the check_inputs function should catch these, so this can be removed. Then there is no need for the _are_latents_packed method, so that can also be removed.

@sayakpaul sayakpaul requested a review from DN6 August 3, 2024 03:11
@sayakpaul
Member

@PDillis thank you for being so receptive to my comments. Let me know once you'd like another round of reviews.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@PDillis
Author

PDillis commented Aug 5, 2024

@sayakpaul Thank you, I've made the suggested changes. Let me know what you think.

Comment on lines +487 to +489
if latents.ndim == 4:
# Packing the latents to be of shape (batch_size, num_patches, channels)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
Collaborator

Is this check necessary? Won't check_inputs raise an error if an ndim == 4 latent tensor is passed in? Also, _prepare_latent_image_ids divides the height and width by 2 before creating the image ids; if the height isn't scaled up before being passed into this method, you will get an incorrect shape.
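
For reference, a simplified sketch of what FluxPipeline._prepare_latent_image_ids builds (not the exact library code): one id triplet per 2x2 patch of the unpacked latents, i.e. a (height // 2, width // 2) grid, which is why the height and width passed in must already be the upscaled latent dimensions:

import torch

def prepare_latent_image_ids_sketch(batch_size, height, width, device, dtype):
    # 3 channels per patch; channels 1 and 2 hold the patch's row and column index.
    ids = torch.zeros(height // 2, width // 2, 3)
    ids[..., 1] += torch.arange(height // 2)[:, None]  # row index of the patch
    ids[..., 2] += torch.arange(width // 2)[None, :]   # column index of the patch
    ids = ids.reshape(1, -1, 3).expand(batch_size, -1, -1)
    return ids.to(device=device, dtype=dtype)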

Author

Perhaps checking the number of dimensions is too harsh. However, there should be a way to clearly tell the user the correct shape for the latents. This was my issue when trying to use the model with custom latents while coming from other "traditional" models: the only thing to guide me before digging into the source code was a tensor size mismatch error. That is, in other models, I'm used to doing the following:

pipe = ...
# ...
latents_shape = (
    batch_size,
    pipe.transformer.config.in_channels,
    height // pipe.vae_scale_factor,
    width // pipe.vae_scale_factor,
)
latents = torch.randn(latents_shape, ...)
# ...
image = pipe(
    prompt=prompt,
    # ...
    latents=latents
).images[0]

For similar code to work with this model, I had to do the following:

pipe = ...
# ...
latents_shape = (
    batch_size,
    pipe.transformer.config.in_channels // 4,
    2 * height // pipe.vae_scale_factor,
    2 * width // pipe.vae_scale_factor,
)
latents = torch.randn(latents_shape, ...)
latents = pipe._pack_latents(latents, *latents.shape)  # Note: 2x the original height and width!
# ...
image = pipe(
    prompt=prompt,
    # ...
    latents=latents
).images[0]

Hence, no issues were encountered when running _prepare_latent_image_ids. A distinction should be made between the dimensions: height and width generally refer to the image's dimensions, but in latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype), for example, height and width refer to the upscaled latent dimensions, as you note.

I think either a method to generate latents of the correct shape is warranted (see the hypothetical sketch below), or, as I was doing in this PR, improving the docstrings and perhaps simply raising an error, in order to avoid any mix-up of dimensions.
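
For example, a hypothetical helper along these lines (prepare_packed_latents is not an existing pipeline method; it just wraps the arithmetic from the snippet above):

import torch

def prepare_packed_latents(pipe, batch_size, height, width, generator=None, device="cuda", dtype=torch.bfloat16):
    # Build 4D latents at the upscaled latent resolution, then pack them into
    # the 3D (batch_size, num_patches, channels) shape the pipeline expects.
    num_channels = pipe.transformer.config.in_channels // 4
    latent_height = 2 * height // pipe.vae_scale_factor
    latent_width = 2 * width // pipe.vae_scale_factor
    latents = torch.randn(
        (batch_size, num_channels, latent_height, latent_width),
        generator=generator, device=device, dtype=dtype,
    )
    return pipe._pack_latents(latents, batch_size, num_channels, latent_height, latent_width)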

if not isinstance(latents, torch.Tensor):
raise ValueError(f"`latents` has to be of type `torch.Tensor` but is {type(latents)}")

batch_size, num_patches, channels = latents.shape
Collaborator

Wouldn't this throw an error if the tensor has ndim == 4? Perhaps we should check that ndim == 3 here and raise an error rather than letting it fail.
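
Concretely, such a guard (a hypothetical sketch, not the PR's final code) could sit right before the batch_size, num_patches, channels = latents.shape line:

def _ensure_latents_are_packed(latents):
    # Raise a clear error instead of letting the tuple unpacking above
    # fail on a tensor that is not 3D.
    if latents.ndim != 3:
        raise ValueError(
            "Expected packed `latents` of shape (batch_size, num_patches, channels), "
            f"but got a {latents.ndim}-D tensor with shape {tuple(latents.shape)}."
        )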

Collaborator

@yiyixuxu yiyixuxu left a comment

I think the latents should be passed in the expected shape, so the check is very nice! But we do not need to pack them for the user; an error is enough, IMO.

Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Sep 14, 2024
@yiyixuxu
Collaborator

yiyixuxu commented Dec 3, 2024

cc @PDillis, are we interested in finishing this PR? I think we can just keep the error-checking part :)

@yiyixuxu yiyixuxu added close-to-merge and removed stale Issues that haven't received updates labels Dec 3, 2024