Skip to content

Commit 7372b62

Browse files
committed
ggml : ggml_get_rows support 2D indexing [n_tokens, n_experts] (cpu only)
1 parent 8b185b7 commit 7372b62

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

ggml.c

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4735,7 +4735,8 @@ struct ggml_tensor * ggml_get_rows(
47354735
struct ggml_context * ctx,
47364736
struct ggml_tensor * a,
47374737
struct ggml_tensor * b) {
4738-
GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
4738+
GGML_ASSERT(a->ne[2] == b->ne[1]);
4739+
GGML_ASSERT(ggml_is_matrix(b) && b->type == GGML_TYPE_I32);
47394740

47404741
bool is_node = false;
47414742

@@ -4745,7 +4746,7 @@ struct ggml_tensor * ggml_get_rows(
47454746

47464747
// TODO: implement non F32 return
47474748
//struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
4748-
struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
4749+
struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1]);
47494750

47504751
result->op = GGML_OP_GET_ROWS;
47514752
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -10348,8 +10349,8 @@ static void ggml_compute_forward_get_rows_q(
1034810349
const enum ggml_type type = src0->type;
1034910350
ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
1035010351

10351-
assert( dst->ne[0] == nc);
10352-
assert( dst->ne[1] == nr);
10352+
assert( dst->ne[0] == nc);
10353+
assert(ggml_nrows(dst) == nr);
1035310354
assert(src0->nb[0] == ggml_type_size(type));
1035410355

1035510356
for (int i = 0; i < nr; ++i) {
@@ -10375,8 +10376,8 @@ static void ggml_compute_forward_get_rows_f16(
1037510376
const int nc = src0->ne[0];
1037610377
const int nr = ggml_nelements(src1);
1037710378

10378-
assert( dst->ne[0] == nc);
10379-
assert( dst->ne[1] == nr);
10379+
assert( dst->ne[0] == nc);
10380+
assert(ggml_nrows(dst) == nr);
1038010381
assert(src0->nb[0] == sizeof(ggml_fp16_t));
1038110382

1038210383
for (int i = 0; i < nr; ++i) {
@@ -10403,8 +10404,8 @@ static void ggml_compute_forward_get_rows_f32(
1040310404
const int nc = src0->ne[0];
1040410405
const int nr = ggml_nelements(src1);
1040510406

10406-
assert( dst->ne[0] == nc);
10407-
assert( dst->ne[1] == nr);
10407+
assert( dst->ne[0] == nc);
10408+
assert(ggml_nrows(dst) == nr);
1040810409
assert(src0->nb[0] == sizeof(float));
1040910410

1041010411
for (int i = 0; i < nr; ++i) {

ggml.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,7 @@ extern "C" {
12631263
struct ggml_context * ctx,
12641264
struct ggml_tensor * a);
12651265

1266+
// supports 3D: a->ne[2] == b->ne[1]
12661267
GGML_API struct ggml_tensor * ggml_get_rows(
12671268
struct ggml_context * ctx,
12681269
struct ggml_tensor * a,

0 commit comments

Comments
 (0)