Skip to content

Commit 5a225b5

Browse files
committed
Arm backend: Add support to ge.Scalar
- Convert ge.Scalar to ge.Tensor - Add new scalar test cases for ge Signed-off-by: Fang-Ching <[email protected]> Change-Id: I140044db2e2bff625e7aab565d491838b17741a3
1 parent a073668 commit 5a225b5

File tree

5 files changed

+94
-44
lines changed

5 files changed

+94
-44
lines changed

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, exported_program):
4949
exir_ops.edge.aten.bitwise_left_shift.Tensor,
5050
exir_ops.edge.aten.eq.Tensor,
5151
exir_ops.edge.aten.gt.Tensor,
52+
exir_ops.edge.aten.ge.Tensor,
5253
exir_ops.edge.aten.lt.Tensor,
5354
exir_ops.edge.aten.pow.Tensor_Tensor,
5455
exir_ops.edge.aten.where.self,

backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor,
2828
exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor,
2929
exir_ops.edge.aten.gt.Scalar: exir_ops.edge.aten.gt.Tensor,
30+
exir_ops.edge.aten.ge.Scalar: exir_ops.edge.aten.ge.Tensor,
3031
exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor,
3132
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
3233
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
@@ -36,6 +37,7 @@
3637
torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor,
3738
torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
3839
torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
40+
torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
3941
torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
4042
}
4143

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class EthosU55NotSupported(OperatorSupportBase):
134134
exir_ops.edge.aten.eq.Tensor,
135135
exir_ops.edge.aten.eq.Scalar,
136136
exir_ops.edge.aten.ge.Tensor,
137+
exir_ops.edge.aten.ge.Scalar,
137138
exir_ops.edge.aten.gt.Tensor,
138139
exir_ops.edge.aten.gt.Scalar,
139140
exir_ops.edge.aten.le.Tensor,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def is_node_supported(
176176
exir_ops.edge.aten.full.default,
177177
exir_ops.edge.aten.full_like.default,
178178
exir_ops.edge.aten.ge.Tensor,
179+
exir_ops.edge.aten.ge.Scalar,
179180
exir_ops.edge.aten.gt.Tensor,
180181
exir_ops.edge.aten.gt.Scalar,
181182
exir_ops.edge.aten.le.Tensor,

backends/arm/test/ops/test_ge.py

Lines changed: 89 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from typing import Tuple
77

8-
import pytest
98
import torch
109
from executorch.backends.arm.test import common
1110

@@ -16,13 +15,14 @@
1615
TosaPipelineMI,
1716
)
1817

19-
aten_op = "torch.ops.aten.ge.Tensor"
20-
exir_op = "executorch_exir_dialects_edge__ops_aten_ge_Tensor"
21-
2218
input_t = Tuple[torch.Tensor]
2319

2420

2521
class GreaterEqual(torch.nn.Module):
22+
aten_op_tensor = "torch.ops.aten.ge.Tensor"
23+
aten_op_scalar = "torch.ops.aten.ge.Scalar"
24+
exir_op = "executorch_exir_dialects_edge__ops_aten_ge_Tensor"
25+
2626
def __init__(self, input, other):
2727
super().__init__()
2828
self.input_ = input
@@ -31,106 +31,151 @@ def __init__(self, input, other):
3131
def forward(
3232
self,
3333
input_: torch.Tensor,
34-
other_: torch.Tensor,
34+
other_: torch.Tensor | int | float,
3535
):
3636
return input_ >= other_
3737

3838
def get_inputs(self):
3939
return (self.input_, self.other_)
4040

4141

