Skip to content

Commit f357169

Browse files
authored
merge q_8w_linear and main functions in q_8w_linear shader
Differential Revision: D70127663 Pull Request resolved: #8704
1 parent ef9c3aa commit f357169

File tree

1 file changed

+30
-37
lines changed

1 file changed

+30
-37
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 30 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,26 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
5252
#define FLOAT_T float
5353
#endif
5454

55-
FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {
56-
const FLOAT_T scale = t_scales[out_idx.x];
55+
void main() {
56+
const int out_bufi = int(gl_GlobalInvocationID.x);
57+
if (out_bufi >= out_numel) {
58+
return;
59+
}
60+
61+
const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0);
62+
63+
const FLOAT_T scale = t_scales[out_tidx.x];
5764

5865
FLOAT_T outval = FLOAT_T(0.0);
5966

60-
// Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0)
61-
int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z;
62-
// Initial qmat2 tensor idx wil be (0, out_idx.x, 0, 0); note that the qmat2
67+
// Initial mat1 tensor idx will be (0, out_tidx.y, out_tidx.z, 0)
68+
int mat1_offset = out_tidx.y * mat1_strides.y + out_tidx.z * qmat2_strides.z;
69+
// Initial qmat2 tensor idx wil be (0, out_tidx.x, 0, 0); note that the qmat2
6370
// tensor is transposed
64-
int qmat2_offset = out_idx.x * qmat2_strides.y;
71+
int qmat2_offset = out_tidx.x * qmat2_strides.y;
6572

66-
// TODO(ssjia): optimize memory access pattern by traversing K in inner loop
67-
for (int i = 0; i < K; i++) {
73+
// TODO(ssjia): optimize memory access pattern by traversing mat1 x in inner loop
74+
for (int i = 0; i < mat1_sizes.x; i++) {
6875
const FLOAT_T mat1_val = t_mat1[mat1_offset];
6976
const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;
7077

@@ -74,33 +81,32 @@ FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {
7481
qmat2_offset++;
7582
}
7683

77-
return outval;
78-
}
79-
80-
void main() {
81-
const int out_bufi = int(gl_GlobalInvocationID.x);
82-
if (out_bufi >= out_numel) {
83-
return;
84-
}
85-
86-
const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0);
87-
88-
t_out[out_bufi] = q_8w_linear(out_tidx, mat1_sizes.x);
84+
t_out[out_bufi] = outval;
8985
}
9086

9187
#else // USING_TEXTURE
9288

9389
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
9490

95-
VEC4_T q_8w_linear(const u16vec2 out_pos, const uint16_t K) {
91+
void main() {
92+
const u16vec2 out_pos = u16vec2(
93+
gl_GlobalInvocationID.x / out_limits.y,
94+
gl_GlobalInvocationID.x % out_limits.y);
95+
if (out_pos.x >= out_limits.x) {
96+
return;
97+
}
98+
9699
const uint16_t qmat2_pos_y = out_pos.x * uint16_t(4);
97100

98101
VEC4_T outtex = VEC4_T(0);
99102

100-
const u16vec3 scales_pos = u16vec3(out_pos.x, 0, 0);
101-
const VEC4_T scales = load_texel(t_scales, scales_pos);
103+
const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0, 0));
102104

103-
for (uint16_t i = uint16_t(0), x = uint16_t(0); i < K; i += uint16_t(4), x++) {
105+
for (
106+
uint16_t i = uint16_t(0), x = uint16_t(0);
107+
i < uint16_t(mat1_sizes.x);
108+
i += uint16_t(4), x++)
109+
{
104110
const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0));
105111
const VEC4_T sums = VEC4_T(
106112
dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y, 0))),
@@ -112,19 +118,6 @@ VEC4_T q_8w_linear(const u16vec2 out_pos, const uint16_t K) {
112118
}
113119

114120
outtex *= scales;
115-
116-
return outtex;
117-
}
118-
119-
void main() {
120-
const u16vec2 out_pos = u16vec2(
121-
gl_GlobalInvocationID.x / out_limits.y,
122-
gl_GlobalInvocationID.x % out_limits.y);
123-
if (out_pos.x >= out_limits.x) {
124-
return;
125-
}
126-
127-
VEC4_T outtex = q_8w_linear(out_pos, uint16_t(mat1_sizes.x));
128121
write_texel(t_out, u16vec3(out_pos, 0), outtex);
129122
}
130123

0 commit comments

Comments
 (0)