@@ -2546,8 +2546,10 @@ typedef struct {
2546
2546
uint8_t signs[QK_K/8 ];
2547
2547
uint8_t scales[IQ3S_N_SCALE];
2548
2548
} block_iq3_s;
2549
- #define IQ3S_MULTIPLIER 518559
2550
- constexpr constant static uint8_t iq3s_values[16 ] = {1 , 1 , 1 , 3 , 3 , 3 , 5 , 5 , 5 , 7 , 7 , 9 , 9 , 11 , 13 , 15 };
2549
+
2550
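+ // The IQ3_S codebook can be generated on the fly (multiply by IQ3S_MULTIPLIER, mask, then shuffle through
+ // iq3s_values), but when such a shuffle is involved it is faster on Metal to use a lookup table (iq3s_grid).
+ // The generator constants are kept below, commented out, for reference.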
2551
+ // #define IQ3S_MULTIPLIER 518559
2552
+ // constexpr constant static uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15};
2551
2553
2552
2554
typedef struct {
2553
2555
half d;
@@ -4085,6 +4087,73 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
4085
4087
0x3e1c1c1c , 0x3e1c3404 , 0x3e24140c , 0x3e24240c , 0x3e2c0404 , 0x3e2c0414 , 0x3e2c1424 , 0x3e341c04 ,
4086
4088
};
4087
4089
4090
// iq3s_grid: 512-entry lookup table for the IQ3_S codebook, stored in the `constant` address space.
// Each 32-bit entry packs four byte-sized values; readers cast an entry to `uint8_t *` and use bytes 0..3
// (see dequantize_iq3_s). The index is 9 bits wide: a `qs` byte OR'ed with one extra bit taken from `qh`
// (masked with 256).
// The entries appear to be the precomputed result of the disabled generator
//     (IQ3S_MULTIPLIER * index) & 0x0f0f0f0f, with each nibble mapped through iq3s_values
// e.g. index 0 -> 0x01010101 and index 1 -> 0x0105070f. NOTE(review): only spot-checked, confirm for all 512.
+ constexpr constant static uint32_t iq3s_grid[512 ] = {
4091
+ 0x01010101 , 0x0105070f , 0x010f030d , 0x0105090b , 0x010f0509 , 0x01050109 , 0x010f0707 , 0x01050307 ,
4092
+ 0x010f0905 , 0x01050505 , 0x010f0105 , 0x01050703 , 0x010d0303 , 0x01050b03 , 0x010d0501 , 0x01050101 ,
4093
+ 0x010d0701 , 0x0105030f , 0x010d0b0d , 0x0105050b , 0x010d0109 , 0x01050709 , 0x010d0307 , 0x01030b07 ,
4094
+ 0x010b0505 , 0x01030105 , 0x010b0705 , 0x01030303 , 0x010b0b03 , 0x01030503 , 0x010b0101 , 0x01030701 ,
4095
+ 0x010b0301 , 0x01030b0f , 0x010b050d , 0x0103010b , 0x01090709 , 0x01030309 , 0x01090b07 , 0x01030507 ,
4096
+ 0x01090105 , 0x01030705 , 0x01090305 , 0x01030b03 , 0x01090503 , 0x01030103 , 0x01090701 , 0x01030301 ,
4097
+ 0x01090b01 , 0x0103050f , 0x0109010d , 0x0103070b , 0x01090309 , 0x01030b09 , 0x01090507 , 0x01030107 ,
4098
+ 0x01090705 , 0x01030305 , 0x01070d05 , 0x01010503 , 0x01070103 , 0x01010703 , 0x01070301 , 0x01010d01 ,
4099
+ 0x01070501 , 0x0101010f , 0x0107070d , 0x0101030b , 0x01070d09 , 0x01010509 , 0x01070107 , 0x01010907 ,
4100
+ 0x01070305 , 0x01010d05 , 0x01070505 , 0x01010103 , 0x01070903 , 0x01010303 , 0x01070d01 , 0x01010501 ,
4101
+ 0x01070101 , 0x0101090f , 0x0105030d , 0x01010d0b , 0x01050509 , 0x01010109 , 0x01050907 , 0x01010307 ,
4102
+ 0x01050d05 , 0x01010505 , 0x01050105 , 0x01010903 , 0x01050303 , 0x010f0d03 , 0x01050501 , 0x010f0101 ,
4103
+ 0x01050901 , 0x010f030f , 0x03050d0d , 0x030f050b , 0x03050109 , 0x030f0909 , 0x03050307 , 0x030d0d07 ,
4104
+ 0x03050505 , 0x030d0105 , 0x03050905 , 0x030d0303 , 0x03050f03 , 0x030d0503 , 0x03050101 , 0x030d0901 ,
4105
+ 0x03050301 , 0x030d0f0f , 0x0305050d , 0x030b010b , 0x03030909 , 0x030b0309 , 0x03030f07 , 0x030b0507 ,
4106
+ 0x03030105 , 0x030b0905 , 0x03030305 , 0x030b0f03 , 0x03030703 , 0x030b0103 , 0x03030901 , 0x03090301 ,
4107
+ 0x03030f01 , 0x0309070f , 0x0303010d , 0x0309090b , 0x03030309 , 0x03090f09 , 0x03030707 , 0x03090107 ,
4108
+ 0x03030905 , 0x03090505 , 0x03030f05 , 0x03090703 , 0x03030103 , 0x03090903 , 0x03030501 , 0x03090f01 ,
4109
+ 0x03030701 , 0x0309030f , 0x0303090d , 0x0309050b , 0x03030f09 , 0x03070709 , 0x03010307 , 0x03070907 ,
4110
+ 0x03010505 , 0x03070105 , 0x03010705 , 0x03070303 , 0x03010903 , 0x03070503 , 0x03010101 , 0x03070701 ,
4111
+ 0x03010301 , 0x0307090f , 0x0301050d , 0x0307010b , 0x03010709 , 0x03070309 , 0x03010b07 , 0x03070507 ,
4112
+ 0x03010105 , 0x03070705 , 0x03010305 , 0x03070b03 , 0x03010503 , 0x03050103 , 0x03010701 , 0x03050301 ,
4113
+ 0x03010b01 , 0x0305050f , 0x0301010d , 0x0305070b , 0x03010309 , 0x03050b09 , 0x03010507 , 0x03050107 ,
4114
+ 0x030f0705 , 0x03050305 , 0x030f0b05 , 0x03050503 , 0x030f0103 , 0x03050703 , 0x030f0301 , 0x03050b01 ,
4115
+ 0x030f0501 , 0x0305010f , 0x030f070d , 0x0505030b , 0x050d0b09 , 0x05050509 , 0x050d0107 , 0x05050707 ,
4116
+ 0x050d0305 , 0x05050b05 , 0x050d0505 , 0x05050103 , 0x050d0703 , 0x05050303 , 0x050b0b01 , 0x05030501 ,
4117
+ 0x050b0101 , 0x0503070f , 0x050b030d , 0x05030d0b , 0x050b0509 , 0x05030109 , 0x050b0707 , 0x05030307 ,
4118
+ 0x050b0d05 , 0x05030505 , 0x05090105 , 0x05030903 , 0x05090303 , 0x05030d03 , 0x05090501 , 0x05030101 ,
4119
+ 0x05090901 , 0x0503030f , 0x05090d0d , 0x0503050b , 0x05090109 , 0x05030909 , 0x05090307 , 0x05030d07 ,
4120
+ 0x05090505 , 0x05030105 , 0x05090905 , 0x05030303 , 0x05090d03 , 0x05030503 , 0x05090101 , 0x05030901 ,
4121
+ 0x05090301 , 0x05010d0f , 0x0507050d , 0x0501010b , 0x05070909 , 0x05010309 , 0x05070d07 , 0x05010507 ,
4122
+ 0x05070105 , 0x05010905 , 0x05070305 , 0x05010d03 , 0x05070503 , 0x05010103 , 0x05070901 , 0x05010301 ,
4123
+ 0x05070f01 , 0x0501050f , 0x0507010d , 0x0501090b , 0x05070309 , 0x05010f09 , 0x05070507 , 0x05010107 ,
4124
+ 0x05050905 , 0x05010305 , 0x05050f05 , 0x05010503 , 0x05050103 , 0x05010903 , 0x05050301 , 0x05010f01 ,
4125
+ 0x05050501 , 0x0501010f , 0x0505090d , 0x050f030b , 0x05050f09 , 0x050f0709 , 0x05050107 , 0x050f0907 ,
4126
+ 0x05050305 , 0x050f0f05 , 0x05050705 , 0x050f0103 , 0x05050903 , 0x050f0503 , 0x05050f01 , 0x050d0701 ,
4127
+ 0x05050101 , 0x050d090f , 0x0505050d , 0x050d0f0b , 0x07050709 , 0x070d0109 , 0x07050907 , 0x070d0507 ,
4128
+ 0x07050f05 , 0x070d0705 , 0x07030305 , 0x070b0903 , 0x07030503 , 0x070b0f03 , 0x07030701 , 0x070b0301 ,
4129
+ 0x07030901 , 0x070b050f , 0x0703010d , 0x070b070b , 0x07030309 , 0x07090909 , 0x07030507 , 0x07090107 ,
4130
+ 0x07030705 , 0x07090305 , 0x07030b05 , 0x07090503 , 0x07030103 , 0x07090703 , 0x07030301 , 0x07090b01 ,
4131
+ 0x07030501 , 0x0709010f , 0x0703070d , 0x0709030b , 0x07030b09 , 0x07090509 , 0x07030107 , 0x07090707 ,
4132
+ 0x07030305 , 0x07090b05 , 0x07030505 , 0x07090103 , 0x07010703 , 0x07070303 , 0x07010b01 , 0x07070501 ,
4133
+ 0x07010101 , 0x0707070f , 0x0701030d , 0x07070b0b , 0x07010509 , 0x07070109 , 0x07010707 , 0x07070307 ,
4134
+ 0x07010b05 , 0x07070505 , 0x07010105 , 0x07070703 , 0x07010303 , 0x07070b03 , 0x07010501 , 0x07070101 ,
4135
+ 0x07010701 , 0x0707030f , 0x07010b0d , 0x0705050b , 0x09010109 , 0x09050709 , 0x09010307 , 0x09050b07 ,
4136
+ 0x09010505 , 0x09050105 , 0x09010705 , 0x09050303 , 0x09010d03 , 0x09050503 , 0x09010101 , 0x09050701 ,
4137
+ 0x090f0301 , 0x09050d0f , 0x090f050d , 0x0905010b , 0x090f0909 , 0x09050309 , 0x090f0d07 , 0x09050507 ,
4138
+ 0x090f0105 , 0x09050905 , 0x090d0305 , 0x09050d03 , 0x090d0503 , 0x09050103 , 0x090d0901 , 0x09050301 ,
4139
+ 0x090d0d01 , 0x0905050f , 0x090d010d , 0x0905090b , 0x090d0309 , 0x09030d09 , 0x090b0507 , 0x09030107 ,
4140
+ 0x090b0905 , 0x09030305 , 0x090b0d05 , 0x09030503 , 0x090b0103 , 0x09030903 , 0x090b0301 , 0x09030d01 ,
4141
+ 0x090b0501 , 0x0903010f , 0x0909090d , 0x0903030b , 0x09090d09 , 0x09030509 , 0x09090107 , 0x09030907 ,
4142
+ 0x09090305 , 0x09030f05 , 0x09090505 , 0x09030103 , 0x09090903 , 0x09030303 , 0x09090f01 , 0x09030501 ,
4143
+ 0x09090101 , 0x0903090f , 0x0909030d , 0x09030f0b , 0x09090509 , 0x0b030109 , 0x0b090907 , 0x0b030307 ,
4144
+ 0x0b070f05 , 0x0b010505 , 0x0b070105 , 0x0b010903 , 0x0b070303 , 0x0b010f03 , 0x0b070701 , 0x0b010101 ,
4145
+ 0x0b070901 , 0x0b01030f , 0x0b070f0d , 0x0b01070b , 0x0b070109 , 0x0b010909 , 0x0b070507 , 0x0b010f07 ,
4146
+ 0x0b070705 , 0x0b010105 , 0x0b070905 , 0x0b010503 , 0x0b070f03 , 0x0b010703 , 0x0b070301 , 0x0b010901 ,
4147
+ 0x0b050501 , 0x0b010f0f , 0x0b05070d , 0x0b01030b , 0x0b050909 , 0x0d010509 , 0x0d050f07 , 0x0d010707 ,
4148
+ 0x0d050305 , 0x0d010905 , 0x0d050505 , 0x0d0f0103 , 0x0d050703 , 0x0d0f0303 , 0x0d050901 , 0x0d0f0501 ,
4149
+ 0x0d050101 , 0x0d0f070f , 0x0d05030d , 0x0d0f0b0b , 0x0d050509 , 0x0d0f0109 , 0x0d050707 , 0x0d0d0307 ,
4150
+ 0x0d050b05 , 0x0d0d0505 , 0x0d050105 , 0x0d0d0703 , 0x0d050303 , 0x0d0d0b03 , 0x0d050501 , 0x0d0d0101 ,
4151
+ 0x0d050701 , 0x0d0b030f , 0x0d030b0d , 0x0d0b050b , 0x0d030109 , 0x0d0b0709 , 0x0f030307 , 0x0f0b0b07 ,
4152
+ 0x0f030505 , 0x0f0b0105 , 0x0f030705 , 0x0f0b0303 , 0x0f030b03 , 0x0f090503 , 0x0f030101 , 0x0f090701 ,
4153
+ 0x0f030301 , 0x0f090b0f , 0x0f03050d , 0x0f09010b , 0x0f030709 , 0x0f090309 , 0x0f030b07 , 0x0f090507 ,
4154
+ 0x0f030105 , 0x0f090705 , 0x0f030305 , 0x0f090b03 , 0x0f030503 , 0x0f090103 , 0x0f030701 , 0x0f090301 ,
4155
+ };
4156
+
4088
4157
#define NGRID_IQ1S 512
4089
4158
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
4090
4159
0xffffffffffff0101 , 0xffffffffff01ff00 , 0xffffffffff010100 , 0xffffffff00000000 ,
@@ -4694,20 +4763,23 @@ void kernel_mul_mv_iq3_s_f32_impl(
4694
4763
{
4695
4764
int nval = 8 ;
4696
4765
int pos = (32 *sgitg + tiisg)*nval;
4697
- uint32_t aux32;
4698
- thread int8_t * q = (thread int8_t *)&aux32;
4699
4766
for (int i = 0 ; i < nval; ++i) {
4700
- aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f ;
4701
- for (int k = 0 ; k < 4 ; ++k) q[k] = iq3s_values[q[k]];
4702
- values[pos + i] = aux32;
4767
+ values[pos + i] = iq3s_grid[pos + i];
4703
4768
}
4769
+ // uint32_t aux32;
4770
+ // thread int8_t * q = (thread int8_t *)&aux32;
4771
+ // for (int i = 0; i < nval; ++i) {
4772
+ // aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f;
4773
+ // for (int k = 0; k < 4; ++k) q[k] = iq3s_values[q[k]];
4774
+ // values[pos + i] = aux32;
4775
+ // }
4704
4776
threadgroup_barrier (mem_flags::mem_threadgroup);
4705
4777
}
4706
4778
4707
4779
const int ix = tiisg;
4708
4780
4709
- uint32_t aux32[2 ];
4710
- thread const int8_t * grid = (thread const int8_t *)aux32;
4781
+ // uint32_t aux32[2];
4782
+ // thread const int8_t * grid = (thread const int8_t *)aux32;
4711
4783
4712
4784
device const float * y4 = y + 32 * ix;
4713
4785
@@ -4735,11 +4807,11 @@ void kernel_mul_mv_iq3_s_f32_impl(
4735
4807
float2 sum = {0 };
4736
4808
for (int l = 0 ; l < 4 ; ++l) {
4737
4809
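// The commented-out variant below (computing the grid values on the fly for every element) is slower than
// pre-computing the grid in threadgroup memory (`values`) and loading from there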
4738
- // aux32[0] = (( IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101 ;
4739
- // aux32[1] = (( IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101 ;
4810
+ // aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f;
4811
+ // aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f;
4740
4812
// for (int j = 0; j < 4; ++j) {
4741
- // sum[0] += yl[8*l + j + 0] * grid[j+0] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
4742
- // sum[1] += yl[8*l + j + 4] * grid[j+4] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
4813
+ // sum[0] += yl[8*l + j + 0] * iq3s_values[ grid[j+0] ] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
4814
+ // sum[1] += yl[8*l + j + 4] * iq3s_values[ grid[j+4] ] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
4743
4815
// }
4744
4816
threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2 *l+0 ] | ((qh[0 ] << (8 -2 *l)) & 256 )));
4745
4817
threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2 *l+1 ] | ((qh[0 ] << (7 -2 *l)) & 256 )));
@@ -5655,20 +5727,32 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
5655
5727
device const uint8_t * signs = xb->signs + 4 *ib32 + 2 *il;
5656
5728
const uint8_t qh = xb->qh [ib32] >> 4 *il;
5657
5729
const float dl = d * (1 + 2 *((xb->scales [ib32/2 ] >> 4 *(ib32%2 )) & 0xf ));
5658
- uint32_t aux32[2 ];
5659
- thread const int8_t * grid = (thread const int8_t *)aux32;
5660
- aux32[0 ] = (IQ3S_MULTIPLIER * (qs[4 *il+0 ] | ((qh << 8 ) & 256 ))) & 0x0f0f0f0f ;
5661
- aux32[1 ] = (IQ3S_MULTIPLIER * (qs[4 *il+1 ] | ((qh << 7 ) & 256 ))) & 0x0f0f0f0f ;
5730
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4 *il+0 ] | ((qh << 8 ) & 256 )));
5731
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4 *il+1 ] | ((qh << 7 ) & 256 )));
5662
5732
for (int i = 0 ; i < 4 ; ++i) {
5663
- reg[0 ][i] = dl * iq3s_values[grid[i+ 0 ] ] * select (1 , -1 , signs[0 ] & kmask_iq2xs[i+0 ]);
5664
- reg[1 ][i] = dl * iq3s_values[grid[i+ 4 ] ] * select (1 , -1 , signs[0 ] & kmask_iq2xs[i+4 ]);
5733
+ reg[0 ][i] = dl * grid1[i ] * select (1 , -1 , signs[0 ] & kmask_iq2xs[i+0 ]);
5734
+ reg[1 ][i] = dl * grid2[i ] * select (1 , -1 , signs[0 ] & kmask_iq2xs[i+4 ]);
5665
5735
}
5666
- aux32[ 0 ] = (IQ3S_MULTIPLIER * (qs[4 *il+2 ] | ((qh << 6 ) & 256 ))) & 0x0f0f0f0f ;
5667
- aux32[ 1 ] = (IQ3S_MULTIPLIER * (qs[4 *il+3 ] | ((qh << 5 ) & 256 ))) & 0x0f0f0f0f ;
5736
+ grid1 = (constant uint8_t *)(iq3s_grid + (qs[4 *il+2 ] | ((qh << 6 ) & 256 )));
5737
+ grid2 = (constant uint8_t *)(iq3s_grid + (qs[4 *il+3 ] | ((qh << 5 ) & 256 )));
5668
5738
for (int i = 0 ; i < 4 ; ++i) {
5669
- reg[2 ][i] = dl * iq3s_values[grid[i+0 ]] * select (1 , -1 , signs[1 ] & kmask_iq2xs[i+0 ]);
5670
- reg[3 ][i] = dl * iq3s_values[grid[i+4 ]] * select (1 , -1 , signs[1 ] & kmask_iq2xs[i+4 ]);
5671
- }
5739
+ reg[2 ][i] = dl * grid1[i] * select (1 , -1 , signs[1 ] & kmask_iq2xs[i+0 ]);
5740
+ reg[3 ][i] = dl * grid2[i] * select (1 , -1 , signs[1 ] & kmask_iq2xs[i+4 ]);
5741
+ }
5742
+ // uint32_t aux32[2];
5743
+ // thread const int8_t * grid = (thread const int8_t *)aux32;
5744
+ // aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f;
5745
+ // aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f;
5746
+ // for (int i = 0; i < 4; ++i) {
5747
+ // reg[0][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
5748
+ // reg[1][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
5749
+ // }
5750
+ // aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f;
5751
+ // aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f;
5752
+ // for (int i = 0; i < 4; ++i) {
5753
+ // reg[2][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
5754
+ // reg[3][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
5755
+ // }
5672
5756
}
5673
5757
5674
5758
template <typename type4x4>
0 commit comments