Skip to content

Commit c352672

Browse files
authored
Fixup op_slice negative start arguments
Differential Revision: D72728353 Pull Request resolved: #10122
1 parent a664d7b commit c352672

File tree

2 files changed

+52
-34
lines changed

2 files changed

+52
-34
lines changed

backends/arm/operators/op_slice.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@ class SliceVisitor(NodeVisitor):
2323
def __init__(self, *args):
2424
super().__init__(*args)
2525

26+
def _fixup_start(self, start, shape, dim):
27+
if start.number < 0:
28+
return start.number % shape[dim]
29+
else:
30+
return start.number
31+
32+
def _fixup_end(self, end, shape, dim):
33+
if end.number < 0:
34+
return end.number % shape[dim]
35+
else:
36+
return min(end.number, shape[dim])
37+
2638
def define_node(
2739
self,
2840
node: Node,
@@ -42,17 +54,21 @@ def define_node(
4254
# Translate and check parameters in Pytorch dim order.
4355
shape = input_node.shape
4456
dim = dim.number
45-
if end.number < 0:
46-
end_index = end.number % shape[dim]
47-
else:
48-
end_index = min(end.number, shape[dim])
49-
size = end_index - start.number
57+
58+
start_index = self._fixup_start(start, shape, dim)
59+
end_index = self._fixup_end(end, shape, dim)
60+
size = end_index - start_index
61+
5062
assert size > 0
5163
assert size <= shape[dim]
5264

5365
# Convert aten args to Tosa's start and size attributes and in TOSA dim order.
5466
attr = ts.TosaSerializerAttribute()
55-
start_attr = [start.number if i == dim else 0 for i in input_node.dim_order]
67+
68+
start_attr = [
69+
self._fixup_start(start, shape, dim) if i == dim else 0
70+
for i in input_node.dim_order
71+
]
5672
size_attr = [size if i == dim else shape[i] for i in input_node.dim_order]
5773
attr.SliceAttribute(start_attr, size_attr)
5874

backends/arm/test/ops/test_slice.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,21 @@
1616
from executorch.exir.backend.compile_spec_schema import CompileSpec
1717
from parameterized import parameterized
1818

19+
test_data_suite = [
20+
(torch.ones(10), [(3, -3)]),
21+
(torch.ones(10), [(-8, 3)]),
22+
(torch.ones(10, 10), [(1, 3), (3, None)]),
23+
(torch.ones(10, 10, 10), [(0, 7), (0, None), (0, 8)]),
24+
(torch.ones((1, 12, 10, 10)), [(None, None), (None, 5), (3, 5), (4, 10)]),
25+
]
26+
1927

2028
class TestSimpleSlice(unittest.TestCase):
2129

2230
class Slice(torch.nn.Module):
23-
24-
sizes = [(10), (10, 10), (10, 10, 10), ((1, 12, 10, 10))]
25-
test_tensors = [(torch.ones(n),) for n in sizes]
26-
27-
def forward(self, x: torch.Tensor):
28-
if x.dim() == 1:
29-
return x[3:-3]
30-
elif x.dim() == 2:
31-
return x[1:3, 3:]
32-
elif x.dim() == 3:
33-
return x[0:7, 0:, 0:8]
34-
elif x.dim() == 4:
35-
return x[:, :5, 3:5, 4:10]
31+
def forward(self, x: torch.Tensor, s: list[tuple[int, int]]):
32+
slices = [slice(*i) for i in s]
33+
return x[slices]
3634

3735
def _test_slice_tosa_MI_pipeline(
3836
self, module: torch.nn.Module, test_data: torch.Tensor
@@ -112,25 +110,29 @@ def _test_slice_u85_BI_pipeline(
112110
common.get_u85_compile_spec(), module, test_data
113111
)
114112

115-
@parameterized.expand(Slice.test_tensors)
113+
@parameterized.expand(test_data_suite)
116114
@pytest.mark.tosa_ref_model
117-
def test_slice_tosa_MI(self, tensor):
118-
self._test_slice_tosa_MI_pipeline(self.Slice(), (tensor,))
115+
def test_slice_tosa_MI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]):
116+
self._test_slice_tosa_MI_pipeline(self.Slice(), (tensor, slices))
119117

120-
@parameterized.expand(Slice.test_tensors[:2])
118+
@parameterized.expand(test_data_suite)
121119
@pytest.mark.tosa_ref_model
122-
def test_slice_nchw_tosa_BI(self, test_tensor: torch.Tensor):
123-
self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,))
120+
def test_slice_nchw_tosa_BI(
121+
self, tensor: torch.Tensor, slices: list[tuple[int, int]]
122+
):
123+
self._test_slice_tosa_BI_pipeline(self.Slice(), (tensor, slices))
124124

125-
@parameterized.expand(Slice.test_tensors[2:])
125+
@parameterized.expand(test_data_suite)
126126
@pytest.mark.tosa_ref_model
127-
def test_slice_nhwc_tosa_BI(self, test_tensor: torch.Tensor):
128-
self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,))
127+
def test_slice_nhwc_tosa_BI(
128+
self, tensor: torch.Tensor, slices: list[tuple[int, int]]
129+
):
130+
self._test_slice_tosa_BI_pipeline(self.Slice(), (tensor, slices))
129131

130-
@parameterized.expand(Slice.test_tensors)
131-
def test_slice_u55_BI(self, test_tensor: torch.Tensor):
132-
self._test_slice_u55_BI_pipeline(self.Slice(), (test_tensor,))
132+
@parameterized.expand(test_data_suite)
133+
def test_slice_u55_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]):
134+
self._test_slice_u55_BI_pipeline(self.Slice(), (tensor, slices))
133135

134-
@parameterized.expand(Slice.test_tensors)
135-
def test_slice_u85_BI(self, test_tensor: torch.Tensor):
136-
self._test_slice_u85_BI_pipeline(self.Slice(), (test_tensor,))
136+
@parameterized.expand(test_data_suite)
137+
def test_slice_u85_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]):
138+
self._test_slice_u85_BI_pipeline(self.Slice(), (tensor, slices))

0 commit comments

Comments
 (0)