@@ -29,16 +29,20 @@ ${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)}
29
29
${layout_declare_tensor(3 , "r", "t_scales", DTYPE, STORAGE)}
30
30
31
31
$if STORAGE == "buffer ":
32
- ${layout_declare_ubo(4 , "ivec4 ", "out_sizes")}
33
- ${layout_declare_ubo(5 , "ivec4 ", "out_strides")}
34
- ${layout_declare_ubo(6 , "int ", "out_numel")}
35
- ${layout_declare_ubo(7 , "ivec4 ", "mat1_sizes")}
36
- ${layout_declare_ubo(8 , "ivec4 ", "mat1_strides")}
37
- ${layout_declare_ubo(9 , "ivec4 ", "qmat2_strides")}
38
- ${layout_declare_ubo(10 , "ivec4 ", "scales_strides")}
32
+ layout (push_constant) uniform restrict Block {
33
+ ivec4 out_sizes;
34
+ ivec4 out_strides;
35
+ ivec4 mat1_sizes;
36
+ ivec4 mat1_strides;
37
+ ivec4 qmat2_strides;
38
+ ivec4 scales_strides;
39
+ int out_numel;
40
+ };
39
41
$else :
40
- ${layout_declare_ubo(4 , "ivec3 ", "out_limits")}
41
- ${layout_declare_ubo(5 , "ivec4 ", "mat1_sizes")}
42
+ layout (push_constant) uniform restrict Block {
43
+ ivec3 out_limits;
44
+ ivec4 mat1_sizes;
45
+ };
42
46
43
47
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
44
48
@@ -83,42 +87,40 @@ void main() {
83
87
84
88
#else // USING_TEXTURE
85
89
86
- #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
87
-
88
90
void main() {
89
- const u16vec2 out_pos = u16vec2 (
90
- gl_GlobalInvocationID.x,
91
- gl_GlobalInvocationID.y );
91
+ const ivec2 out_pos = ivec2 (
92
+ gl_GlobalInvocationID.x % out_limits.x ,
93
+ gl_GlobalInvocationID.x / out_limits.x );
92
94
93
- if (out_pos.x >= out_limits.x || out_pos. y >= out_limits.y) {
95
+ if (out_pos.y >= out_limits.y) {
94
96
return ;
95
97
}
96
98
97
- const uint16_t qmat2_pos_x = out_pos.x;
99
+ const int qmat2_pos_x = out_pos.x;
98
100
99
101
VEC4_T outtex = VEC4_T(0 );
100
102
101
- const VEC4_T scales = load_texel(t_scales, u16vec3 (out_pos.x, 0 , 0 ));
103
+ const VEC4_T scales = load_texel(t_scales, ivec3 (out_pos.x, 0 , 0 ));
102
104
103
105
VEC4_T mat1_tex;
104
106
VEC4_T mat2_tex[4 ];
105
107
for (
106
- uint16_t i = uint16_t( 0 ) , x = uint16_t( 0 ) ;
107
- i < uint16_t( mat1_sizes.x) ;
108
- i += uint16_t( 4 ) , x++ )
108
+ int i = 0 , x = 0 ;
109
+ i < mat1_sizes.x;
110
+ i += 4 , x++ )
109
111
{
110
- mat1_tex = load_texel(t_mat1, u16vec3 (x, out_pos.y, 0 ));
112
+ mat1_tex = load_texel(t_mat1, ivec3 (x, out_pos.y, 0 ));
111
113
112
- mat2_tex[0 ] = load_texel(t_qmat2, u16vec3 (out_pos.x, i, 0 ));
113
- mat2_tex[1 ] = load_texel(t_qmat2, u16vec3 (out_pos.x, i + uint16_t( 1 ) , 0 ));
114
- mat2_tex[2 ] = load_texel(t_qmat2, u16vec3 (out_pos.x, i + uint16_t( 2 ) , 0 ));
115
- mat2_tex[3 ] = load_texel(t_qmat2, u16vec3 (out_pos.x, i + uint16_t( 3 ) , 0 ));
114
+ mat2_tex[0 ] = load_texel(t_qmat2, ivec3 (out_pos.x, i, 0 ));
115
+ mat2_tex[1 ] = load_texel(t_qmat2, ivec3 (out_pos.x, i + 1 , 0 ));
116
+ mat2_tex[2 ] = load_texel(t_qmat2, ivec3 (out_pos.x, i + 2 , 0 ));
117
+ mat2_tex[3 ] = load_texel(t_qmat2, ivec3 (out_pos.x, i + 3 , 0 ));
116
118
117
119
outtex += mat1_tex.x * mat2_tex[0 ] + mat1_tex.y * mat2_tex[1 ] + mat1_tex.z * mat2_tex[2 ] + mat1_tex.w * mat2_tex[3 ];
118
120
}
119
121
120
122
outtex *= scales;
121
- write_texel(t_out, u16vec3 (out_pos, 0 ), outtex);
123
+ write_texel(t_out, ivec3 (out_pos, 0 ), outtex);
122
124
}
123
125
124
126
#endif
0 commit comments