Add support for Deepseek-R1 flash attention #11557

Open · wants to merge 1 commit into master

3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/fattn.cu
@@ -34,6 +34,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
break;
case 192:
ggml_cuda_flash_attn_ext_wmma_f16_case<192, cols_per_block, float>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
break;
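
Why head size 192: DeepSeek-R1 uses the DeepSeek-V2 attention layout, in which each Q/K head is 192 elements wide (128 non-RoPE dimensions plus 64 RoPE dimensions) while each V head is only 128 elements wide, so none of the existing WMMA head-size instantiations matched. The sketch below only illustrates that arithmetic; the dimension names and values are taken from the published DeepSeek configuration and are an assumption here, not something stated in this PR.

// Hedged sketch of the head-size arithmetic behind the new 192 case.
// Values mirror the DeepSeek-V2/R1 model configuration (assumed, not read from this PR).
constexpr int qk_nope_head_dim = 128; // non-RoPE part of each Q/K head
constexpr int qk_rope_head_dim = 64;  // RoPE part of each Q/K head
constexpr int v_head_dim       = 128; // V head size

constexpr int n_embd_head_k = qk_nope_head_dim + qk_rope_head_dim; // 192
constexpr int n_embd_head_v = v_head_dim;                          // 128

static_assert(n_embd_head_k == 192, "Q/K heads need the new 192 kernel instantiation");
static_assert(n_embd_head_v < n_embd_head_k, "V heads are narrower and must be zero-padded");
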
55 changes: 48 additions & 7 deletions ggml/src/ggml-cuda/pad.cu
@@ -25,6 +25,31 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
}
}

static __global__ void pad_f16(const half * x, half * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
// blockIdx.z: idx of ne2*ne3, aka ne02*ne03
// blockIdx.y: idx of ne1
// blockIdx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}

// operation
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
int offset_src =
nidx +
blockIdx.y * ne00 +
blockIdx.z * ne00 * ne01;
dst[offset_dst] = x[offset_src];
} else {
dst[offset_dst] = 0.0f;
}
}

static void pad_f32_cuda(const float * x, float * dst,
const int ne00, const int ne01, const int ne02, const int ne03,
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
@@ -33,17 +58,33 @@ static void pad_f32_cuda(const float * x, float * dst,
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
}

static void pad_f16_cuda(const half * x, half * dst,
const int ne00, const int ne01, const int ne02, const int ne03,
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2*ne3);
pad_f16<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
}

void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == src0->type);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors

pad_f32_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
if (src0->type == GGML_TYPE_F32) {
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
pad_f32_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
} else {
const half * src0_d = (const half *)src0->data;
half * dst_d = (half *)dst->data;
pad_f16_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}
}
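
The CUDA pad operator previously existed only for F32, but the V tensor padded in the flash-attention path below is a view of the KV cache, which is typically F16, so GGML_OP_PAD needs an F16 path; pad_f16 and pad_f16_cuda mirror their F32 counterparts line for line. A minimal host-side sketch of what this enables follows (an assumed usage pattern, not code from this PR): zero-padding the leading dimension of an F16 tensor through the regular ggml API, which on the CUDA backend now dispatches to pad_f16_cuda.

// Minimal sketch (assumed usage, not code from this PR): with the F16 path
// above, GGML_OP_PAD can run on the CUDA backend for half-precision data,
// e.g. an F16 KV-cache view.
#include "ggml.h"

static struct ggml_tensor * zero_pad_f16_dim0(struct ggml_context * ctx,
                                              struct ggml_tensor * t,
                                              int64_t new_ne0) {
    GGML_ASSERT(t->type == GGML_TYPE_F16);
    GGML_ASSERT(new_ne0 >= t->ne[0]);
    // ggml_pad appends zeros at the end of each dimension; only dim 0 grows here.
    return ggml_pad(ctx, t, (int)(new_ne0 - t->ne[0]), 0, 0, 0);
}
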
24 changes: 21 additions & 3 deletions src/llama.cpp
@@ -589,12 +589,30 @@ static struct ggml_tensor * llm_build_kqv(
0);
cb(v, "v", il);

cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
struct ggml_tensor * padded_v = v;
int64_t n_embd_head_v_out = n_embd_head_v;
if (n_embd_head_v < n_embd_head_k) {
padded_v = ggml_pad(ctx, v, 0, k->ne[0] - v->ne[1], 0, 0);
cb(padded_v, "padded_v", il);
n_embd_head_v_out = n_embd_head_k;
}

cur = ggml_flash_attn_ext(ctx, q, k, padded_v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);

ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

if (n_embd_head_v < n_embd_head_k) {
cur = ggml_reshape_3d(ctx, cur, n_embd_head_v_out, n_head, n_tokens);
cur = ggml_view_3d(ctx, cur, n_embd_head_v, n_head, n_tokens,
ggml_element_size(cur) * n_embd_head_v_out,
ggml_element_size(cur) * n_embd_head_v_out * n_head,
0);
cur = ggml_cont(ctx, cur);
}

cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);

} else {
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
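
The graph change is a pad-attend-trim pattern: V is zero-padded along its head dimension up to the K head size so that ggml_flash_attn_ext sees equally sized K and V heads, and because the padded part of V is all zeros, the corresponding columns of the attention output are also zeros and can simply be sliced off each output head afterwards. The helper below restates that trim step in isolation with explicit shape comments; it mirrors the diff above but is an illustration, not the PR verbatim.

// Hedged sketch of the trim step, mirroring the diff above. After attention
// runs with padded heads, each output head holds n_embd_head_v useful values
// followed by (n_embd_head_k - n_embd_head_v) zeros from the padding; a
// strided view keeps only the useful prefix of every head.
#include "ggml.h"

static struct ggml_tensor * trim_padded_heads(struct ggml_context * ctx,
                                              struct ggml_tensor * cur, // flash-attn output, contiguous
                                              int64_t n_embd_head_k,    // padded head size, e.g. 192
                                              int64_t n_embd_head_v,    // real V head size, e.g. 128
                                              int64_t n_head,
                                              int64_t n_tokens) {
    cur = ggml_reshape_3d(ctx, cur, n_embd_head_k, n_head, n_tokens);
    cur = ggml_view_3d(ctx, cur, n_embd_head_v, n_head, n_tokens,
                       ggml_element_size(cur) * n_embd_head_k,          // byte stride between heads
                       ggml_element_size(cur) * n_embd_head_k * n_head, // byte stride between tokens
                       0);                                              // keep the first n_embd_head_v of each head
    return ggml_cont(ctx, cur); // materialize the strided view contiguously
}
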
@@ -9551,8 +9569,8 @@ struct llama_context * llama_init_from_model(
params.flash_attn = false;
}

if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
if (params.flash_attn && model->hparams.n_embd_head_k < model->hparams.n_embd_head_v) {
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k >= n_embd_head_v - forcing off\n", __func__);
params.flash_attn = false;
}

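With the padding path in place, the earlier blanket requirement that the K and V head sizes match can be relaxed: flash attention only has to be forced off when the V heads are wider than the K heads, since a narrower V can be zero-padded up to the K width but not the other way around. A tiny illustration of the relaxed gate with DeepSeek-R1's head sizes (an illustration only, not the PR's code; the head sizes are assumed from the model configuration):

#include <cstdint>

// Hedged illustration of the relaxed head-size gate.
static bool flash_attn_head_sizes_ok(uint32_t n_embd_head_k, uint32_t n_embd_head_v) {
    return n_embd_head_k >= n_embd_head_v; // narrower V can be padded up to K's width
}

// DeepSeek-R1: n_embd_head_k = 192, n_embd_head_v = 128 (assumed).
// The old condition (k != v) forced flash attention off; the relaxed
// condition (k < v) does not trigger, so flash attention stays enabled:
// flash_attn_head_sizes_ok(192, 128) == true.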