
Add support for Deepseek-R1 flash attention #11557


Open
siddartha-RE wants to merge 1 commit into master

Conversation

siddartha-RE

This PR adds support for Flash Attention in DeepSeek V3 models.

Since DeepSeek V3 has n_embd_head_v != n_embd_head_k, padding is needed for the flash attention operation. This change pads V before the flash attention operation and slices the result afterwards, in lieu of adding support for a different V head dimension to the flash attention kernel itself.
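
For illustration, here is a minimal sketch of this pad-then-slice approach against ggml's public graph API (not the exact diff in this PR); q, k, v, kq_mask and kq_scale follow llama.cpp's graph-building names, and 192/128 are the assumed DeepSeek-V3 K/V head sizes:

// Sketch only: zero-pad V's head dimension up to K's before flash attention,
// then view away the padded columns of the result.
// assumed: n_embd_head_k = 192, n_embd_head_v = 128 (DeepSeek-V3 MLA)

// v: [n_embd_head_v, n_kv, n_head_kv] -> [n_embd_head_k, n_kv, n_head_kv]
// (assumes v is F32 and contiguous here; the real KV-cache tensor types need more care)
struct ggml_tensor * v_pad = ggml_pad(ctx, v, n_embd_head_k - n_embd_head_v, 0, 0, 0);

// the kernel now sees matching K and V head sizes
struct ggml_tensor * kqv = ggml_flash_attn_ext(ctx, q, k, v_pad, kq_mask, kq_scale, 0.0f, 0.0f);

// output rows are n_embd_head_k wide; only the first n_embd_head_v values per row are meaningful
struct ggml_tensor * kqv_v = ggml_view_3d(ctx, kqv,
        n_embd_head_v, kqv->ne[1], kqv->ne[2],
        kqv->nb[1],    kqv->nb[2], 0);

// make the sliced view contiguous and merge heads: [n_embd_head_v*n_head, n_tokens]
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_v, n_embd_head_v*kqv->ne[1], kqv->ne[2]);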

@siddartha-RE
Author

This is not yet working, probably because the padding and slicing are not quite right. I am investigating, but if you have a suggestion other than padding I can try that.

@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Jan 31, 2025
@davidsyoung

This would really make a massive difference to running R1 locally.

I don't have the ability to work on this, but is there anything that the community can do to help push this forward?

@dmatora

dmatora commented Mar 7, 2025

DeepSeek support for flash attention was just merged into ik_llama.cpp - ikawrakow/ik_llama.cpp#241
@siddartha-RE can you have a look and maybe adapt it for llama.cpp?

@fredlas
Contributor

fredlas commented Mar 17, 2025

I took a stab at integrating ikawrakow's work with this PR (also updating to current master): https://github.com/fredlas/llama.cpp/tree/ik_r1_fa

It needed some additional changes to compile and to pass the dimension sanity-check asserts. Unfortunately it's generating gibberish, so I guess I'm plugging something in wrong. The likely culprits are ggml_compute_forward_flash_attn_ext_f16() and ggml_compute_forward_flash_attn_back_f32() in ggml-cpu.c, where I used AI to split D into Dq, Dk, and Dv and gave the result only the lightest of reviews, because I have no idea what's going on in there.

I am not going to be able to take it any further, but maybe someone who knows the transformer guts could skim it and see something obviously (to them) wrong.

@pl752
Contributor

pl752 commented Mar 25, 2025

@fredlas, I have added a comment with a suggestion about the pointer offsets in the code.
I also attempted a similar experiment recently, which worked pretty well; details are available in commit pl752@ba2ef97.
Also, the ggml_compute_forward_flash_attn_back_f32 function doesn't seem to be necessary for inference.

VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*Dv); // (temporary) FP16 VKQ accumulator
Q_q   = (ggml_fp16_t *) (VKQ32 + 2*Dv); // was 2*Dq: the offset past the accumulator buffers should be Dv, not Dq
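
For clarity, here is a self-contained sketch (assumed names, not the actual ggml-cpu.c code) of the per-row scratch layout this implies once the single head size D is split into Dk (query/key) and Dv (value):

#include <stdint.h>

typedef uint16_t ggml_fp16_t;                  // stand-in for ggml's half type

struct fa_scratch {
    float       * VKQ32;                       // FP32 VKQ accumulator, Dv floats
    ggml_fp16_t * VKQ16;                       // (temporary) FP16 VKQ accumulator / V buffer
    ggml_fp16_t * Q_q;                         // Q row converted for the K vec-dot, Dk elements
};

static struct fa_scratch fa_scratch_layout(float * wrow, int64_t Dk, int64_t Dv) {
    struct fa_scratch s;
    s.VKQ32 = wrow;
    s.VKQ16 = (ggml_fp16_t *) (wrow + 1*Dv);   // offset by Dv, not Dq
    s.Q_q   = (ggml_fp16_t *) (wrow + 2*Dv);   // offset by 2*Dv, not 2*Dq
    (void) Dk;                                 // wrow must span at least 2*Dv + Dk floats
    return s;
}

Only the converted Q row at the end depends on Dk; both VKQ accumulators only ever need Dv elements, which is why the offsets use Dv.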

P.S. You might also want to build my llama.cpp fork and test CPU-only inference of a model like DeepSeek-V2-Lite (V2, V3 and R1 have a similar attention mechanism); -fa -ctk q8_0 -ctv q8_0 seems to work.
(screenshot attached: 2025-03-26_050828593)

@pl752
Contributor

pl752 commented Mar 26, 2025

UPD: The padding is definitely troublesome somehow, since on CUDA it turns the output into complete gibberish.
Turning padding off results in somewhat coherent text on CUDA, but the output quality is extremely low.

example:

 To achieve the desired functionality, you can be able todrawBall,draw using buffercalculateget the ball size,updatecalculate.



the ball.



phins
phins

bufferData

To use the buffer objectupdate.update(GL_ARRAY_BUFFER, GL_ARRAY_BUFFER,
GL_TRIANGLES, GL_TRIANGLES, GL_TRIANGLE


To create aupdate thebuffer object



To use buffer to create a buffer object for the ball in OpenGL, you can use the following code:




#include <getBall.
#include(GLFW/glfw.h)
#include <GL/GL.h>
#include <GLFW/glfw32.


#include_updateBuffer

#include
#include.



#include("B
#
#include
