Commit f4d7e54

Authored by ikawrakow
SOTA 3-bit quants (#5196)
* iq3_xxs: quantize/dequantize
  RMSE seems a bit high-ish, at about half-way between q2_K and q3_K, so need to check more.

* iq3_xxs: CUDA dequantize works

* iq2_xxs: tuning quantization

* iq3_xxs: starting to look better
  PPL on wiki.test.raw:
  LLaMA-v1-7B: 6.4218
  LLaMA-v2-7B: 6.3560
  Mistral-7B : 6.0717
  This is better than Q3_K_XS, with a 5% reduction in quantized model size.

* iq3_xxs: CUDA dot product
  We have PP-512: 5891 t/s, TG-128: 143.9 t/s.

* iq3_xxs: scalar and AVX2 dot products

* iq3_xxs: ARM_NEON and Metal
  Metal performance is decent, ARM_NEON is pathetic.

* iq3_xxs: slightly better grid points

* Faster iq3_xxs and iq2_xs dot products on CUDA

* iq3_xxs: add some quant mix

* iq3_xxs: fix failing quantization test
  The dot product test still fails. Is this real?

* iq3_xxs: hopefully fix ROCm

* iq3_xxs: failing tests
  This time the dot product accuracy test did find an actual bug in the AVX2 implementation.

* Add IQ3_XXS to test-backend-ops

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
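For orientation: the "3.06 bpw" figure this commit advertises elsewhere works out to 98 bytes per super-block of 256 weights (98 * 8 / 256 = 3.0625 bits per weight). The sketch below shows a block layout consistent with that size, following the conventions of the other ggml i-quants (an fp16 per-block scale plus densely packed grid indices and sign bits); the exact field split is an assumption for illustration, not a statement of the struct this commit actually adds.

#include <stdint.h>

// Sketch only: a 256-weight super-block at 3.0625 bpw.
// The field layout is assumed; the block struct added by this commit may differ.
#define QK_K 256

typedef struct {
    uint16_t d;                 // per-super-block scale, stored as fp16 (ggml_half)
    uint8_t  qs[3 * QK_K / 8];  // 96 bytes of packed grid indices, signs and sub-scales
} example_iq3_xxs_block;        // 2 + 96 = 98 bytes -> 98*8/256 = 3.0625 bpw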
1 parent 2256f36 commit f4d7e54

14 files changed: +1215, -18 lines

examples/quantize-stats/quantize-stats.cpp

Lines changed: 2 additions & 0 deletions
@@ -378,6 +378,8 @@ int main(int argc, char ** argv) {
             printf("testing %s ...\n", ggml_type_name(type));
         }
 
+        ggml_quantize_init(type);
+
         error_stats global_stats {};
 
         for (const auto& kv_tensor : tensors) {
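The new ggml_quantize_init(type) call is the substantive change here: the i-quants appear to rely on precomputed lookup grids, and since quantize-stats drives the type's quantization routines directly, the tables need to be set up once per type before the per-tensor loop. A minimal usage sketch, under the assumption that ggml_quantize_free() is the matching cleanup call:

#include "ggml.h"

// Sketch: build the i-quant lookup tables once, run the quantize/dequantize
// passes, then release the tables. ggml_quantize_free() is assumed to be the
// matching cleanup entry point.
static void run_quant_passes(void) {
    ggml_quantize_init(GGML_TYPE_IQ3_XXS);  // prepare grids/tables for IQ3_XXS
    // ... iterate over the tensors and collect error stats, as quantize-stats does ...
    ggml_quantize_free();                   // drop the shared tables again
}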

examples/quantize/quantize.cpp

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "IQ2_XS",  LLAMA_FTYPE_MOSTLY_IQ2_XS,  " 2.31 bpw quantization",            },
     { "Q2_K",    LLAMA_FTYPE_MOSTLY_Q2_K,    " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
     { "Q2_K_S",  LLAMA_FTYPE_MOSTLY_Q2_K_S,  " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
+    { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization",            },
     { "Q3_K",    LLAMA_FTYPE_MOSTLY_Q3_K_M,  "alias for Q3_K_M" },
     { "Q3_K_XS", LLAMA_FTYPE_MOSTLY_Q3_K_XS, "3-bit extra small quantization",    },
     { "Q3_K_S",  LLAMA_FTYPE_MOSTLY_Q3_K_S,  " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", },
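With this entry in QUANT_OPTIONS, the new type can be selected by name on the quantize command line like any other ftype, e.g. quantize ggml-model-f16.gguf ggml-model-iq3_xxs.gguf IQ3_XXS (file names here are placeholders). The "3.06 bpw quantization" help string follows the same convention as the neighbouring IQ2_XS entry: bits per weight for the bulk of the tensors, before any quant-mix adjustments for particular layers.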

ggml-cuda.cu

Lines changed: 189 additions & 11 deletions
Large diffs are not rendered by default.

ggml-metal.m

Lines changed: 35 additions & 0 deletions
@@ -60,6 +60,7 @@
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -81,6 +82,7 @@
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
   //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -98,6 +100,7 @@
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -112,6 +115,7 @@
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -126,6 +130,7 @@
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
     GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -422,6 +427,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,            get_rows_q6_K,          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,         get_rows_iq2_xxs,       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,          get_rows_iq2_xs,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,         get_rows_iq3_xxs,       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,             get_rows_i32,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                 rms_norm,               ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,               group_norm,             ctx->support_simdgroup_reduction);
@@ -443,6 +449,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,          mul_mv_q6_K_f32,        ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,       mul_mv_iq2_xxs_f32,     ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,        mul_mv_iq2_xs_f32,      ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,       mul_mv_iq3_xxs_f32,     ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,        mul_mv_id_f32_f32,      ctx->support_simdgroup_reduction);
       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,        mul_mv_id_f16_f16,      ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,        mul_mv_id_f16_f32,      ctx->support_simdgroup_reduction);
@@ -460,6 +467,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,       mul_mv_id_q6_K_f32,     ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,    mul_mv_id_iq2_xxs_f32,  ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,     mul_mv_id_iq2_xs_f32,   ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,    mul_mv_id_iq3_xxs_f32,  ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,           mul_mm_f32_f32,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,           mul_mm_f16_f32,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,          mul_mm_q4_0_f32,        ctx->support_simdgroup_mm);
@@ -474,6 +482,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,          mul_mm_q6_K_f32,        ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,       mul_mm_iq2_xxs_f32,     ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,        mul_mm_iq2_xs_f32,      ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,       mul_mm_iq3_xxs_f32,     ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,        mul_mm_id_f32_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,        mul_mm_id_f16_f32,      ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,       mul_mm_id_q4_0_f32,     ctx->support_simdgroup_mm);
@@ -488,6 +497,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,       mul_mm_id_q6_K_f32,     ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,    mul_mm_id_iq2_xxs_f32,  ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,     mul_mm_id_iq2_xs_f32,   ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,    mul_mm_id_iq3_xxs_f32,  ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                 rope_f32,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                 rope_f16,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                alibi_f32,              true);
@@ -1260,6 +1270,7 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32   ].pipeline; break;
                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
+                case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
                 default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
             }
 
@@ -1388,6 +1399,12 @@ static bool ggml_metal_graph_compute(
                         nth1 = 16;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
                     } break;
+                case GGML_TYPE_IQ3_XXS:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
+                    } break;
                 default:
                     {
                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
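A note on the launch parameters: nth0 = 4 and nth1 = 16 give 64 threads per threadgroup, i.e. two 32-lane simdgroups on Apple GPUs, which matches the geometry of the IQ2_XS case shown in the same hunk; the new matrix-vector kernel appears to simply reuse that dispatch shape.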
@@ -1430,6 +1447,11 @@ static bool ggml_metal_graph_compute(
                     [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
+                else if (src0t == GGML_TYPE_IQ3_XXS) {
+                    const int mem_size = 256*4+128;
+                    [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                }
                 else if (src0t == GGML_TYPE_Q4_K) {
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
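The 256*4 + 128 = 1152 bytes of threadgroup memory reserved above reads most naturally as a shared staging area for the kernel's lookup data: 256 four-byte grid entries plus 128 bytes of sign-mask table, mirroring how the existing iq2 kernels stage their tables. That split is an interpretation of the constant, not something the diff states; only the total is certain. A small C sketch of the accounting:

#include <stdint.h>

// Sketch: one way to account for the 256*4 + 128 byte threadgroup allocation.
// The grid/sign-mask split is an assumed reading of the Metal kernel, not a
// quote from it; only the 1152-byte total comes from the diff.
enum {
    GRID_ENTRIES    = 256,               // lookup-grid entries staged in shared memory
    GRID_ENTRY_SIZE = sizeof(uint32_t),  // 4 bytes per packed entry
    SIGN_TABLE_SIZE = 128,               // sign-mask table
    IQ3_XXS_TG_MEM  = GRID_ENTRIES * GRID_ENTRY_SIZE + SIGN_TABLE_SIZE  // = 1152
};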
@@ -1524,6 +1546,7 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32   ].pipeline; break;
                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
+                case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
                 default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
             }
 
@@ -1655,6 +1678,12 @@ static bool ggml_metal_graph_compute(
                         nth1 = 16;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
                     } break;
+                case GGML_TYPE_IQ3_XXS:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
+                    } break;
                 default:
                     {
                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1713,6 +1742,11 @@ static bool ggml_metal_graph_compute(
                     [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                     [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
+                else if (src2t == GGML_TYPE_IQ3_XXS) {
+                    const int mem_size = 256*4+128;
+                    [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                }
                 else if (src2t == GGML_TYPE_Q4_K) {
                     [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
@@ -1753,6 +1787,7 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K   ].pipeline; break;
                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
+                case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
                 case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
                 default: GGML_ASSERT(false && "not implemented");
             }
