Skip to content

Arm backend: Add decomposition pass for aten.ne #10475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
from .decompose_linear_pass import DecomposeLinearPass # noqa
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
from .decompose_select import DecomposeSelectPass # noqa
from .decompose_silu_pass import DecomposeSiluPass # noqa
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
DecomposeLeakyReLUPass,
DecomposeLinearPass,
DecomposeMeanDimPass,
DecomposeNotEqualPass,
DecomposeSelectPass,
DecomposeSiluPass,
DecomposeSoftmaxPass,
Expand Down Expand Up @@ -131,6 +132,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(DecomposeLayerNormPass())
self.add_pass(DecomposeVarPass())
self.add_pass(DecomposeMeanDimPass())
self.add_pass(DecomposeNotEqualPass())
self.add_pass(ConvertMeanDimToAveragePoolPass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeSoftmaxPass())
Expand Down Expand Up @@ -194,6 +196,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeLayerNormPass())
self.add_pass(DecomposeVarPass())
self.add_pass(DecomposeMeanDimPass())
self.add_pass(DecomposeNotEqualPass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeLeakyReLUPass())
self.add_pass(DecomposeSqrtPass())
Expand Down
69 changes: 69 additions & 0 deletions backends/arm/_passes/decompose_ne_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops

# ne-op overloads handled by this pass, split by dialect: edge dialect
# (post-to_edge graphs) and core aten (pre-partitioning / annotation graphs).
# The in-place variant ne_ is included for the aten dialect.
edge_ne_ops = (exir_ops.edge.aten.ne.Tensor,)
aten_ne_ops = (torch.ops.aten.ne.Tensor, torch.ops.aten.ne_.Tensor)


def get_ne_decomposition(op) -> tuple:
    """Map a not-equal operator to the op pair that implements it.

    TOSA has no NOT_EQUAL operator, so ne is lowered as:
        ne(x, y) -> logical_not(eq(x, y))

    Both the core-aten and the edge-dialect variants of ne are recognised; the
    returned (eq_op, logical_not_op) pair is taken from the same dialect as
    the input op.

    Returns:
        A tuple (eq_op, logical_not_op) of operator overloads matching op's
        dialect.

    Raises:
        RuntimeError: If op is not one of the supported ne overloads.
    """
    if op in aten_ne_ops:
        return (torch.ops.aten.eq.Tensor, torch.ops.aten.logical_not.default)
    if op in edge_ne_ops:
        return (
            exir_ops.edge.aten.eq.Tensor,
            exir_ops.edge.aten.logical_not.default,
        )

    raise RuntimeError(f"Can't get ne decomposition for op {op}")


class DecomposeNotEqualPass(ArmPass):
    """Rewrite unsupported `aten.ne` calls into TOSA-compatible ops.

    TOSA provides no native NOT_EQUAL operator, so every occurrence of
        ne(x, y)
    is replaced with
        logical_not(eq(x, y))
    using the eq/logical_not overloads from the same dialect as the original
    op.

    Handled input ops:
        - aten.ne.Tensor(x, y)
        - aten.ne_.Tensor(x, y)
        - exir_ops.edge.aten.ne.Tensor(x, y)
    """

    def call_operator(self, op, args, kwargs, meta):
        # Anything that is not a ne variant passes through unchanged.
        if op not in edge_ne_ops and op not in aten_ne_ops:
            return super().call_operator(op, args, kwargs, meta)

        eq_op, logical_not_op = get_ne_decomposition(op)

        lhs, rhs = args
        # Build eq(x, y) first, then negate it; both nodes keep the original
        # node's meta and are flagged as updated.
        eq_result = super().call_operator(eq_op, (lhs, rhs), {}, meta, updated=True)
        return super().call_operator(
            logical_not_op, (eq_result,), {}, meta, updated=True
        )
2 changes: 2 additions & 0 deletions backends/arm/_passes/replace_scalar_with_tensor_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
exir_ops.edge.aten.gt.Scalar: exir_ops.edge.aten.gt.Tensor,
exir_ops.edge.aten.ge.Scalar: exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.ne.Scalar: exir_ops.edge.aten.ne.Tensor,
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
Expand All @@ -39,6 +40,7 @@
torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor,
}


