Skip to content

Commit e40bba5

Browse files
committed
[ET-VK] Modify quantized linear naive shader to linearly dispatch work to improve performance.
This diff changes the naive quantized linear matmul op to use push constants instead of uniform buffers, and changes the dispatch pattern to linear to improve performance. Differential Revision: [D72862490](https://our.internmc.facebook.com/intern/diff/D72862490/) ghstack-source-id: 277628836 Pull Request resolved: #10116
1 parent 0e6a7f6 commit e40bba5

File tree

2 files changed

+44
-65
lines changed

2 files changed

+44
-65
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,20 @@ ${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)}
2929
${layout_declare_tensor(3, "r", "t_scales", DTYPE, STORAGE)}
3030

3131
$if STORAGE == "buffer":
32-
${layout_declare_ubo(4, "ivec4", "out_sizes")}
33-
${layout_declare_ubo(5, "ivec4", "out_strides")}
34-
${layout_declare_ubo(6, "int", "out_numel")}
35-
${layout_declare_ubo(7, "ivec4", "mat1_sizes")}
36-
${layout_declare_ubo(8, "ivec4", "mat1_strides")}
37-
${layout_declare_ubo(9, "ivec4", "qmat2_strides")}
38-
${layout_declare_ubo(10, "ivec4", "scales_strides")}
32+
layout(push_constant) uniform restrict Block {
33+
ivec4 out_sizes;
34+
ivec4 out_strides;
35+
ivec4 mat1_sizes;
36+
ivec4 mat1_strides;
37+
ivec4 qmat2_strides;
38+
ivec4 scales_strides;
39+
int out_numel;
40+
};
3941
$else:
40-
${layout_declare_ubo(4, "ivec3", "out_limits")}
41-
${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
42+
layout(push_constant) uniform restrict Block {
43+
ivec3 out_limits;
44+
ivec4 mat1_sizes;
45+
};
4246

4347
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4448

@@ -83,42 +87,40 @@ void main() {
8387

8488
#else // USING_TEXTURE
8589

86-
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
87-
8890
void main() {
89-
const u16vec2 out_pos = u16vec2(
90-
gl_GlobalInvocationID.x,
91-
gl_GlobalInvocationID.y);
91+
const ivec2 out_pos = ivec2(
92+
gl_GlobalInvocationID.x % out_limits.x,
93+
gl_GlobalInvocationID.x / out_limits.x);
9294

93-
if (out_pos.x >= out_limits.x || out_pos.y >= out_limits.y) {
95+
if (out_pos.y >= out_limits.y) {
9496
return;
9597
}
9698

97-
const uint16_t qmat2_pos_x = out_pos.x;
99+
const int qmat2_pos_x = out_pos.x;
98100

99101
VEC4_T outtex = VEC4_T(0);
100102

101-
const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0, 0));
103+
const VEC4_T scales = load_texel(t_scales, ivec3(out_pos.x, 0, 0));
102104

103105
VEC4_T mat1_tex;
104106
VEC4_T mat2_tex[4];
105107
for (
106-
uint16_t i = uint16_t(0), x = uint16_t(0);
107-
i < uint16_t(mat1_sizes.x);
108-
i += uint16_t(4), x++)
108+
int i = 0, x = 0;
109+
i < mat1_sizes.x;
110+
i += 4, x++)
109111
{
110-
mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0));
112+
mat1_tex = load_texel(t_mat1, ivec3(x, out_pos.y, 0));
111113

112-
mat2_tex[0] = load_texel(t_qmat2, u16vec3(out_pos.x, i, 0));
113-
mat2_tex[1] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(1), 0));
114-
mat2_tex[2] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(2), 0));
115-
mat2_tex[3] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(3), 0));
114+
mat2_tex[0] = load_texel(t_qmat2, ivec3(out_pos.x, i, 0));
115+
mat2_tex[1] = load_texel(t_qmat2, ivec3(out_pos.x, i + 1, 0));
116+
mat2_tex[2] = load_texel(t_qmat2, ivec3(out_pos.x, i + 2, 0));
117+
mat2_tex[3] = load_texel(t_qmat2, ivec3(out_pos.x, i + 3, 0));
116118

