Skip to content

Commit 8e89094

Browse files
Arm backend: Added 8 new unit tests for testing various passes. (#9037)
Refactored tests cases to use TestPassPipeline instead of ArmTester. Signed-off-by: Michiel Olieslagers <[email protected]>
1 parent 4022ff1 commit 8e89094

8 files changed

+555
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
10+
ConvertExpandCopyToRepeatPass,
11+
)
12+
13+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
14+
15+
input_t = Tuple[torch.Tensor] # Input x
16+
17+
18+
class Expand(torch.nn.Module):
19+
"""
20+
Basic expand model using torch.Tensor.expand function
21+
"""
22+
23+
def __init__(self):
24+
super(Expand, self).__init__()
25+
26+
def forward(self, x):
27+
return x.expand(3, 4)
28+
29+
def get_inputs(self) -> input_t:
30+
return (torch.rand(3, 1),)
31+
32+
33+
def test_expand_to_repeat_tosa_BI():
34+
module = Expand()
35+
pipeline = PassPipeline[input_t](
36+
module,
37+
module.get_inputs(),
38+
tosa_version="TOSA-0.80+BI",
39+
ops_before_pass={
40+
"executorch_exir_dialects_edge__ops_aten_expand_copy_default": 1,
41+
},
42+
ops_not_before_pass=["executorch_exir_dialects_edge__ops_aten_repeat_default"],
43+
ops_after_pass={
44+
"executorch_exir_dialects_edge__ops_aten_repeat_default": 1,
45+
},
46+
ops_not_after_pass=[
47+
"executorch_exir_dialects_edge__ops_aten_expand_copy_default"
48+
],
49+
pass_list=[ConvertExpandCopyToRepeatPass],
50+
)
51+
pipeline.run()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm._passes.convert_split_to_slice import (
10+
ConvertSplitToSlicePass,
11+
)
12+
13+
from executorch.backends.arm.test import common
14+
15+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
16+
17+
input_t = Tuple[torch.Tensor] # Input x
18+
19+
20+
class Split(torch.nn.Module):
21+
"""
22+
Basic split model using torch.split function
23+
"""
24+
25+
def get_inputs(self) -> input_t:
26+
return (torch.rand(10),)
27+
28+
def forward(self, x):
29+
return torch.split(x, 2)
30+
31+
32+
class SplitTensor(torch.nn.Module):
33+
"""
34+
Basic split model using torch.Tensor.split function
35+
"""
36+
37+
def get_inputs(self) -> input_t:
38+
return (torch.rand(10),)
39+
40+
def forward(self, x):
41+
return x.split(2)
42+
43+
44+
modules = {"split_basic": Split(), "split_tensor": SplitTensor()}
45+
46+
47+
@common.parametrize("module", modules)
48+
def test_split_to_slice_tosa_BI(module):
49+
pipeline = PassPipeline[input_t](
50+
module,
51+
module.get_inputs(),
52+
tosa_version="TOSA-0.80+BI",
53+
ops_before_pass={
54+
"executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default": 1,
55+
},
56+
ops_not_before_pass=[
57+
"executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"
58+
],
59+
ops_after_pass={
60+
"executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 5,
61+
},
62+
ops_not_after_pass=[
63+
"executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default"
64+
],
65+
pass_list=[ConvertSplitToSlicePass],
66+
)
67+
pipeline.run()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
10+
11+
from executorch.backends.arm.test import common
12+
13+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
14+
15+
input_t = Tuple[torch.Tensor] # Input x
16+
17+
18+
class Div(torch.nn.Module):
19+
"""
20+
Basic div model using torch.div
21+
"""
22+
23+
def get_inputs(self) -> input_t:
24+
return (torch.rand(10),)
25+
26+
def forward(self, x):
27+
return torch.div(x, 2)
28+
29+
30+
class DivTensor(torch.nn.Module):
31+
"""
32+
Basic div model using torch.Tensor.div
33+
"""
34+
35+
def get_inputs(self) -> input_t:
36+
return (torch.rand(10),)
37+
38+
def forward(self, x):
39+
return x.div(2)
40+
41+
42+
modules = {"div_basic": Div(), "div_tensor": DivTensor()}
43+
44+
45+
@common.parametrize("module", modules)
46+
def test_decompose_div_tosa_MI(module):
47+
pipeline = PassPipeline[input_t](
48+
module,
49+
module.get_inputs(),
50+
tosa_version="TOSA-0.80+MI",
51+
ops_before_pass={
52+
"executorch_exir_dialects_edge__ops_aten_div_Tensor": 1,
53+
},
54+
ops_not_before_pass=[
55+
"executorch_exir_dialects_edge__ops_aten_mul_Tensor",
56+
"executorch_exir_dialects_edge__ops_aten_reciprocal_default",
57+
],
58+
ops_after_pass={
59+
"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
60+
"executorch_exir_dialects_edge__ops_aten_reciprocal_default": 1,
61+
},
62+
ops_not_after_pass=["executorch_exir_dialects_edge__ops_aten_div_Tensor"],
63+
pass_list=[DecomposeDivPass],
64+
)
65+
pipeline.run()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm._passes.decompose_layernorm_pass import (
10+
DecomposeLayerNormPass,
11+
)
12+
13+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
14+
15+
input_t = Tuple[torch.Tensor] # Input x
16+
17+
18+
class LayerNorm(torch.nn.Module):
19+
"""
20+
Basic layer_norm model using torch.nn.layer_norm layer
21+
"""
22+
23+
def __init__(self):
24+
super(LayerNorm, self).__init__()
25+
self.layer_norm = torch.nn.LayerNorm(10)
26+
27+
def forward(self, x):
28+
x = self.layer_norm(x)
29+
return x
30+
31+
def get_inputs(self) -> input_t:
32+
return (torch.rand(10),)
33+
34+
35+
def test_decompose_layernorm_tosa_MI():
36+
module = LayerNorm()
37+
pipeline = PassPipeline[input_t](
38+
module,
39+
module.get_inputs(),
40+
tosa_version="TOSA-0.80+MI",
41+
ops_before_pass={
42+
"executorch_exir_dialects_edge__ops_aten_native_layer_norm_default": 1,
43+
},
44+
ops_not_before_pass=[
45+
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
46+
"executorch_exir_dialects_edge__ops_aten_view_copy_default",
47+
"executorch_exir_dialects_edge__ops_aten_mul_Tensor",
48+
"executorch_exir_dialects_edge__ops_aten_full_default",
49+
"executorch_exir_dialects_edge__ops_aten_rsqrt_default",
50+
"executorch_exir_dialects_edge__ops_aten_var_correction",
51+
"executorch_exir_dialects_edge__ops_aten_sub_Tensor",
52+
"executorch_exir_dialects_edge__ops_aten_mean_dim",
53+
],
54+
ops_after_pass={
55+
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 2,
56+
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
57+
"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 2,
58+
"executorch_exir_dialects_edge__ops_aten_full_default": 1,
59+
"executorch_exir_dialects_edge__ops_aten_rsqrt_default": 1,
60+
"executorch_exir_dialects_edge__ops_aten_var_correction": 1,
61+
"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1,
62+
"executorch_exir_dialects_edge__ops_aten_mean_dim": 1,
63+
},
64+
ops_not_after_pass=[
65+
"executorch_exir_dialects_edge__ops_aten_expand_copy_default"
66+
],
67+
pass_list=[DecomposeLayerNormPass],
68+
)
69+
pipeline.run()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
10+
11+
from executorch.backends.arm.test import common
12+
13+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
14+
15+
input_t = Tuple[torch.Tensor] # Input x
16+
17+
18+
class MeanDim(torch.nn.Module):
19+
"""
20+
Basic mean model using torch.mean function making sure keepdim=True (keepdim=False doesnt work for this pass for some reason)
21+
"""
22+
23+
def __init__(self):
24+
super(MeanDim, self).__init__()
25+
26+
def forward(self, x):
27+
return torch.mean(x, 1, True)
28+
29+
def get_inputs(self) -> input_t:
30+
return (torch.rand(4, 4),)
31+
32+
33+
class MeanDimTensor(torch.nn.Module):
34+
"""
35+
Basic mean model using torch.Tensor.mean function making sure keepdim=True (keepdim=False doesnt work for this pass for some reason)
36+
"""
37+
38+
def __init__(self):
39+
super(MeanDimTensor, self).__init__()
40+
41+
def forward(self, x):
42+
return x.mean(1, True)
43+
44+
def get_inputs(self) -> input_t:
45+
return (torch.rand(4, 4),)
46+
47+
48+
modules = {"meandim_basic": MeanDim(), "meandim_tensor": MeanDimTensor()}
49+
50+
51+
@common.parametrize("module", modules)
52+
def test_decompose_meandim_tosa_MI(module):
53+
pipeline = PassPipeline[input_t](
54+
module,
55+
module.get_inputs(),
56+
tosa_version="TOSA-0.80+MI",
57+
ops_before_pass={
58+
"executorch_exir_dialects_edge__ops_aten_mean_dim": 1,
59+
},
60+
ops_not_before_pass=[
61+
"executorch_exir_dialects_edge__ops_aten_mul_Tensor",
62+
"executorch_exir_dialects_edge__ops_aten_full_default",
63+
"executorch_exir_dialects_edge__ops_aten_sum_dim_IntList",
64+
],
65+
ops_after_pass={
66+
"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
67+
"executorch_exir_dialects_edge__ops_aten_full_default": 1,
68+
"executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 1,
69+
},
70+
ops_not_after_pass=["executorch_exir_dialects_edge__ops_aten_mean_dim"],
71+
pass_list=[DecomposeMeanDimPass],
72+
)
73+
pipeline.run()

0 commit comments

Comments
 (0)