Skip to content

Arm backend: Update node visitors to support TOSA 1.0 #10390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 23, 2025
1 change: 0 additions & 1 deletion backends/arm/operator_support/to_copy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def is_node_tosa_supported(
) -> bool:
assert node.target in self.targets

assert tosa_spec.support_integer()
supported_dtypes = (
self.ALL_SUPPORTED_TYPES
if tosa_spec.support_float()
Expand Down
49 changes: 45 additions & 4 deletions backends/arm/operators/op_amax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from typing import Any, List

import tosa_tools.v0_80.serializer.tosa_serializer as ts
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
Expand All @@ -15,19 +14,22 @@


@register_node_visitor
class MaxVisitor(NodeVisitor):
class MaxVisitor_0_80(NodeVisitor):
target = "aten.amax.default"

tosa_specs = NodeVisitor.tosa_specs_0_80

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts

input = inputs[0]
dim = inputs[1].number
Expand All @@ -49,3 +51,42 @@ def define_node(
tosa_graph.addOperator(
ts.TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr
)


@register_node_visitor
class MaxVisitor(NodeVisitor):
target = "aten.amax.default"

tosa_specs = NodeVisitor.tosa_specs_1_00

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import serializer.tosa_serializer as ts

input = inputs[0]
dim = inputs[1].number

if dim < 0:
tensor = get_first_fake_tensor(node)
rank = len(tensor.size())
dim = rank + dim

keep_dims = inputs[2].number
if not keep_dims:
raise RuntimeError(
"TOSA only supports keepdims == True; Did you run the convert_minmax pass?"
)

attr = ts.TosaSerializerAttribute()
attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=1)
tosa_graph.addOperator(
ts.TosaOp.Op().REDUCE_MAX, [input.name], [output.name], attr
)
49 changes: 45 additions & 4 deletions backends/arm/operators/op_amin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from typing import Any, List

import tosa_tools.v0_80.serializer.tosa_serializer as ts
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
Expand All @@ -15,19 +14,22 @@


@register_node_visitor
class MinVisitor(NodeVisitor):
class MinVisitor_0_80(NodeVisitor):
target = "aten.amin.default"

tosa_specs = NodeVisitor.tosa_specs_0_80

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts

input = inputs[0]
dim = inputs[1].number
Expand All @@ -49,3 +51,42 @@ def define_node(
tosa_graph.addOperator(
ts.TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr
)


@register_node_visitor
class MinVisitor(NodeVisitor):
target = "aten.amin.default"

tosa_specs = NodeVisitor.tosa_specs_1_00

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import serializer.tosa_serializer as ts

input = inputs[0]
dim = inputs[1].number

if dim < 0:
tensor = get_first_fake_tensor(node)
rank = len(tensor.size())
dim = rank + dim

keep_dims = inputs[2].number
if not keep_dims:
raise RuntimeError(
"TOSA only supports keepdims == True; Did you run the convert_minmax pass?"
)

attr = ts.TosaSerializerAttribute()
attr.ReduceMinAttribute(axis=input.dim_order.index(dim), nan_mode=1)
tosa_graph.addOperator(
ts.TosaOp.Op().REDUCE_MIN, [input.name], [output.name], attr
)
127 changes: 123 additions & 4 deletions backends/arm/operators/op_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

from typing import Any, List, Tuple

import numpy as np
import torch

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
Expand All @@ -34,14 +34,16 @@ def __init__(self, *args):

def _create_clamp_node(
self,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
input_name: str,
output_name: str,
min_int: int,
max_int: int,
min_fp32: float,
max_fp32: float,
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

attr = ts.TosaSerializerAttribute()
attr.ClampAttribute(
tosa_graph.builder,
Expand Down Expand Up @@ -81,7 +83,7 @@ def cast_type(value: Any) -> int | float:
def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
Expand Down Expand Up @@ -122,10 +124,12 @@ def __init__(self, *args):
def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

if len(node.all_input_nodes) != 1:
raise ValueError(
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
Expand All @@ -150,3 +154,118 @@ def define_node(
min_fp32,
max_fp32,
)


@register_node_visitor
class ClampVisitor_INT(NodeVisitor):
target = "aten.clamp.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
]

def __init__(self, *args):
super().__init__(*args)

def _get_min_max_arguments(
self, node: Node, dtype_min: int | float, dtype_max: int | float
) -> Tuple[int | float, int | float]:

def cast_type(value: Any) -> int | float:
if isinstance(value, int):
return value
else:
# Attempt to cast to float
return float(value)

if len(node.args) != 2 and len(node.args) != 3:
raise ValueError(f"Expected len(node.args) to be 2 or 3, got {node.args}")

min_arg = dtype_min
max_arg = dtype_max

if node.args[1] is not None:
min_arg = cast_type(node.args[1])

if len(node.args) > 2:
if node.args[2] is not None:
max_arg = cast_type(node.args[2])

return min_arg, max_arg

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import serializer.tosa_serializer as ts # type: ignore

if len(node.all_input_nodes) != 1:
raise ValueError(
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
)

# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
min_int8, max_int8 = self._get_min_max_arguments(
node,
torch.iinfo(torch.int8).min,
torch.iinfo(torch.int8).max,
)

attr = ts.TosaSerializerAttribute()
attr.ClampAttribute(
tosa_graph.builder,
np.int8(min_int8).tobytes(),
np.int8(max_int8).tobytes(),
nan_mode=1,
)

tosa_graph.addOperator(
ts.TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr
)


@register_node_visitor
class ClampVisitor_FP(ClampVisitor_INT):
# inheriting 'target' from INT class

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import serializer.tosa_serializer as ts # type: ignore

if len(node.all_input_nodes) != 1:
raise ValueError(
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
)

min_fp32, max_fp32 = self._get_min_max_arguments(
node,
torch.finfo(torch.float32).min,
torch.finfo(torch.float32).max,
)

attr = ts.TosaSerializerAttribute()
attr.ClampAttribute(
tosa_graph.builder,
np.float32(min_fp32).tobytes(),
np.float32(max_fp32).tobytes(),
nan_mode=1,
)

tosa_graph.addOperator(
ts.TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr
)
Loading
Loading