Skip to content

Commit 31cecc8

Browse files
committed
iq3_s_mult_shuffle: use lookup table on Metal
~4% faster TG and ~2% faster PP that way.
1 parent 93034df commit 31cecc8

File tree

1 file changed

+108
-24
lines changed

1 file changed

+108
-24
lines changed

ggml-metal.metal

Lines changed: 108 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2546,8 +2546,10 @@ typedef struct {
25462546
uint8_t signs[QK_K/8];
25472547
uint8_t scales[IQ3S_N_SCALE];
25482548
} block_iq3_s;
2549-
#define IQ3S_MULTIPLIER 518559
2550-
constexpr constant static uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15};
2549+
2550+
// When a shuffle is involved in the codebook, on Metal it is faster to use a lookup table
2551+
//#define IQ3S_MULTIPLIER 518559
2552+
//constexpr constant static uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15};
25512553

25522554
typedef struct {
25532555
half d;
@@ -4085,6 +4087,73 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
40854087
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
40864088
};
40874089

4090+
constexpr constant static uint32_t iq3s_grid[512] = {
4091+
0x01010101, 0x0105070f, 0x010f030d, 0x0105090b, 0x010f0509, 0x01050109, 0x010f0707, 0x01050307,
4092+
0x010f0905, 0x01050505, 0x010f0105, 0x01050703, 0x010d0303, 0x01050b03, 0x010d0501, 0x01050101,
4093+
0x010d0701, 0x0105030f, 0x010d0b0d, 0x0105050b, 0x010d0109, 0x01050709, 0x010d0307, 0x01030b07,
4094+
0x010b0505, 0x01030105, 0x010b0705, 0x01030303, 0x010b0b03, 0x01030503, 0x010b0101, 0x01030701,
4095+
0x010b0301, 0x01030b0f, 0x010b050d, 0x0103010b, 0x01090709, 0x01030309, 0x01090b07, 0x01030507,
4096+
0x01090105, 0x01030705, 0x01090305, 0x01030b03, 0x01090503, 0x01030103, 0x01090701, 0x01030301,
4097+
0x01090b01, 0x0103050f, 0x0109010d, 0x0103070b, 0x01090309, 0x01030b09, 0x01090507, 0x01030107,
4098+
0x01090705, 0x01030305, 0x01070d05, 0x01010503, 0x01070103, 0x01010703, 0x01070301, 0x01010d01,
4099+
0x01070501, 0x0101010f, 0x0107070d, 0x0101030b, 0x01070d09, 0x01010509, 0x01070107, 0x01010907,
4100+
0x01070305, 0x01010d05, 0x01070505, 0x01010103, 0x01070903, 0x01010303, 0x01070d01, 0x01010501,
4101+
0x01070101, 0x0101090f, 0x0105030d, 0x01010d0b, 0x01050509, 0x01010109, 0x01050907, 0x01010307,
4102+
0x01050d05, 0x01010505, 0x01050105, 0x01010903, 0x01050303, 0x010f0d03, 0x01050501, 0x010f0101,
4103+
0x01050901, 0x010f030f, 0x03050d0d, 0x030f050b, 0x03050109, 0x030f0909, 0x03050307, 0x030d0d07,
4104+
0x03050505, 0x030d0105, 0x03050905, 0x030d0303, 0x03050f03, 0x030d0503, 0x03050101, 0x030d0901,
4105+
0x03050301, 0x030d0f0f, 0x0305050d, 0x030b010b, 0x03030909, 0x030b0309, 0x03030f07, 0x030b0507,
4106+
0x03030105, 0x030b0905, 0x03030305, 0x030b0f03, 0x03030703, 0x030b0103, 0x03030901, 0x03090301,
4107+
0x03030f01, 0x0309070f, 0x0303010d, 0x0309090b, 0x03030309, 0x03090f09, 0x03030707, 0x03090107,
4108+
0x03030905, 0x03090505, 0x03030f05, 0x03090703, 0x03030103, 0x03090903, 0x03030501, 0x03090f01,
4109+
0x03030701, 0x0309030f, 0x0303090d, 0x0309050b, 0x03030f09, 0x03070709, 0x03010307, 0x03070907,
4110+
0x03010505, 0x03070105, 0x03010705, 0x03070303, 0x03010903, 0x03070503, 0x03010101, 0x03070701,
4111+
0x03010301, 0x0307090f, 0x0301050d, 0x0307010b, 0x03010709, 0x03070309, 0x03010b07, 0x03070507,
4112+
0x03010105, 0x03070705, 0x03010305, 0x03070b03, 0x03010503, 0x03050103, 0x03010701, 0x03050301,
4113+
0x03010b01, 0x0305050f, 0x0301010d, 0x0305070b, 0x03010309, 0x03050b09, 0x03010507, 0x03050107,
4114+
0x030f0705, 0x03050305, 0x030f0b05, 0x03050503, 0x030f0103, 0x03050703, 0x030f0301, 0x03050b01,
4115+
0x030f0501, 0x0305010f, 0x030f070d, 0x0505030b, 0x050d0b09, 0x05050509, 0x050d0107, 0x05050707,
4116+
0x050d0305, 0x05050b05, 0x050d0505, 0x05050103, 0x050d0703, 0x05050303, 0x050b0b01, 0x05030501,
4117+
0x050b0101, 0x0503070f, 0x050b030d, 0x05030d0b, 0x050b0509, 0x05030109, 0x050b0707, 0x05030307,
4118+
0x050b0d05, 0x05030505, 0x05090105, 0x05030903, 0x05090303, 0x05030d03, 0x05090501, 0x05030101,
4119+
0x05090901, 0x0503030f, 0x05090d0d, 0x0503050b, 0x05090109, 0x05030909, 0x05090307, 0x05030d07,
4120+
0x05090505, 0x05030105, 0x05090905, 0x05030303, 0x05090d03, 0x05030503, 0x05090101, 0x05030901,
4121+
0x05090301, 0x05010d0f, 0x0507050d, 0x0501010b, 0x05070909, 0x05010309, 0x05070d07, 0x05010507,
4122+
0x05070105, 0x05010905, 0x05070305, 0x05010d03, 0x05070503, 0x05010103, 0x05070901, 0x05010301,
4123+
0x05070f01, 0x0501050f, 0x0507010d, 0x0501090b, 0x05070309, 0x05010f09, 0x05070507, 0x05010107,
4124+
0x05050905, 0x05010305, 0x05050f05, 0x05010503, 0x05050103, 0x05010903, 0x05050301, 0x05010f01,
4125+
0x05050501, 0x0501010f, 0x0505090d, 0x050f030b, 0x05050f09, 0x050f0709, 0x05050107, 0x050f0907,
4126+
0x05050305, 0x050f0f05, 0x05050705, 0x050f0103, 0x05050903, 0x050f0503, 0x05050f01, 0x050d0701,
4127+
0x05050101, 0x050d090f, 0x0505050d, 0x050d0f0b, 0x07050709, 0x070d0109, 0x07050907, 0x070d0507,
4128+
0x07050f05, 0x070d0705, 0x07030305, 0x070b0903, 0x07030503, 0x070b0f03, 0x07030701, 0x070b0301,
4129+
0x07030901, 0x070b050f, 0x0703010d, 0x070b070b, 0x07030309, 0x07090909, 0x07030507, 0x07090107,
4130+
0x07030705, 0x07090305, 0x07030b05, 0x07090503, 0x07030103, 0x07090703, 0x07030301, 0x07090b01,
4131+
0x07030501, 0x0709010f, 0x0703070d, 0x0709030b, 0x07030b09, 0x07090509, 0x07030107, 0x07090707,
4132+
0x07030305, 0x07090b05, 0x07030505, 0x07090103, 0x07010703, 0x07070303, 0x07010b01, 0x07070501,
4133+
0x07010101, 0x0707070f, 0x0701030d, 0x07070b0b, 0x07010509, 0x07070109, 0x07010707, 0x07070307,
4134+
0x07010b05, 0x07070505, 0x07010105, 0x07070703, 0x07010303, 0x07070b03, 0x07010501, 0x07070101,
4135+
0x07010701, 0x0707030f, 0x07010b0d, 0x0705050b, 0x09010109, 0x09050709, 0x09010307, 0x09050b07,
4136+
0x09010505, 0x09050105, 0x09010705, 0x09050303, 0x09010d03, 0x09050503, 0x09010101, 0x09050701,
4137+
0x090f0301, 0x09050d0f, 0x090f050d, 0x0905010b, 0x090f0909, 0x09050309, 0x090f0d07, 0x09050507,
4138+
0x090f0105, 0x09050905, 0x090d0305, 0x09050d03, 0x090d0503, 0x09050103, 0x090d0901, 0x09050301,
4139+
0x090d0d01, 0x0905050f, 0x090d010d, 0x0905090b, 0x090d0309, 0x09030d09, 0x090b0507, 0x09030107,
4140+
0x090b0905, 0x09030305, 0x090b0d05, 0x09030503, 0x090b0103, 0x09030903, 0x090b0301, 0x09030d01,
4141+
0x090b0501, 0x0903010f, 0x0909090d, 0x0903030b, 0x09090d09, 0x09030509, 0x09090107, 0x09030907,
4142+
0x09090305, 0x09030f05, 0x09090505, 0x09030103, 0x09090903, 0x09030303, 0x09090f01, 0x09030501,
4143+
0x09090101, 0x0903090f, 0x0909030d, 0x09030f0b, 0x09090509, 0x0b030109, 0x0b090907, 0x0b030307,
4144+
0x0b070f05, 0x0b010505, 0x0b070105, 0x0b010903, 0x0b070303, 0x0b010f03, 0x0b070701, 0x0b010101,
4145+
0x0b070901, 0x0b01030f, 0x0b070f0d, 0x0b01070b, 0x0b070109, 0x0b010909, 0x0b070507, 0x0b010f07,
4146+
0x0b070705, 0x0b010105, 0x0b070905, 0x0b010503, 0x0b070f03, 0x0b010703, 0x0b070301, 0x0b010901,
4147+
0x0b050501, 0x0b010f0f, 0x0b05070d, 0x0b01030b, 0x0b050909, 0x0d010509, 0x0d050f07, 0x0d010707,
4148+
0x0d050305, 0x0d010905, 0x0d050505, 0x0d0f0103, 0x0d050703, 0x0d0f0303, 0x0d050901, 0x0d0f0501,
4149+
0x0d050101, 0x0d0f070f, 0x0d05030d, 0x0d0f0b0b, 0x0d050509, 0x0d0f0109, 0x0d050707, 0x0d0d0307,
4150+
0x0d050b05, 0x0d0d0505, 0x0d050105, 0x0d0d0703, 0x0d050303, 0x0d0d0b03, 0x0d050501, 0x0d0d0101,
4151+
0x0d050701, 0x0d0b030f, 0x0d030b0d, 0x0d0b050b, 0x0d030109, 0x0d0b0709, 0x0f030307, 0x0f0b0b07,
4152+
0x0f030505, 0x0f0b0105, 0x0f030705, 0x0f0b0303, 0x0f030b03, 0x0f090503, 0x0f030101, 0x0f090701,
4153+
0x0f030301, 0x0f090b0f, 0x0f03050d, 0x0f09010b, 0x0f030709, 0x0f090309, 0x0f030b07, 0x0f090507,
4154+
0x0f030105, 0x0f090705, 0x0f030305, 0x0f090b03, 0x0f030503, 0x0f090103, 0x0f030701, 0x0f090301,
4155+
};
4156+
40884157
#define NGRID_IQ1S 512
40894158
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
40904159
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
@@ -4694,20 +4763,23 @@ void kernel_mul_mv_iq3_s_f32_impl(
46944763
{
46954764
int nval = 8;
46964765
int pos = (32*sgitg + tiisg)*nval;
4697-
uint32_t aux32;
4698-
thread int8_t * q = (thread int8_t *)&aux32;
46994766
for (int i = 0; i < nval; ++i) {
4700-
aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f;
4701-
for (int k = 0; k < 4; ++k) q[k] = iq3s_values[q[k]];
4702-
values[pos + i] = aux32;
4767+
values[pos + i] = iq3s_grid[pos + i];
47034768
}
4769+
//uint32_t aux32;
4770+
//thread int8_t * q = (thread int8_t *)&aux32;
4771+
//for (int i = 0; i < nval; ++i) {
4772+
// aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f;
4773+
// for (int k = 0; k < 4; ++k) q[k] = iq3s_values[q[k]];
4774+
// values[pos + i] = aux32;
4775+
//}
47044776
threadgroup_barrier(mem_flags::mem_threadgroup);
47054777
}
47064778

47074779
const int ix = tiisg;
47084780

4709-
uint32_t aux32[2];
4710-
thread const int8_t * grid = (thread const int8_t *)aux32;
4781+
//uint32_t aux32[2];
4782+
//thread const int8_t * grid = (thread const int8_t *)aux32;
47114783

47124784
device const float * y4 = y + 32 * ix;
47134785

@@ -4735,11 +4807,11 @@ void kernel_mul_mv_iq3_s_f32_impl(
47354807
float2 sum = {0};
47364808
for (int l = 0; l < 4; ++l) {
47374809
// This is slower than pre-computing the grid in shared memory and loading from there
4738-
//aux32[0] = ((IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101;
4739-
//aux32[1] = ((IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101;
4810+
//aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f;
4811+
//aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f;
47404812
//for (int j = 0; j < 4; ++j) {
4741-
// sum[0] += yl[8*l + j + 0] * grid[j+0] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
4742-
// sum[1] += yl[8*l + j + 4] * grid[j+4] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
4813+
// sum[0] += yl[8*l + j + 0] * iq3s_values[grid[j+0]] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
4814+
// sum[1] += yl[8*l + j + 4] * iq3s_values[grid[j+4]] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
47434815
//}
47444816
threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
47454817
threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
@@ -5655,20 +5727,32 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
56555727
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
56565728
const uint8_t qh = xb->qh[ib32] >> 4*il;
56575729
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
5658-
uint32_t aux32[2];
5659-
thread const int8_t * grid = (thread const int8_t *)aux32;
5660-
aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f;
5661-
aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f;
5730+
constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
5731+
constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
56625732
for (int i = 0; i < 4; ++i) {
5663-
reg[0][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
5664-
reg[1][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
5733+
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
5734+
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
56655735
}
5666-
aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f;
5667-
aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f;
5736+
grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
5737+
grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
56685738
for (int i = 0; i < 4; ++i) {
5669-
reg[2][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
5670-
reg[3][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
5671-
}
5739+
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
5740+
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
5741+
}
5742+
//uint32_t aux32[2];
5743+
//thread const int8_t * grid = (thread const int8_t *)aux32;
5744+
//aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f;
5745+
//aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f;
5746+
//for (int i = 0; i < 4; ++i) {
5747+
// reg[0][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
5748+
// reg[1][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
5749+
//}
5750+
//aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f;
5751+
//aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f;
5752+
//for (int i = 0; i < 4; ++i) {
5753+
// reg[2][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
5754+
// reg[3][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
5755+
//}
56725756
}
56735757

56745758
template <typename type4x4>

0 commit comments

Comments
 (0)