|
16 | 16 | from executorch.exir.backend.compile_spec_schema import CompileSpec
|
17 | 17 | from parameterized import parameterized
|
18 | 18 |
|
| 19 | +test_data_suite = [ |
| 20 | + (torch.ones(10), [(3, -3)]), |
| 21 | + (torch.ones(10), [(-8, 3)]), |
| 22 | + (torch.ones(10, 10), [(1, 3), (3, None)]), |
| 23 | + (torch.ones(10, 10, 10), [(0, 7), (0, None), (0, 8)]), |
| 24 | + (torch.ones((1, 12, 10, 10)), [(None, None), (None, 5), (3, 5), (4, 10)]), |
| 25 | +] |
| 26 | + |
19 | 27 |
|
20 | 28 | class TestSimpleSlice(unittest.TestCase):
|
21 | 29 |
|
22 | 30 | class Slice(torch.nn.Module):
|
23 |
| - |
24 |
| - sizes = [(10), (10, 10), (10, 10, 10), ((1, 12, 10, 10))] |
25 |
| - test_tensors = [(torch.ones(n),) for n in sizes] |
26 |
| - |
27 |
| - def forward(self, x: torch.Tensor): |
28 |
| - if x.dim() == 1: |
29 |
| - return x[3:-3] |
30 |
| - elif x.dim() == 2: |
31 |
| - return x[1:3, 3:] |
32 |
| - elif x.dim() == 3: |
33 |
| - return x[0:7, 0:, 0:8] |
34 |
| - elif x.dim() == 4: |
35 |
| - return x[:, :5, 3:5, 4:10] |
| 31 | + def forward(self, x: torch.Tensor, s: list[tuple[int, int]]): |
| 32 | + slices = [slice(*i) for i in s] |
| 33 | + return x[slices] |
36 | 34 |
|
37 | 35 | def _test_slice_tosa_MI_pipeline(
|
38 | 36 | self, module: torch.nn.Module, test_data: torch.Tensor
|
@@ -112,25 +110,29 @@ def _test_slice_u85_BI_pipeline(
|
112 | 110 | common.get_u85_compile_spec(), module, test_data
|
113 | 111 | )
|
114 | 112 |
|
115 |
| - @parameterized.expand(Slice.test_tensors) |
| 113 | + @parameterized.expand(test_data_suite) |
116 | 114 | @pytest.mark.tosa_ref_model
|
117 |
| - def test_slice_tosa_MI(self, tensor): |
118 |
| - self._test_slice_tosa_MI_pipeline(self.Slice(), (tensor,)) |
| 115 | + def test_slice_tosa_MI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]): |
| 116 | + self._test_slice_tosa_MI_pipeline(self.Slice(), (tensor, slices)) |
119 | 117 |
|
120 |
| - @parameterized.expand(Slice.test_tensors[:2]) |
| 118 | + @parameterized.expand(test_data_suite) |
121 | 119 | @pytest.mark.tosa_ref_model
|
122 |
| - def test_slice_nchw_tosa_BI(self, test_tensor: torch.Tensor): |
123 |
| - self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,)) |
| 120 | + def test_slice_nchw_tosa_BI( |
| 121 | + self, tensor: torch.Tensor, slices: list[tuple[int, int]] |
| 122 | + ): |
| 123 | + self._test_slice_tosa_BI_pipeline(self.Slice(), (tensor, slices)) |
124 | 124 |
|
125 |
| - @parameterized.expand(Slice.test_tensors[2:]) |
| 125 | + @parameterized.expand(test_data_suite) |
126 | 126 | @pytest.mark.tosa_ref_model
|
127 |
| - def test_slice_nhwc_tosa_BI(self, test_tensor: torch.Tensor): |
128 |
| - self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,)) |
| 127 | + def test_slice_nhwc_tosa_BI( |
| 128 | + self, tensor: torch.Tensor, slices: list[tuple[int, int]] |
| 129 | + ): |
| 130 | + self._test_slice_tosa_BI_pipeline(self.Slice(), (tensor, slices)) |
129 | 131 |
|
130 |
| - @parameterized.expand(Slice.test_tensors) |
131 |
| - def test_slice_u55_BI(self, test_tensor: torch.Tensor): |
132 |
| - self._test_slice_u55_BI_pipeline(self.Slice(), (test_tensor,)) |
| 132 | + @parameterized.expand(test_data_suite) |
| 133 | + def test_slice_u55_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]): |
| 134 | + self._test_slice_u55_BI_pipeline(self.Slice(), (tensor, slices)) |
133 | 135 |
|
134 |
| - @parameterized.expand(Slice.test_tensors) |
135 |
| - def test_slice_u85_BI(self, test_tensor: torch.Tensor): |
136 |
| - self._test_slice_u85_BI_pipeline(self.Slice(), (test_tensor,)) |
| 136 | + @parameterized.expand(test_data_suite) |
| 137 | + def test_slice_u85_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]): |
| 138 | + self._test_slice_u85_BI_pipeline(self.Slice(), (tensor, slices)) |
0 commit comments