Skip to content

Commit ea215fd

Browse files
committed
[ET-VK] Add buffer support for binary ops
## Context

As the title says: add an implementation of binary operators for buffer-backed tensors.

Differential Revision: [D70810338](https://our.internmc.facebook.com/intern/diff/D70810338/)
ghstack-source-id: 270464061
Pull Request resolved: #9063
1 parent 6099020 commit ea215fd

File tree

5 files changed

+132
-19
lines changed

5 files changed

+132
-19
lines changed

backends/vulkan/runtime/api/containers/Tensor.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,19 +245,19 @@ class vTensor final {
245245
TextureLimits logical_limits;
246246
// Contains the number of elements in the tensor according to the canonical
247247
// sizes.
248-
size_t numel;
248+
int32_t numel;
249249

250250
friend class vTensor;
251251

252252
UniformData(
253253
const std::vector<int64_t>& sizes,
254254
const std::vector<int64_t>& strides,
255255
const TextureLimits& logical_limits,
256-
const size_t numel)
256+
const size_t numel_ll)
257257
: sizes_v(utils::make_whcn_ivec4(sizes)),
258258
strides_v(utils::make_whcn_ivec4(strides)),
259259
logical_limits(logical_limits),
260-
numel(numel) {}
260+
numel(utils::safe_downcast<int32_t>(numel_ll)) {}
261261

262262
public:
263263
/*

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,37 +11,83 @@
1111
#define PRECISION ${PRECISION}
1212

1313
#define VEC4_T ${texel_type(DTYPE)}
14+
#define T ${buffer_scalar_type(DTYPE)}
1415

1516
#define op(X, Y, A) ${OPERATOR}
1617

18+
${define_active_storage_type(STORAGE)}
19+
${define_required_extensions(DTYPE)}
20+
1721
layout(std430) buffer;
1822

1923
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
2024
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
2125
${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}
2226

27+
$if STORAGE == "buffer":
28+
layout(push_constant) uniform restrict Block {
29+
ivec4 in_sizes;
30+
ivec4 other_sizes;
31+
ivec4 out_strides;
32+
ivec4 in_strides;
33+
ivec4 other_strides;
34+
int out_numel;
35+
float alpha;
36+
};
37+
$else:
38+
layout(push_constant) uniform restrict Block {
39+
ivec4 out_sizes;
40+
ivec4 in_sizes;
41+
ivec4 other_sizes;
42+
ivec2 broadcast_params;
43+
float alpha;
44+
};
45+
2346
#include "broadcasting_utils.h"
2447
#include "indexing_utils.h"
2548

2649
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2750

28-
${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
29-
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
30-
const lowp int packed_dim = unhash_packed_dim(out_layout);
51+
$if STORAGE == "buffer":
52+
${layout_declare_spec_const(C, "int", "out_packed_dim", "DEFAULT_LAYOUT")}
53+
${layout_declare_spec_const(C, "int", "in_packed_dim", "DEFAULT_LAYOUT")}
54+
${layout_declare_spec_const(C, "int", "other_packed_dim", "DEFAULT_LAYOUT")}
55+
$else:
56+
${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
57+
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
58+
const lowp int packed_dim = unhash_packed_dim(out_layout);
3159

32-
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
33-
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);
60+
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
61+
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);
3462

35-
${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")}
36-
const lowp ivec4 other_axis_map = unhash_axis_map(other_layout);
63+
${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")}
64+
const lowp ivec4 other_axis_map = unhash_axis_map(other_layout);
3765

38-
layout(push_constant) uniform restrict Block {
39-
ivec4 out_sizes;
40-
ivec4 in_sizes;
41-
ivec4 other_sizes;
42-
ivec2 broadcast_params;
43-
float alpha;
44-
};
66+
#ifdef USING_BUFFER
67+
68+
/*
 * Buffer-storage entry point. Each invocation produces one output element,
 * identified by its linear index into t_out.
 */
void main() {
  const int out_bufi = ivec3(gl_GlobalInvocationID).x;
  // Ignore invocations that fall past the last output element.
  if (out_bufi >= out_numel) {
    return;
  }

  // Fast path: no broadcasting AND all three tensors share the same strides,
  // so one linear index addresses the matching element in every buffer.
  // Equal sizes alone are not sufficient: the inputs and the output carry
  // their own strides, and if those differ the same linear index refers to
  // different tensor elements in each buffer.
  if (in_sizes == other_sizes && in_strides == other_strides &&
      in_strides == out_strides) {
    t_out[out_bufi] = T(op(t_in[out_bufi], t_other[out_bufi], T(alpha)));
    return;
  }

  // General path: recover the tensor index of this output element, then clamp
  // it into each input's extents. A dimension of size 1 clamps to index 0,
  // which is what re-reads the same input element along a broadcast dimension.
  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim);
  const ivec4 in_tidx = min(out_tidx, in_sizes - 1);
  const ivec4 other_tidx = min(out_tidx, other_sizes - 1);

  // Convert each clamped tensor index back to a linear index using that
  // input's own strides.
  const int in_bufi = tidx_to_bufi(in_tidx, in_strides);
  const int other_bufi = tidx_to_bufi(other_tidx, other_strides);

  t_out[out_bufi] = T(op(t_in[in_bufi], t_other[other_bufi], T(alpha)));
}
89+
90+
#else // USING_TEXTURE
4591

