Commit c28f087

Update
[ghstack-poisoned]
2 parents: 3c30f09 + a088f9f

22 files changed: +287 −130 lines changed


CODEOWNERS

Lines changed: 8 additions & 8 deletions

@@ -15,7 +15,7 @@
 /backends/vulkan @SS-JIA
 /backends/xnnpack @digantdesai @mcr229

-/build @GregoryComer @dbort @kirklandsign
+/build @GregoryComer @kirklandsign

 /codegen @larryliu0820 @lucylq

@@ -47,32 +47,32 @@
 /extension/apple @shoumikhin
 /extension/aten_util @JacobSzwejbka
 /extension/benchmark @tarun292
-/extension/data_loader @JacobSzwejbka @lucylq @dbort
-/extension/evalue_util @GregoryComer @dbort
+/extension/data_loader @JacobSzwejbka @lucylq
+/extension/evalue_util @GregoryComer
 /extension/export_util @kimishpatel
 /extension/flat_tensor @lucylq
 /extension/gguf_util @larryliu0820
 /extension/kernel_util @kimishpatel @manuelcandales
 /extension/llm @jackzhxng @iseeyuan @larryliu0820
-/extension/memory_allocator @JacobSzwejbka @dbort
+/extension/memory_allocator @JacobSzwejbka
 /extension/module @shoumikhin
 /extension/parallel @kimishpatel
 /extension/pybindings @JacobSzwejbka @larryliu0820
 /extension/pytree @JacobSzwejbka
-/extension/runner_util @dbort
+# /extension/runner_util @dbort
 /extension/tensor @shoumikhin
-/extension/testing_util @dbort
+# /extension/testing_util @dbort
 /extension/threadpool @kimishpatel
 /extension/training @JacobSzwejbka

 /kernels @manuelcandales

 /profiler @tarun292 @Gasoonjia

-/runtime @dbort @JacobSzwejbka @lucylq
+/runtime @JacobSzwejbka @lucylq
 /runtime/backend @cccclai

-/schema @dbort @JacobSzwejbka @lucylq
+/schema @JacobSzwejbka @lucylq

 /scripts @GregoryComer

backends/cadence/aot/remove_ops.py

Lines changed: 66 additions & 0 deletions

@@ -807,6 +807,72 @@ def remove_branched(
                 user.replace_all_uses_with(node.args[0])


+class RemoveCatFromSliceCopyPass(ExportPass):
+    def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
+        slice_copy_nodes = [
+            node
+            for node in graph_module.graph.nodes
+            if node.target == exir_ops.edge.aten.slice_copy.Tensor
+        ]
+        for slice_copy_node in slice_copy_nodes:
+            slice_dim, start_idx, end_idx, step = 0, 0, float("inf"), 1
+            input_node, *other_args = slice_copy_node.args
+            if len(other_args) >= 1:
+                slice_dim = other_args[0]
+            if len(other_args) >= 2:
+                start_idx = other_args[1]
+            if len(other_args) >= 3:
+                end_idx = other_args[2]
+            if len(other_args) >= 4:
+                step = other_args[3]
+            if step != 1:
+                continue
+            slice_copy_dtype = slice_copy_node.meta["val"].dtype
+            if input_node.target != exir_ops.edge.aten.cat.default:
+                continue
+            cat_dtype = input_node.meta["val"].dtype
+            if slice_copy_dtype != cat_dtype:
+                continue
+            cat_dim = input_node.args[1:]
+            if len(cat_dim) == 0:
+                cat_dim = 0
+            if cat_dim != slice_dim:
+                continue
+            cat_output_shape = input_node.meta["val"].shape
+            start_idx = (
+                cat_output_shape[cat_dim] + start_idx if start_idx < 0 else start_idx
+            )
+            end_idx = (
+                cat_output_shape[cat_dim]
+                if end_idx > cat_output_shape[cat_dim]
+                else end_idx
+            )
+            base_idx = 0
+            cat_input_to_keep = None
+            for cat_input_node in input_node.args[0]:
+                cat_input_dtype = cat_input_node.meta["val"].dtype
+                if slice_copy_dtype != cat_input_dtype:
+                    continue
+                cat_input_shape = cat_input_node.meta["val"].shape
+
+                # check if the slice range overlaps with the cat range
+                if (
+                    base_idx <= start_idx
+                    and end_idx <= list(cat_input_shape)[cat_dim] + base_idx
+                ):
+                    cat_input_to_keep = cat_input_node
+                    break
+                base_idx += list(cat_input_shape)[cat_dim]
+            if cat_input_to_keep is not None:
+                slice_copy_node.replace_input_with(input_node, cat_input_to_keep)
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        self._remove_unused_cat(graph_module)
+        graph_module.recompile()
+        graph_module.graph.eliminate_dead_code()
+        return super().call(graph_module)
+
+
 # The following class consolidates functions to remove ops that are redundant
 # in Jarvis. Currently, each function in this class iterates over each node of
 # the graph module once. In future, we could consolidate them into a monolithic
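
For intuition, the overlap check in this pass keeps a running offset (base_idx) along the cat dimension and bypasses the cat only when a single cat input fully contains the requested slice. A minimal sketch of that bookkeeping in plain Python, using hypothetical shapes rather than anything taken from this commit:

# Sketch of the overlap bookkeeping in RemoveCatFromSliceCopyPass, with
# hypothetical inputs: cat of two (2, 4) tensors along dim 0, followed by
# slice_copy(dim=0, start=0, end=1).
cat_input_shapes = [(2, 4), (2, 4)]
cat_dim, start_idx, end_idx = 0, 0, 1

base_idx = 0
kept_input = None
for i, shape in enumerate(cat_input_shapes):
    # This cat input alone can serve the slice if [start_idx, end_idx)
    # lies inside [base_idx, base_idx + shape[cat_dim]).
    if base_idx <= start_idx and end_idx <= shape[cat_dim] + base_idx:
        kept_input = i
        break
    base_idx += shape[cat_dim]

print(kept_input)  # 0: the slice reads only the first cat input, so the cat can be bypassed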

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 52 additions & 0 deletions

@@ -22,6 +22,7 @@
 from executorch.backends.cadence.aot.remove_ops import (
     RemoveAliasCopyOpPass,
     RemoveBranchedQuantDequant,
+    RemoveCatFromSliceCopyPass,
     RemoveCloneOpPass,
     RemoveContiguousOpPass,
     RemoveDetachCopyPass,

@@ -741,3 +742,54 @@ def forward(self, x):
                 },
             )
         )
+
+    def test_remove_cat_from_slice_copy_all_removal(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                x1 = torch.cat((x, y), 0)  # (2, 4)
+                return torch.slice_copy(x1, dim=0, start=0, end=1)
+
+        inputs = tuple(torch.randn(2, 4) for _ in range(2))
+        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
+        p = RemoveCatFromSliceCopyPass()
+        graph_module = cast(PassResult, p(graph_module)).graph_module
+
+        # Ensure the cat node was removed
+        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)
+
+    def test_remove_cat_from_slice_copy_no_removal(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                x1 = torch.cat((x, y), 0)  # (2, 4)
+                return torch.slice_copy(x1, dim=0, start=0, end=3)
+
+        inputs = tuple(torch.randn(2, 4) for _ in range(2))
+        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
+        p = RemoveCatFromSliceCopyPass()
+        graph_module = cast(PassResult, p(graph_module)).graph_module
+
+        # Ensure the cat node was not removed (the slice spans both cat inputs)
+        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
+
+    def test_remove_cat_from_slice_copy_zero_range(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                x1 = torch.cat((x, y), 0)  # (2, 4)
+                return torch.slice_copy(x1, dim=0, start=0, end=0)
+
+        inputs = tuple(torch.randn(2, 4) for _ in range(2))
+        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
+        p = RemoveCatFromSliceCopyPass()
+        graph_module = cast(PassResult, p(graph_module)).graph_module
+
+        # Ensure the cat node was removed
+        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)

backends/vulkan/op_registry.py

Lines changed: 13 additions & 1 deletion

@@ -530,7 +530,6 @@ def register_view_op(features: OpFeatures):
     exir_ops.edge.aten.flip.default,
     exir_ops.edge.aten.index_select.default,
     exir_ops.edge.aten.select_copy.int,
-    exir_ops.edge.aten.slice_copy.Tensor,
     # Tensor combination
     exir_ops.edge.aten.cat.default,
     exir_ops.edge.aten.split_with_sizes_copy.default,

@@ -557,6 +556,19 @@ def register_ported_op(features: OpFeatures):
     return features


+@update_features(
+    [
+        # Indexing and lookup
+        exir_ops.edge.aten.slice_copy.Tensor,
+    ]
+)
+def register_ported_op_all_packed_dims(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims=all_packed_dims,
+    )
+    return features
+
+
 # Ported ops that support their own prepacking.
 @update_features(
     [

backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.glsl

Lines changed: 19 additions & 7 deletions

@@ -27,8 +27,7 @@ layout(set = 0, binding = 3) uniform PRECISION restrict SliceArg {
   int dim;
   int offset;
   int step;
-  // Used when dim=batch. Stride is the # of plances for each batch value.
-  int stride;
+  int image_in_channel_size;
 }
 slice_arg;

@@ -45,11 +44,24 @@ void main() {

   ivec3 in_pos = pos;

-  int index = pos[slice_arg.dim] / slice_arg.stride;
-  int within_stride = pos[slice_arg.dim] % slice_arg.stride;
-
-  in_pos[slice_arg.dim] = slice_arg.offset * slice_arg.stride + index * slice_arg.step *
-      slice_arg.stride + within_stride;
+  // slice along batch axis
+  if (slice_arg.dim == 3) {
+    // index of the channel inside a batch
+    const int chanl_index = pos.z % slice_arg.image_in_channel_size;
+    // index of batch
+    const int batch_index = pos.z / slice_arg.image_in_channel_size;
+    in_pos.z = (slice_arg.offset + batch_index * slice_arg.step) * slice_arg.image_in_channel_size + chanl_index;
+  } else if (slice_arg.dim == C_DIM) {
+    // index of the channel inside a batch
+    const int chanl_index = pos.z % sizes.z;
+    // index of batch
+    const int batch_index = pos.z / sizes.z;
+    in_pos.z = slice_arg.offset + batch_index * slice_arg.image_in_channel_size + chanl_index * slice_arg.step;
+  } else if (slice_arg.dim == H_DIM) {
+    in_pos.y = slice_arg.offset + pos.y * slice_arg.step;
+  } else {
+    in_pos.x = slice_arg.offset + pos.x * slice_arg.step;
+  }

   imageStore(image_out, pos, texelFetch(image_in, in_pos, 0));
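
For intuition, the new batch-axis branch maps every output texel plane back to an input plane with modular arithmetic over image_in_channel_size (the number of texel planes per batch). A small worked example of that index math, using hypothetical sizes that are not taken from this shader's tests:

# Worked example of the dim == batch branch above, with hypothetical values.
image_in_channel_size = 2  # e.g. 8 channels packed 4-wide -> ceil(8 / 4) planes per batch
offset, step = 1, 2        # slice offset and step along the batch axis

for out_z in range(4):  # z coordinate of the output texel
    chan_index = out_z % image_in_channel_size    # plane inside the batch
    batch_index = out_z // image_in_channel_size  # output batch index
    in_z = (offset + batch_index * step) * image_in_channel_size + chan_index
    print(out_z, "->", in_z)
# Prints 0 -> 2, 1 -> 3, 2 -> 6, 3 -> 7: output batches 0 and 1 read input batches 1 and 3.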

backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl

Lines changed: 2 additions & 2 deletions

@@ -49,10 +49,10 @@ void main() {
   for (int i=0;i<4;i++) {
     ivec4 user_coor = nchwi_to_tidx(buf_indices[i], out_sizes);

-    int in_channel = user_coor.z;
+    int in_dim = user_coor[packed_dim];

     ivec4 in_user_coor = user_coor;
-    in_user_coor.z = slice_arg.offset + in_channel * slice_arg.step;
+    in_user_coor[packed_dim] = slice_arg.offset + in_dim * slice_arg.step;

     ivec4 in_pow_elem = to_texture_elem_pos(
       in_user_coor,

backends/vulkan/runtime/graph/ops/impl/Slice.cpp

Lines changed: 25 additions & 14 deletions

@@ -44,8 +44,7 @@ void add_slice_tensor_copy_node(
   vTensorPtr t_in = graph.get_tensor(in);
   vTensorPtr t_out = graph.get_tensor(out);

-  VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim));
-  VK_CHECK_COND(check_packed_dim_is(*t_out, WHCN::kChannelsDim));
+  VK_CHECK_COND(check_same_packed_dim(*t_in, *t_out));

   // Need normalize the dim
   int64_t dim = graph.extract_scalar<int64_t>(dim_ref);

@@ -76,7 +75,13 @@ void add_slice_tensor_copy_node(
   start = normalize_idx(start, in_sizes[dim], 0);
   end = normalize_idx(end, in_sizes[dim], in_sizes[dim]);

-  if (dim_index == kChannel4D) {
+  const vkapi::SpecVarList spec_vars = {t_in->packed_dim()};
+
+  const auto packed_dim_idx =
+      static_cast<DimIndex>(DimIndex::DIM_LAST - t_in->packed_dim());
+
+  // if the slice dim is the same as the packed dim, we can use the channel slice
+  if (dim_index == packed_dim_idx) {
     // slice by channel
     std::string kernel_name = "slice_channel";
     kernel_name.reserve(kShaderNameReserve);

@@ -99,26 +104,31 @@ void add_slice_tensor_copy_node(
          {in, vkapi::MemoryAccessType::READ}},
         {t_out->sizes_ubo(),
          t_in->sizes_ubo(),
-         graph.create_params_buffer(params)}));
+         graph.create_params_buffer(params)},
+        spec_vars));

   } else {
     // GPU's coordinate is in x, y, z
     int64_t gpu_dim = -1;
-    int64_t stride = 1;
+    int64_t in_channel_stride = 1;
     if (dim_index == kWidth4D) {
       gpu_dim = 0; // width: x dimension in gpu
       VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
     } else if (dim_index == kHeight4D) {
       gpu_dim = 1; // height: y dimension
       VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
-    } else if (dim_index == kBatch4D) {
-      gpu_dim = 2; // batch: z dimension
-
-      // Due to channel packing, each batch value is span over stride planes
-      int64_t n_channels = dim_at(in_sizes, kChannel4D);
-      stride = utils::div_up_4(n_channels);
+    } else if (dim_index == kChannel4D) {
+      gpu_dim = 2; // channel: z dimension
+      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
+      in_channel_stride = dim_at(in_sizes, kChannel4D);
     } else {
-      VK_THROW("Unexpected ncwh_dim!");
+      gpu_dim = 3; // batch: w dimension
+
+      in_channel_stride = dim_at(in_sizes, kChannel4D);
+      if (packed_dim_idx == kChannel4D) {
+        // Due to channel packing, each batch value spans in_channel_stride planes
+        in_channel_stride = utils::div_up_4(in_channel_stride);
+      }
     }

     std::string kernel_name = "slice_batch_height_width";

@@ -137,7 +147,7 @@ void add_slice_tensor_copy_node(
         static_cast<int32_t>(gpu_dim),
         static_cast<int32_t>(start),
         static_cast<int32_t>(step),
-        static_cast<int32_t>(stride),
+        static_cast<int32_t>(in_channel_stride),
     };

     graph.execute_nodes().emplace_back(new DispatchNode(

@@ -147,7 +157,8 @@ void add_slice_tensor_copy_node(
         local_size,
         {{out, vkapi::MemoryAccessType::WRITE},
          {in, vkapi::MemoryAccessType::READ}},
-        {t_out->sizes_ubo(), graph.create_params_buffer(params)}));
+        {t_out->sizes_ubo(), graph.create_params_buffer(params)},
+        spec_vars));
   }
 }
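
For reference, the packed_dim to DimIndex conversion introduced above is a small negative-index calculation. A rough sketch of the idea, assuming DIM_LAST equals -1 and packed_dim is the WHCN index of the packed axis (width = 0, height = 1, channels = 2); neither assumption is spelled out in this diff:

# Rough sketch of static_cast<DimIndex>(DimIndex::DIM_LAST - t_in->packed_dim()).
# Assumptions (not shown in this diff): DIM_LAST == -1 and packed_dim counts
# WHCN axes as width = 0, height = 1, channels = 2.
DIM_LAST = -1

def packed_dim_to_dim_index(packed_dim: int) -> int:
    return DIM_LAST - packed_dim

print(packed_dim_to_dim_index(0))  # -1, the width DimIndex (kWidth4D under these assumptions)
print(packed_dim_to_dim_index(2))  # -3, the channels DimIndex (kChannel4D under these assumptions)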

backends/vulkan/test/op_tests/cases.py

Lines changed: 5 additions & 1 deletion

@@ -585,7 +585,11 @@ def get_slice_out_inputs():
     test_suite = VkTestSuite([tuple(tc) for tc in test_cases])

     test_suite.dtypes = ["at::kFloat", "at::kHalf"]
-    test_suite.layouts = ["utils::kChannelsPacked"]
+    test_suite.layouts = [
+        "utils::kWidthPacked",
+        "utils::kHeightPacked",
+        "utils::kChannelsPacked",
+    ]
     test_suite.data_gen = "make_seq_tensor"
     return test_suite

build/cmake_deps.toml

Lines changed: 16 additions & 0 deletions

@@ -58,6 +58,21 @@ deps = [
   "executorch_core",
 ]

+# HACK: prevent reduce_util from also showing up in custom_ops. The
+# actual medium-term fix is to stop using Buck to drive our CMake
+# builds.
+[targets.reduce_util]
+buck_targets = [
+  "//kernels/portable/cpu/util:reduce_util",
+]
+filters = [
+  ".cpp$",
+]
+deps = [
+  "executorch",
+  "executorch_core",
+]
+
 [targets.optimized_kernels]
 buck_targets = [
   "//kernels/optimized:generated_lib",

@@ -414,6 +429,7 @@ deps = [
   "optimized_kernels",
   "extension_parallel",
   "extension_threadpool",
+  "reduce_util",
   "xnnpack_backend",
 ]
