
Commit 861cd67

ggml : sync latest ggml_mul_mat_id
1 parent: a3eefe9

File tree

ggml-cuda.cu
ggml-metal.m
ggml.c
tests/test-backend-ops.cpp

4 files changed: 110 additions, 71 deletions

ggml-cuda.cu

Lines changed: 41 additions & 24 deletions
@@ -1,13 +1,15 @@
 #include <algorithm>
+#include <assert.h>
+#include <atomic>
+#include <cinttypes>
 #include <cstddef>
 #include <cstdint>
-#include <cinttypes>
 #include <float.h>
 #include <limits>
 #include <stdint.h>
 #include <stdio.h>
-#include <atomic>
-#include <assert.h>
+#include <vector>
+

 #if defined(GGML_USE_HIPBLAS)
 #include <hip/hip_runtime.h>
@@ -8234,36 +8236,51 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
 }
 #endif

-static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 #if 0
-//#ifdef CUDA_USE_TENSOR_CORES
-//    const bool use_tensor_cores = true;
-//#else
-//    const bool use_tensor_cores = false;
-//#endif
-
     ggml_cuda_mul_mat_id_cublas(dst);
-
     // TODO: mmq/mmv support
-#else
-    const struct ggml_tensor * ids = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const int id = dst->op_params[0];
+#endif

-    int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+    const struct ggml_tensor * ids = src0;
+    const int32_t id = dst->op_params[0];
+    const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];

-    int32_t a_id;
-    CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
     CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));

-    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
-    const struct ggml_tensor * src0 = dst->src[a_id + 2];
+    const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
+    const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;

-    ggml_cuda_mul_mat(src0, src1, dst);
-#endif
+    ggml_tensor_extra_gpu src1_row_extra;
+    ggml_tensor_extra_gpu dst_row_extra;
+
+    ggml_tensor src1_row = *src1;
+    ggml_tensor dst_row = *dst;
+
+    src1_row.ne[1] = 1;
+    dst_row.ne[1] = 1;
+
+    src1_row.extra = &src1_row_extra;
+    dst_row.extra = &dst_row_extra;

-    (void) _src0;
-    (void) _src1;
+    for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+        //int32_t row_id;
+        //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+        //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+        const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+
+        GGML_ASSERT(row_id >= 0 && row_id < ids->ne[0]);
+
+        const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+
+        src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+        dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
+
+        ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
+    }
 }

 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
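
The new loop replaces a single expert lookup with one lookup per src1 row: the whole 2-D ids tensor is copied to the host once, and the expert index for row i01 is read at byte offset i01*nb[1] + id*nb[0]. Below is a minimal host-only sketch of that strided lookup, with illustrative sizes and a hand-built buffer standing in for a real ggml tensor:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const int64_t n_as   = 4; // number of expert matrices (ids->ne[0])
        const int64_t n_rows = 3; // rows in src1/dst          (ids->ne[1])
        const int64_t id     = 0; // dst->op_params[0]: which column of ids to read

        // row-major int32 buffer with explicit byte strides, mimicking a GGML_TYPE_I32 tensor
        const size_t nb0 = sizeof(int32_t); // ids->nb[0]
        const size_t nb1 = n_as * nb0;      // ids->nb[1]
        std::vector<char> ids_host(n_rows * nb1);
        for (int64_t i01 = 0; i01 < n_rows; i01++) {
            for (int64_t a = 0; a < n_as; a++) {
                *(int32_t *) (ids_host.data() + i01*nb1 + a*nb0) = (int32_t) ((i01 + a) % n_as);
            }
        }

        // the same lookup the CUDA loop performs: one expert row id per src1 row
        for (int64_t i01 = 0; i01 < n_rows; i01++) {
            const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*nb1 + id*nb0);
            printf("src1 row %d -> expert %d\n", (int) i01, row_id);
        }
        return 0;
    }

Copying ids to the host once avoids the per-row cudaMemcpyAsync + cudaStreamSynchronize round trip that the commented-out lines inside the loop would incur, and the src1_row/dst_row views only shift data_device pointers, so no tensor data is ever copied.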

ggml-metal.m

Lines changed: 15 additions & 6 deletions
@@ -177,6 +177,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
     } else {
         char* buffer2 = malloc(len+1);
+        va_end(args);
+        va_start(args, format);
         vsnprintf(buffer2, len+1, format, args);
         buffer2[len] = 0;
         ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
@@ -1193,7 +1195,9 @@ void ggml_metal_graph_compute(
                         const float scale = ((float *) dst->op_params)[0];

                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        if (id_src1) {
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        }
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                         [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
                         [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
@@ -1511,9 +1515,7 @@ void ggml_metal_graph_compute(
                            case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
                            default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
                        }
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                        [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                        const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
                        [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
                        [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
                        [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
@@ -1523,7 +1525,7 @@ void ggml_metal_graph_compute(
                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
                        [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
-                        [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
+                        [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
                        [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
                        [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
                        [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
@@ -1538,7 +1540,14 @@ void ggml_metal_graph_compute(
                        }

                        [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+
+                        for (int64_t i01 = 0; i01 < src0->ne[1]; i01++) {
+                            [encoder setBuffer:id_src0 offset:offs_src0 + i01*nb01 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 + i01*nb11 atIndex:1];
+                            [encoder setBuffer:id_dst offset:offs_dst + i01*nb1 atIndex:2];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                        }
                    }
                } break;
            case GGML_OP_GET_ROWS:
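
The two lines added to ggml_metal_log fix a genuine portability bug: once vsnprintf has consumed a va_list, the list is indeterminate, so it must be re-armed with va_end/va_start before the second formatting pass. A self-contained sketch of the same pattern; log_twice is a hypothetical helper, not the ggml function:

    #include <cstdarg>
    #include <cstdio>
    #include <cstdlib>

    static void log_twice(const char * format, ...) {
        va_list args;
        va_start(args, format);

        char small[16];
        int len = vsnprintf(small, sizeof(small), format, args); // consumes args

        if (len >= (int) sizeof(small)) {
            va_end(args);           // args is indeterminate after vsnprintf
            va_start(args, format); // re-initialize before the second pass
            char * big = (char *) malloc(len + 1);
            vsnprintf(big, len + 1, format, args);
            puts(big);
            free(big);
        } else {
            puts(small);
        }
        va_end(args);
    }

    int main() {
        log_twice("a fairly long message: %s", "hello, world");
        return 0;
    }

va_copy before the first vsnprintf would work equally well; restarting is simplest when the original arguments are still in scope. The per-row dispatch loop at the end of the last hunk mirrors the CUDA change: one kernel launch per src0 row, with the buffer offsets advanced by nb01/nb11/nb1.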

ggml.c

Lines changed: 37 additions & 24 deletions
@@ -4083,7 +4083,9 @@ struct ggml_tensor * ggml_mul_mat_id(
     int64_t n_as = ids->ne[0];

     GGML_ASSERT(ids->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_is_vector(ids));
+    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
+    GGML_ASSERT(ids->ne[1] == b->ne[1]);
+    GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
     GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
     GGML_ASSERT(id >= 0 && id < n_as);

@@ -9519,11 +9521,16 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 }
 #endif

+// off1 = offset in i11 and i1
+// cne1 = ne11 and ne1
+// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
+// during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
+        struct ggml_tensor * dst,
+        int64_t off1, int64_t cne1) {
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);

@@ -9591,10 +9598,9 @@ static void ggml_compute_forward_mul_mat(
                 const int64_t i03 = i13/r3;
                 const int64_t i02 = i12/r2;

-                const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
-                const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
-
-                float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+                const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+                const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
+                float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3);

                 if (type != GGML_TYPE_F32) {
                     float * const wdata = params->wdata;
@@ -9611,10 +9617,10 @@ static void ggml_compute_forward_mul_mat(
                 }

                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        ne11, ne01, ne10,
-                        1.0f, y, ne10,
-                        x, ne00,
-                        0.0f, d, ne01);
+                        cne1, ne01, ne10,
+                        1.0f, y, ne10,
+                        x, ne00,
+                        0.0f, d, ne01);
             }
         }

@@ -9630,6 +9636,7 @@ static void ggml_compute_forward_mul_mat(
         const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);

         assert(params->wsize >= ne11*ne12*ne13*row_size);
+        assert(src1->type == GGML_TYPE_F32);

         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -9652,7 +9659,7 @@ static void ggml_compute_forward_mul_mat(
     const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);

     const int64_t nr0 = ne01; // src0 rows
-    const int64_t nr1 = ne11*ne12*ne13; // src1 rows
+    const int64_t nr1 = cne1*ne12*ne13; // src1 rows

     //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);

@@ -9694,9 +9701,9 @@ static void ggml_compute_forward_mul_mat(
     for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
         for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
             for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                const int64_t i13 = (ir1/(ne12*ne11));
-                const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
-                const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
+                const int64_t i13 = (ir1/(ne12*cne1));
+                const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
+                const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;

                 // broadcast src0 into src1
                 const int64_t i03 = i13/r3;
@@ -9736,20 +9743,26 @@ static void ggml_compute_forward_mul_mat(

 static void ggml_compute_forward_mul_mat_id(
         const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {

-    const struct ggml_tensor * ids = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        // during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
+        ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
+        return;
+    }

+    const struct ggml_tensor * ids = src0;
     const int id = ggml_get_op_params_i32(dst, 0);

-    const int a_id = ((int32_t *)ids->data)[id];
-
-    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+    for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+        const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
+        GGML_ASSERT(row_id >= 0 && row_id < ids->ne[0]);

-    const struct ggml_tensor * src0 = dst->src[a_id + 2];
-
-    ggml_compute_forward_mul_mat(params, src0, src1, dst);
+        const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+        ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
+    }
 }

 // ggml_compute_forward_out_prod
@@ -14037,11 +14050,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
             } break;
         case GGML_OP_MUL_MAT_ID:
            {
-                ggml_compute_forward_mul_mat_id(params, tensor);
+                ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
            } break;
        case GGML_OP_OUT_PROD:
            {
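
The new off1/cne1 parameters let ggml_compute_forward_mul_mat operate on a slice of src1/dst rows, which is exactly what ggml_compute_forward_mul_mat_id needs: one single-row multiplication per src1 row, each against the expert that ids selects. A naive sketch of the same control flow over plain row-major float arrays; all names and helpers here are illustrative, not ggml's:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // C = A * B^T restricted to rows [off1, off1 + cne1) of B and C.
    // a: m x k, b: n x k, dst: n x m, all row-major.
    static void mul_mat_slice(const float * a, const float * b, float * dst,
                              int64_t m, int64_t k,
                              int64_t off1, int64_t cne1) {
        for (int64_t i1 = off1; i1 < off1 + cne1; i1++) { // rows of b / dst
            for (int64_t i0 = 0; i0 < m; i0++) {          // rows of a
                float sum = 0.0f;
                for (int64_t j = 0; j < k; j++) {
                    sum += a[i0*k + j] * b[i1*k + j];
                }
                dst[i1*m + i0] = sum;
            }
        }
    }

    // mul_mat_id: row i01 of b is multiplied by the expert that ids picks.
    // ids is n x n_as, and column `id` selects which id set to use.
    static void mul_mat_id(const std::vector<const float *> & as,
                           const int32_t * ids, int64_t n_as, int64_t id,
                           const float * b, float * dst,
                           int64_t m, int64_t n, int64_t k) {
        for (int64_t i01 = 0; i01 < n; i01++) {
            const int32_t row_id = ids[i01*n_as + id];
            mul_mat_slice(as[row_id], b, dst, m, k, /*off1=*/i01, /*cne1=*/1);
        }
    }

    int main() {
        const int64_t m = 2, n = 3, k = 4, n_as = 2;
        std::vector<float> a0(m*k, 1.0f), a1(m*k, 2.0f), b(n*k, 1.0f), dst(n*m);
        const std::vector<int32_t> ids = {0, 1,  1, 0,  0, 1}; // one row per row of b
        mul_mat_id({a0.data(), a1.data()}, ids.data(), n_as, /*id=*/0, b.data(), dst.data(), m, n, k);
        for (int64_t i1 = 0; i1 < n; i1++) {
            printf("dst row %d: %g %g\n", (int) i1, dst[i1*m + 0], dst[i1*m + 1]);
        }
        return 0;
    }

The slicing also explains the GGML_TASK_INIT path in the diff: the full src1 is converted to vec_dot_type once up front, with off1/cne1 ignored, and only the per-row compute work is restricted.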

tests/test-backend-ops.cpp

Lines changed: 17 additions & 17 deletions
@@ -770,19 +770,17 @@ struct test_mul_mat_id : public test_case {
     const int64_t m;
     const int64_t n;
     const int64_t k;
-    const std::array<int64_t, 2> bs; // dims 3 and 4
-    const std::array<int64_t, 2> nr; // repeat in dims 3 and 4

     std::string vars() override {
-        return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr);
+        return VARS_TO_STR7(type_a, type_b, n_mats, id, m, n, k);
     }

     double max_nmse_err() override {
         return 5e-4;
     }

     size_t op_size(ggml_tensor * t) override {
-        size_t a = ggml_nbytes(t->src[2]) * n * nr[0] * nr[1];
+        size_t a = ggml_nbytes(t->src[2]) * n;
         size_t b = ggml_nbytes(t->src[1]) * m;
         size_t c = ggml_nbytes(t);
         return a + b + c;
@@ -792,35 +790,37 @@ struct test_mul_mat_id : public test_case {

     test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
             int n_mats = 2, int id = 0,
-            int64_t m = 32, int64_t n = 32, int64_t k = 32,
-            std::array<int64_t, 2> bs = {10, 10},
-            std::array<int64_t, 2> nr = {2, 2})
+            int64_t m = 32, int64_t n = 32, int64_t k = 32)
         : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
-            m(m), n(n), k(k), bs(bs), nr(nr) {}
+            m(m), n(n), k(k) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
         std::vector<ggml_tensor *> mats;
         for (int i = 0; i < n_mats; i++) {
-            ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+            ggml_tensor * a = ggml_new_tensor_2d(ctx, type_a, k, m);
             mats.push_back(a);
         }
-        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats);
-        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
+        ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n);
         ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b);
         return out;
     }

     void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I32) {
                 // ids
-                std::vector<int> data(n_mats);
-                for (int i = 0; i < n_mats; i++) {
-                    data[i] = i;
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector<int32_t> data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
                 }
-                std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()()));
-                ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int));
             } else {
                 init_tensor_uniform(t);
             }
@@ -1215,7 +1215,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
             for (int n_mats : {1, 2, 4}) {
                 for (int id = 0; id < n_mats; id++) {
-                    test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1}));
+                    test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256));
                }
            }
        }
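
With ids now 2-D, initialize_tensors fills every row with its own shuffled permutation of [0, n_mats), so each dst row can exercise a different expert; note the random engine is also seeded once and reused instead of being re-constructed for every shuffle. A stand-alone sketch of that initialization into a plain buffer, where std::copy stands in for ggml_backend_tensor_set and a contiguous layout (t->nb[1] == n_mats * sizeof(int32_t)) is assumed:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <random>
    #include <vector>

    int main() {
        const int64_t n_mats = 4, n_rows = 3;

        std::random_device rd;
        std::default_random_engine rng(rd()); // one engine, reused per row

        std::vector<int32_t> ids(n_rows * n_mats);
        for (int64_t r = 0; r < n_rows; r++) {
            std::vector<int32_t> data(n_mats);
            for (int64_t i = 0; i < n_mats; i++) {
                data[i] = (int32_t) i;
            }
            std::shuffle(data.begin(), data.end(), rng);   // a permutation per row
            std::copy(data.begin(), data.end(), ids.begin() + r*n_mats);
        }

        for (int64_t r = 0; r < n_rows; r++) {
            printf("ids row %d:", (int) r);
            for (int64_t i = 0; i < n_mats; i++) {
                printf(" %d", ids[r*n_mats + i]);
            }
            printf("\n");
        }
        return 0;
    }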