117119
outtex += mat1_tex.x * mat2_tex[0] + mat1_tex.y * mat2_tex[1] + mat1_tex.z * mat2_tex[2] + mat1_tex.w * mat2_tex[3];
118120
}
119121

120122
outtex *= scales;
121-
write_texel(t_out, u16vec3(out_pos, 0), outtex);
123+
write_texel(t_out, ivec3(out_pos, 0), outtex);
122124
}
123125

124126
#endif

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -98,47 +98,22 @@ void add_q_8w_linear_node(
9898
add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
9999
add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));
100100

101-
vkapi::ParamsBindList ubos({});
101+
std::vector<PushConstantDataInfo> pcs;
102102
if (graph.is_buffer_storage(out_W_packed)) {
103-
ubos.append(
104-
{graph.sizes_ubo(out_W_packed),
105-
graph.strides_ubo(out_W_packed),
106-
graph.numel_ubo(out_W_packed),
107-
graph.sizes_ubo(mat1_W_packed),
108-
graph.strides_ubo(mat1),
109-
graph.strides_ubo(q_mat2),
110-
graph.strides_ubo(scales)});
103+
pcs = {graph.sizes_pc_of(out_W_packed),
104+
graph.strides_pc_of(out_W_packed),
105+
graph.sizes_pc_of(mat1_W_packed),
106+
graph.strides_pc_of(mat1),
107+
graph.strides_pc_of(q_mat2),
108+
graph.strides_pc_of(scales),
109+
graph.numel_pc_of(out_W_packed)};
111110
} else {
112-
ubos.append(
113-
{graph.logical_limits_ubo(out_W_packed),
114-
graph.sizes_ubo(mat1_W_packed)});
111+
pcs = {graph.logical_limits_pc_of(out_W_packed),
112+
graph.sizes_pc_of(mat1_W_packed)};
115113
}
116114

117-
utils::uvec3 global_wg;
118-
if (graph.is_buffer_storage(out)) {
119-
global_wg = {static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1};
120-
} else {
121-
global_wg = graph.logical_limits_of(out_W_packed);
122-
}
123-
124-
utils::uvec3 local_wg{8, 8, 1};
125-
int32_t out_W = graph.size_at<int32_t>(-1, out_W_packed);
126-
127-
if (graph.is_buffer_storage(out_W_packed)) {
128-
local_wg[0] = 64;
129-
local_wg[1] = 1;
130-
local_wg[2] = 1;
131-
} else {
132-
if (out_W % 8 != 0) {
133-
if (out_W % 4 == 0) {
134-
local_wg[0] = 4;
135-
local_wg[1] = 16;
136-
} else {
137-
local_wg[0] = 2;
138-
local_wg[1] = 32;
139-
}
140-
}
141-
}
115+
const utils::uvec3 global_wg = {static_cast<uint32_t>(graph.numel_of(out_W_packed)), 1, 1};
116+
const utils::uvec3 local_wg{64, 1, 1};
142117

143118
graph.execute_nodes().emplace_back(new DispatchNode(
144119
graph,
@@ -149,11 +124,13 @@ void add_q_8w_linear_node(
149124
{{out_W_packed, vkapi::MemoryAccessType::WRITE},
150125
{{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
151126
// Shader params buffers
152-
ubos,
127+
{},
153128
// Specialization Constants
154129
{},
155130
// Resizing Logic
156-
resize_q_8w_linear_node));
131+
resize_q_8w_linear_node,
132+
{},
133+
pcs));
157134
if (!graph.is_buffer_storage(out) &&
158135
graph.packed_dim_of(out) != WHCN::kWidthDim) {
159136
viewFn(graph, {out_W_packed, graph.add_none(), out});

0 commit comments

Comments
 (0)