Skip to content

Commit 4c3c6ea

Browse files
committed
[ET-VK] Minor improvement to permute op.
Pull Request resolved: #10117 This change reduces the complexity of boundary comparison in permute op to improve speed. ghstack-source-id: 277933492 @exported-using-ghexport Differential Revision: [D72866962](https://our.internmc.facebook.com/intern/diff/D72866962/)
1 parent 735e365 commit 4c3c6ea

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

backends/vulkan/runtime/graph/ops/glsl/permute.glsl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ layout(push_constant) uniform PRECISION restrict Block {
3131
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3232
layout(constant_id = 3) const int packed_dim = C_DIM;
3333

34+
#extension GL_EXT_control_flow_attributes : require
35+
3436
void main() {
3537
ivec3 pos = ivec3(gl_GlobalInvocationID);
3638

@@ -54,11 +56,16 @@ void main() {
5456
in_bchw_pos[out_ndims[2]] = pos.y;
5557
in_bchw_pos[out_ndims[3]] = pos.x;
5658

57-
for (int j = 0; j < 4; ++j) {
59+
const int in_packed_dim_size = in_sizes[3 - out_ndims[in_packed_dim_bchw_index]];
60+
61+
[[unroll]] for (int j = 0, bchw_index = in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]; j < 4; ++j, ++bchw_index) {
5862
// terminate the loop if trying to access input texture out of bounds
59-
if (any(greaterThanEqual(in_bchw_pos.wzyx, in_sizes.xyzw))) {
63+
if (bchw_index >= in_packed_dim_size) {
6064
break;
6165
}
66+
// go to position in the input, that is mapped to the packed dim in the output
67+
in_bchw_pos[out_ndims[in_packed_dim_bchw_index]] = bchw_index;
68+
6269
ivec3 fetch_pos;
6370

6471
fetch_pos.xy = in_bchw_pos.wz;
@@ -74,9 +81,6 @@ void main() {
7481
// fetch input texel
7582
VEC4_T inval = VEC4_T(load_texel(t_in, fetch_pos));
7683
outval[j] = inval[in_packed_dim_lane_index];
77-
78-
// go to next position in the input, that is mapped to the packed dim in the output
79-
in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]++;
8084
}
8185

8286
pos[packed_dim] = int(gl_GlobalInvocationID[packed_dim]);

0 commit comments

Comments
 (0)