ggml : add Flash Attention #5021
Merged · +2,917 −453
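This PR adds a fused flash-attention operator, `ggml_flash_attn_ext`, with CPU, Metal and CUDA kernels, and wires it into llama.cpp behind a new `flash_attn` context parameter (see the commit list below). As a hedged summary, not taken from the PR text itself: the fused node evaluates standard scaled-dot-product attention, and per the "ggml : online attention (CPU)" commit it does so online over blocks of the KV cache, so the full KQ matrix is never materialized:

```math
O = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d_{\mathrm{head}}}} + M\right) V
```

Here `M` is the KQ mask (padded F16, per the commits) and the `1/sqrt(d_head)` scale is passed to the operator explicitly rather than hardcoded.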
Commits (145)
a1c004e  ggml : add ggml_flash_attn_ext API (ggerganov)
fa7ebcc  ggml : fix GQA support in ggml_flash_attn_ext (ggerganov)
c3cdfff  Merge branch 'master' into gg/flash-attn (ggerganov)
a9681fe  ggml : online attention (CPU) (ggerganov)
1173f49  metal : initial implementation (ggerganov)
528da75  metal : f16 precision (ggerganov)
52ae085  metal : reduce branches (ggerganov)
b973258  metal : specialize for head size (ggerganov)
8cde449  wip : 8 rows per simd group (ggerganov)
f31955f  wip : 4 rows per simd group (ggerganov)
a4b6341  wip : template for rows per warp (ggerganov)
77d08f3  metal : parallelize across KV size (ggerganov)
17720fa  metal : parallel reduce across heads (ggerganov)
1446a12  metal : efficient flash_attn_f16 implementation (ggerganov)
d917746  metal : avoid redundant loads of the attention (ggerganov)
432ad04  metal : scale and mask in matrix form (ggerganov)
40ea8cd  metal : fix comment (ggerganov)
f9ca5dc  llama : avoid ggml_cast, use F32 query (ggerganov)
6fea843  metal : add parallel reduce version (disabled) (ggerganov)
b3dd7d9  Merge branch 'master' into gg/flash-attn (ggerganov)
77f6976  metal : move output into local memory + optimize (ggerganov)
ecc466a  metal : add tests, fix scaling, support C > 32 (ggerganov)
3a428a1  metal : improve precision (ggerganov)
8612864  ggml : fix f16 mad (ggerganov)
0ad44ba  Merge branch 'master' into gg/flash-attn (ggerganov)
134c81c  metal : minor (ggerganov)
1db22d7  metal : support Q > 8 (ggerganov)
4794821  tests : add ATTN tests (ggerganov)
abeaf0d  metal : disable buffer allocation logs (ggerganov)
c6c1132  tests : more (ggerganov)
5fcb9c1  metal : faster inner loop for C == 32 (ggerganov)
d073e4f  metal : fix array initialization (ggerganov)
78df552  tests : ifdef (ggerganov)
3d03bcb  Merge branch 'master' into gg/flash-attn (ggerganov)
2ddc9bb  Merge branch 'master' into gg/flash-attn (ggerganov)
8ad92dc  ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext (ggerganov)
910b15b  ggml : fix ggml_soft_max mask requirement (ggerganov)
2e46013  cuda : fix soft_max to use correct mask size (ggerganov)
5a19a9f  cuda : add flash_attn kernel (wip) (ggerganov)
41d136b  Merge branch 'master' into gg/flash-attn (ggerganov)
56e45a2  metal : optimize softmax for C > 32 (ggerganov)
cda5a60  metal : optimize softmax (ggerganov)
c6769b9  tests : minor fix (ggerganov)
db1f3c4  cuda : avoid zeroing fragments (ggerganov)
12eaa22  tests : update dims (ggerganov)
b68a112  cuda : fix __hisinf() result check (ggerganov)
b150abe  cuda : avoid warp_reduce for smax (ggerganov)
7c34655  cuda : use int instead of int64_t (ggerganov)
1f8a592  cuda : make loops use the same loop values (ggerganov)
92472ea  cuda : unroll some of the loops (ggerganov)
c51f27c  cuda : avoid __hisinf branches (ggerganov)
b958151  cuda : use half2 in softmax (ggerganov)
a7b4715  cuda : switch to 1 warp for bs > 16 (ggerganov)
3b1c4e7  cuda : speed-up reduce part of the kernel (ggerganov)
5b263dd  cuda : unroll Q*K^T loop (ggerganov)
e04ff39  cuda : fix -INF block check (ggerganov)
cfd9732  cuda : simplify softmax (ggerganov)
ef68fac  cuda : fix matrix names (ggerganov)
1846e92  cuda : minor (ggerganov)
6875997  Merge branch 'master' into gg/flash-attn (ggerganov)
31109ca  Merge branch 'master' into gg/flash-attn (ggerganov)
f249c99  llama : adapt to F16 KQ_pos (ggerganov)
02a645e  Merge branch 'master' into gg/flash-attn (ggerganov)
6aefd11  llama : adapt new models to F16 KQ_mask (ggerganov)
e307882  Merge branch 'master' into gg/flash-attn (ggerganov)
58c7f61  ggml : fix F16 store (ARM NEON) (ggerganov)
9495d39  Merge branch 'master' into gg/flash-attn (ggerganov)
3a468e6  llama : fix type of KQ_mask and KQ_pos (ggerganov)
0953212  ggml : fix CPU soft_max (ggerganov)
e425810  tests : add hs=256 (ggerganov)
013721d  Merge branch 'master' into gg/flash-attn (ggerganov)
6be02b5  cuda : fix build (ggerganov)
57c03b7  metal : improve perf via smaller int registers (ggerganov)
3e318e7  Merge branch 'master' into gg/flash-attn (ggerganov)
08e69c5  cuda : adapt soft_max to F16 mask and pos (ggerganov)
75aa7b4  CUDA: faster FlashAttention, kernel for bs == 1 (JohannesGaessler)
d59ac67  16 cols for Phi-2 (JohannesGaessler)
81da919  no vec for hs, no hs==256 ncols==32 for Volta (JohannesGaessler)
269374e  adjust kernel selection logic (JohannesGaessler)
cca6d02  4 warps, 256 stride for all D (JohannesGaessler)
68d793b  no ncols == 64 (JohannesGaessler)
3f777ac  Multiple parallel blocks for batch size 1 (JohannesGaessler)
e1ecd3b  fix compile warnings (JohannesGaessler)
bb0d51a  fix excessive KQ_b loads (JohannesGaessler)
c63dfdf  fix cmake build (JohannesGaessler)
ee19a4a  fix KV cache padding, NaN from INFINITY (#6438) (JohannesGaessler)
89961de  Merge branch 'master' into gg/flash-attn (ggerganov)
2c41180  Merge branch 'master' into gg/flash-attn (ggerganov)
599ce84  llama : flash_attn cparam + fix defrag (ggerganov)
4053857  server: support flash_attn param (phymbert)
5668c79  server: bench: enable flash_attn param (phymbert)
34f93bb  CUDA: refactor host code, dyn. par. blocks (JohannesGaessler)
6a3b842  fix flash_attn_vec_f16 race condition (JohannesGaessler)
ef9e159  flush softmax exp below threshold to 0 (JohannesGaessler)
a5b0e2d  store temp KQ in registers (JohannesGaessler)
0bc67dd  Calculate KQ as FP32 if KQV has GGML_PREC_F32 (JohannesGaessler)
2f538b9  Add __hgt2_mask implementation for CUDA 11 (JohannesGaessler)
87968de  fix KQ FP32 precision fpr parallel_blocks > 1 (JohannesGaessler)
260cdb2  llama-bench : add -fa,--flash-attn arg (ggerganov)
105332c  metal : add BS=1 kernel for flash attention (#6508) (ggerganov)
fa9e8c6  Merge branch 'master' into gg/flash-attn (ggerganov)
c16a7c2  metal : use F32 attention accumulators (ggerganov)
9ca8698  batched-bench : add fattn arg (ggerganov)
74d57f9  llama : simplify llama_build_kv_store (ggerganov)
1db66c1  Merge branch 'master' into gg/flash-attn (ggerganov)
e32b281  llama : adapt build_olmo to changes (ggerganov)
703c6e6  ggml : fix arm fp16 store on windows (ggerganov)
97eaece  metal : clean-up (ggerganov)
1a88565  metal : clean-up kernel code (ggerganov)
bc34616  metal : minor (ggerganov)
29f6ad8  Merge branch 'master' into gg/flash-attn (ggerganov)
5294542  tests : remove benchmarks (ggerganov)
3badef1  ggml : fix avx512 const correctness (ggerganov)
871fcb6  ggml : fix soft_max with bias on CPU (ggerganov)
a39217d  common : print --flash-attn in help (ggerganov)
cb76d74  ggml : fix num dimensions in ggml_flash_attn_ext (ggerganov)
c11d05f  llama : force disable flash attention for incompatible models (ggerganov)
f725ca9  ggml : ggml_soft_max support F16/F32 mask/pos (ggerganov)
5408d55  cuda : uint -> uint32_t (ggerganov)
c70bfd7  cuda : "constexpr dim3" -> "const dim3" (ggerganov)
c129369  cuda : try to fix __hgt2_mask (ggerganov)
3864eea  ggml : add TODO's for F16/F32 mask/pos support in other backends (ggerganov)
78d363b  llama : replace bool need_kq_pos with use_alibi (ggerganov)
19e8982  llama : prep ALiBi support for BERT models (ggerganov)
56657e5  llama : fix n_batch requirements (ggerganov)
d228bf8  cont (ggerganov)
751591d  server : add help for --flash-attn arg (ggerganov)
8937ec5  Merge branch 'master' into gg/flash-attn (ggerganov)
ce281b9  llama : disable FA for AMD (ggerganov)
1f77f49  Merge branch 'master' into gg/flash-attn (ggerganov)
ff2c64a  tests : remove TMP_ATTN_BENCH (ggerganov)
cb3547a  Merge branch 'master' into gg/flash-attn (ggerganov)
1fd5bc3  llama : support save/load state with FA enabled (ggerganov)
09d0381  Merge branch 'master' into gg/flash-attn (ggerganov)
ac1c6d9  ci : add CUDA save-load-state tests (ggerganov)
c225609  llama : llama_kv_cache_clear zeroes data + fix save-load seq (ggerganov)
bab346b  llama : fix copy-paste errors, add TODO (ggerganov)
0fc5c5e  llama : disallow incompatible states (ggerganov)
1e590ac  llama : update llama_state_get_size after v_trans field (ggerganov)
4f4c024  metal : remove tmp log (ggerganov)
9e38760  llama : add static reminder for llama_state_get_size (ggerganov)
a1616e9  Merge branch 'master' into gg/flash-attn (ggerganov)
ca0275c  Merge branch 'master' into gg/flash-attn (ggerganov)
e180fcd  metal : fix max nsg (ggerganov)
c240ae2  ci : fix arg order (ggerganov)
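The commits above center on the new fused operator. Below is a minimal sketch of how graph-building code might call it, assuming the parameter list (ctx, q, k, v, mask, scale) suggested by the commit titles ("ggml : add ggml_flash_attn_ext API", "llama : avoid ggml_cast, use F32 query", "ggml : switch to padded F16 mask ..."); the exact signature should be checked against ggml.h at the merged revision.

```c
// Hedged sketch, not code from the PR: build one fused attention node in a
// ggml graph instead of the usual mul_mat -> soft_max -> mul_mat chain.
// Assumptions: parameter order (ctx, q, k, v, mask, scale); F32 query,
// F16 K/V cache views and a padded F16 mask, as the commit titles indicate.
#include <math.h>

#include "ggml.h"

static struct ggml_tensor * build_fused_attn(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,     // query            (F32)
        struct ggml_tensor  * k,     // key cache view   (F16)
        struct ggml_tensor  * v,     // value cache view (F16, not transposed when FA is on)
        struct ggml_tensor  * mask,  // KQ mask          (F16, padded)
        int                   n_embd_head) {
    const float scale = 1.0f/sqrtf((float) n_embd_head);

    // single fused node: softmax(scale*Q*K^T + mask) multiplied by V
    return ggml_flash_attn_ext(ctx, q, k, v, mask, scale);
}
```

Per the commits, the CPU, Metal and CUDA backends each provide kernels for this node, and llama.cpp force-disables the path for incompatible models.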
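On the llama.cpp side the feature is exposed as a context parameter and as a -fa/--flash-attn command-line flag (commits "llama : flash_attn cparam + fix defrag", "llama-bench : add -fa,--flash-attn arg", "server : add help for --flash-attn arg"). A hedged usage sketch follows; the bool field name `flash_attn` is inferred from the commit titles and should be verified against llama.h of the merged revision.

```c
// Hedged sketch, not code from the PR: request the flash-attention path when
// creating a llama.cpp context. The `flash_attn` field is an assumption based
// on the "llama : flash_attn cparam" commit title.
#include <stdbool.h>

#include "llama.h"

static struct llama_context * make_ctx(struct llama_model * model) {
    struct llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn = true; // route attention through ggml_flash_attn_ext
    return llama_new_context_with_model(model, cparams);
}
```

The same switch is wired through the common help output (--flash-attn), the server, llama-bench (-fa) and batched-bench (fattn), per the commits above.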