Skip to content

Commit b504ba2

Browse files
Arm backend: Add decomposition pass for aten.ne (#10475)
Add a DecomposeNotEqualPass that rewrites `aten.ne` and its variants (e.g. `aten.ne_`, `aten.ne.Scalar`) into a combination of supported TOSA ops: ne(x, y) → logical_not(eq(x, y)). This decomposition is necessary because TOSA does not define a NOT_EQUAL operator. The pass ensures compatibility with both the MI and BI TOSA profiles by emitting only supported ops (`eq` and `logical_not`), and is integrated into the ArmPassManager pipeline accordingly. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 8d43ac4 commit b504ba2

File tree

7 files changed

+275
-0
lines changed

7 files changed

+275
-0
lines changed

backends/arm/_passes/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
2626
from .decompose_linear_pass import DecomposeLinearPass # noqa
2727
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
28+
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
2829
from .decompose_select import DecomposeSelectPass # noqa
2930
from .decompose_silu_pass import DecomposeSiluPass # noqa
3031
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa

backends/arm/_passes/arm_pass_manager.py

+3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
DecomposeLeakyReLUPass,
3131
DecomposeLinearPass,
3232
DecomposeMeanDimPass,
33+
DecomposeNotEqualPass,
3334
DecomposeSelectPass,
3435
DecomposeSiluPass,
3536
DecomposeSoftmaxPass,
@@ -131,6 +132,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
131132
self.add_pass(DecomposeLayerNormPass())
132133
self.add_pass(DecomposeVarPass())
133134
self.add_pass(DecomposeMeanDimPass())
135+
self.add_pass(DecomposeNotEqualPass())
134136
self.add_pass(ConvertMeanDimToAveragePoolPass())
135137
self.add_pass(DecomposeDivPass())
136138
self.add_pass(DecomposeSoftmaxPass())
@@ -194,6 +196,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
194196
self.add_pass(DecomposeLayerNormPass())
195197
self.add_pass(DecomposeVarPass())
196198
self.add_pass(DecomposeMeanDimPass())
199+
self.add_pass(DecomposeNotEqualPass())
197200
self.add_pass(DecomposeDivPass())
198201
self.add_pass(DecomposeLeakyReLUPass())
199202
self.add_pass(DecomposeSqrtPass())
+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm._passes import ArmPass
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
10+
# Edge-dialect `ne` overloads handled by the decomposition in this module.
edge_ne_ops = (
    exir_ops.edge.aten.ne.Tensor,
)
# Core ATen `ne` / `ne_` Tensor overloads handled by the decomposition in this
# module.
aten_ne_ops = (
    torch.ops.aten.ne.Tensor,
    torch.ops.aten.ne_.Tensor,
)
12+
13+
14+
def get_ne_decomposition(op) -> tuple:
    """
    Look up the operators used to rewrite a not-equal op.

    The rewrite is:
        ne(x, y) -> logical_not(eq(x, y))

    The replacement operators come from the same dialect as ``op``: edge
    dialect ops map to edge dialect ops, core ATen ops map to core ATen ops.

    Args:
        op: The ne operator overload that should be decomposed.

    Returns:
        A tuple (eq_op, logical_not_op) holding the operator overloads that
        together replace ``op``.

    Raises:
        RuntimeError: If ``op`` is in neither ``edge_ne_ops`` nor
            ``aten_ne_ops``.
    """
    # Pick the operator namespace matching the dialect of the incoming op.
    if op in edge_ne_ops:
        namespace = exir_ops.edge.aten
    elif op in aten_ne_ops:
        namespace = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get ne decomposition for op {op}")

    return (namespace.eq.Tensor, namespace.logical_not.default)
36+
37+
38+
class DecomposeNotEqualPass(ArmPass):
    """
    Pass that rewrites not-equal operators using equality and logical negation.

    TOSA has no native NOT_EQUAL operator, so every matching op is replaced by:
        ne(x, y) → logical_not(eq(x, y))

    Ops that are rewritten:
        - aten.ne.Tensor(x, y)
        - aten.ne_.Tensor(x, y)
        - exir_ops.edge.aten.ne.Tensor(x, y)

    Ops that are emitted instead:
        - aten.eq.Tensor or exir_ops.edge.aten.eq.Tensor
        - then aten.logical_not.default or its edge equivalent
    """

    def call_operator(self, op, args, kwargs, meta):
        """
        Replace a not-equal op with ``logical_not(eq(lhs, rhs))``.

        Any op that is not listed in ``edge_ne_ops`` or ``aten_ne_ops`` is
        forwarded unchanged to the parent implementation.
        """
        is_ne = op in edge_ne_ops or op in aten_ne_ops
        if is_ne:
            x, y = args
            eq_op, logical_not_op = get_ne_decomposition(op)

            # Emit the equality comparison first, then negate its result.
            equal = super().call_operator(eq_op, (x, y), {}, meta, updated=True)
            return super().call_operator(
                logical_not_op, (equal,), {}, meta, updated=True
            )

        return super().call_operator(op, args, kwargs, meta)

backends/arm/_passes/replace_scalar_with_tensor_pass.py

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
exir_ops.edge.aten.gt.Scalar: exir_ops.edge.aten.gt.Tensor,
3030
exir_ops.edge.aten.ge.Scalar: exir_ops.edge.aten.ge.Tensor,
3131
exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor,
32+
exir_ops.edge.aten.ne.Scalar: exir_ops.edge.aten.ne.Tensor,
3233
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
3334
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
3435
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
@@ -39,6 +40,7 @@
3940
torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
4041
torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
4142
torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
43+
torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor,
4244
}
4345