42-
op_ge_rank1_ones = GreaterEqual(
42+
op_ge_tensor_rank1_ones = GreaterEqual(
4343
torch.ones(5),
4444
torch.ones(5),
4545
)
46-
op_ge_rank2_rand = GreaterEqual(
46+
op_ge_tensor_rank2_rand = GreaterEqual(
4747
torch.rand(4, 5),
4848
torch.rand(1, 5),
4949
)
50-
op_ge_rank3_randn = GreaterEqual(
50+
op_ge_tensor_rank3_randn = GreaterEqual(
5151
torch.randn(10, 5, 2),
5252
torch.randn(10, 5, 2),
5353
)
54-
op_ge_rank4_randn = GreaterEqual(
54+
op_ge_tensor_rank4_randn = GreaterEqual(
5555
torch.randn(3, 2, 2, 2),
5656
torch.randn(3, 2, 2, 2),
5757
)
5858

59-
test_data_common = {
60-
"ge_rank1_ones": op_ge_rank1_ones,
61-
"ge_rank2_rand": op_ge_rank2_rand,
62-
"ge_rank3_randn": op_ge_rank3_randn,
63-
"ge_rank4_randn": op_ge_rank4_randn,
59+
op_ge_scalar_rank1_ones = GreaterEqual(torch.ones(5), 1.0)
60+
op_ge_scalar_rank2_rand = GreaterEqual(torch.rand(4, 5), 0.2)
61+
op_ge_scalar_rank3_randn = GreaterEqual(torch.randn(10, 5, 2), -0.1)
62+
op_ge_scalar_rank4_randn = GreaterEqual(torch.randn(3, 2, 2, 2), 0.3)
63+
64+
test_data_tensor = {
65+
"ge_tensor_rank1_ones": op_ge_tensor_rank1_ones,
66+
"ge_tensor_rank2_rand": op_ge_tensor_rank2_rand,
67+
"ge_tensor_rank3_randn": op_ge_tensor_rank3_randn,
68+
"ge_tensor_rank4_randn": op_ge_tensor_rank4_randn,
69+
}
70+
71+
test_data_scalar = {
72+
"ge_scalar_rank1_ones": op_ge_scalar_rank1_ones,
73+
"ge_scalar_rank2_rand": op_ge_scalar_rank2_rand,
74+
"ge_scalar_rank3_randn": op_ge_scalar_rank3_randn,
75+
"ge_scalar_rank4_randn": op_ge_scalar_rank4_randn,
6476
}
6577

6678

67-
@common.parametrize("test_module", test_data_common)
68-
def test_ge_tosa_MI(test_module):
79+
@common.parametrize("test_module", test_data_tensor)
80+
def test_ge_tensor_tosa_MI(test_module):
81+
pipeline = TosaPipelineMI[input_t](
82+
test_module,
83+
test_module.get_inputs(),
84+
GreaterEqual.aten_op_tensor,
85+
GreaterEqual.exir_op,
86+
)
87+
pipeline.run()
88+
89+
90+
@common.parametrize("test_module", test_data_scalar)
91+
def test_ge_scalar_tosa_MI(test_module):
6992
pipeline = TosaPipelineMI[input_t](
70-
test_module, test_module.get_inputs(), aten_op, exir_op
93+
test_module,
94+
test_module.get_inputs(),
95+
GreaterEqual.aten_op_scalar,
96+
GreaterEqual.exir_op,
7197
)
7298
pipeline.run()
7399

74100

75-
@common.parametrize("test_module", test_data_common)
76-
def test_ge_tosa_BI(test_module):
101+
@common.parametrize("test_module", test_data_tensor)
102+
def test_ge_tensor_tosa_BI(test_module):
77103
pipeline = TosaPipelineBI[input_t](
78-
test_module, test_module.get_inputs(), aten_op, exir_op
104+
test_module,
105+
test_module.get_inputs(),
106+
GreaterEqual.aten_op_tensor,
107+
GreaterEqual.exir_op,
79108
)
80109
pipeline.run()
81110

82111

83-
@common.parametrize("test_module", test_data_common)
84-
def test_ge_u55_BI(test_module):
85-
# GREATER_EQUAL is not supported on U55.
86-
pipeline = OpNotSupportedPipeline[input_t](
112+
@common.parametrize("test_module", test_data_scalar)
113+
def test_ge_scalar_tosa_BI(test_module):
114+
pipeline = TosaPipelineBI[input_t](
87115
test_module,
88116
test_module.get_inputs(),
89-
"TOSA-0.80+BI+u55",
90-
{exir_op: 1},
117+
GreaterEqual.aten_op_tensor,
118+
GreaterEqual.exir_op,
91119
)
92120
pipeline.run()
93121

94122

95-
@common.parametrize("test_module", test_data_common)
96-
def test_ge_u85_BI(test_module):
97-
pipeline = EthosU85PipelineBI[input_t](
123+
@common.parametrize("test_module", test_data_tensor)
124+
@common.XfailIfNoCorstone300
125+
def test_ge_tensor_u55_BI(test_module):
126+
# GREATER_EQUAL is not supported on U55.
127+
pipeline = OpNotSupportedPipeline[input_t](
98128
test_module,
99129
test_module.get_inputs(),
100-
aten_op,
101-
exir_op,
102-
run_on_fvp=False,
103-
use_to_edge_transform_and_lower=True,
130+
"TOSA-0.80+BI+u55",
131+
{GreaterEqual.exir_op: 1},
104132
)
105133
pipeline.run()
106134

107135

108-
@common.parametrize("test_module", test_data_common)
109-
@pytest.mark.skip(reason="The same as test_ge_u55_BI")
110-
def test_ge_u55_BI_on_fvp(test_module):
136+
@common.parametrize("test_module", test_data_scalar)
137+
@common.XfailIfNoCorstone300
138+
def test_ge_scalar_u55_BI(test_module):
111139
# GREATER_EQUAL is not supported on U55.
112140
pipeline = OpNotSupportedPipeline[input_t](
113141
test_module,
114142
test_module.get_inputs(),
115143
"TOSA-0.80+BI+u55",
116-
{exir_op: 1},
144+
{GreaterEqual.exir_op: 1},
145+
n_expected_delegates=1,
146+
)
147+
pipeline.run()
148+
149+
150+
@common.parametrize(
151+
"test_module",
152+
test_data_tensor,
153+
xfails={"ge_tensor_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85"},
154+
)
155+
@common.XfailIfNoCorstone320
156+
def test_ge_tensor_u85_BI(test_module):
157+
pipeline = EthosU85PipelineBI[input_t](
158+
test_module,
159+
test_module.get_inputs(),
160+
GreaterEqual.aten_op_tensor,
161+
GreaterEqual.exir_op,
162+
run_on_fvp=True,
117163
)
118164
pipeline.run()
119165

120166

121167
@common.parametrize(
122168
"test_module",
123-
test_data_common,
124-
xfails={"ge_rank4_randn": "4D fails because boolean Tensors can't be subtracted"},
169+
test_data_scalar,
170+
xfails={"ge_scalar_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85"},
125171
)
126-
@common.SkipIfNoCorstone320
127-
def test_ge_u85_BI_on_fvp(test_module):
172+
@common.XfailIfNoCorstone320
173+
def test_ge_scalar_u85_BI(test_module):
128174
pipeline = EthosU85PipelineBI[input_t](
129175
test_module,
130176
test_module.get_inputs(),
131-
aten_op,
132-
exir_op,
177+
GreaterEqual.aten_op_tensor,
178+
GreaterEqual.exir_op,
133179
run_on_fvp=True,
134-
use_to_edge_transform_and_lower=True,
135180
)
136181
pipeline.run()

0 commit comments

Comments
 (0)