
Commit d7b8be1

llama : add llama_kv_cache_compress (EXPERIMENTAL)
1 parent bf08e00 · commit d7b8be1

File tree: 3 files changed, +262 -0 lines changed


examples/passkey/passkey.cpp

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@ int main(int argc, char ** argv) {
 
             llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
             llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+            llama_kv_cache_compress(ctx, 0);
             llama_kv_cache_update  (ctx);
 
             n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
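
For context: this call site sits right after the self-extend position transform, where llama_kv_cache_seq_div has divided the positions of a group of cells by n_grp, so several same-sequence cells end up sharing the same position. Passing delta == 0 to llama_kv_cache_compress then merges exactly those coinciding cells, since the merge criterion (see llama.cpp below) is abs(p0 - p1) <= delta. A minimal standalone sketch of the position arithmetic, with hypothetical values and not part of the commit:

    // Illustration: integer division by n_grp collapses consecutive
    // positions onto shared values; compress(ctx, 0) then merges the
    // cells whose positions became equal.
    #include <cstdio>

    int main() {
        const int n_grp = 4; // hypothetical grouping factor for self-extend
        for (int pos = 0; pos < 8; ++pos) {
            printf("pos %d -> %d\n", pos, pos / n_grp); // 0,0,0,0,1,1,1,1
        }
        // cells that share a position satisfy abs(p0 - p1) <= 0,
        // so these 8 cells would be merged down to 2
        return 0;
    }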

llama.cpp

Lines changed: 253 additions & 0 deletions
@@ -1737,6 +1737,9 @@ struct llama_kv_cache {
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;
 
+    // if non-negative, compress data on next update
+    llama_pos compress_delta = -1;
+
     std::vector<llama_kv_cell> cells;
 
     std::vector<struct ggml_tensor *> k_l; // per layer
@@ -2272,6 +2275,10 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
     return result;
 }
 
+static void llama_kv_cache_compress(struct llama_kv_cache & cache, llama_pos delta) {
+    cache.compress_delta = delta;
+}
+
 static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
     cache.do_defrag = true;
 }
@@ -8091,6 +8098,240 @@ static int llama_decode_internal(
     return 0;
 }
 
+// summary:
+//
+// - determine which KV cell pairs (i0, i1) to merge:
+//
+//     abs(cell[i0].pos - cell[i1].pos) <= compress_delta
+//
+// - move the KV cache to the host memory for easier manipulation
+// - processing is done layer-by-layer
+// - convert the KV data to F32
+// - merge the KV data (different ways to merge)
+// - convert the KV data back to the original type
+// - move the KV cache back to the device memory
+// - update the KV cache metadata
+//
+// as a side effect, the new KV cache is defragmented
+//
+static void llama_kv_cache_compress_internal(struct llama_context & lctx) {
+    auto & kv_self = lctx.kv_self;
+
+    const auto & hparams = lctx.model.hparams;
+
+    const uint32_t n_layer       = hparams.n_layer;
+    const uint32_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
+    const uint32_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
+    const uint32_t n_embd_head_k = hparams.n_embd_head_k; GGML_UNUSED(n_embd_head_k);
+    const uint32_t n_embd_head_v = hparams.n_embd_head_v; GGML_UNUSED(n_embd_head_v);
+    const uint32_t n_head_kv     = hparams.n_head_kv;     GGML_UNUSED(n_head_kv);
+    const uint32_t kv_size       = kv_self.size;
+
+    const int64_t t_start = ggml_time_us();
+
+    std::vector<uint8_t> buf_q;
+
+    std::vector<float> buf_src_f32;
+    std::vector<float> buf_dst_f32;
+
+    struct c_pair { uint32_t i0, i1; };
+    struct c_info { bool merged; uint32_t id, cnt, r; };
+
+    std::vector<c_info> infos(kv_size, { false, 0, 0, 0 });
+
+    // the destination cell in the new KV cache
+    uint32_t id = 0;
+
+    // number of pairs merged
+    uint32_t n_merges = 0;
+
+    // determine which KV cells to merge
+    for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
+        const auto & cell0 = kv_self.cells[i0];
+
+        if (!cell0.is_empty() && !infos[i0].merged) {
+            infos[i0] = { true, id, 0, 0 };
+            infos[id].cnt = 1;
+
+            const llama_pos p0 = cell0.pos;
+
+            for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
+                const auto & cell1 = kv_self.cells[i1];
+
+                if (i0 != i1 && cell0.is_same_seq(cell1)) {
+                    const llama_pos p1 = cell1.pos;
+
+                    if (std::abs(p0 - p1) <= kv_self.compress_delta) {
+                        infos[i1] = { true, id, 0, 0 };
+                        infos[id].cnt++;
+                        n_merges++;
+                    }
+                }
+            }
+
+            if (i0 != id) {
+                kv_self.cells[id] = cell0;
+            }
+
+            id++;
+        }
+    }
+
+    kv_self.head = id;
+    kv_self.used = id;
+
+    for (uint32_t i = id; i < kv_size; ++i) {
+        kv_self.cells[i] = llama_kv_cell();
+    }
+
+    LLAMA_LOG_INFO("(tmp log) KV compress pairs: %u\n", n_merges);
+
+    ggml_type_traits_t tt_k;
+    ggml_type_traits_t tt_v;
+
+    tt_k = ggml_internal_get_type_traits(kv_self.type_k);
+    tt_v = ggml_internal_get_type_traits(kv_self.type_v);
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            infos[i].r = 0;
+        }
+
+        // update keys
+        {
+            const int64_t ne = n_embd_k_gqa*kv_size;
+
+            const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, ne);
+
+            buf_q.resize(k_size);
+
+            buf_src_f32.resize(ne);
+            buf_dst_f32.resize(ne);
+
+            ggml_backend_tensor_get(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
+
+            tt_k.to_float(buf_q.data(), buf_src_f32.data(), ne);
+
+            std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
+
+            for (uint32_t i = 0; i < kv_size; ++i) {
+                if (!infos[i].merged) {
+                    continue;
+                }
+
+                const uint32_t id = infos[i].id;
+
+#if 1
+                // merge using averaging
+                {
+                    const float scale = 1.0f/float(infos[id].cnt);
+
+                    const int64_t os = i*n_embd_k_gqa;
+                    const int64_t od = id*n_embd_k_gqa;
+
+                    for (uint32_t j = 0; j < n_embd_k_gqa; ++j) {
+                        buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
+                    }
+                }
+#else
+                // merge separate heads
+                {
+                    for (uint32_t h = 0; h < n_head_kv; ++h) {
+                        if ((h + il) % infos[id].cnt != infos[id].r) {
+                            continue;
+                        }
+
+                        const int64_t os = i*n_embd_k_gqa + h*n_embd_head_k;
+                        const int64_t od = id*n_embd_k_gqa + h*n_embd_head_k;
+
+                        for (uint32_t j = 0; j < n_embd_head_k; ++j) {
+                            buf_dst_f32[od + j] = buf_src_f32[os + j];
+                        }
+                    }
+                }
+
+                infos[id].r++;
+#endif
+            }
+
+            tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
+
+            ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
+        }
+
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            infos[i].r = 0;
+        }
+
+        // update values (note: they are transposed)
+        {
+            const int64_t ne = n_embd_v_gqa*kv_size;
+
+            const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, ne);
+
+            buf_q.resize(v_size);
+
+            buf_src_f32.resize(ne);
+            buf_dst_f32.resize(ne);
+
+            ggml_backend_tensor_get(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
+
+            tt_v.to_float(buf_q.data(), buf_src_f32.data(), ne);
+
+            std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
+
+            for (uint32_t i = 0; i < kv_size; ++i) {
+                if (!infos[i].merged) {
+                    continue;
+                }
+
+                const uint32_t id = infos[i].id;
+
+#if 1
+                // merge using averaging
+                {
+                    const float scale = 1.0f/float(infos[id].cnt);
+                    //printf("i: %d -> id: %d, scale: %f\n", i, id, scale);
+
+                    const int64_t os = i;
+                    const int64_t od = id;
+
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
+                    }
+                }
+#else
+                // merge separate heads
+                {
+                    for (uint32_t h = 0; h < n_head_kv; ++h) {
+                        if ((h + il) % infos[id].cnt != infos[id].r) {
+                            continue;
+                        }
+
+                        const int64_t os = i;
+                        const int64_t od = id;
+
+                        for (uint32_t j = h*n_embd_head_v; j < (h + 1)*n_embd_head_v; ++j) {
+                            buf_dst_f32[od + j*kv_size] = buf_src_f32[os + j*kv_size];
+                        }
+                    }
+                }
+
+                infos[id].r++;
+#endif
+            }
+
+            tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
+
+            ggml_backend_tensor_set(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
+        }
+    }
+
+    const int64_t t_end = ggml_time_us();
+
+    LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0);
+}
+
 // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
 static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     auto & kv_self = lctx.kv_self;
@@ -8298,6 +8539,14 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         }
     }
 
+    // compress the KV cache data if needed
+    if (lctx.kv_self.compress_delta >= 0) {
+        llama_kv_cache_compress_internal(lctx);
+
+        lctx.kv_self.compress_delta = -1;
+        lctx.kv_self.do_defrag = false;
+    }
+
     // defragment the KV cache if needed
     if (lctx.kv_self.do_defrag) {
         llama_kv_cache_defrag_internal(lctx);
@@ -12374,6 +12623,10 @@ llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id se
     return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
 }
 
+void llama_kv_cache_compress(struct llama_context * ctx, llama_pos delta) {
+    llama_kv_cache_compress(ctx->kv_self, delta);
+}
+
 void llama_kv_cache_defrag(struct llama_context * ctx) {
     llama_kv_cache_defrag(ctx->kv_self);
 }
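
One detail worth calling out from the function above: the index arithmetic differs between the keys and values loops because the V cache is stored transposed, as the source comment notes. For K, the components of one cell are contiguous (offset i*n_embd_k_gqa + j), while for V the components of one cell are strided by kv_size (offset i + j*kv_size). A minimal sketch of the two addressing schemes, my illustration with hypothetical sizes, not code from the commit:

    // Flat offsets for cell i, embedding component j:
    //   K (row-major per cell): i*n_embd_k_gqa + j
    //   V (transposed):         i + j*kv_size
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t kv_size      = 8; // illustrative cache size
        const uint32_t n_embd_k_gqa = 4; // illustrative embedding size

        const uint32_t i = 3; // cell index
        const uint32_t j = 2; // embedding component

        const int64_t k_off = (int64_t) i*n_embd_k_gqa + j;      // as in the keys loop
        const int64_t v_off = (int64_t) i + (int64_t) j*kv_size; // as in the values loop

        printf("K offset: %lld, V offset: %lld\n",
               (long long) k_off, (long long) v_off);
        return 0;
    }

This layout is also why the averaging path accumulates with += and a 1/cnt scale: each of the cnt source cells contributes an equal share to its destination cell.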

llama.h

Lines changed: 8 additions & 0 deletions
@@ -554,6 +554,14 @@ extern "C" {
             struct llama_context * ctx,
             llama_seq_id seq_id);
 
+    // [EXPERIMENTAL] Compress the data in the KV cache
+    // This will be applied:
+    //   - lazily on next llama_decode()
+    //   - explicitly with llama_kv_cache_update()
+    LLAMA_API void llama_kv_cache_compress(
+            struct llama_context * ctx,
+            llama_pos delta);
+
     // Defragment the KV cache
     // This will be applied:
     //   - lazily on next llama_decode()
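
For reference, a minimal usage sketch of the new API, mirroring the passkey example above; it assumes an already initialized llama_context * ctx and is not part of the commit:

    // Schedule a compression that merges same-sequence cells whose
    // positions differ by at most 0, then apply it immediately:
    llama_kv_cache_compress(ctx, 0);
    llama_kv_cache_update  (ctx);

    // Alternatively, omit the explicit update and the compression
    // will be applied lazily on the next llama_decode().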

0 commit comments