4446

backends/arm/operator_support/ethos_u55_support.py

+2
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ class EthosU55NotSupported(OperatorSupportBase):
140140
exir_ops.edge.aten.le.Tensor,
141141
exir_ops.edge.aten.lt.Tensor,
142142
exir_ops.edge.aten.lt.Scalar,
143+
exir_ops.edge.aten.ne.Tensor,
144+
exir_ops.edge.aten.ne.Scalar,
143145
exir_ops.edge.aten.flip.default, # REVERSE
144146
exir_ops.edge.aten.grid_sampler_2d, # GATHER
145147
exir_ops.edge.aten.scatter.src,

backends/arm/operator_support/tosa_supported_operators.py

+4
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ def is_node_supported(
185185
exir_ops.edge.aten.lt.Tensor,
186186
exir_ops.edge.aten.lt.Scalar,
187187
exir_ops.edge.aten.mul.Tensor,
188+
exir_ops.edge.aten.ne.Tensor,
189+
exir_ops.edge.aten.ne.Scalar,
188190
exir_ops.edge.aten.add.Scalar,
189191
exir_ops.edge.aten.sub.Scalar,
190192
exir_ops.edge.aten.mul.Scalar,
@@ -269,6 +271,8 @@ def is_node_supported(
269271
exir_ops.edge.aten.sqrt.default,
270272
exir_ops.edge.aten.sub.Scalar,
271273
exir_ops.edge.aten.mul.Scalar,
274+
exir_ops.edge.aten.ne.Tensor,
275+
exir_ops.edge.aten.ne.Scalar,
272276
exir_ops.edge.aten.div.Scalar,
273277
exir_ops.edge.aten.leaky_relu.default,
274278
]

backends/arm/test/ops/test_ne.py

+194
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm.test import common
10+
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU85PipelineBI,
13+
OpNotSupportedPipeline,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
19+
# Type parameter passed to the test pipeline classes used in this module.
input_t = Tuple[torch.Tensor]
20+
21+
22+
class NotEqual(torch.nn.Module):
23+
aten_op_Tensor = "torch.ops.aten.ne.Tensor"
24+
aten_op_Scalar = "torch.ops.aten.ne.Scalar"
25+
decomposed_ops = ["torch.ops.aten.eq.Tensor", "torch.ops.aten.logical_not.default"]
26+
decomposed_exir_ops = [
27+
"executorch_exir_dialects_edge__ops_aten_eq_Tensor",
28+
"executorch_exir_dialects_edge__ops_aten_logical_not_default",
29+
]
30+
exir_op = "executorch_exir_dialects_edge__ops_aten_ne_Tensor"
31+
32+
def __init__(self, input, other):
33+
super().__init__()
34+
self.input_ = input
35+
self.other_ = other
36+
37+
def forward(
38+
self,
39+
input_: torch.Tensor,
40+
other_: torch.Tensor | int | float,
41+
):
42+
return input_ != other_
43+
44+
def get_inputs(self):
45+
return (self.input_, self.other_)
46+
47+
48+
# Tensor-vs-tensor cases, ranks 1 to 4. The rank-2 case compares a (4, 5)
# tensor against a (1, 5) tensor.
op_ne_tensor_rank1_ones = NotEqual(torch.ones(5), torch.ones(5))
op_ne_tensor_rank2_rand = NotEqual(torch.rand(4, 5), torch.rand(1, 5))
op_ne_tensor_rank3_randn = NotEqual(torch.randn(10, 5, 2), torch.randn(10, 5, 2))
op_ne_tensor_rank4_randn = NotEqual(torch.randn(3, 2, 2, 2), torch.randn(3, 2, 2, 2))

# Tensor-vs-scalar cases, ranks 1 to 4, plus a rank-4 case with leading
# dimension 1.
op_ne_scalar_rank1_ones = NotEqual(torch.ones(5), 1.0)
op_ne_scalar_rank2_rand = NotEqual(torch.rand(4, 5), 0.2)
op_ne_scalar_rank3_randn = NotEqual(torch.randn(10, 5, 2), -0.1)
op_ne_scalar_rank4_randn = NotEqual(torch.randn(3, 2, 2, 2), 0.3)
op_ne_scalar_rank4_randn_1batch = NotEqual(torch.randn(1, 2, 2, 2), 0.3)

# Mappings from test case name to module, used to parametrize the tests.
test_data_tensor = dict(
    ne_tensor_rank1_ones=op_ne_tensor_rank1_ones,
    ne_tensor_rank2_rand=op_ne_tensor_rank2_rand,
    ne_tensor_rank3_randn=op_ne_tensor_rank3_randn,
    ne_tensor_rank4_randn=op_ne_tensor_rank4_randn,
)

test_data_scalar = dict(
    ne_scalar_rank1_ones=op_ne_scalar_rank1_ones,
    ne_scalar_rank2_rand=op_ne_scalar_rank2_rand,
    ne_scalar_rank3_randn=op_ne_scalar_rank3_randn,
    ne_scalar_rank4_randn=op_ne_scalar_rank4_randn,
    ne_scalar_rank4_randn_1batch=op_ne_scalar_rank4_randn_1batch,
)
85+
86+
87+
@common.parametrize("test_module", test_data_tensor)
def test_ne_tensor_tosa_MI(test_module):
    """Run the TOSA MI pipeline on each tensor-vs-tensor ne case."""
    inputs = test_module.get_inputs()
    pipeline = TosaPipelineMI[input_t](
        test_module,
        inputs,
        NotEqual.aten_op_Tensor,
        NotEqual.exir_op,
    )
    pipeline.run()
93+
94+
95+
@common.parametrize("test_module", test_data_scalar)
def test_ne_scalar_tosa_MI(test_module):
    """Run the TOSA MI pipeline on each tensor-vs-scalar ne case."""
    inputs = test_module.get_inputs()
    pipeline = TosaPipelineMI[input_t](
        test_module, inputs, NotEqual.aten_op_Scalar, NotEqual.exir_op
    )
    pipeline.run()
104+
105+
106+
@common.parametrize("test_module", test_data_tensor)
def test_ne_tensor_tosa_BI(test_module):
    """Run the TOSA BI pipeline on each tensor-vs-tensor ne case."""
    inputs = test_module.get_inputs()
    pipeline = TosaPipelineBI[input_t](
        test_module,
        inputs,
        NotEqual.decomposed_ops,
        NotEqual.exir_op,
    )
    pipeline.run()
112+
113+
114+
@common.parametrize("test_module", test_data_scalar)
def test_ne_scalar_tosa_BI(test_module):
    """Run the TOSA BI pipeline on each tensor-vs-scalar ne case."""
    inputs = test_module.get_inputs()
    pipeline = TosaPipelineBI[input_t](
        test_module,
        inputs,
        NotEqual.decomposed_ops,
        NotEqual.exir_op,
    )
    pipeline.run()
120+
121+
122+
@common.parametrize("test_module", test_data_tensor)
@common.XfailIfNoCorstone300
def test_ne_tensor_u55_BI(test_module):
    """Run the op-not-supported pipeline for U55 on tensor-vs-tensor cases."""
    # EQUAL is not supported on U55.
    eq_op, logical_not_op = NotEqual.decomposed_exir_ops
    expected_op_counts = {eq_op: 1, logical_not_op: 1}
    pipeline = OpNotSupportedPipeline[input_t](
        test_module,
        test_module.get_inputs(),
        "TOSA-0.80+BI+u55",
        expected_op_counts,
    )
    pipeline.run()
136+
137+
138+
@common.parametrize("test_module", test_data_scalar)
@common.XfailIfNoCorstone300
def test_ne_scalar_u55_BI(test_module):
    """Run the op-not-supported pipeline for U55 on tensor-vs-scalar cases."""
    # Not equal (ne) is decomposed into the TOSA ops EQUAL and LOGICAL_NOT, both of
    # which are unsupported on U55.
    eq_op, logical_not_op = NotEqual.decomposed_exir_ops
    expected_op_counts = {eq_op: 1, logical_not_op: 1}
    pipeline = OpNotSupportedPipeline[input_t](
        test_module,
        test_module.get_inputs(),
        "TOSA-0.80+BI+u55",
        expected_op_counts,
        n_expected_delegates=1,
    )
    pipeline.run()
154+
155+
156+
@common.parametrize(
    "test_module",
    test_data_tensor,
    xfails={"ne_tensor_rank4_randn": "MLETORCH-517: Batch size > 1 not fully supported"},
    strict=False,
)
@common.XfailIfNoCorstone320
def test_ne_tensor_u85_BI(test_module):
    """Run the Ethos-U85 BI pipeline on each tensor-vs-tensor ne case."""
    inputs = test_module.get_inputs()
    pipeline = EthosU85PipelineBI[input_t](
        test_module,
        inputs,
        NotEqual.decomposed_ops,
        NotEqual.decomposed_exir_ops,
        run_on_fvp=True,
    )
    pipeline.run()
174+
175+
176+
@common.parametrize(
    "test_module",
    test_data_scalar,
    xfails={
        "ne_scalar_rank4_randn": "MLETORCH-517: Batch size > 1 not fully supported",
        "ne_scalar_rank4_randn_1batch": "MLETORCH-847: Boolean ne result unstable on U85",
    },
    strict=False,
)
@common.XfailIfNoCorstone320
def test_ne_scalar_u85_BI(test_module):
    """Run the Ethos-U85 BI pipeline on each tensor-vs-scalar ne case."""
    inputs = test_module.get_inputs()
    pipeline = EthosU85PipelineBI[input_t](
        test_module,
        inputs,
        NotEqual.decomposed_ops,
        NotEqual.decomposed_exir_ops,
        run_on_fvp=True,
    )
    pipeline.run()

0 commit comments

Comments
 (0)