@@ -333,13 +333,14 @@ inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) {
   return true;
 }
 
+template <bool requires_same_num_heads=true>
 inline bool check_grouped_query_attention(sdp_params const& params, bool debug) {
   const auto q_num_heads = params.query.sym_size(-3);
   const auto k_num_heads = params.key.sym_size(-3);
   const auto v_num_heads = params.value.sym_size(-3);
   const bool same_kv_heads = k_num_heads == v_num_heads;
 
-  if (!(same_kv_heads)){
+  if (requires_same_num_heads && !(same_kv_heads)){
     if (debug) {
       TORCH_WARN(
           "Both fused kernels require key and value to have the same num_heads and batch_size but got: ",
@@ -355,10 +356,10 @@ inline bool check_grouped_query_attention(sdp_params const& params, bool debug)
   }
   // Check if grouped query attention is supported and validate the number of
   // heads
-  if (q_num_heads % k_num_heads != 0) {
+  if (q_num_heads % k_num_heads != 0 || (!requires_same_num_heads && (q_num_heads % v_num_heads != 0))) {
     if (debug) {
       TORCH_WARN(
-          "FlashAttentionV2 only supports grouped query attention, where the number of heads in key/value must divide number of heads in query.",
+          "The number of heads in key/value must divide number of heads in query.",
           "Got input Key sizes(): ",
           params.key.sym_size(-3),
           ", Value sizes(): ",
@@ -372,7 +373,7 @@ inline bool check_grouped_query_attention(sdp_params const& params, bool debug)
   return true;
 }
 
-template <bool supports_gqa>
+template <bool supports_gqa, bool requires_same_num_heads=true>
 inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
   // This is expected to be called after check_tensor_shapes ensuring that the
   // size() calls won't error since the inputs are all 4 dimensional
@@ -407,9 +408,10 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool
   }
 
   if (params.enable_gqa && supports_gqa){
-    return check_grouped_query_attention(params, debug);
+    return check_grouped_query_attention<requires_same_num_heads>(params, debug);
   }
 
+  // same num heads condition for non-gqa case
   if (!same_num_heads){
     if (debug) {
       TORCH_WARN(
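For reference, here is a minimal standalone sketch of the head-count validation these hunks generalize, stripped of the sdp_params/TORCH_WARN machinery. The function and parameter names (heads_are_compatible, q_heads, k_heads, v_heads) are illustrative and do not exist in the PyTorch sources; the sketch only mirrors the condition added above: with the default requires_same_num_heads=true, key and value must have equal head counts and the key head count must divide the query head count, while instantiating with false lets key and value head counts differ as long as each divides the query head count.

#include <cstdint>
#include <iostream>

// Sketch of the check done by check_grouped_query_attention<requires_same_num_heads>.
template <bool requires_same_num_heads = true>
bool heads_are_compatible(int64_t q_heads, int64_t k_heads, int64_t v_heads) {
  // With the default (true), key and value must have the same number of heads.
  if (requires_same_num_heads && k_heads != v_heads) {
    return false;
  }
  // GQA: each key/value head count must divide the query head count. When the
  // same-heads requirement is relaxed, value heads are checked separately.
  if (q_heads % k_heads != 0 ||
      (!requires_same_num_heads && q_heads % v_heads != 0)) {
    return false;
  }
  return true;
}

int main() {
  std::cout << heads_are_compatible<true>(32, 8, 8) << '\n';   // 1: classic GQA
  std::cout << heads_are_compatible<true>(32, 8, 4) << '\n';   // 0: K/V heads differ
  std::cout << heads_are_compatible<false>(32, 8, 4) << '\n';  // 1: allowed once relaxed
}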