Commit 989b692

eqy authored and timocafe committed
[cuDNN][SDPA] Loosen constraints for GQA for cuDNN Attention (pytorch#150337)
cuDNN attention doesn't require key and value tensors to have the same number of heads.

Pull Request resolved: pytorch#150337
Approved by: https://github.com/drisspg
1 parent 5fca74e commit 989b692
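
As a rough sketch (not part of the commit), the call pattern this change enables looks like the following. It assumes a PyTorch build that includes this change, a GPU/cuDNN combination supported by the cuDNN SDPA backend, and illustrative shape values:

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention

# Query has 32 heads, key has 8, value has 4: h_k != h_v is now accepted by cuDNN.
batch, seq_len_q, seq_len_kv, D = 4, 64, 128, 64  # illustrative sizes
query = torch.rand(batch, 32, seq_len_q, D, device='cuda', dtype=torch.bfloat16)
key = torch.rand(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16)
value = torch.rand(batch, 4, seq_len_kv, D, device='cuda', dtype=torch.bfloat16)

with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
    out = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)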

3 files changed (+12, -8 lines)


aten/src/ATen/native/transformers/cuda/sdp_utils.cpp (+3, -2)
@@ -553,9 +553,10 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) {
       TORCH_WARN("Experimental cuDNN SDPA nested tensor support is not enabled.");
     }
     return false;
-  } else if (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad()) {
+  } else if (has_for_nested_inputs(params) && (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad())) {
     if (debug) {
       TORCH_WARN("Experimental cuDNN SDPA nested tensor support does not support backward.");
+      return false;
     }
   }

@@ -645,7 +646,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
   constexpr auto dense_constraints =
       c10::array_of<bool (*)(sdp_params const&, bool)>(
           check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>,
-          check_batch_size_and_num_heads_dense<true /*enable_gqa*/>
+          check_batch_size_and_num_heads_dense<true /*enable_gqa*/, false /*requires_same_num_heads*/>
       );

   if (has_only_dense_inputs(params)) {

aten/src/ATen/native/transformers/sdp_utils_cpp.h (+7, -5)
@@ -333,13 +333,14 @@ inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) {
   return true;
 }

+template <bool requires_same_num_heads=true>
 inline bool check_grouped_query_attention(sdp_params const& params, bool debug) {
   const auto q_num_heads = params.query.sym_size(-3);
   const auto k_num_heads = params.key.sym_size(-3);
   const auto v_num_heads = params.value.sym_size(-3);
   const bool same_kv_heads = k_num_heads == v_num_heads;

-  if (!(same_kv_heads)){
+  if (requires_same_num_heads && !(same_kv_heads)){
     if (debug) {
       TORCH_WARN(
           "Both fused kernels require key and value to have the same num_heads and batch_size but got: ",
@@ -355,10 +356,10 @@ inline bool check_grouped_query_attention(sdp_params const& params, bool debug)
   }
   // Check if grouped query attention is supported and validate the number of
   // heads
-  if (q_num_heads % k_num_heads != 0) {
+  if (q_num_heads % k_num_heads != 0 || (!requires_same_num_heads && (q_num_heads % v_num_heads != 0))) {
     if (debug) {
       TORCH_WARN(
-          "FlashAttentionV2 only supports grouped query attention, where the number of heads in key/value must divide number of heads in query.",
+          "The number of heads in key/value must divide number of heads in query.",
           "Got input Key sizes(): ",
           params.key.sym_size(-3),
           ", Value sizes(): ",
@@ -372,7 +373,7 @@ inline bool check_grouped_query_attention(sdp_params const& params, bool debug)
   return true;
 }

-template <bool supports_gqa>
+template <bool supports_gqa, bool requires_same_num_heads=true>
 inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
   // This is expected to be called after check_tensor_shapes ensuring that the
   // size() calls won't error since the inputs are all 4 dimensional
@@ -407,9 +408,10 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool
   }

   if(params.enable_gqa && supports_gqa){
-    return check_grouped_query_attention(params, debug);
+    return check_grouped_query_attention<requires_same_num_heads>(params, debug);
   }

+  // same num heads condition for non-gqa case
   if (!same_num_heads){
     if (debug) {
       TORCH_WARN(
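
For clarity, the head-count rule the relaxed check enforces can be summarized with a small Python sketch. This is an approximation of the C++ helper above, not the actual PyTorch code, and heads_compatible is a hypothetical name:

def heads_compatible(q_heads, k_heads, v_heads, requires_same_num_heads=True):
    # When the backend requires matching KV heads (e.g. the other fused kernels),
    # key and value must have the same head count; otherwise each only needs to
    # divide the query head count (the cuDNN case after this change).
    if requires_same_num_heads and k_heads != v_heads:
        return False
    if q_heads % k_heads != 0:
        return False
    if not requires_same_num_heads and q_heads % v_heads != 0:
        return False
    return True

# Shapes from the test below: 32 query heads, 8 key heads, 4 value heads.
assert heads_compatible(32, 8, 4, requires_same_num_heads=False)     # accepted for cuDNN
assert not heads_compatible(32, 8, 4, requires_same_num_heads=True)  # rejected when KV heads must match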

test/test_transformers.py (+2, -1)
@@ -2479,7 +2479,8 @@ def test_cudnn_attention_gqa(self, device):
         # Sample call to SDPA - GQ
         query = torch.rand(batch, 32, seq_len_q, D, device='cuda', dtype=torch.bfloat16)
         key = torch.rand(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16)
-        value = torch.rand(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16)
+        # cuDNN supports h_k != h_v
+        value = torch.rand(batch, 4, seq_len_kv, D, device='cuda', dtype=torch.bfloat16)
         with sdpa_kernel([SDPBackend.MATH]):
             output_math = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
