@@ -6445,7 +6445,7 @@ static __global__ void flash_attn_ext_f16(
     const int D16 = D/16;
     const int Q16 = Q/16;
     const int NW = WARP_SIZE;
-    const int SH = (C + 2*Q); // shared memory per simdgroup in (half)
+    const int SH = (C + Q); // shared memory per simdgroup in (half)
 
     const int T  = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2;              // shared memory size per query in (half2)
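
The first hunk shrinks the per-warp scratch from C + 2*Q to C + Q halves, because the extra QxQ staging area is no longer needed (a later hunk reuses the diagonal block of the scratch instead). A rough tally of what this saves, assuming D = 128, Q = nqpb = 16, C = ncpw = 32 and num_warps = 4 (illustrative values; only ncpw = 32 appears in the launcher hunk below, the rest are assumptions):

// Illustrative arithmetic only, with assumed D = 128, Q = 16, C = 32, num_warps = 4.
const int D = 128, Q = 16, C = 32, num_warps = 4;

const int SH_old = C + 2*Q;             // 64 halves of scratch per warp (before)
const int SH_new = C + Q;               // 48 halves of scratch per warp (after)

const int T_old = D + num_warps*SH_old; // 384 halves per query row (before)
const int T_new = D + num_warps*SH_new; // 320 halves per query row (after)

// total dynamic shared memory for a block of Q query rows:
//   before: Q*T_old*sizeof(half) = 16*384*2 = 12288 bytes
//   after:  Q*T_new*sizeof(half) = 16*320*2 = 10240 bytes
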
@@ -6455,6 +6455,8 @@ static __global__ void flash_attn_ext_f16(
     half  * sq  = (half  *) (__flash_attn_f16_shmem +              0*D); // holds the query data
     half2 * sq2 = (half2 *) (__flash_attn_f16_shmem +              0*D); // same as above but in half2
     half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    half16x16_acc zr;
 
     half16x16_acc lo[Q16][D16];
 
     // load heads from Q to shared memory
@@ -6470,6 +6472,8 @@ static __global__ void flash_attn_ext_f16(
         }
     }
 
+    nvcuda::wmma::fill_fragment(zr, 0.0);
+
     // zero out lo
     for (int64_t j = 0; j < Q16; ++j) {
         for (int64_t i = 0; i < D16; ++i) {
@@ -6648,13 +6652,15 @@ static __global__ void flash_attn_ext_f16(
 
         for (int64_t i = 0; i < D16; ++i) {
             // convert accumulator to matrix_b
-            // TODO: try to avoid the extra QxQ matrix in shared memory needed for this conversion
-            nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + Q, lo[j][i], T, nvcuda::wmma::mem_row_major);
-            nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + Q, T);
+            nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
+            nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
 
             nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
             nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
         }
+
+        // restore zeros
+        nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
     }
 
     // O = O + (Q*K^T)*V
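
The conversion above has to pass the accumulator through shared memory before it can be re-loaded as a matrix_b operand for the diag(ms)*O rescale, since wmma offers no direct fragment-to-fragment conversion. Instead of a dedicated QxQ area at offset C + Q, the hunk reuses the 16x16 diagonal block of ss at offset C + 16*j (the same block that diag(ms) was loaded from as matrix_a into mm), and the zero fragment zr then writes zeros back so later passes still see a cleared scratch. Below is a standalone sketch of the same round trip on a single 16x16 tile, assuming one full warp and half-precision accumulators (sm_70+); the function and parameter names are illustrative, not from this commit:

#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

// Rescale a 16x16 tile in place: out = diag * out, where `diag` holds a 16x16
// diagonal scaling matrix. Must be executed by one full warp.
__global__ void rescale_tile_f16(const half * diag, half * out) {
    __shared__ half scratch[16*16]; // staging tile, leading dimension 16

    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> mm;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> lob;
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc;
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> zr;

    wmma::fill_fragment(zr, 0.0);

    wmma::load_matrix_sync(mm,  diag, 16);                      // diag(ms) as matrix_a
    wmma::load_matrix_sync(acc, out,  16, wmma::mem_row_major); // current accumulator

    // accumulator -> shared memory -> matrix_b
    wmma::store_matrix_sync(scratch, acc, 16, wmma::mem_row_major);
    wmma::load_matrix_sync(lob, scratch, 16);

    // acc = mm * lob + 0
    wmma::fill_fragment(acc, 0.0);
    wmma::mma_sync(acc, mm, lob, acc);

    wmma::store_matrix_sync(out, acc, 16, wmma::mem_row_major);

    // put zeros back into the staging area, as the kernel does with `zr`
    wmma::store_matrix_sync(scratch, zr, 16, wmma::mem_row_major);
}
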
@@ -10928,14 +10934,13 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     const int ncpw = 32; // cache values per warp (does not work for other values)
 
     const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
+                              // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
     const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;
 
     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
 
-    // TODO: compare to Metal, here we need extra `nqpb` space in order to do the diag(ms)*O scaling
-    // try to avoid this
-    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2);
+    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
 
     switch (Q->ne[0])
     {
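
The host-side hunk drops the matching extra nqpb per warp from the shared memory request (sizeof(float)/2 is just sizeof(half)). Large head sizes or warp counts can still push the request past the default 48 KB per-block limit, so a sanity check along these lines may be worth adding; a sketch only, not part of this commit, assuming the CUDA_CHECK and GGML_ASSERT macros already used elsewhere in ggml-cuda.cu:

// Sketch: verify the dynamic shared memory request fits the device.
// (Requests above 48 KB also need a cudaFuncSetAttribute opt-in on the kernel.)
int device    = 0;
int shmem_max = 0;
CUDA_CHECK(cudaGetDevice(&device));
CUDA_CHECK(cudaDeviceGetAttribute(&shmem_max, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
GGML_ASSERT(shmem <= (size_t) shmem_max);
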