
Commit 1364bcd

mpt : removed ne01 + n_past == ne00 assertion from alibi (cuda/f32) and rope_shift from build_mpt
1 parent 4708012 commit 1364bcd

File tree

3 files changed: +5 -35 lines changed

ggml-cuda.cu (+2 -2)
@@ -6295,12 +6295,12 @@ inline void ggml_cuda_op_alibi(
     const int64_t ne02 = src0->ne[2];
     const int64_t nrows = ggml_nrows(src0);
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
-    GGML_ASSERT(ne01 + n_past == ne00);
+    //GGML_ASSERT(ne01 + n_past == ne00);
     GGML_ASSERT(n_head == ne02);
 
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
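
Note on the surviving context line: n_heads_log2_floor drives ggml's ALiBi slope schedule. A minimal standalone sketch of that schedule (my paraphrase of the kernel's slope math, not code quoted from this commit; alibi_slope is a hypothetical helper):

    #include <math.h>

    // Per-head ALiBi slope (sketch): the first n_heads_log2_floor heads get
    // slopes m0^(k+1); heads beyond the power-of-two boundary are interleaved
    // using m1. For MPT-style models, max_bias (max_alibi_bias) is typically 8.0f.
    static float alibi_slope(int k, int n_head, float max_bias) {
        const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
        const float m0 = powf(2.0f, -max_bias          / n_heads_log2_floor);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
        return k < n_heads_log2_floor ? powf(m0, k + 1)
                                      : powf(m1, 2*(k - n_heads_log2_floor) + 1);
    }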

ggml.c (+2 -2)
@@ -12889,7 +12889,7 @@ static void ggml_compute_forward_alibi_f32(
         return;
     }
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -12910,7 +12910,7 @@ static void ggml_compute_forward_alibi_f32(
     //const int nb3 = src0->nb[3];
 
     GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(ne1 + n_past == ne0);
+    //GGML_ASSERT(ne1 + n_past == ne0);
     GGML_ASSERT(n_head == ne2);
 
     // add alibi to src0 (KQ_scaled)
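
For orientation, this is roughly what the f32 kernel computes per head once the asserts are gone (an illustrative sketch, not the ggml source): the bias added to each element depends only on the key-position column i and the head slope m_k, so n_past never enters the arithmetic; the commented-out assert was its only remaining use.

    // Sketch: add the ALiBi bias m_k * i to a [ne0 x ne1] f32 slice for one
    // head, where i indexes key positions (columns) and j indexes query
    // positions (rows). Assumes a contiguous row-major layout for brevity;
    // the real kernel walks the tensor via its nb[] strides.
    static void alibi_bias_f32(float * data, int ne0, int ne1, float m_k) {
        for (int j = 0; j < ne1; ++j) {
            for (int i = 0; i < ne0; ++i) {
                data[j*ne0 + i] += m_k * i;   // grows with key position only
            }
        }
    }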

llama.cpp (+1 -31)
@@ -4076,8 +4076,6 @@ static struct ggml_cgraph * llm_build_mpt(
     const int32_t n_kv    = ggml_allocr_is_measure(lctx.alloc) ? n_ctx            : kv_self.n;
     const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
 
-    const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
-
     //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n",
     //       kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift);
 
@@ -4176,34 +4174,6 @@ static struct ggml_cgraph * llm_build_mpt(
         }
     }
 
-    // shift the entire K-cache if needed
-    // TODO: Do we need to handle it? (MPT uses alibi instead of rope)
-    /* if (do_rope_shift) {
-        struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
-        offload_func_kq(K_shift);
-        ggml_set_name(K_shift, "K_shift");
-        ggml_allocr_alloc(lctx.alloc, K_shift);
-        if (!ggml_allocr_is_measure(lctx.alloc)) {
-            int * data = (int *) K_shift->data;
-            for (int i = 0; i < n_ctx; ++i) {
-                data[i] = kv_self.cells[i].delta;
-            }
-        }
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * tmp =
-                    ggml_rope_custom_inplace(ctx0,
-                        ggml_view_3d(ctx0, kv_self.k,
-                            n_embd_head, n_head_kv, n_ctx,
-                            ggml_element_size(kv_self.k)*n_embd_head,
-                            ggml_element_size(kv_self.k)*n_embd_gqa,
-                            ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
-                        K_shift, n_embd_head, 2, 0, freq_base, freq_scale);
-            offload_func_kq(tmp);
-            ggml_build_forward_expand(gf, tmp);
-        }
-    }*/
-
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * attn_norm;
 
@@ -4306,7 +4276,7 @@ static struct ggml_cgraph * llm_build_mpt(
 
         // TODO: replace with ggml_add()
         struct ggml_tensor * KQ_scaled_alibi =
-                ggml_alibi(ctx0, KQ_scaled, std::max(kv_head, n_kv - n_tokens), n_head, max_alibi_bias);
+                ggml_alibi(ctx0, KQ_scaled, 0, n_head, max_alibi_bias);
         offload_func_kq(KQ_scaled_alibi);
         ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
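
Why hard-coding 0 for the n_past argument is harmless here (my reasoning, not a note from the commit): with the kernels above biasing each element by m_k * column, replacing a distance-style bias with an absolute-column bias changes every row of KQ_scaled by a per-row constant, and soft-max is invariant to per-row constants. A self-contained check of that invariance:

    #include <math.h>
    #include <stdio.h>

    // Soft-max of one attention row, in place.
    static void softmax(float * x, int n) {
        float mx = x[0];
        for (int i = 1; i < n; ++i) if (x[i] > mx) mx = x[i];
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) { x[i] = expf(x[i] - mx); sum += x[i]; }
        for (int i = 0; i < n; ++i) x[i] /= sum;
    }

    int main(void) {
        const float scores[4] = {0.1f, 0.4f, 0.2f, 0.3f}; // one query row, 4 key positions
        const float m   = 0.5f;  // head slope
        const int   pos = 2;     // absolute query position (would include n_past)

        float a[4], b[4];
        for (int i = 0; i < 4; ++i) {
            a[i] = scores[i] + m * i;           // absolute-column bias, n_past ignored
            b[i] = scores[i] - m * (pos - i);   // classic ALiBi distance bias
        }
        softmax(a, 4);                          // b[i] = a[i] - m*pos, a per-row constant,
        softmax(b, 4);                          // so both soft-maxes come out identical
        for (int i = 0; i < 4; ++i) {
            printf("%.6f %.6f\n", a[i], b[i]);
        }
        return 0;
    }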
