Improve FluxPipeline checks and logging #9064
base: main
Conversation
…d in `self.check_inputs` and `self.prepare_latents`)
@@ -373,7 +373,6 @@ def forward(
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

-       print(f"{txt_ids.shape=}, {img_ids.shape=}")
I was actually fixing it here #9057. But okay.
Understood, we can ignore this then.
Thank you. I have left some comments. LMK what you think.
                "Unpacked latents detected. These will be automatically packed. "
                "In the future, please provide packed latents to improve performance."
            )
            latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
Packing the latents is an inexpensive operation. So, I think this warning is unnecessary.
Agreed, the warning can be removed. Note also that I moved the definitions of `height` and `width` for the case where no latents are provided, as otherwise we would be giving the wrong dimensions to `self._prepare_latent_image_ids`.
            This pipeline expects `latents` to be in a packed format. If you're providing
            custom latents, make sure to use the `_pack_latents` method to prepare them.
            Packed latents should be a 3D tensor of shape (batch_size, num_patches, channels).
I think this note is unnecessary as well, given that `_pack_latents()` is an inexpensive operation. I think your updates to `check_inputs()` should do the trick.
Agreed, maybe then I can add this note as part of the `latents` docstring in the `__call__`.
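For concreteness, one possible shape for such a note (illustrative wording only, not the final text; the shape description follows the snippet quoted above):

# Hypothetical `latents` entry for the `__call__` docstring (wording is a sketch, not final):
LATENTS_DOC_NOTE = (
    "latents (`torch.Tensor`, *optional*): Pre-generated latents to be used as inputs for "
    "generation. Must already be packed with `_pack_latents`, i.e. a 3D tensor of shape "
    "(batch_size, num_patches, channels)."
)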
        if not isinstance(latents, torch.Tensor):
            raise ValueError(f"`latents` has to be of type `torch.Tensor` but is {type(latents)}")

        if not _are_latents_packed(latents):
But later in `prepare_latents()`, we are packing the latents and throwing a warning, no? So, I think we can remove this check.
Same, this check can be removed from the `check_inputs` function.
        if not self._are_latents_packed(latents):
            raise ValueError(
                "Latents are not in the correct packed format. Please use `_pack_latents` to prepare them."
            )
This can be removed too, since essentially `check_inputs()` should catch these.
Yes, the `check_inputs` function should catch these, so this can be removed. Then there is no need for the `_are_latents_packed` method, so that can also be removed.
@PDillis thank you for being so receptive to my comments. Let me know once you'd like another round of reviews.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…cstring for shape of latents.
@sayakpaul Thank you, I've made the suggested changes. Let me know what you think.
        if latents.ndim == 4:
            # Packing the latents to be of shape (batch_size, num_patches, channels)
            latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
Is this check necessary? Won't `check_inputs` raise an error if an `ndim == 4` latent tensor is passed in? Also, `_prepare_latent_image_ids` divides the height and width by 2 before creating the image ids; if the height isn't scaled up before being passed into this method, you will get an incorrect shape.
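To make the scaling point concrete, here is a standalone sketch (not the actual `_prepare_latent_image_ids` implementation; only the grid layout matters): the ids are laid out on a `(height // 2, width // 2)` grid, so `height` and `width` must already be the upscaled latent dimensions, or the number of ids will not match the number of packed patches.

import torch

# Standalone illustration of the id grid (not the diffusers method):
def latent_image_ids_sketch(height: int, width: int) -> torch.Tensor:
    ids = torch.zeros(height // 2, width // 2, 3)
    ids[..., 1] = torch.arange(height // 2)[:, None]  # row index per 2x2 patch
    ids[..., 2] = torch.arange(width // 2)[None, :]   # column index per 2x2 patch
    return ids.reshape((height // 2) * (width // 2), 3)

print(latent_image_ids_sketch(128, 128).shape)  # torch.Size([4096, 3]), one id per packed patch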
Perhaps checking the number of dimensions is too harsh. However, there must be a way to clearly tell the user the correct shape for the latents. This was my issue when trying to use the model with custom latents while coming from other "traditional" models: the only thing to guide me before digging into the source code was a size-mismatch error. That is, in other models, I'm used to doing the following:
pipe = ...
# ...
latents_shape = (
    batch_size,
    pipe.transformer.config.in_channels,
    height // pipe.vae_scale_factor,
    width // pipe.vae_scale_factor,
)
latents = torch.randn(latents_shape, ...)
# ...
image = pipe(
    prompt=prompt,
    # ...
    latents=latents
).images[0]
For similar code to work with this model, I had to do it like so:
pipe = ...
# ...
latents_shape = (
    batch_size,
    pipe.transformer.config.in_channels // 4,
    2 * height // pipe.vae_scale_factor,
    2 * width // pipe.vae_scale_factor,
)
latents = torch.randn(latents_shape, ...)
latents = pipe._pack_latents(latents, *latents.shape)  # Note: 2x the original height and width!
# ...
image = pipe(
    prompt=prompt,
    # ...
    latents=latents
).images[0]
Hence, no issues were encountered when we would run `_prepare_latent_image_ids`. A distinction should be made between the dimensions: `height` and `width` generally refer to the image's dimensions, but in `latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)`, for example, `height` and `width` refer to the upscaled latent dimensions, as you note.
I think either a method to generate latents of the correct shape is warranted, or, as I was doing in this PR, improving the docstrings and perhaps simply raising an error in order to avoid any mix-up of dimensions.
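As a rough sketch of the first option, the workaround above could be wrapped into a small helper (the name `make_packed_flux_latents` is hypothetical, not an existing pipeline method; the attributes it reads are the ones used in the snippet above):

import torch

# Hypothetical helper, not part of the pipeline: builds latents in the packed
# (batch_size, num_patches, channels) shape, reusing the shape arithmetic above.
def make_packed_flux_latents(pipe, batch_size, height, width, generator=None, device="cpu", dtype=torch.float32):
    num_channels_latents = pipe.transformer.config.in_channels // 4
    latent_height = 2 * height // pipe.vae_scale_factor
    latent_width = 2 * width // pipe.vae_scale_factor
    latents = torch.randn(
        (batch_size, num_channels_latents, latent_height, latent_width),
        generator=generator, device=device, dtype=dtype,
    )
    return pipe._pack_latents(latents, batch_size, num_channels_latents, latent_height, latent_width)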
        if not isinstance(latents, torch.Tensor):
            raise ValueError(f"`latents` has to be of type `torch.Tensor` but is {type(latents)}")

        batch_size, num_patches, channels = latents.shape
Wouldn't this throw an error if the shape is `ndim == 4`? Perhaps we should check that `ndim == 3` here and raise an error rather than letting it fail.
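For illustration, a sketch of the early check being suggested here (the error text is illustrative, not final wording):

import torch

# Sketch of the suggested check on user-provided `latents`:
def check_packed_latents(latents) -> None:
    if not isinstance(latents, torch.Tensor):
        raise ValueError(f"`latents` has to be of type `torch.Tensor` but is {type(latents)}")
    if latents.ndim != 3:
        raise ValueError(
            "`latents` must be packed into a 3D tensor of shape "
            "(batch_size, num_patches, channels), e.g. via `_pack_latents`, "
            f"but got a tensor with {latents.ndim} dimensions."
        )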
I think the `latents` should be passed in the expected shape, so the check is very nice! But we do not need to pack them for the user; an error is enough IMO.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
cc @PDillis are we interested in finishing this PR? I think we can just keep the error-checking part :)
What does this PR do?

- Removes the `print` of the `txt_ids` and `img_ids` shapes that bloats up the console, especially when generating large quantities of images. This is done inside of `models/transformers/transformer_flux.py`.
- If the user passes their own `latents` to the `FluxPipeline.__call__`, then there are additional checks to ensure the right shape is being used and, if not, the expected shapes are communicated to the user. Otherwise, the typical PyTorch tensor-mismatch message is uninformative at best. As such, the docstring has also been updated.

@sayakpaul