Skip to content

Commit e88aafc

Browse files
authored
Arm backend: Add support to neg.default (#10653)
This patch adds support to aten.neg.default via TOSA NEGATE and added test cases. Signed-off-by: Fang-Ching <[email protected]>
1 parent cd3b53d commit e88aafc

File tree

5 files changed

+150
-0
lines changed

5 files changed

+150
-0
lines changed

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def is_node_supported(
194194
exir_ops.edge.aten.mul.Tensor,
195195
exir_ops.edge.aten.ne.Tensor,
196196
exir_ops.edge.aten.ne.Scalar,
197+
exir_ops.edge.aten.neg.default,
197198
exir_ops.edge.aten.add.Scalar,
198199
exir_ops.edge.aten.sub.Scalar,
199200
exir_ops.edge.aten.mul.Scalar,
@@ -311,6 +312,7 @@ class CheckProperQuantization(OperatorSupportBase):
311312
exir_ops.edge.aten.max_pool2d_with_indices.default,
312313
exir_ops.edge.aten.mm.default,
313314
exir_ops.edge.aten.mul.Tensor,
315+
exir_ops.edge.aten.neg.default,
314316
exir_ops.edge.aten.relu.default,
315317
exir_ops.edge.aten.sub.Tensor,
316318
exir_ops.edge.aten.upsample_bilinear2d.vec,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
op_maximum,
3232
op_minimum,
3333
op_mul,
34+
op_neg,
3435
op_permute,
3536
op_pow,
3637
op_reciprocal,

backends/arm/operators/op_neg.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import torch.fx
10+
11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
13+
get_input_qparams,
14+
get_output_qparams,
15+
)
16+
from executorch.backends.arm.operators.node_visitor import (
17+
NodeVisitor,
18+
register_node_visitor,
19+
)
20+
21+
from executorch.backends.arm.tosa_mapping import TosaArg
22+
23+
24+
def get_negate_zero_points(node: torch.fx.Node, dtype: ts.DType) -> tuple[int, int]:
25+
"""
26+
Returns (input1_zp, output_zp) for TOSA NEGATE.
27+
Must be zero for non-int8 types.
28+
"""
29+
if dtype == ts.DType.INT8:
30+
return (
31+
get_input_qparams(node)[0].zp,
32+
get_output_qparams(node)[0].zp,
33+
)
34+
return (0, 0)
35+
36+
37+
@register_node_visitor
38+
class NegVisitor(NodeVisitor):
39+
target = "aten.neg.default"
40+
41+
supported_dtypes = {
42+
ts.DType.INT8,
43+
ts.DType.INT16,
44+
ts.DType.INT32,
45+
ts.DType.FP16,
46+
ts.DType.BF16,
47+
ts.DType.FP32,
48+
}
49+
50+
def __init__(self, *args):
51+
super().__init__(*args)
52+
53+
def define_node(
54+
self,
55+
node: torch.fx.Node,
56+
tosa_graph: ts.TosaSerializer,
57+
inputs: List[TosaArg],
58+
output: TosaArg,
59+
) -> None:
60+
61+
if inputs[0].dtype not in self.supported_dtypes:
62+
raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}")
63+
64+
if inputs[0].dtype != output.dtype:
65+
raise ValueError(
66+
"All inputs and output need same dtype."
67+
f"Got {inputs[0].dtype=}, {output.dtype=}"
68+
)
69+
input_zp, output_zp = get_negate_zero_points(node, inputs[0].dtype)
70+
71+
attr = ts.TosaSerializerAttribute()
72+
attr.NegateAttribute(input1_zp=input_zp, output_zp=output_zp)
73+
tosa_graph.addOperator(
74+
ts.TosaOp.Op().NEGATE,
75+
[inputs[0].name],
76+
[output.name],
77+
attributes=attr,
78+
)

backends/arm/quantizer/quantization_annotator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,9 @@ def any_or_hardtanh_min_zero(n: Node):
375375
)
376376
]
377377
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
378+
elif node.target in (torch.ops.aten.neg.default,):
379+
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
380+
quant_properties.quant_output = _QuantProperty(0, input_act_qspec)
378381
elif node.target in _one_to_one:
379382
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
380383
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)

backends/arm/test/ops/test_neg.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Dict, Tuple
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
input_t1 = Tuple[torch.Tensor]
19+
20+
21+
class Neg(torch.nn.Module):
22+
23+
aten_op = "torch.ops.aten.neg.default"
24+
exir_op = "executorch_exir_dialects_edge__ops_aten_neg_default"
25+
26+
test_data: Dict[str, input_t1] = {
27+
"rank_1_ramp": (torch.arange(-16, 16, 0.2),),
28+
"rank_2_rand_uniform": (torch.rand(10, 10) - 0.5,),
29+
"rank_3_all_ones": (torch.ones(10, 10, 10),),
30+
"rank_4_all_zeros": (torch.zeros(1, 10, 10, 10),),
31+
"rank_4_randn_pos": (torch.randn(1, 4, 4, 4) + 10,),
32+
"rank_4_randn_neg": (torch.randn(1, 4, 4, 4) - 10,),
33+
}
34+
35+
def forward(self, x: torch.Tensor):
36+
return torch.neg(x)
37+
38+
39+
@common.parametrize("test_data", Neg.test_data)
40+
def test_neg_tosa_MI(test_data: input_t1):
41+
pipeline = TosaPipelineMI[input_t1](Neg(), test_data, Neg.aten_op, Neg.exir_op)
42+
pipeline.run()
43+
44+
45+
@common.parametrize("test_data", Neg.test_data)
46+
def test_neg_tosa_BI(test_data: input_t1):
47+
pipeline = TosaPipelineBI[input_t1](Neg(), test_data, Neg.aten_op, Neg.exir_op)
48+
pipeline.run()
49+
50+
51+
@common.parametrize("test_data", Neg.test_data)
52+
@common.XfailIfNoCorstone300
53+
def test_neg_u55_BI(test_data: input_t1):
54+
pipeline = EthosU55PipelineBI[input_t1](
55+
Neg(), test_data, Neg.aten_op, Neg.exir_op, run_on_fvp=True
56+
)
57+
pipeline.run()
58+
59+
60+
@common.parametrize("test_data", Neg.test_data)
61+
@common.XfailIfNoCorstone320
62+
def test_neg_u85_BI(test_data: input_t1):
63+
pipeline = EthosU85PipelineBI[input_t1](
64+
Neg(), test_data, Neg.aten_op, Neg.exir_op, run_on_fvp=True
65+
)
66+
pipeline.run()

0 commit comments

Comments
 (0)