Skip to content

Commit 3ef5cb2

Browse files
Arm backend: Add TOSA support for GroupNorm
- Decompose groupnorm into a sequence of supported operators - Have some numerical issues with BI profile - Fix docstring in decompose_layernorm_pass - Add "native_group_norm.default" to CUSTOM_EDGE_OPS Change-Id: I3f70388c12b8d9afd52876840b6c008a1b0bec4e Signed-off-by: Yufeng Shi <[email protected]>
1 parent 8e89094 commit 3ef5cb2

File tree

7 files changed

+351
-5
lines changed

7 files changed

+351
-5
lines changed

backends/arm/_passes/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa
2222
from .decompose_div_pass import DecomposeDivPass # noqa
2323
from .decompose_gelu_pass import DecomposeGeluPass # noqa
24+
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
2425
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
2526
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
2627
from .decompose_linear_pass import DecomposeLinearPass # noqa

backends/arm/_passes/arm_pass_manager.py

+3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
DecomposeBatchNormPass,
2727
DecomposeDivPass,
2828
DecomposeGeluPass,
29+
DecomposeGroupNormPass,
2930
DecomposeLayerNormPass,
3031
DecomposeLeakyReLUPass,
3132
DecomposeLinearPass,
@@ -127,6 +128,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
127128
self.add_pass(DecomposeLinearPass())
128129
self.add_pass(DecomposeLeakyReLUPass())
129130
self.add_pass(DecomposeBatchNormPass())
131+
self.add_pass(DecomposeGroupNormPass())
130132
self.add_pass(DecomposeLayerNormPass())
131133
self.add_pass(DecomposeVarPass())
132134
self.add_pass(DecomposeMeanDimPass())
@@ -180,6 +182,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
180182
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
181183
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
182184
self.add_pass(ScalarsToAttributePass())
185+
self.add_pass(DecomposeGroupNormPass())
183186
self.add_pass(DecomposeLayerNormPass())
184187
self.add_pass(DecomposeVarPass())
185188
self.add_pass(DecomposeMeanDimPass())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import operator
9+
10+
import torch
11+
from executorch.backends.arm._passes import ArmPass
12+
from executorch.backends.arm._passes.arm_pass_utils import create_node
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import PassResult
15+
16+
17+
def get_group_norm_decomposition(op) -> tuple:
18+
if op == exir_ops.edge.aten.native_group_norm.default:
19+
return (
20+
exir_ops.edge.aten.mean.dim,
21+
exir_ops.edge.aten.sub.Tensor,
22+
exir_ops.edge.aten.var.correction,
23+
exir_ops.edge.aten.full.default,
24+
exir_ops.edge.aten.add.Tensor,
25+
exir_ops.edge.aten.rsqrt.default,
26+
exir_ops.edge.aten.mul.Tensor,
27+
exir_ops.edge.aten.view_copy.default,
28+
)
29+
if op == torch.ops.aten.group_norm.default:
30+
return (
31+
torch.ops.aten.mean.dim,
32+
torch.ops.aten.sub.Tensor,
33+
torch.ops.aten.var.correction,
34+
torch.ops.aten.full.default,
35+
torch.ops.aten.add.Tensor,
36+
torch.ops.aten.rsqrt.default,
37+
torch.ops.aten.mul.Tensor,
38+
torch.ops.aten.view_copy.default,
39+
)
40+
raise RuntimeError(f"Can't get group_norm composition for op {op}")
41+
42+
43+
class DecomposeGroupNormPass(ArmPass):
44+
"""
45+
groupnorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
46+
Decompose groupnorm(x, weight, bias, N, C, HxW, group, eps) to a sequence of:
47+
mean = op_mean(x, dims) # E[x]
48+
var = op_var(x, dims) # Var[x]
49+
numerator = op_sub(x, mean) # (x - E[x])
50+
add = op_add(var, eps) # Var[x] + eps
51+
rsqrt = op_rsqrt(add) # 1 / sqrt(Var[x] + eps)
52+
mul = op_mul(numerator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps))
53+
weigths = op_mul(mul, weigths) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths
54+
bias = op_add(weigths, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias
55+
where x can viewed with shape [N, group, C//group, HxW] dims=[C//group, HxW]
56+
57+
Source: https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html
58+
"""
59+
60+
def call(self, graph_module: torch.fx.GraphModule):
61+
modified = False
62+
for node in graph_module.graph.nodes:
63+
if node.op != "call_function" or node.target not in (
64+
exir_ops.edge.aten.native_group_norm.default,
65+
torch.ops.aten.group_norm.default,
66+
):
67+
continue
68+
69+
# epsilon default value
70+
eps = torch.finfo().eps
71+
weights = None
72+
bias = None
73+
args = node.args
74+
meta = node.meta
75+
if isinstance(meta["val"], tuple):
76+
shape = meta["val"][0].size()
77+
dtype = meta["val"][0].dtype
78+
else:
79+
shape = meta["val"].size()
80+
dtype = meta["val"].dtype
81+
match len(args):
82+
# MI profile always provides all the args: x, weight, bias, N, C, HxW, group, eps
83+
case 8:
84+
x, weights, bias, N, C, HxW, group, eps = args
85+
# BI profile: affine=[True|False], eps!=1e-5
86+
case 5:
87+
x, group, weights, bias, eps = args
88+
# BI profile: affine=True, eps=1e-5
89+
case 4:
90+
x, group, weights, bias = args
91+
# BI profile: affine=False, eps=1e=5
92+
case 2:
93+
x, group = args
94+
# Unsupported args
95+
case _:
96+
raise ValueError(
97+
f"Unsupported group_norm argument pattern with {len(args)} args"
98+
)
99+
N = shape[0]
100+
C = shape[1]
101+
HxW = 1
102+
for dim in shape[2:]:
103+
HxW *= dim
104+
channels_per_group = C // group
105+
grouped_shape = torch.Size([N, group, channels_per_group, HxW])
106+
dims = [2, 3]
107+
epsilon_reshaped_shape = torch.Size([1] * len(grouped_shape))
108+
weights_reshaped_shape = torch.Size([1, group, channels_per_group, 1])
109+
(
110+
mean_op,
111+
sub_op,
112+
var_op,
113+
full_op,
114+
add_op,
115+
rsqrt_op,
116+
mul_op,
117+
view_op,
118+
) = get_group_norm_decomposition(node.target)
119+
with graph_module.graph.inserting_before(node):
120+
keepdim = True
121+
x_reshaped = create_node(
122+
graph_module.graph,
123+
view_op,
124+
args=(x, grouped_shape),
125+
from_node=node,
126+
)
127+
mean = create_node(
128+
graph_module.graph, mean_op, args=(x_reshaped, dims, keepdim)
129+
)
130+
sub = create_node(graph_module.graph, sub_op, args=(x_reshaped, mean))
131+
var = create_node(
132+
graph_module.graph,
133+
var_op,
134+
args=(x_reshaped, dims),
135+
kwargs={"correction": 0, "keepdim": keepdim},
136+
from_node=node,
137+
)
138+
full = create_node(
139+
graph_module.graph,
140+
full_op,
141+
args=(epsilon_reshaped_shape, eps),
142+
kwargs={"dtype": dtype},
143+
from_node=node,
144+
)
145+
add0 = create_node(
146+
graph_module.graph, add_op, args=(var, full), from_node=node
147+
)
148+
rsqrt = create_node(
149+
graph_module.graph, rsqrt_op, args=(add0,), from_node=node
150+
)
151+
mul0 = create_node(
152+
graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
153+
)
154+
if weights is not None:
155+
weights_reshaped = create_node(
156+
graph_module.graph,
157+
view_op,
158+
args=(weights, weights_reshaped_shape),
159+
from_node=node,
160+
)
161+
mul1 = create_node(
162+
graph_module.graph,
163+
mul_op,
164+
args=(
165+
mul0,
166+
weights_reshaped,
167+
),
168+
from_node=node,
169+
)
170+
else:
171+
mul1 = mul0
172+
if bias is not None:
173+
bias_reshaped_shape = weights_reshaped_shape
174+
bias_reshaped = create_node(
175+
graph_module.graph,
176+
view_op,
177+
args=(bias, bias_reshaped_shape),
178+
from_node=node,
179+
)
180+
output = create_node(
181+
graph_module.graph,
182+
add_op,
183+
args=(mul1, bias_reshaped),
184+
from_node=node,
185+
)
186+
else:
187+
output = mul1
188+
189+
output_reshaped = create_node(
190+
graph_module.graph,
191+
view_op,
192+
args=(output, shape),
193+
from_node=node,
194+
)
195+
196+
users = [user for user in node.users if node != user]
197+
node.replace_all_uses_with(output_reshaped)
198+
for user in users:
199+
if user.target == operator.getitem:
200+
user.replace_all_uses_with(output_reshaped)
201+
graph_module.graph.erase_node(node)
202+
graph_module.graph.eliminate_dead_code()
203+
modified = True
204+
if modified:
205+
graph_module.recompile()
206+
graph_module = super().call(graph_module).graph_module
207+
208+
return PassResult(graph_module, modified)

