|
| 1 | +# Copyright 2025 Arm Limited and/or its affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +# pyre-unsafe |
| 7 | + |
| 8 | +import operator |
| 9 | + |
| 10 | +import torch |
| 11 | +from executorch.backends.arm._passes import ArmPass |
| 12 | +from executorch.backends.arm._passes.arm_pass_utils import create_node |
| 13 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 14 | +from executorch.exir.pass_base import PassResult |
| 15 | + |
| 16 | + |
| 17 | +def get_group_norm_decomposition(op) -> tuple: |
| 18 | + if op == exir_ops.edge.aten.native_group_norm.default: |
| 19 | + return ( |
| 20 | + exir_ops.edge.aten.mean.dim, |
| 21 | + exir_ops.edge.aten.sub.Tensor, |
| 22 | + exir_ops.edge.aten.var.correction, |
| 23 | + exir_ops.edge.aten.full.default, |
| 24 | + exir_ops.edge.aten.add.Tensor, |
| 25 | + exir_ops.edge.aten.rsqrt.default, |
| 26 | + exir_ops.edge.aten.mul.Tensor, |
| 27 | + exir_ops.edge.aten.view_copy.default, |
| 28 | + ) |
| 29 | + if op == torch.ops.aten.group_norm.default: |
| 30 | + return ( |
| 31 | + torch.ops.aten.mean.dim, |
| 32 | + torch.ops.aten.sub.Tensor, |
| 33 | + torch.ops.aten.var.correction, |
| 34 | + torch.ops.aten.full.default, |
| 35 | + torch.ops.aten.add.Tensor, |
| 36 | + torch.ops.aten.rsqrt.default, |
| 37 | + torch.ops.aten.mul.Tensor, |
| 38 | + torch.ops.aten.view_copy.default, |
| 39 | + ) |
| 40 | + raise RuntimeError(f"Can't get group_norm composition for op {op}") |
| 41 | + |
| 42 | + |
| 43 | +class DecomposeGroupNormPass(ArmPass): |
| 44 | + """ |
| 45 | + groupnorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias |
| 46 | + Decompose groupnorm(x, weight, bias, N, C, HxW, group, eps) to a sequence of: |
| 47 | + mean = op_mean(x, dims) # E[x] |
| 48 | + var = op_var(x, dims) # Var[x] |
| 49 | + numerator = op_sub(x, mean) # (x - E[x]) |
| 50 | + add = op_add(var, eps) # Var[x] + eps |
| 51 | + rsqrt = op_rsqrt(add) # 1 / sqrt(Var[x] + eps) |
| 52 | + mul = op_mul(numerator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps)) |
| 53 | + weigths = op_mul(mul, weigths) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths |
| 54 | + bias = op_add(weigths, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias |
| 55 | + where x can viewed with shape [N, group, C//group, HxW] dims=[C//group, HxW] |
| 56 | +
|
| 57 | + Source: https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html |
| 58 | + """ |
| 59 | + |
| 60 | + def call(self, graph_module: torch.fx.GraphModule): |
| 61 | + modified = False |
| 62 | + for node in graph_module.graph.nodes: |
| 63 | + if node.op != "call_function" or node.target not in ( |
| 64 | + exir_ops.edge.aten.native_group_norm.default, |
| 65 | + torch.ops.aten.group_norm.default, |
| 66 | + ): |
| 67 | + continue |
| 68 | + |
| 69 | + # epsilon default value |
| 70 | + eps = torch.finfo().eps |
| 71 | + weights = None |
| 72 | + bias = None |
| 73 | + args = node.args |
| 74 | + meta = node.meta |
| 75 | + if isinstance(meta["val"], tuple): |
| 76 | + shape = meta["val"][0].size() |
| 77 | + dtype = meta["val"][0].dtype |
| 78 | + else: |
| 79 | + shape = meta["val"].size() |
| 80 | + dtype = meta["val"].dtype |
| 81 | + match len(args): |
| 82 | + # MI profile always provides all the args: x, weight, bias, N, C, HxW, group, eps |
| 83 | + case 8: |
| 84 | + x, weights, bias, N, C, HxW, group, eps = args |
| 85 | + # BI profile: affine=[True|False], eps!=1e-5 |
| 86 | + case 5: |
| 87 | + x, group, weights, bias, eps = args |
| 88 | + # BI profile: affine=True, eps=1e-5 |
| 89 | + case 4: |
| 90 | + x, group, weights, bias = args |
| 91 | + # BI profile: affine=False, eps=1e=5 |
| 92 | + case 2: |
| 93 | + x, group = args |
| 94 | + # Unsupported args |
| 95 | + case _: |
| 96 | + raise ValueError( |
| 97 | + f"Unsupported group_norm argument pattern with {len(args)} args" |
| 98 | + ) |
| 99 | + N = shape[0] |
| 100 | + C = shape[1] |
| 101 | + HxW = 1 |
| 102 | + for dim in shape[2:]: |
| 103 | + HxW *= dim |
| 104 | + channels_per_group = C // group |
| 105 | + grouped_shape = torch.Size([N, group, channels_per_group, HxW]) |
| 106 | + dims = [2, 3] |
| 107 | + epsilon_reshaped_shape = torch.Size([1] * len(grouped_shape)) |
| 108 | + weights_reshaped_shape = torch.Size([1, group, channels_per_group, 1]) |
| 109 | + ( |
| 110 | + mean_op, |
| 111 | + sub_op, |
| 112 | + var_op, |
| 113 | + full_op, |
| 114 | + add_op, |
| 115 | + rsqrt_op, |
| 116 | + mul_op, |
| 117 | + view_op, |
| 118 | + ) = get_group_norm_decomposition(node.target) |
| 119 | + with graph_module.graph.inserting_before(node): |
| 120 | + keepdim = True |
| 121 | + x_reshaped = create_node( |
| 122 | + graph_module.graph, |
| 123 | + view_op, |
| 124 | + args=(x, grouped_shape), |
| 125 | + from_node=node, |
| 126 | + ) |
| 127 | + mean = create_node( |
| 128 | + graph_module.graph, mean_op, args=(x_reshaped, dims, keepdim) |
| 129 | + ) |
| 130 | + sub = create_node(graph_module.graph, sub_op, args=(x_reshaped, mean)) |
| 131 | + var = create_node( |
| 132 | + graph_module.graph, |
| 133 | + var_op, |
| 134 | + args=(x_reshaped, dims), |
| 135 | + kwargs={"correction": 0, "keepdim": keepdim}, |
| 136 | + from_node=node, |
| 137 | + ) |
| 138 | + full = create_node( |
| 139 | + graph_module.graph, |
| 140 | + full_op, |
| 141 | + args=(epsilon_reshaped_shape, eps), |
| 142 | + kwargs={"dtype": dtype}, |
| 143 | + from_node=node, |
| 144 | + ) |
| 145 | + add0 = create_node( |
| 146 | + graph_module.graph, add_op, args=(var, full), from_node=node |
| 147 | + ) |
| 148 | + rsqrt = create_node( |
| 149 | + graph_module.graph, rsqrt_op, args=(add0,), from_node=node |
| 150 | + ) |
| 151 | + mul0 = create_node( |
| 152 | + graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node |
| 153 | + ) |
| 154 | + if weights is not None: |
| 155 | + weights_reshaped = create_node( |
| 156 | + graph_module.graph, |
| 157 | + view_op, |
| 158 | + args=(weights, weights_reshaped_shape), |
| 159 | + from_node=node, |
| 160 | + ) |
| 161 | + mul1 = create_node( |
| 162 | + graph_module.graph, |
| 163 | + mul_op, |
| 164 | + args=( |
| 165 | + mul0, |
| 166 | + weights_reshaped, |
| 167 | + ), |
| 168 | + from_node=node, |
| 169 | + ) |
| 170 | + else: |
| 171 | + mul1 = mul0 |
| 172 | + if bias is not None: |
| 173 | + bias_reshaped_shape = weights_reshaped_shape |
| 174 | + bias_reshaped = create_node( |
| 175 | + graph_module.graph, |
| 176 | + view_op, |
| 177 | + args=(bias, bias_reshaped_shape), |
| 178 | + from_node=node, |
| 179 | + ) |
| 180 | + output = create_node( |
| 181 | + graph_module.graph, |
| 182 | + add_op, |
| 183 | + args=(mul1, bias_reshaped), |
| 184 | + from_node=node, |
| 185 | + ) |
| 186 | + else: |
| 187 | + output = mul1 |
| 188 | + |
| 189 | + output_reshaped = create_node( |
| 190 | + graph_module.graph, |
| 191 | + view_op, |
| 192 | + args=(output, shape), |
| 193 | + from_node=node, |
| 194 | + ) |
| 195 | + |
| 196 | + users = [user for user in node.users if node != user] |
| 197 | + node.replace_all_uses_with(output_reshaped) |
| 198 | + for user in users: |
| 199 | + if user.target == operator.getitem: |
| 200 | + user.replace_all_uses_with(output_reshaped) |
| 201 | + graph_module.graph.erase_node(node) |
| 202 | + graph_module.graph.eliminate_dead_code() |
| 203 | + modified = True |
| 204 | + if modified: |
| 205 | + graph_module.recompile() |
| 206 | + graph_module = super().call(graph_module).graph_module |
| 207 | + |
| 208 | + return PassResult(graph_module, modified) |
0 commit comments