Commit 2c04bee

cuda : avoid extra QxQ matrix in shared memory
1 parent 71b69aa commit 2c04bee
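
Rough scale of the saving, using the host-side sizing formula changed in this commit (per-warp scratch shrinks from C + 2*Q to C + Q halves). The snippet below is an illustrative sketch only: the concrete values D = 128, nqpb = 16 and nwarps = 8 are assumptions for the sake of the arithmetic and are not part of this diff.

// Illustrative only: per-block shared memory before/after dropping the extra QxQ
// scratch matrix per warp. D, nqpb and nwarps are assumed values, ncpw = 32 is
// taken from the host-side hunk below.
#include <stdio.h>

int main(void) {
    const int D      = 128; // head size (Q->ne[0]) - assumed
    const int nqpb   = 16;  // queries per block (Q in the kernel) - assumed
    const int ncpw   = 32;  // cache values per warp (C in the kernel)
    const int nwarps = 8;   // assumed, matches nwarps_max in the diff

    // same expressions as the old and new shmem computation in ggml_cuda_flash_attn_ext
    const size_t shmem_old = nqpb*(D + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2);
    const size_t shmem_new = nqpb*(D + nwarps*(ncpw +   nqpb))*(sizeof(float)/2);

    printf("old: %zu bytes, new: %zu bytes, saved: %zu bytes per block\n",
           shmem_old, shmem_new, shmem_old - shmem_new);
    // with these assumed values: old: 20480, new: 16384, saved: 4096 bytes per block
    return 0;
}

With these assumed values the change saves nqpb*nwarps*nqpb halves, i.e. 4 KiB of shared memory per block.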

File tree

1 file changed: +12 −7 lines

ggml-cuda.cu

Lines changed: 12 additions & 7 deletions
@@ -6445,7 +6445,7 @@ static __global__ void flash_attn_ext_f16(
     const int D16 = D/16;
     const int Q16 = Q/16;
     const int NW  = WARP_SIZE;
-    const int SH  = (C + 2*Q); // shared memory per simdgroup in (half)
+    const int SH  = (C + Q); // shared memory per simdgroup in (half)

     const int T  = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2; // shared memory size per query in (half2)
@@ -6455,6 +6455,8 @@ static __global__ void flash_attn_ext_f16(
     half  * sq  = (half  *) (__flash_attn_f16_shmem + 0*D); // holds the query data
     half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
     half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    half16x16_acc zr;
     half16x16_acc lo[Q16][D16];

     // load heads from Q to shared memory
@@ -6470,6 +6472,8 @@ static __global__ void flash_attn_ext_f16(
         }
     }

+    nvcuda::wmma::fill_fragment(zr, 0.0);
+
     // zero out lo
     for (int64_t j = 0; j < Q16; ++j) {
         for (int64_t i = 0; i < D16; ++i) {
@@ -6648,13 +6652,15 @@ static __global__ void flash_attn_ext_f16(

         for (int64_t i = 0; i < D16; ++i) {
             // convert accumulator to matrix_b
-            // TODO: try to avoid the extra QxQ matrix in shared memory needed for this conversion
-            nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + Q, lo[j][i], T, nvcuda::wmma::mem_row_major);
-            nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + Q, T);
+            nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
+            nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);

             nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
             nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
         }
+
+        // restore zeros
+        nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
     }

     // O = O + (Q*K^T)*V
@@ -10928,14 +10934,13 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     const int ncpw = 32; // cache values per warp (does not work for other values)

     const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
+    // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
     const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;

     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);

-    // TODO: compare to Metal, here we need extra `nqpb` space in order to do the diag(ms)*O scaling
-    // try to avoid this
-    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2);
+    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);

     switch (Q->ne[0])
     {
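
What the kernel-side change does, in brief: a WMMA accumulator fragment cannot be fed directly into mma_sync as a matrix_b operand, so each 16x16 tile of lo is round-tripped through shared memory. Instead of reserving a separate QxQ matrix per warp for that round trip (the removed store at ss + 16*j*T + C + Q, which required SH = C + 2*Q), the commit writes each tile into the 16x16 block at ss + 16*j*T + C + 16*j inside the existing per-warp scratch and then zeroes that block again with a cached zero fragment (zr). The following standalone sketch shows the same conversion pattern for a single 16x16 tile; names and sizes are assumptions for illustration, not the kernel's own code (assumes sm_70+, compiled with nvcc, one warp).

// Sketch only: convert a WMMA accumulator into a matrix_b operand by bouncing it
// through shared memory, then restore the borrowed shared-memory block to zero
// with a pre-filled zero fragment, mirroring the "restore zeros" step above.
#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

__global__ void acc_to_matrix_b_demo(const half * A, const half * B, half * out) {
    __shared__ half scratch[16*16]; // stands in for the reused 16x16 block of `ss`

    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> b;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> acc_as_b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc;
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> zr;

    wmma::fill_fragment(zr, 0.0); // cached zero tile, used to restore the scratch

    // produce some accumulator (here simply acc = A*B)
    wmma::load_matrix_sync(a, A, 16);
    wmma::load_matrix_sync(b, B, 16);
    wmma::fill_fragment(acc, 0.0);
    wmma::mma_sync(acc, a, b, acc);

    // accumulator -> matrix_b: store to shared memory, load back as a matrix_b fragment
    wmma::store_matrix_sync(scratch, acc, 16, wmma::mem_row_major);
    wmma::load_matrix_sync(acc_as_b, scratch, 16);

    // restore zeros so the borrowed block can keep serving its original purpose
    wmma::store_matrix_sync(scratch, zr, 16, wmma::mem_row_major);

    // the converted tile can now be the B operand of another mma_sync (out = A*(A*B))
    wmma::fill_fragment(acc, 0.0);
    wmma::mma_sync(acc, a, acc_as_b, acc);
    wmma::store_matrix_sync(out, acc, 16, wmma::mem_row_major);
}

int main(void) {
    half *A, *B, *out;
    cudaMalloc(&A,   16*16*sizeof(half));
    cudaMalloc(&B,   16*16*sizeof(half));
    cudaMalloc(&out, 16*16*sizeof(half));
    cudaMemset(A, 0, 16*16*sizeof(half));
    cudaMemset(B, 0, 16*16*sizeof(half));

    acc_to_matrix_b_demo<<<1, 32>>>(A, B, out); // one warp
    cudaDeviceSynchronize();

    cudaFree(A); cudaFree(B); cudaFree(out);
    return 0;
}

The kernel does the same thing per tile lo[j][i], with mm as the A operand, the block at ss + 16*j*T + C + 16*j as the bounce buffer, and zr filled once before the main loop.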
