@@ -1737,6 +1737,9 @@ struct llama_kv_cache {
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;
 
+    // if non-negative, compress data on next update
+    llama_pos compress_delta = -1;
+
     std::vector<llama_kv_cell> cells;
 
     std::vector<struct ggml_tensor *> k_l; // per layer
@@ -2272,6 +2275,10 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
     return result;
 }
 
+static void llama_kv_cache_compress(struct llama_kv_cache & cache, llama_pos delta) {
+    cache.compress_delta = delta;
+}
+
 static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
     cache.do_defrag = true;
 }
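
The two hunks above introduce a one-shot request flag in the same spirit as do_defrag: llama_kv_cache_compress() only records the desired delta, and the actual work is deferred to the next cache update (see the llama_kv_cache_update_internal hunk further down). A minimal standalone sketch of that contract, using simplified hypothetical types rather than the real llama.cpp structs:

#include <cstdint>

// simplified stand-ins for llama_kv_cache / llama_pos (illustration only)
struct toy_kv_cache {
    int32_t compress_delta = -1;    // < 0 means "no compression requested"
    bool    do_defrag      = false;
};

// mirrors llama_kv_cache_compress(): record the request, nothing else
static void toy_request_compress(toy_kv_cache & cache, int32_t delta) {
    cache.compress_delta = delta;
}

// mirrors the consuming side in llama_kv_cache_update_internal()
static void toy_update(toy_kv_cache & cache) {
    if (cache.compress_delta >= 0) {
        // ... merge/compress the cache data here ...
        cache.compress_delta = -1;    // one-shot: reset after applying
        cache.do_defrag      = false; // compression already leaves the cells defragmented
    }
}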
@@ -8091,6 +8098,240 @@ static int llama_decode_internal(
     return 0;
 }
 
+// summary:
+//
+//  - determine which KV cell pairs (i0, i1) to merge:
+//
+//      abs(cell[i0].pos - cell[i1].pos) <= compress_delta
+//
+//  - move the KV cache to the host memory for easier manipulation
+//  - processing is done layer-by-layer
+//  - convert the KV data to F32
+//  - merge the KV data (different ways to merge)
+//  - convert the KV data back to the original type
+//  - move the KV cache back to the device memory
+//  - update the KV cache metadata
+//
+// as a side effect, the new KV cache is defragmented
+//
+static void llama_kv_cache_compress_internal(struct llama_context & lctx) {
+    auto & kv_self = lctx.kv_self;
+
+    const auto & hparams = lctx.model.hparams;
+
+    const uint32_t n_layer       = hparams.n_layer;
+    const uint32_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
+    const uint32_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
+    const uint32_t n_embd_head_k = hparams.n_embd_head_k; GGML_UNUSED(n_embd_head_k);
+    const uint32_t n_embd_head_v = hparams.n_embd_head_v; GGML_UNUSED(n_embd_head_v);
+    const uint32_t n_head_kv     = hparams.n_head_kv;     GGML_UNUSED(n_head_kv);
+    const uint32_t kv_size       = kv_self.size;
+
+    const int64_t t_start = ggml_time_us();
+
+    std::vector<uint8_t> buf_q;
+
+    std::vector<float> buf_src_f32;
+    std::vector<float> buf_dst_f32;
+
+    struct c_pair { uint32_t i0, i1; };
+    struct c_info { bool merged; uint32_t id, cnt, r; };
+
+    std::vector<c_info> infos(kv_size, { false, 0, 0, 0 });
+
+    // the destination cell in the new KV cache
+    uint32_t id = 0;
+
+    // number of pairs merged
+    uint32_t n_merges = 0;
+
+    // determine which KV cells to merge
+    for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
+        const auto & cell0 = kv_self.cells[i0];
+
+        if (!cell0.is_empty() && !infos[i0].merged) {
+            infos[i0] = { true, id, 0, 0 };
+            infos[id].cnt = 1;
+
+            const llama_pos p0 = cell0.pos;
+
+            for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
+                const auto & cell1 = kv_self.cells[i1];
+
+                if (i0 != i1 && cell0.is_same_seq(cell1)) {
+                    const llama_pos p1 = cell1.pos;
+
+                    if (std::abs(p0 - p1) <= kv_self.compress_delta) {
+                        infos[i1] = { true, id, 0, 0 };
+                        infos[id].cnt++;
+                        n_merges++;
+                    }
+                }
+            }
+
+            if (i0 != id) {
+                kv_self.cells[id] = cell0;
+            }
+
+            id++;
+        }
+    }
+
+    kv_self.head = id;
+    kv_self.used = id;
+
+    for (uint32_t i = id; i < kv_size; ++i) {
+        kv_self.cells[i] = llama_kv_cell();
+    }
+
+    LLAMA_LOG_INFO("(tmp log) KV compress pairs: %u\n", n_merges);
+
+    ggml_type_traits_t tt_k;
+    ggml_type_traits_t tt_v;
+
+    tt_k = ggml_internal_get_type_traits(kv_self.type_k);
+    tt_v = ggml_internal_get_type_traits(kv_self.type_v);
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            infos[i].r = 0;
+        }
+
+        // update keys
+        {
+            const int64_t ne = n_embd_k_gqa*kv_size;
+
+            const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, ne);
+
+            buf_q.resize(k_size);
+
+            buf_src_f32.resize(ne);
+            buf_dst_f32.resize(ne);
+
+            ggml_backend_tensor_get(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
+
+            tt_k.to_float(buf_q.data(), buf_src_f32.data(), ne);
+
+            std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
+
+            for (uint32_t i = 0; i < kv_size; ++i) {
+                if (!infos[i].merged) {
+                    continue;
+                }
+
+                const uint32_t id = infos[i].id;
+
+#if 1
+                // merge using averaging
+                {
+                    const float scale = 1.0f/float(infos[id].cnt);
+
+                    const int64_t os = i*n_embd_k_gqa;
+                    const int64_t od = id*n_embd_k_gqa;
+
+                    for (uint32_t j = 0; j < n_embd_k_gqa; ++j) {
+                        buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
+                    }
+                }
+#else
+                // merge separate heads
+                {
+                    for (uint32_t h = 0; h < n_head_kv; ++h) {
+                        if ((h + il) % infos[id].cnt != infos[id].r) {
+                            continue;
+                        }
+
+                        const int64_t os = i*n_embd_k_gqa + h*n_embd_head_k;
+                        const int64_t od = id*n_embd_k_gqa + h*n_embd_head_k;
+
+                        for (uint32_t j = 0; j < n_embd_head_k; ++j) {
+                            buf_dst_f32[od + j] = buf_src_f32[os + j];
+                        }
+                    }
+                }
+
+                infos[id].r++;
+#endif
+            }
+
+            tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
+
+            ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
+        }
+
+        for (uint32_t i = 0; i < kv_size; ++i) {
+            infos[i].r = 0;
+        }
+
+        // update values (note: they are transposed)
+        {
+            const int64_t ne = n_embd_v_gqa*kv_size;
+
+            const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, ne);
+
+            buf_q.resize(v_size);
+
+            buf_src_f32.resize(ne);
+            buf_dst_f32.resize(ne);
+
+            ggml_backend_tensor_get(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
+
+            tt_v.to_float(buf_q.data(), buf_src_f32.data(), ne);
+
+            std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
+
+            for (uint32_t i = 0; i < kv_size; ++i) {
+                if (!infos[i].merged) {
+                    continue;
+                }
+
+                const uint32_t id = infos[i].id;
+
+#if 1
+                // merge using averaging
+                {
+                    const float scale = 1.0f/float(infos[id].cnt);
+                    //printf("i: %d -> id: %d, scale: %f\n", i, id, scale);
+
+                    const int64_t os = i;
+                    const int64_t od = id;
+
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
+                    }
+                }
+#else
+                // merge separate heads
+                {
+                    for (uint32_t h = 0; h < n_head_kv; ++h) {
+                        if ((h + il) % infos[id].cnt != infos[id].r) {
+                            continue;
+                        }
+
+                        const int64_t os = i;
+                        const int64_t od = id;
+
+                        for (uint32_t j = h*n_embd_head_v; j < (h + 1)*n_embd_head_v; ++j) {
+                            buf_dst_f32[od + j*kv_size] = buf_src_f32[os + j*kv_size];
+                        }
+                    }
+                }
+
+                infos[id].r++;
+#endif
+            }
+
+            tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
+
+            ggml_backend_tensor_set(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
+        }
+    }
+
+    const int64_t t_end = ggml_time_us();
+
+    LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0);
+}
+
 // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
 static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     auto & kv_self = lctx.kv_self;
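
To make the merge rule above concrete, here is a small self-contained sketch of the two core steps of llama_kv_cache_compress_internal(): greedily assigning each cell to a destination cell whose position lies within compress_delta, and averaging the corresponding F32 K rows into that destination. It is illustrative only: plain std containers stand in for the real ggml tensors and KV metadata, a single sequence is assumed, and the dequantize/requantize round-trip is omitted.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

// toy cell: position only, single sequence assumed (hypothetical simplification)
struct toy_cell { int32_t pos; bool empty; };

int main() {
    const uint32_t kv_size        = 6;
    const uint32_t n_embd         = 4; // stands in for n_embd_k_gqa
    const int32_t  compress_delta = 1; // merge cells whose positions differ by at most 1

    std::vector<toy_cell> cells = {
        {0, false}, {1, false}, {2, false}, {3, false}, {10, false}, {-1, true},
    };

    // flat [kv_size x n_embd] F32 K buffer, row i belongs to cell i
    std::vector<float> k_src(kv_size*n_embd);
    for (uint32_t i = 0; i < kv_size*n_embd; ++i) {
        k_src[i] = float(i);
    }

    std::vector<uint32_t> dst(kv_size, 0);    // destination cell for each source cell
    std::vector<uint32_t> cnt(kv_size, 0);    // how many sources merge into each destination
    std::vector<bool>     merged(kv_size, false);

    // step 1: pair selection - a greedy scheme in the spirit of the diff above
    uint32_t id = 0; // next destination cell
    for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
        if (cells[i0].empty || merged[i0]) {
            continue;
        }
        merged[i0] = true; dst[i0] = id; cnt[id] = 1;

        for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
            if (!cells[i1].empty && !merged[i1] &&
                std::abs(cells[i0].pos - cells[i1].pos) <= compress_delta) {
                merged[i1] = true; dst[i1] = id; cnt[id]++;
            }
        }
        id++;
    }

    // step 2: merge the K rows by averaging them into their destination rows
    std::vector<float> k_dst(kv_size*n_embd, 0.0f);
    for (uint32_t i = 0; i < kv_size; ++i) {
        if (!merged[i]) {
            continue;
        }
        const float scale = 1.0f/float(cnt[dst[i]]);
        for (uint32_t j = 0; j < n_embd; ++j) {
            k_dst[dst[i]*n_embd + j] += k_src[i*n_embd + j]*scale;
        }
    }

    // cells 0+1 -> dst 0, cells 2+3 -> dst 1, cell 4 -> dst 2; rows 3..5 of k_dst stay zero
    for (uint32_t i = 0; i < id; ++i) {
        printf("dst row %u: %.1f %.1f %.1f %.1f\n",
                i, k_dst[i*n_embd + 0], k_dst[i*n_embd + 1], k_dst[i*n_embd + 2], k_dst[i*n_embd + 3]);
    }

    return 0;
}

The same averaging is applied to the V data in the diff, only with transposed indexing (buf[cell + j*kv_size] instead of buf[cell*n_embd + j]), because the V cache is stored transposed.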
@@ -8298,6 +8539,14 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         }
     }
 
+    // compress the KV cache data if needed
+    if (lctx.kv_self.compress_delta >= 0) {
+        llama_kv_cache_compress_internal(lctx);
+
+        lctx.kv_self.compress_delta = -1;
+        lctx.kv_self.do_defrag = false;
+    }
+
     // defragment the KV cache if needed
     if (lctx.kv_self.do_defrag) {
         llama_kv_cache_defrag_internal(lctx);
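
One consequence of the hunk above is that a compression request supersedes a pending defrag request, since the compression pass already leaves the cells defragmented. A hedged caller-side sketch (assuming the public declaration for llama_kv_cache_compress() is also added to llama.h on this branch, which is not shown in these hunks):

#include "llama.h"

// ctx is an initialized llama_context; the delta value is arbitrary here
static void request_compress_example(struct llama_context * ctx) {
    llama_kv_cache_defrag  (ctx);     // records do_defrag
    llama_kv_cache_compress(ctx, 32); // records compress_delta = 32
    llama_kv_cache_update  (ctx);     // runs the compression pass above and resets both flags
}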
@@ -12374,6 +12623,10 @@ llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id se
     return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
 }
 
+void llama_kv_cache_compress(struct llama_context * ctx, llama_pos delta) {
+    llama_kv_cache_compress(ctx->kv_self, delta);
+}
+
 void llama_kv_cache_defrag(struct llama_context * ctx) {
     llama_kv_cache_defrag(ctx->kv_self);
 }
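
A possible end-to-end usage sketch for the new public entry point (again assuming the matching LLAMA_API declaration lands in llama.h, which is not part of the hunks shown): when the cache gets close to full, request a merge of cells whose positions are within a small window, then trigger the update explicitly instead of waiting for the next llama_decode().

#include "llama.h"

// hypothetical helper: compress when more than ~90% of the KV cells are in use
static void maybe_compress_kv(struct llama_context * ctx, llama_pos delta) {
    const int32_t n_ctx  = (int32_t) llama_n_ctx(ctx);
    const int32_t n_used = llama_kv_cache_used_cells(ctx);

    if (n_used > (9*n_ctx)/10) {
        llama_kv_cache_compress(ctx, delta); // merge cells whose positions differ by <= delta
        llama_kv_cache_update(ctx);          // apply the request now
    }
}

Merging cells this way is lossy, so the choice of delta trades context fidelity for freed cache space.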