Skip to content

Commit 735e365

Browse files
committed
[ET-VK] Modify quantized linear naive shader to linearly dispatch work to improve performance.
Pull Request resolved: #10116 This diff changes naive quantized linear mat mul op to use push constant instead of uniform buffers and change dispatch pattern to linear to improve performance. ghstack-source-id: 277933493 @exported-using-ghexport Differential Revision: [D72862490](https://our.internmc.facebook.com/intern/diff/D72862490/)
1 parent 38c4c77 commit 735e365

File tree

2 files changed

+44
-65
lines changed

2 files changed

+44
-65
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,20 @@ ${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)}
2929
${layout_declare_tensor(3, "r", "t_scales", DTYPE, STORAGE)}
3030

3131
$if STORAGE == "buffer":
32-
${layout_declare_ubo(4, "ivec4", "out_sizes")}
33-
${layout_declare_ubo(5, "ivec4", "out_strides")}
34-
${layout_declare_ubo(6, "int", "out_numel")}
35-
${layout_declare_ubo(7, "ivec4", "mat1_sizes")}
36-
${layout_declare_ubo(8, "ivec4", "mat1_strides")}
37-
${layout_declare_ubo(9, "ivec4", "qmat2_strides")}
38-
${layout_declare_ubo(10, "ivec4", "scales_strides")}
32+
layout(push_constant) uniform restrict Block {
33+
ivec4 out_sizes;
34+
ivec4 out_strides;
35+
ivec4 mat1_sizes;
36+
ivec4 mat1_strides;
37+
ivec4 qmat2_strides;
38+
ivec4 scales_strides;
39+
int out_numel;
40+
};
3941
$else:
40-
${layout_declare_ubo(4, "ivec3", "out_limits")}
41-
${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
42+
layout(push_constant) uniform restrict Block {
43+
ivec3 out_limits;
44+
ivec4 mat1_sizes;
45+
};
4246

4347
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4448

@@ -83,42 +87,40 @@ void main() {
8387

8488
#else // USING_TEXTURE
8589

86-
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
87-
8890
void main() {
89-
const u16vec2 out_pos = u16vec2(
90-
gl_GlobalInvocationID.x,
91-
gl_GlobalInvocationID.y);
91+
const ivec2 out_pos = ivec2(
92+
gl_GlobalInvocationID.x % out_limits.x,
93+
gl_GlobalInvocationID.x / out_limits.x);
9294

93-
if (out_pos.x >= out_limits.x || out_pos.y >= out_limits.y) {
95+
if (out_pos.y >= out_limits.y) {
9496
return;
9597
}
9698

97-
const uint16_t qmat2_pos_x = out_pos.x;
99+
const int qmat2_pos_x = out_pos.x;
98100

99101
VEC4_T outtex = VEC4_T(0);
100102

101-
const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0, 0));
103+
const VEC4_T scales = load_texel(t_scales, ivec3(out_pos.x, 0, 0));
102104

103105
VEC4_T mat1_tex;
104106
VEC4_T mat2_tex[4];
105107
for (
106-
uint16_t i = uint16_t(0), x = uint16_t(0);
107-
i < uint16_t(mat1_sizes.x);
108-
i += uint16_t(4), x++)
108+
int i = 0, x = 0;
109+
i < mat1_sizes.x;
110+
i += 4, x++)
109111
{
110-
mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0));
112+
mat1_tex = load_texel(t_mat1, ivec3(x, out_pos.y, 0));
111113

112-
mat2_tex[0] = load_texel(t_qmat2, u16vec3(out_pos.x, i, 0));
113-
mat2_tex[1] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(1), 0));
114-
mat2_tex[2] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(2), 0));
115-
mat2_tex[3] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(3), 0));
114+
mat2_tex[0] = load_texel(t_qmat2, ivec3(out_pos.x, i, 0));
115+
mat2_tex[1] = load_texel(t_qmat2, ivec3(out_pos.x, i + 1, 0));
116+
mat2_tex[2] = load_texel(t_qmat2, ivec3(out_pos.x, i + 2, 0));
117+
mat2_tex[3] = load_texel(t_qmat2, ivec3(out_pos.x, i + 3, 0));
116118