backends/arm/_passes/decompose_layernorm_pass.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -47,11 +46,12 @@ class DecomposeLayerNormPass(ArmPass):
4746
Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
4847
mean = op_mean(x, dims) # E[x]
4948
var = op_var(x, dims) # Var[x]
50-
denominator = op_sub(x, mean) # (x - E[x])
49+
numerator = op_sub(x, mean) # (x - E[x])
5150
add = op_add(var, eps) # Var[x] + eps
5251
rsqrt = op_rsqrt(add) # 1 / sqrt(Var[x] + eps)
53-
mul = op_mul(denominator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths
54-
bias = op_add(mul, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias
52+
mul = op_mul(numerator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps))
53+
weigths = op_mul(mul, weigths) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths
54+
bias = op_add(weigths, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias
5555
5656
Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
5757
"""

backends/arm/operator_support/tosa_supported_operators.py

+2
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def is_node_supported(
188188
exir_ops.edge.aten.div.Scalar,
189189
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
190190
exir_ops.edge.aten.native_layer_norm.default,
191+
exir_ops.edge.aten.native_group_norm.default,
191192
exir_ops.edge.aten.sigmoid.default,
192193
exir_ops.edge.aten.mean.dim,
193194
exir_ops.edge.aten.mm.default,
@@ -255,6 +256,7 @@ def is_node_supported(
255256
exir_ops.edge.aten.div.Tensor,
256257
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
257258
exir_ops.edge.aten.native_layer_norm.default,
259+
exir_ops.edge.aten.native_group_norm.default,
258260
exir_ops.edge.aten.mean.dim,
259261
exir_ops.edge.aten._softmax.default,
260262
exir_ops.edge.aten._log_softmax.default,

backends/arm/scripts/parse_test_names.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from executorch.exir.dialects.edge.spec.utils import SAMPLE_INPUT
66

77
# Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here.
8-
CUSTOM_EDGE_OPS = ["linspace.default", "eye.default"]
8+
CUSTOM_EDGE_OPS = ["linspace.default", "eye.default", "native_group_norm.default"]
99
ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS
1010

1111
# Add all targets and TOSA profiles we support here.

0 commit comments

Comments
 (0)