60
60
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
61
61
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
62
62
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
63
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
63
64
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
64
65
GGML_METAL_KERNEL_TYPE_RMS_NORM,
65
66
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
81
82
GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
82
83
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
83
84
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
85
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
84
86
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
85
87
// GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
86
88
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
98
100
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
99
101
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
100
102
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
103
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
101
104
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
102
105
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
103
106
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
112
115
GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
113
116
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
114
117
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
118
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
115
119
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
116
120
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
117
121
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
126
130
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
127
131
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
128
132
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
133
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
129
134
GGML_METAL_KERNEL_TYPE_ROPE_F32,
130
135
GGML_METAL_KERNEL_TYPE_ROPE_F16,
131
136
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -422,6 +427,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
422
427
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true );
423
428
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true );
424
429
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true );
430
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true );
425
431
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true );
426
432
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction );
427
433
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction );
@@ -443,6 +449,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
443
449
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction );
444
450
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction );
445
451
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction );
452
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction );
446
453
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction );
447
454
// GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
448
455
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction );
@@ -460,6 +467,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
460
467
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction );
461
468
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction );
462
469
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction );
470
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction );
463
471
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm );
464
472
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm );
465
473
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm );
@@ -474,6 +482,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
474
482
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm );
475
483
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm );
476
484
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm );
485
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm );
477
486
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm );
478
487
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm );
479
488
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm );
@@ -488,6 +497,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
488
497
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm );
489
498
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm );
490
499
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm );
500
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm );
491
501
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true );
492
502
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true );
493
503
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true );
@@ -1260,6 +1270,7 @@ static bool ggml_metal_graph_compute(
1260
1270
case GGML_TYPE_Q6_K: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline ; break ;
1261
1271
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline ; break ;
1262
1272
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline ; break ;
1273
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline ; break ;
1263
1274
default : GGML_ASSERT (false && " MUL MAT-MAT not implemented" );
1264
1275
}
1265
1276
@@ -1388,6 +1399,12 @@ static bool ggml_metal_graph_compute(
1388
1399
nth1 = 16 ;
1389
1400
pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline ;
1390
1401
} break ;
1402
+ case GGML_TYPE_IQ3_XXS:
1403
+ {
1404
+ nth0 = 4 ;
1405
+ nth1 = 16 ;
1406
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline ;
1407
+ } break ;
1391
1408
default :
1392
1409
{
1393
1410
GGML_METAL_LOG_ERROR (" Asserting on type %d \n " , (int )src0t);
@@ -1430,6 +1447,11 @@ static bool ggml_metal_graph_compute(
1430
1447
[encoder setThreadgroupMemoryLength: mem_size atIndex: 0 ];
1431
1448
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1432
1449
}
1450
+ else if (src0t == GGML_TYPE_IQ3_XXS) {
1451
+ const int mem_size = 256 *4 +128 ;
1452
+ [encoder setThreadgroupMemoryLength: mem_size atIndex: 0 ];
1453
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1454
+ }
1433
1455
else if (src0t == GGML_TYPE_Q4_K) {
1434
1456
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 3 )/4 , ne11, ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1435
1457
}
@@ -1524,6 +1546,7 @@ static bool ggml_metal_graph_compute(
1524
1546
case GGML_TYPE_Q6_K: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline ; break ;
1525
1547
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline ; break ;
1526
1548
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline ; break ;
1549
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline ; break ;
1527
1550
default : GGML_ASSERT (false && " MUL_MAT_ID not implemented" );
1528
1551
}
1529
1552
@@ -1655,6 +1678,12 @@ static bool ggml_metal_graph_compute(
1655
1678
nth1 = 16 ;
1656
1679
pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline ;
1657
1680
} break ;
1681
+ case GGML_TYPE_IQ3_XXS:
1682
+ {
1683
+ nth0 = 4 ;
1684
+ nth1 = 16 ;
1685
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline ;
1686
+ } break ;
1658
1687
default :
1659
1688
{
1660
1689
GGML_METAL_LOG_ERROR (" Asserting on type %d \n " , (int )src2t);
@@ -1713,6 +1742,11 @@ static bool ggml_metal_graph_compute(
1713
1742
[encoder setThreadgroupMemoryLength: mem_size atIndex: 0 ];
1714
1743
[encoder dispatchThreadgroups: MTLSizeMake ((ne21 + 7 )/8 , _ne1, ne01*ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1715
1744
}
1745
+ else if (src2t == GGML_TYPE_IQ3_XXS) {
1746
+ const int mem_size = 256 *4 +128 ;
1747
+ [encoder setThreadgroupMemoryLength: mem_size atIndex: 0 ];
1748
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne21 + 7 )/8 , _ne1, ne01*ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1749
+ }
1716
1750
else if (src2t == GGML_TYPE_Q4_K) {
1717
1751
[encoder dispatchThreadgroups: MTLSizeMake ((ne21 + 3 )/4 , _ne1, ne01*ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1718
1752
}
@@ -1753,6 +1787,7 @@ static bool ggml_metal_graph_compute(
1753
1787
case GGML_TYPE_Q6_K: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline ; break ;
1754
1788
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline ; break ;
1755
1789
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline ; break ;
1790
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline ; break ;
1756
1791
case GGML_TYPE_I32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline ; break ;
1757
1792
default : GGML_ASSERT (false && " not implemented" );
1758
1793
}
0 commit comments