Attn added kv processor torch 2.0 block #3023
Conversation
The documentation is not available anymore as the PR was closed or merged.
Force-pushed from 26db440 to 40ec593
Force-pushed from 40ec593 to a38867e
if out_dim == 3:
    if attention_mask.shape[0] < batch_size * head_size:
        attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
elif out_dim == 4:
    attention_mask = attention_mask.unsqueeze(1)
    attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
@patrickvonplaten when using the torch built-in attention and putting heads in the second dim, we need to make the attention mask also put the heads in the second dim. I'm not sure what the equivalent of the attention_mask.shape[0] < batch_size * head_size check is. If we assume the input attention mask always has the same batch size as the inputs, we don't have to do the check, and I think this works. My understanding is that's what the original code was doing anyway, since it just repeats by the head size regardless.
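For illustration, a minimal standalone sketch (not the PR's helper itself) of how the two layouts differ; the shapes are hypothetical:

import torch

# hypothetical shapes for illustration only
batch_size, head_size, query_len, key_len = 2, 8, 4, 6
attention_mask = torch.zeros(batch_size, query_len, key_len)

# out_dim == 3: classic layout, heads folded into the batch dimension
mask_3d = attention_mask.repeat_interleave(head_size, dim=0)
assert mask_3d.shape == (batch_size * head_size, query_len, key_len)

# out_dim == 4: torch 2.0 SDPA layout, heads as the second dimension
mask_4d = attention_mask.unsqueeze(1).repeat_interleave(head_size, dim=1)
assert mask_4d.shape == (batch_size, head_size, query_len, key_len)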
Force-pushed from a38867e to a480bc0
else:
    raise ValueError(
        f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
    )
Nice!
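For context, a rough sketch of the branching this error message implies; the function name, the num_groups default, and the eps value are illustrative assumptions, not the PR's code:

import torch.nn as nn

def build_norm_cross(cross_attention_norm, cross_attention_dim, norm_num_groups=32):
    # sketch only: pick the cross-attention norm from the three accepted values
    if cross_attention_norm is None:
        return None
    elif cross_attention_norm == "layer_norm":
        return nn.LayerNorm(cross_attention_dim)
    elif cross_attention_norm == "group_norm":
        return nn.GroupNorm(
            num_channels=cross_attention_dim, num_groups=norm_num_groups, eps=1e-5, affine=True
        )
    else:
        raise ValueError(
            f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
        )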
Force-pushed from a480bc0 to fb485ad
Very nice refactor!
value = encoder_hidden_states_value_proj

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
Out of curiosity, which scale is this one?
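For context on the TODO, a small sketch assuming torch >= 2.1, where F.scaled_dot_product_attention gains an explicit scale argument that attn.scale could be forwarded to; the tensors and the custom scale value below are placeholders:

import torch
import torch.nn.functional as F

query = torch.randn(2, 8, 4, 64)   # (batch, num_heads, seq_len, head_dim)
key = torch.randn(2, 8, 6, 64)
value = torch.randn(2, 8, 6, 64)

# torch 2.0: the softmax scale is fixed to 1 / sqrt(head_dim) internally
out_default = F.scaled_dot_product_attention(query, key, value)

# torch >= 2.1: a custom scale can be passed explicitly
custom_scale = 1 / 8  # placeholder standing in for attn.scale
out_scaled = F.scaled_dot_product_attention(query, key, value, scale=custom_scale)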
elif isinstance(self.norm_cross, nn.GroupNorm):
    # Group norm norms along the channels dimension and expects
    # input to be in the shape of (N, C, *). In this case, we want
    # to norm along the hidden dimension, so we need to move
    # (batch_size, sequence_length, hidden_size) ->
    # (batch_size, hidden_size, sequence_length)
    encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
    encoder_hidden_states = self.norm_cross(encoder_hidden_states)
    encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
👌
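A standalone sketch of the transpose trick above, with hypothetical shapes and a num_groups chosen only for the example:

import torch
import torch.nn as nn

batch_size, sequence_length, hidden_size = 2, 77, 768
encoder_hidden_states = torch.randn(batch_size, sequence_length, hidden_size)

# GroupNorm normalizes over the channel dimension (N, C, *), so move the
# hidden dimension into the channel slot, normalize, then move it back
norm_cross = nn.GroupNorm(num_groups=32, num_channels=hidden_size)
out = norm_cross(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)
assert out.shape == (batch_size, sequence_length, hidden_size)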
if not attn.only_cross_attention:
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)
    key = attn.head_to_batch_dim(key)
    value = attn.head_to_batch_dim(value)
    key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
    value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
    key = encoder_hidden_states_key_proj
    value = encoder_hidden_states_value_proj
Clean 🌿
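To make the concatenation concrete, a small sketch with hypothetical shapes; the projections are omitted and plain tensors stand in for the outputs of head_to_batch_dim:

import torch

batch_heads, head_dim = 2 * 8, 64
self_attn_len, added_kv_len = 4096, 77  # e.g. image tokens vs. text tokens

# after head_to_batch_dim, keys/values are (batch * heads, seq_len, head_dim)
key = torch.randn(batch_heads, self_attn_len, head_dim)
encoder_hidden_states_key_proj = torch.randn(batch_heads, added_kv_len, head_dim)

# the added (cross-attention) keys are prepended along the sequence dimension,
# so each query attends to both the added tokens and the self-attention tokens
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
assert key.shape == (batch_heads, added_kv_len + self_attn_len, head_dim)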
only_cross_attention=only_cross_attention,
cross_attention_norm=cross_attention_norm,
processor=processor,
👌
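As a usage note, a sketch of enabling the new processor on a model that uses added KV projections; the UnCLIP pipeline, the example checkpoint, and the import path are assumptions for illustration, not something this PR prescribes:

import torch
from diffusers import UnCLIPPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor2_0  # import path assumed

# the UnCLIP decoder uses added KV projections, so it is a natural target;
# set_attn_processor swaps the processor on every attention module at once
pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
pipe.decoder.set_attn_processor(AttnAddedKVProcessor2_0())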
# Check is relaxed because there is not a torch 2.0 sliced attention added kv processor
expected_max_diff = 1e-2

self._test_attention_slicing_forward_pass(
    test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
)
Should this be conditioned on the torch version being used?
Sorry, should have noted: the intention is to follow up and add the 2.0 sliced attention added KV processor.
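A possible version gate, sketched with the same SDPA presence check diffusers uses to detect torch 2.0; the 1e-3 fallback is a hypothetical tighter default, not taken from the PR:

import torch.nn.functional as F

# relax the tolerance only when the torch 2.0 added-KV processor is in use;
# there is no torch 2.0 sliced-attention added-KV processor yet, so the sliced
# and unsliced paths can differ slightly on torch >= 2.0
if hasattr(F, "scaled_dot_product_attention"):
    expected_max_diff = 1e-2
else:
    expected_max_diff = 1e-3  # hypothetical tighter default for older torch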
Force-pushed from fb485ad to ac3786a
Force-pushed from ac3786a to a65dd58
add AttnAddedKVProcessor2_0 block
Rebased on top of #3021 and #3011.
200a8c7 and onward are the main commits.