Commit 8b4500b

[ET-VK] Add co-op algorithm for 4 bit weight only quantized linear (#10235)
## Context

As title. Add an alternative compute shader for int4 weight-only quantized linear that utilizes a co-operative algorithm. This shader is more performant than standard tiled algorithms for `gemv` cases, i.e. when `mat1` is a vector rather than a matrix.

## Changes

* Add the cooperative shader
* Use the cooperative shader when the height of `mat1` is 1

Differential Revision: [D73044650](https://our.internmc.facebook.com/intern/diff/D73044650/)
1 parent: 614dde0

4 files changed (+252 −3 lines)
Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define T ${buffer_scalar_type(DTYPE)}
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

#define TILE_ROWS ${TILE_ROWS}

#define NGROUPS 8
#define NWORKERS 8

${define_required_extensions(DTYPE)}
$if WEIGHT_STORAGE == "buffer":
  ${define_required_extensions("uint8")}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)}

layout(push_constant) uniform restrict Block {
  ivec4 out_sizes;
  ivec4 mat1_sizes;
  ivec4 qmat2_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int group_size = 64;

shared VEC4_T partial_sums[NGROUPS][NWORKERS][TILE_ROWS][2];

/*
 * This shader computes a linear operator between a floating point input matrix
 * x and a weights matrix that is quantized to 4 bits. Please refer to the
 * q_4w_linear shader for more details.
 *
 * This shader implements a co-operative algorithm to compute the output. The
 * work group size is {NGROUPS, 1, NWORKERS}, and each group of NWORKERS threads
 * co-operates to compute TILE_ROWS * 2 output texels. Therefore,
 * NGROUPS * TILE_ROWS * 2 output texels are computed across one work group.
 *
 * The threads co-operate by each thread computing a partial reduction along the
 * K dimension. To illustrate the computation, consider a scalar variant of the
 * algorithm that computes the dot product of 2 vectors. Also assume that
 * NWORKERS is 8.
 *
 * Thread 1 in each group will compute:
 * (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ...
 *
 * Thread 2 in each group will compute:
 * (mat1[1] * mat2[1]) + (mat1[9] * mat2[9]) + (mat1[17] * mat2[17]) + ...
 *
 * Thread 3 in each group will compute:
 * (mat1[2] * mat2[2]) + (mat1[10] * mat2[10]) + (mat1[18] * mat2[18]) + ...
 *
 * The partial accumulations are structured such that memory accesses in each
 * loop iteration can be coalesced.
 *
 * Then, at the end, the first thread in each group will accumulate the partial
 * accumulations computed by each thread to obtain the final result.
 *
 * Note that this shader assumes that all tensors are width packed.
 */
void main() {
  const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
  // Each thread writes out 2 texels along the width axis, equivalent to 8
  // scalar elements. Therefore multiply the thread_idx.x by 8.
  const uint out_col = gl_GlobalInvocationID.x << 3;
  // Similar reasoning to the above, each thread works on 2 texels along the
  // width axis so multiply thread_idx.x by 2.
  const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;

  const uint gid = gl_LocalInvocationID.x; // group id
  const uint wid = gl_LocalInvocationID.z; // worker id

  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
    return;
  }

  const int num_blocks = mat1_sizes.x / group_size;

  VEC4_T mat1[TILE_ROWS];
  VEC4_T qmat2[4][2];
  VEC4_T local_sums[TILE_ROWS][2];

  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
    local_sums[r][0] = VEC4_T(0);
    local_sums[r][1] = VEC4_T(0);
  }

  VEC4_T scales[2];
  VEC4_T zeros[2];

  $if WEIGHT_STORAGE == "buffer":
    const int qmat2_stride = qmat2_sizes.x >> 2;
  $if PARAMS_STORAGE == "buffer":
    const int qparams_y_stride = out_sizes.x >> 2;
    const int qparams_z_stride = qparams_y_stride * 2;

  for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
    $if PARAMS_STORAGE == "buffer":
      scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx];
      zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride];

      scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1];
      zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride];
    $else:
      scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
      zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);

      scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
      zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);

    for (uint g_idx = 4 * wid; g_idx < group_size; g_idx += (4 * NWORKERS)) {
      const uint k = block_idx * group_size + g_idx;

      // Preload B
      [[unroll]] for (int r = 0; r < 4; ++r) {
        $if WEIGHT_STORAGE == "buffer":
          const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x];
        $else:
          const uvec4 packed_weight_tex = texelFetch(
              t_qmat2,
              ivec2(gl_GlobalInvocationID.x, k + r),
              0);

        qmat2[r][0] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0) * scales[0] + zeros[0];
        qmat2[r][1] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0) * scales[1] + zeros[1];
      }

      // Preload A
      [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
        $if IN_STORAGE == "buffer":
          mat1[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2];
        $else:
          mat1[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0);
      }

      // Accumulate local output tile
      [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
        local_sums[r][0] += mat1[r].x * qmat2[0][0]
                          + mat1[r].y * qmat2[1][0]
                          + mat1[r].z * qmat2[2][0]
                          + mat1[r].w * qmat2[3][0];

        local_sums[r][1] += mat1[r].x * qmat2[0][1]
                          + mat1[r].y * qmat2[1][1]
                          + mat1[r].z * qmat2[2][1]
                          + mat1[r].w * qmat2[3][1];
      }
    }
  }

  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
    partial_sums[gid][wid][r][0] = local_sums[r][0];
    partial_sums[gid][wid][r][1] = local_sums[r][1];
  }

  memoryBarrierShared();
  barrier();

  if (wid != 0) {
    return;
  }

  VEC4_T sums[TILE_ROWS][2];

  for (int r = 0; r < TILE_ROWS; ++r) {
    sums[r][0] = VEC4_T(0);
    sums[r][1] = VEC4_T(0);
    [[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) {
      sums[r][0] += partial_sums[gid][worker][r][0];
      sums[r][1] += partial_sums[gid][worker][r][1];
    }
  }

  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
    $if OUT_STORAGE == "buffer":
      t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = sums[r][0];
      t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = sums[r][1];
    $else:
      imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), sums[r][0]);
      imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), sums[r][1]);
  }
}
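To make the co-operative reduction described in the shader comment concrete, here is a minimal scalar sketch in C++ (a host-side illustration only, not part of this commit): each of NWORKERS "threads" accumulates a strided slice of the K dimension, and a final pass sums the per-worker partials, which is what the `partial_sums` shared array and the `wid != 0` early return implement on the GPU.

```cpp
#include <cstddef>
#include <vector>

// Scalar illustration of the co-operative reduction: NWORKERS workers each
// accumulate a strided partial dot product along K, then the partials are
// summed to produce the final result (the shader does the last step on the
// first worker of each group).
float coop_dot(const std::vector<float>& mat1, const std::vector<float>& mat2) {
  constexpr size_t NWORKERS = 8;
  float partial_sums[NWORKERS] = {};

  // Worker `wid` handles elements wid, wid + 8, wid + 16, ... so that in a
  // given iteration the 8 workers touch 8 consecutive elements (coalesced).
  for (size_t wid = 0; wid < NWORKERS; ++wid) {
    for (size_t k = wid; k < mat1.size(); k += NWORKERS) {
      partial_sums[wid] += mat1[k] * mat2[k];
    }
  }

  // Final reduction, performed by the first worker in each group in the shader.
  float sum = 0.0f;
  for (size_t wid = 0; wid < NWORKERS; ++wid) {
    sum += partial_sums[wid];
  }
  return sum;
}
```

In the actual shader each worker accumulates vec4 partials for TILE_ROWS rows and 2 output texels, so the `partial_sums` shared array holds NGROUPS * NWORKERS * TILE_ROWS * 2 vec4 values (8 * 8 * 1 * 2 = 128 vec4s with the default TILE_ROWS = 1).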
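The weight preload step ("Preload B") unpacks two 4-bit values from each byte of `packed_weight_tex`: the high nibble feeds the first output texel and the low nibble the second, with an offset of 8 to recenter the unsigned nibble before the per-group scale and zero are applied. Below is a scalar sketch of that dequantization for a single packed byte (illustrative only; the names are made up for the example, and the shader applies the same math componentwise over a 4-byte texel).

```cpp
#include <cstdint>

// Dequantize one packed byte into two weight values, mirroring
//   qmat2[r][0] = (float(high nibble) - 8.0) * scales[0] + zeros[0];
//   qmat2[r][1] = (float(low nibble)  - 8.0) * scales[1] + zeros[1];
// from the shader. Each byte holds two 4-bit quantized weights.
struct DequantPair {
  float hi; // contributes to the first output texel
  float lo; // contributes to the second output texel
};

DequantPair dequantize_int4_byte(
    uint8_t packed,
    float scale_hi, float zero_hi,
    float scale_lo, float zero_lo) {
  const int hi_nibble = (packed & 0xF0) >> 4;
  const int lo_nibble = packed & 0x0F;
  return DequantPair{
      (static_cast<float>(hi_nibble) - 8.0f) * scale_hi + zero_hi,
      (static_cast<float>(lo_nibble) - 8.0f) * scale_lo + zero_lo};
}
```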
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

q_4w_linear_coop:
  parameter_names_with_default_values:
    DTYPE: float
    OUT_STORAGE: texture3d
    IN_STORAGE: texture3d
    WEIGHT_STORAGE: texture2d
    PARAMS_STORAGE: buffer
    TILE_ROWS: 1
  shader_variants:
    - NAME: q_4w_linear_coop_texture3d_texture3d_texture2d_float
    - NAME: q_4w_linear_coop_buffer_buffer_texture2d_float
      OUT_STORAGE: buffer
      IN_STORAGE: buffer
    - NAME: q_4w_linear_coop_buffer_buffer_buffer_float
      OUT_STORAGE: buffer
      IN_STORAGE: buffer
      WEIGHT_STORAGE: buffer
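The variant names above follow the storage-suffix convention used when the op implementation builds the kernel name: the base name is suffixed with the output, input, and weight storage types, then the dtype. A minimal sketch of that naming scheme follows; the helper is a simplified stand-in for illustration, not the actual graph API.

```cpp
#include <iostream>
#include <string>

// Assemble a shader variant name the same way the op implementation does:
// base name + storage suffix for out, mat1, and the weight tensor + dtype.
std::string make_variant_name(
    const std::string& base,
    const std::string& out_storage,
    const std::string& in_storage,
    const std::string& weight_storage,
    const std::string& dtype) {
  return base + "_" + out_storage + "_" + in_storage + "_" + weight_storage +
      "_" + dtype;
}

int main() {
  // Matches the second variant declared in the YAML above.
  std::cout << make_variant_name(
                   "q_4w_linear_coop", "buffer", "buffer", "texture2d", "float")
            << "\n"; // q_4w_linear_coop_buffer_buffer_texture2d_float
  return 0;
}
```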

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearGroupwiseInt4.cpp

Lines changed: 15 additions & 2 deletions
@@ -128,24 +128,37 @@ void add_q_4w_linear_node(
   check_q_4w_linear_args(
       graph, mat1, mat2_data, group_size, scales_and_zeros_data, out);
 
+  const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);
+
+  bool use_coop_algorithm = false;
+  // Apply the coop algorithm for gemv cases, i.e. mat1 is a vector as opposed
+  // to a matrix.
+  if (graph.size_at<uint32_t>(-2, mat1) == 1) {
+    use_coop_algorithm = true;
+  }
+
   ValueRef mat2 =
       prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data);
 
   ValueRef scales_and_zeros = prepack_standard_hw_transposed(
       graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked);
 
   std::string kernel_name = "q_4w_linear";
+  if (use_coop_algorithm) {
+    kernel_name += "_coop";
+  }
   add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
   add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
   add_storage_type_suffix(kernel_name, graph.storage_type_of(mat2));
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
-  const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);
-
   utils::uvec3 global_wg_size = graph.logical_limits_of(out);
   global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2));
 
   utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);
+  if (use_coop_algorithm) {
+    local_wg_size = {8, 1, 8};
+  }
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
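For the gemv path, the local work group size is pinned to {8, 1, 8}: 8 groups along x, each owning 2 output texels, and 8 co-operating workers along z. As a rough worked example, assuming the output extents are in texels and the workgroup count per axis is div_up(global, local), the M = 1, K = 256, N = 256 case from the tests below would dispatch as sketched here (illustrative arithmetic only, not the actual ComputeGraph API):

```cpp
#include <cstdint>
#include <iostream>

// Sketch of how the co-op dispatch covers the output for the gemv case
// M = 1, K = 256, N = 256 with width-packed tensors (4 elements per texel).
uint32_t div_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  const uint32_t N = 256;
  const uint32_t out_width_texels = N / 4;                // 64 texels
  const uint32_t global_x = div_up(out_width_texels, 2);  // each thread -> 2 texels => 32
  const uint32_t local_x = 8;                             // NGROUPS
  const uint32_t local_z = 8;                             // NWORKERS (reduction only)

  const uint32_t workgroups_x = div_up(global_x, local_x);          // 4
  const uint32_t outputs_per_group = 2 * 4;                         // 2 texels = 8 scalars
  const uint32_t outputs_per_workgroup = local_x * outputs_per_group; // 64

  // 4 workgroups * 64 outputs = 256 = N, so the whole output row is covered.
  std::cout << workgroups_x * outputs_per_workgroup << "\n";
  return 0;
}
```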

backends/vulkan/test/op_tests/linear_weight_int4_test.cpp

Lines changed: 15 additions & 1 deletion
@@ -273,10 +273,24 @@ TEST(VulkanInt4LinearTest, test_reference_impl) {
       /*N = */ 32);
 }
 
-TEST(VulkanInt4LinearTest, test_vulkan_impl) {
+TEST(VulkanInt4LinearTest, test_vulkan_impl_small_m) {
   test_vulkan_linear_int4(
       /*B = */ 1,
       /*M = */ 4,
       /*K = */ 128,
       /*N = */ 32);
+
+  test_vulkan_linear_int4(
+      /*B = */ 1,
+      /*M = */ 1,
+      /*K = */ 256,
+      /*N = */ 256);
+}
+
+TEST(VulkanInt4LinearTest, test_vulkan_impl_gemm) {
+  test_vulkan_linear_int4(
+      /*B = */ 1,
+      /*M = */ 256,
+      /*K = */ 256,
+      /*N = */ 256);
 }
