
Attn added kv processor torch 2.0 block #3023


Merged

Conversation

williamberman (Contributor) commented Apr 8, 2023

Rebased on top of #3021 and #3011.

200a8c7 and onward is the main commit.

HuggingFaceDocBuilderDev commented Apr 8, 2023

The documentation is not available anymore as the PR was closed or merged.

williamberman force-pushed the AttnAddedKVProcessor2_0 branch 2 times, most recently from 26db440 to 40ec593 on April 9, 2023 at 00:04
williamberman force-pushed the AttnAddedKVProcessor2_0 branch from 40ec593 to a38867e on April 9, 2023 at 01:57
Comment on lines +324 to +333
if out_dim == 3:
    if attention_mask.shape[0] < batch_size * head_size:
        attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
elif out_dim == 4:
    attention_mask = attention_mask.unsqueeze(1)
    attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

williamberman (Contributor, Author)

@patrickvonplaten when using the torch built-in attention and putting the heads in the second dim, we need the attention mask to also put the heads in the second dim. I'm not sure what the equivalent of the attention_mask.shape[0] < batch_size * head_size check is. If we assume the input attention mask always has the same batch size as the inputs, we don't have to do the check, and I think this works. My understanding is that's what the original code was doing anyway, since it just repeats by the head size regardless.
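
For context, a minimal standalone sketch of the mask-shape difference being discussed here; the tensor sizes are illustrative assumptions, not values from the PR:

```python
import torch
import torch.nn.functional as F

batch_size, head_size, seq_len, head_dim = 2, 8, 77, 64

query = torch.randn(batch_size, head_size, seq_len, head_dim)
key = torch.randn(batch_size, head_size, seq_len, head_dim)
value = torch.randn(batch_size, head_size, seq_len, head_dim)

# Mask as it typically arrives: one row per batch element (True = attend).
attention_mask = torch.ones(batch_size, 1, seq_len, dtype=torch.bool)

# out_dim == 3 path (classic batched-matmul attention): heads are folded into
# the batch dimension, so the mask is repeated along dim=0.
mask_3d = attention_mask.repeat_interleave(head_size, dim=0)  # (batch * heads, 1, seq_len)

# out_dim == 4 path (torch 2.0 F.scaled_dot_product_attention): heads live in
# dim 1, so the mask gets an explicit head dimension instead.
mask_4d = attention_mask.unsqueeze(1).repeat_interleave(head_size, dim=1)  # (batch, heads, 1, seq_len)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=mask_4d)
print(out.shape)  # torch.Size([2, 8, 77, 64])
```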

williamberman force-pushed the AttnAddedKVProcessor2_0 branch from a38867e to a480bc0 on April 11, 2023 at 00:52
williamberman (Contributor, Author)

waiting to merge until #3021 and #3011 are merged

Comment on lines 113 to 116
else:
    raise ValueError(
        f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
    )
Member

Nice!

williamberman force-pushed the AttnAddedKVProcessor2_0 branch from a480bc0 to fb485ad on April 11, 2023 at 01:43
sayakpaul (Member) left a comment

Very nice refactor!

value = encoder_hidden_states_value_proj

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
Member

Out of curiosity, which scale is this one?
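
For readers following along: attn.scale here refers to the attention module's softmax scaling factor (in diffusers it defaults to 1/sqrt(head_dim)). In torch 2.0, F.scaled_dot_product_attention always applies that default scaling internally and offers no way to override it; torch 2.1 adds an explicit scale keyword. A minimal sketch of what the follow-up could look like under those assumptions, not the PR's actual code:

```python
import math

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 8, 77, 64
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# torch 2.0: SDPA scales the attention logits by 1 / sqrt(head_dim) internally,
# so a module-level attn.scale cannot be forwarded.
out_default = F.scaled_dot_product_attention(query, key, value)

# torch >= 2.1: the scale keyword exposes that factor, so a custom value
# (e.g. the module's attn.scale attribute) could be passed through.
attn_scale = 1 / math.sqrt(head_dim)  # stand-in for attn.scale
out_custom = F.scaled_dot_product_attention(query, key, value, scale=attn_scale)
```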

Comment on lines 342 to 349
elif isinstance(self.norm_cross, nn.GroupNorm):
    # Group norm norms along the channels dimension and expects
    # input to be in the shape of (N, C, *). In this case, we want
    # to norm along the hidden dimension, so we need to move
    # (batch_size, sequence_length, hidden_size) ->
    # (batch_size, hidden_size, sequence_length)
    encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
    encoder_hidden_states = self.norm_cross(encoder_hidden_states)
    encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
Member

👌
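
Side note on the snippet above: nn.GroupNorm normalizes over the channel dimension and expects input shaped (N, C, *), which is why the hidden dimension is moved into position 1 before the norm and back afterwards. A minimal standalone illustration; the shapes and num_groups value are assumptions for the example, not taken from the PR:

```python
import torch
import torch.nn as nn

batch_size, sequence_length, hidden_size = 2, 77, 768
encoder_hidden_states = torch.randn(batch_size, sequence_length, hidden_size)

# GroupNorm treats dim 1 as channels, i.e. it expects (N, C, *).
norm_cross = nn.GroupNorm(num_groups=32, num_channels=hidden_size)

# Move the hidden dimension into the channel slot, normalize, then move it back.
out = norm_cross(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)
print(out.shape)  # torch.Size([2, 77, 768])
```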

Comment on lines 486 to 494
if not attn.only_cross_attention:
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)
    key = attn.head_to_batch_dim(key)
    value = attn.head_to_batch_dim(value)
    key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
    value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
    key = encoder_hidden_states_key_proj
    value = encoder_hidden_states_value_proj
Member

Clean 🌿
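
On the snippet above: when the block is not cross-attention-only, the projected encoder keys/values are concatenated with the self-attention keys/values along the sequence dimension, so each query attends over the added conditioning tokens and the spatial tokens in one pass. A shape-level sketch with illustrative sizes, not the PR's code:

```python
import torch

batch, heads, seq_spatial, seq_cond, head_dim = 2, 8, 4096, 77, 40

# Keys/values already reshaped so the heads are folded into the batch dimension.
encoder_hidden_states_key_proj = torch.randn(batch * heads, seq_cond, head_dim)
encoder_hidden_states_value_proj = torch.randn(batch * heads, seq_cond, head_dim)
key = torch.randn(batch * heads, seq_spatial, head_dim)
value = torch.randn(batch * heads, seq_spatial, head_dim)

# Concatenate along the sequence dimension: every query sees both token sets.
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
print(key.shape)  # torch.Size([16, 4173, 40])
```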

Comment on lines 1523 to 1569
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=processor,
Member

👌

Comment on lines +424 to +429
# Check is relaxed because there is not a torch 2.0 sliced attention added kv processor
expected_max_diff = 1e-2

self._test_attention_slicing_forward_pass(
    test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
)
Member

Should this be conditioned on the torch version being used?

williamberman (Contributor, Author) commented Apr 11, 2023

Sorry, I should have noted: the intention is to follow up and add the 2.0 sliced attention added kv processor.

williamberman force-pushed the AttnAddedKVProcessor2_0 branch from fb485ad to ac3786a on April 11, 2023 at 17:44
williamberman force-pushed the AttnAddedKVProcessor2_0 branch from ac3786a to a65dd58 on April 11, 2023 at 22:59
williamberman merged commit ea39cd7 into huggingface:main on Apr 11, 2023
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
dg845 pushed a commit to dg845/diffusers that referenced this pull request May 6, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024