Skip to content

Commit 1172fd1

Browse files
committed
[ET-VK] Allow int4 linear to execute without 8-bit buffer support
Pull Request resolved: #10030

## Context

Some Vulkan devices do not have support for 8-bit buffers, which is currently required to execute the int4 linear compute shader because its weight prepacking shader requires it. This diff bypasses that restriction by introducing a variant of the prepacking shader that does not need 8-bit buffers.

## Changes

Introduce a variant of the int4 weight prepacking shader that interprets the tensor data as an array of `uint` instead of `uint8_t`. Each `uint` represents 4 `uint8_t` values.

Differential Revision: [D72750897](https://our.internmc.facebook.com/intern/diff/D72750897/)

ghstack-source-id: 277175676
1 parent 9e19ece commit 1172fd1

File tree

6 files changed

+56
-30
lines changed

6 files changed

+56
-30
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -125,6 +125,8 @@ def buffer_gvec_type(dtype: str, n: int) -> str:
125125

126126
if dtype == "float":
127127
return f"vec{n}"
128+
if dtype == "uint":
129+
return f"uvec{n}"
128130
elif dtype == "half":
129131
return f"f16vec{n}"
130132
elif dtype == "int":

backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl

Lines changed: 42 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -10,32 +10,52 @@
1010

1111
#define PRECISION ${PRECISION}
1212

13-
${define_required_extensions("uint8")}
14-
${define_required_extensions("int8")}
13+
$if not NO_INT8_BUFFERS:
14+
${define_required_extensions("uint8")}
15+
$if STORAGE == "buffer":
16+
${define_required_extensions("int8")}
1517

1618
layout(std430) buffer;
1719

1820
${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
19-
${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}
21+
$if NO_INT8_BUFFERS:
22+
${layout_declare_tensor(B, "r", "nchw_4x2", "uint", "buffer")}
23+
$else:
24+
${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}
2025

2126
layout(push_constant) uniform restrict Block {
2227
ivec4 qmat2_sizes;
2328
};
2429

2530
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2631

27-
uint8_t get_first(const uint8_t packed) {
28-
return uint8_t((packed & 0xF0) >> 4);
32+
$if NO_INT8_BUFFERS:
33+
#define BUF_T uint
34+
$else:
35+
#define BUF_T uint8_t
36+
37+
$if STORAGE == "buffer":
38+
#define UVEC4_T u8vec4
39+
$else:
40+
#define UVEC4_T uvec4
41+
42+
uint get_first(const BUF_T packed) {
43+
return (packed & 0xF0) >> 4;
2944
}
3045

31-
uint8_t get_second(const uint8_t packed) {
32-
return uint8_t(packed & 0x0F);
46+
uint get_second(const BUF_T packed) {
47+
return packed & 0x0F;
3348
}
3449

35-
uint8_t combine(const uint8_t first, const uint8_t second) {
36-
return uint8_t(first << 4 | second);
50+
uint combine(const uint first, const uint second) {
51+
return (first << 4 | second);
3752
}
3853

54+
$if NO_INT8_BUFFERS:
55+
uint extract_comp(const uint packed4, const uint idx) {
56+
return (packed4 >> (idx * 8)) & 0xFF;
57+
}
58+
3959
/*
4060
* This shader packs the weight tensor into a texture.
4161
*
@@ -102,25 +122,32 @@ void main() {
102122
int in_numcols = qmat2_sizes.y;
103123
int in_num_int8_cols = qmat2_sizes.y >> 1;
104124

105-
uint8_t in_vals[8][2];
125+
uint in_vals[8][2];
106126
for (int r = 0; r < 8; ++r) {
107127
if (in_row + r < in_numrows) {
108-
uint8_t in_val_packed = nchw_4x2[(in_row + r) * in_num_int8_cols + in_int8_col];
128+
uint scalar_idx = (in_row + r) * in_num_int8_cols + in_int8_col;
129+
$if NO_INT8_BUFFERS:
130+
BUF_T in_val_packed_texel = nchw_4x2[scalar_idx >> 2];
131+
const uint packed_idx = scalar_idx % 4;
132+
uint in_val_packed = extract_comp(in_val_packed_texel, packed_idx);
133+
$else:
134+
BUF_T in_val_packed = nchw_4x2[scalar_idx];
135+
109136
in_vals[r][0] = get_first(in_val_packed);
110137
in_vals[r][1] = get_second(in_val_packed);
111138
} else {
112-
in_vals[r][0] = uint8_t(0);
113-
in_vals[r][1] = uint8_t(0);
139+
in_vals[r][0] = uint(0);
140+
in_vals[r][1] = uint(0);
114141
}
115142
}
116143

117-
u8vec4 out_tex_1 = u8vec4(
144+
UVEC4_T out_tex_1 = UVEC4_T(
118145
combine(in_vals[0][0], in_vals[4][0]),
119146
combine(in_vals[1][0], in_vals[5][0]),
120147
combine(in_vals[2][0], in_vals[6][0]),
121148
combine(in_vals[3][0], in_vals[7][0]));
122149

123-
u8vec4 out_tex_2 = u8vec4(
150+
UVEC4_T out_tex_2 = UVEC4_T(
124151
combine(in_vals[0][1], in_vals[4][1]),
125152
combine(in_vals[1][1], in_vals[5][1]),
126153
combine(in_vals[2][1], in_vals[6][1]),

backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,10 @@
77
pack_int4_linear_weight_transposed_interleaved:
88
parameter_names_with_default_values:
99
STORAGE: texture2d
10-
generate_variant_forall:
11-
STORAGE:
12-
- VALUE: texture2d
13-
- VALUE: buffer
10+
NO_INT8_BUFFERS: false
1411
shader_variants:
15-
- NAME: pack_int4_linear_weight_transposed_interleaved
12+
- NAME: pack_int4_linear_weight_transposed_interleaved_texture2d
13+
- NAME: pack_int4_linear_weight_transposed_interleaved_buffer
14+
STORAGE: buffer
15+
- NAME: pack_int4_linear_weight_transposed_interleaved_nobitw8buffer_texture2d
16+
NO_INT8_BUFFERS: true

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,8 @@
1414
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
1515

1616
${define_required_extensions(DTYPE)}
17-
${define_required_extensions("int8")}
17+
$if WEIGHT_STORAGE == "buffer":
18+
${define_required_extensions("uint8")}
1819

1920
layout(std430) buffer;
2021

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearGroupwiseInt4.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,6 @@ void check_q_4w_linear_args(
2222
const ValueRef group_size,
2323
const ValueRef scales_and_zeros,
2424
const ValueRef out) {
25-
VK_CHECK_COND(graph.int16_shader_types_enabled());
26-
VK_CHECK_COND(graph.int8_buffers_enabled());
27-
2825
VK_CHECK_COND(graph.val_is_tensor(mat1));
2926
VK_CHECK_COND(graph.val_is_tref(mat2_data));
3027
VK_CHECK_COND(graph.val_is_tref(scales_and_zeros));
@@ -97,7 +94,10 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved(
9794
global_wg_size = graph.logical_limits_of(qmat2);
9895
global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(2));
9996

100-
std::string kernel_name = "pack_int4_linear_weight_transposed_interleaved";
97+
std::string kernel_name =
98+
graph.context()->adapter_ptr()->has_full_int8_buffers_support()
99+
? "pack_int4_linear_weight_transposed_interleaved"
100+
: "pack_int4_linear_weight_transposed_interleaved_nobitw8buffer";
101101
add_storage_type_suffix(kernel_name, storage_type);
102102

103103
graph.prepack_nodes().emplace_back(new PrepackNode(

backends/vulkan/test/op_tests/linear_weight_int4_test.cpp

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -274,11 +274,6 @@ TEST(VulkanInt4LinearTest, test_reference_impl) {
274274
}
275275

276276
TEST(VulkanInt4LinearTest, test_vulkan_impl) {
277-
if (!vkcompute::api::context()
278-
->adapter_ptr()
279-
->has_full_int8_buffers_support()) {
280-
GTEST_SKIP();
281-
}
282277
test_vulkan_linear_int4(
283278
/*B = */ 1,
284279
/*M = */ 4,

0 commit comments

Comments
 (0)