4692
void main() {
4793
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
@@ -79,3 +125,5 @@ void main() {
79125
VEC4_T(op(in_texel, other_texel, alpha)),
80126
out_axis_map);
81127
}
128+
129+
#endif

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ binary_op:
1010
NDIM: 3
1111
DTYPE: float
1212
PACKING: C_packed
13-
STORAGE: texture3d
1413
generate_variant_forall:
14+
STORAGE:
15+
- VALUE: texture3d
16+
- VALUE: buffer
1517
DTYPE:
1618
- VALUE: half
1719
- VALUE: float

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ void resize_binary_op_node(
4444
out->virtual_resize(new_out_sizes);
4545
}
4646

47-
void add_binary_op_node(
47+
void add_binary_op_texture_node(
4848
ComputeGraph& graph,
4949
const ValueRef in1,
5050
const ValueRef in2,
@@ -75,6 +75,7 @@ void add_binary_op_node(
7575
std::string kernel_name("binary_");
7676
kernel_name.reserve(kShaderNameReserve);
7777
kernel_name += op_name;
78+
add_storage_type_suffix(kernel_name, *t_out);
7879
add_dtype_suffix(kernel_name, *t_out);
7980

8081
graph.execute_nodes().emplace_back(new DispatchNode(
@@ -98,6 +99,67 @@ void add_binary_op_node(
9899
PushConstantDataInfo(&binary_ops_params, sizeof(binary_ops_params))}}));
99100
}
100101

102+
// Appends a DispatchNode to the graph that runs the buffer-storage variant of
// the "binary_<op_name>" shader, reading in1 and in2 and writing out.
//
// alpha is optional: when it is not a valid reference, or when it refers to a
// string value, the shader receives 1.0f.
void add_binary_op_buffer_node(
    ComputeGraph& graph,
    const ValueRef in1,
    const ValueRef in2,
    const ValueRef alpha,
    const ValueRef out,
    const std::string& op_name) {
  // check_binary_op_args(*t_in1, *t_in2, *t_out);

  // String is checked since floor_div passes in an unused string argument in
  // place of alpha
  const bool has_alpha = is_valid(alpha) && !graph.val_is_string(alpha);
  float alpha_val = has_alpha ? graph.extract_scalar<float>(alpha) : 1.0f;

  // Shader name: "binary_" + op name + storage-type suffix + dtype suffix,
  // with both suffixes taken from the output tensor.
  std::string kernel_name;
  kernel_name.reserve(kShaderNameReserve);
  kernel_name.append("binary_").append(op_name);
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      graph.create_global_wg_size(out),
      graph.create_local_wg_size(out),
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{in1, in2}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {},
      // Specialization Constants
      {graph.packed_dim_of(out),
       graph.packed_dim_of(in1),
       graph.packed_dim_of(in2)},
      // Resizing Logic
      resize_binary_op_node,
      {},
      // Push constants, listed in the order the shader's push_constant block
      // declares them.
      {{
          graph.sizes_pc_of(in1),
          graph.sizes_pc_of(in2),
          graph.strides_pc_of(out),
          graph.strides_pc_of(in1),
          graph.strides_pc_of(in2),
          graph.numel_pc_of(out),
          PushConstantDataInfo(&alpha_val, sizeof(float)),
      }}));
}
148+
149+
// Common entry point for binary ops: forwards to the buffer or the texture
// implementation depending on the storage type of the output tensor.
void add_binary_op_node(
    ComputeGraph& graph,
    const ValueRef in1,
    const ValueRef in2,
    const ValueRef alpha,
    const ValueRef out,
    const std::string& op_name) {
  if (!graph.is_buffer_storage(out)) {
    add_binary_op_texture_node(graph, in1, in2, alpha, out, op_name);
    return;
  }
  add_binary_op_buffer_node(graph, in1, in2, alpha, out, op_name);
}
162+
101163
#define DEFINE_BINARY_OP_WITH_ALPHA_FN(op_name) \
102164
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
103165
return add_binary_op_node( \

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def get_binary_elementwise_inputs():
5656
"utils::kWidthPacked",
5757
"utils::kChannelsPacked",
5858
]
59+
test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
5960
return test_suite
6061

6162

0 commit comments

Comments
 (0)