Expand Down
2 changes: 2 additions & 0 deletions backends/arm/operator_support/ethos_u55_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ class EthosU55NotSupported(OperatorSupportBase):
exir_ops.edge.aten.le.Tensor,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.lt.Scalar,
exir_ops.edge.aten.ne.Tensor,
exir_ops.edge.aten.ne.Scalar,
exir_ops.edge.aten.flip.default, # REVERSE
exir_ops.edge.aten.grid_sampler_2d, # GATHER
exir_ops.edge.aten.scatter.src,
Expand Down
4 changes: 4 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def is_node_supported(
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.lt.Scalar,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.ne.Tensor,
exir_ops.edge.aten.ne.Scalar,
exir_ops.edge.aten.add.Scalar,
exir_ops.edge.aten.sub.Scalar,
exir_ops.edge.aten.mul.Scalar,
Expand Down Expand Up @@ -269,6 +271,8 @@ def is_node_supported(
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.sub.Scalar,
exir_ops.edge.aten.mul.Scalar,
exir_ops.edge.aten.ne.Tensor,
exir_ops.edge.aten.ne.Scalar,
exir_ops.edge.aten.div.Scalar,
exir_ops.edge.aten.leaky_relu.default,
]
Expand Down
194 changes: 194 additions & 0 deletions backends/arm/test/ops/test_ne.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from executorch.backends.arm.test import common

from executorch.backends.arm.test.tester.test_pipeline import (
EthosU85PipelineBI,
OpNotSupportedPipeline,
TosaPipelineBI,
TosaPipelineMI,
)


input_t = Tuple[torch.Tensor]


class NotEqual(torch.nn.Module):
aten_op_Tensor = "torch.ops.aten.ne.Tensor"
aten_op_Scalar = "torch.ops.aten.ne.Scalar"
decomposed_ops = ["torch.ops.aten.eq.Tensor", "torch.ops.aten.logical_not.default"]
decomposed_exir_ops = [
"executorch_exir_dialects_edge__ops_aten_eq_Tensor",
"executorch_exir_dialects_edge__ops_aten_logical_not_default",
]
exir_op = "executorch_exir_dialects_edge__ops_aten_ne_Tensor"

def __init__(self, input, other):
super().__init__()
self.input_ = input
self.other_ = other

def forward(
self,
input_: torch.Tensor,
other_: torch.Tensor | int | float,
):
return input_ != other_

def get_inputs(self):
return (self.input_, self.other_)


# Tensor-vs-tensor test modules, one per rank from 1 to 4.
op_ne_tensor_rank1_ones = NotEqual(torch.ones(5), torch.ones(5))
op_ne_tensor_rank2_rand = NotEqual(torch.rand(4, 5), torch.rand(1, 5))
op_ne_tensor_rank3_randn = NotEqual(torch.randn(10, 5, 2), torch.randn(10, 5, 2))
op_ne_tensor_rank4_randn = NotEqual(torch.randn(3, 2, 2, 2), torch.randn(3, 2, 2, 2))

# Tensor-vs-scalar test modules, one per rank, plus a batch-1 rank-4 case.
op_ne_scalar_rank1_ones = NotEqual(torch.ones(5), 1.0)
op_ne_scalar_rank2_rand = NotEqual(torch.rand(4, 5), 0.2)
op_ne_scalar_rank3_randn = NotEqual(torch.randn(10, 5, 2), -0.1)
op_ne_scalar_rank4_randn = NotEqual(torch.randn(3, 2, 2, 2), 0.3)
op_ne_scalar_rank4_randn_1batch = NotEqual(torch.randn(1, 2, 2, 2), 0.3)

test_data_tensor = {
    "ne_tensor_rank1_ones": op_ne_tensor_rank1_ones,
    "ne_tensor_rank2_rand": op_ne_tensor_rank2_rand,
    "ne_tensor_rank3_randn": op_ne_tensor_rank3_randn,
    "ne_tensor_rank4_randn": op_ne_tensor_rank4_randn,
}

test_data_scalar = {
    "ne_scalar_rank1_ones": op_ne_scalar_rank1_ones,
    "ne_scalar_rank2_rand": op_ne_scalar_rank2_rand,
    "ne_scalar_rank3_randn": op_ne_scalar_rank3_randn,
    "ne_scalar_rank4_randn": op_ne_scalar_rank4_randn,
    "ne_scalar_rank4_randn_1batch": op_ne_scalar_rank4_randn_1batch,
}


@common.parametrize("test_module", test_data_tensor)
def test_ne_tensor_tosa_MI(test_module):
    """Check ne.Tensor lowers through the TOSA MI pipeline."""
    inputs = test_module.get_inputs()
    TosaPipelineMI[input_t](
        test_module,
        inputs,
        NotEqual.aten_op_Tensor,
        NotEqual.exir_op,
    ).run()


@common.parametrize("test_module", test_data_scalar)
def test_ne_scalar_tosa_MI(test_module):
    """Check ne.Scalar lowers through the TOSA MI pipeline."""
    inputs = test_module.get_inputs()
    TosaPipelineMI[input_t](
        test_module,
        inputs,
        NotEqual.aten_op_Scalar,
        NotEqual.exir_op,
    ).run()


@common.parametrize("test_module", test_data_tensor)
def test_ne_tensor_tosa_BI(test_module):
    """Check ne.Tensor in the quantized TOSA BI pipeline via its decomposed ops."""
    inputs = test_module.get_inputs()
    TosaPipelineBI[input_t](
        test_module,
        inputs,
        NotEqual.decomposed_ops,
        NotEqual.exir_op,
    ).run()


@common.parametrize("test_module", test_data_scalar)
def test_ne_scalar_tosa_BI(test_module):
    """Check ne.Scalar in the quantized TOSA BI pipeline via its decomposed ops."""
    inputs = test_module.get_inputs()
    TosaPipelineBI[input_t](
        test_module,
        inputs,
        NotEqual.decomposed_ops,
        NotEqual.exir_op,
    ).run()


@common.parametrize("test_module", test_data_tensor)
@common.XfailIfNoCorstone300
def test_ne_tensor_u55_BI(test_module):
    """ne decomposes to EQUAL + LOGICAL_NOT; EQUAL is not supported on U55."""
    # One occurrence of each decomposed edge op is expected to be rejected.
    unsupported_counts = {
        NotEqual.decomposed_exir_ops[0]: 1,
        NotEqual.decomposed_exir_ops[1]: 1,
    }
    pipeline = OpNotSupportedPipeline[input_t](
        test_module,
        test_module.get_inputs(),
        "TOSA-0.80+BI+u55",
        unsupported_counts,
    )
    pipeline.run()


@common.parametrize("test_module", test_data_scalar)
@common.XfailIfNoCorstone300
def test_ne_scalar_u55_BI(test_module):
    """Scalar ne also decomposes to the TOSA ops EQUAL and LOGICAL_NOT, both
    unsupported on U55; one delegated partition is still expected."""
    unsupported_counts = {
        NotEqual.decomposed_exir_ops[0]: 1,
        NotEqual.decomposed_exir_ops[1]: 1,
    }
    pipeline = OpNotSupportedPipeline[input_t](
        test_module,
        test_module.get_inputs(),
        "TOSA-0.80+BI+u55",
        unsupported_counts,
        n_expected_delegates=1,
    )
    pipeline.run()


@common.parametrize(
    "test_module",
    test_data_tensor,
    xfails={
        "ne_tensor_rank4_randn": "MLETORCH-517: Batch size > 1 not fully supported",
    },
    strict=False,
)
@common.XfailIfNoCorstone320
def test_ne_tensor_u85_BI(test_module):
    """Run ne.Tensor through the Ethos-U85 BI pipeline on the FVP."""
    EthosU85PipelineBI[input_t](
        test_module,
        test_module.get_inputs(),
        NotEqual.decomposed_ops,
        NotEqual.decomposed_exir_ops,
        run_on_fvp=True,
    ).run()


@common.parametrize(
    "test_module",
    test_data_scalar,
    xfails={
        "ne_scalar_rank4_randn": "MLETORCH-517: Batch size > 1 not fully supported",
        "ne_scalar_rank4_randn_1batch": "MLETORCH-847: Boolean ne result unstable on U85",
    },
    strict=False,
)
@common.XfailIfNoCorstone320
def test_ne_scalar_u85_BI(test_module):
    """Run ne.Scalar through the Ethos-U85 BI pipeline on the FVP."""
    EthosU85PipelineBI[input_t](
        test_module,
        test_module.get_inputs(),
        NotEqual.decomposed_ops,
        NotEqual.decomposed_exir_ops,
        run_on_fvp=True,
    ).run()
Loading