117119
outtex += mat1_tex.x * mat2_tex[0] + mat1_tex.y * mat2_tex[1] + mat1_tex.z * mat2_tex[2] + mat1_tex.w * mat2_tex[3];
118120
}
119121

120122
outtex *= scales;
121-
write_texel(t_out, u16vec3(out_pos, 0), outtex);
123+
write_texel(t_out, ivec3(out_pos, 0), outtex);
122124
}
123125

124126
#endif

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -98,47 +98,22 @@ void add_q_8w_linear_node(
9898
add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
9999
add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));
100100

101-
vkapi::ParamsBindList ubos({});
101+
std::vector<PushConstantDataInfo> pcs;
102102
if (graph.is_buffer_storage(out_W_packed)) {
103-
ubos.append(
104-
{graph.sizes_ubo(out_W_packed),
105-
graph.strides_ubo(out_W_packed),
106-
graph.numel_ubo(out_W_packed),
107-
graph.sizes_ubo(mat1_W_packed),
108-
graph.strides_ubo(mat1),
109-
graph.strides_ubo(q_mat2),
110-
graph.strides_ubo(scales)});
103+
pcs = {graph.sizes_pc_of(out_W_packed),
104+
graph.strides_pc_of(out_W_packed),
105+
graph.sizes_pc_of(mat1_W_packed),
106+
graph.strides_pc_of(mat1),
107+
graph.strides_pc_of(q_mat2),
108+
graph.strides_pc_of(scales),
109+
graph.numel_pc_of(out_W_packed)};
111110
} else {
112-
ubos.append(
113-
{graph.logical_limits_ubo(out_W_packed),
114-
graph.sizes_ubo(mat1_W_packed)});
111+
pcs = {graph.logical_limits_pc_of(out_W_packed),
112+
graph.sizes_pc_of(mat1_W_packed)};
115113
}
116114

117-
utils::uvec3 global_wg;
118-
if (graph.is_buffer_storage(out)) {
119-
global_wg = {static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1};
120-
} else {
121-
global_wg = graph.logical_limits_of(out_W_packed);
122-
}
123-
124-
utils::uvec3 local_wg{8, 8, 1};
125-
int32_t out_W = graph.size_at<int32_t>(-1, out_W_packed);
126-
127-
if (graph.is_buffer_storage(out_W_packed)) {
128-
local_wg[0] = 64;
129-
local_wg[1] = 1;
130-
local_wg[2] = 1;
131-
} else {
132-
if (out_W % 8 != 0) {
133-
if (out_W % 4 == 0) {
134-
local_wg[0] = 4;
135-
local_wg[1] = 16;
136-
} else {
137-
local_wg[0] = 2;
138-
local_wg[1] = 32;
139-
}
140-
}
141-
}
115+
const utils::uvec3 global_wg = {static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1};
116+
const utils::uvec3 local_wg{64, 1, 1};
142117

143118
graph.execute_nodes().emplace_back(new DispatchNode(
144119
graph,
@@ -149,11 +124,13 @@ void add_q_8w_linear_node(
149124
{{out_W_packed, vkapi::MemoryAccessType::WRITE},
150125
{{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
151126
// Shader params buffers
152-
ubos,
127+
{},
153128
// Specialization Constants
154129
{},
155130
// Resizing Logic
156-
resize_q_8w_linear_node));
131+
resize_q_8w_linear_node,
132+
{},
133+
pcs));
157134
if (!graph.is_buffer_storage(out) &&
158135
graph.packed_dim_of(out) != WHCN::kWidthDim) {
159136
viewFn(graph, {out_W_packed, graph.add_none(), out});

0 commit comments

Comments
 (0)