
Commit 012f120

Arm backend: Add ComputeConstantOpsAOT pass (#9504)
Operators that output tensors based on constant args are pre-computed and added as buffers.

- The pass currently supports full, arange, linspace, and eye.
- Remove some logic for full that is now handled by the pass.
- Rename FuseConstantOpsPass to FuseConstantArgsPass and make minor improvements.

Also fix retracing in FuseViewCopyTransform: since the pass can change the shapes of ops, the graph needs to be retraced so that the new shapes are reflected in node.meta["val"].

Signed-off-by: Erik Lundell <[email protected]>
1 parent: 766bbdc

13 files changed: +338, -136 lines
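To make the commit message concrete: the idea behind ComputeConstantOpsAOT is that an operator whose output depends only on constant arguments can be evaluated once at export time and stored as a buffer. The table and helper below are a minimal hypothetical sketch of that idea, not the actual pass implementation.

import torch

# Hypothetical sketch (not the actual pass): each supported op's output
# depends only on constant arguments, so it can be evaluated eagerly
# ahead of time and the result kept as a constant buffer.
CONSTANT_OPS = {
    "full": torch.full,
    "arange": torch.arange,
    "linspace": torch.linspace,
    "eye": torch.eye,
}

def precompute_constant_op(op: str, args: tuple, kwargs: dict) -> torch.Tensor:
    """Evaluate a constant-argument op once; the result can then be
    registered as a buffer so the lowered graph loads a constant
    instead of computing the op at runtime."""
    return CONSTANT_OPS[op](*args, **kwargs)

# Example: full((2, 3), 1.5) becomes a (2, 3) constant buffer ahead of time.
buffer = precompute_constant_op("full", ((2, 3), 1.5), {})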

backends/arm/_passes/arm_pass_manager.py (+13, -7)
@@ -55,7 +55,10 @@
     RetraceFoldedDtypesPass,
 )
 from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
-from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass
+from executorch.backends.arm._passes.fuse_constant_ops_pass import (
+    ComputeConstantOpsAOT,
+    FuseConstantArgsPass,
+)
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (  # type: ignore[import-not-found]
     FuseQuantizedActivationPass,
 )
@@ -121,21 +124,23 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(ComputeConstantOpsAOT(exported_program))
 
         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(CastInt64ToInt32Pass(exported_program))
-        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
 
         self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantOpsPass(exported_program))
+        self.add_pass(FuseConstantArgsPass(exported_program))
+
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())
@@ -166,21 +171,22 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(ComputeConstantOpsAOT(exported_program))
 
         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(CastInt64ToInt32Pass(exported_program))
-        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
 
         self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantOpsPass(exported_program))
+        self.add_pass(FuseConstantArgsPass(exported_program))
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())
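Two things are worth noting in this diff: UnsqueezeScalarPlaceholdersPass and MatchArgRanksPass move up so that they and the new ComputeConstantOpsAOT run earlier in both the BI and MI pipelines, and the rename means both fusion-related passes are now imported from the same module. A quick sanity check of the rename, assuming a checkout of executorch that includes this commit:

# Both passes now live in fuse_constant_ops_pass (assumes an executorch
# checkout that includes this commit):
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
    ComputeConstantOpsAOT,
    FuseConstantArgsPass,
)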

backends/arm/_passes/cast_int64_pass.py (+28, -27)
@@ -8,7 +8,6 @@
 import logging
 
 import torch
-from executorch.backends.arm._passes.arm_pass_utils import is_param_node
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._export.utils import is_buffer
 
@@ -25,35 +24,37 @@ def __init__(self, exported_program: torch.export.ExportedProgram):
         super(CastInt64ToInt32Pass, self).__init__()
         self.exported_program = exported_program
 
+    def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node):
+        if torch.min(tensor) < torch.iinfo(torch.int32).min:
+            raise RuntimeError(
+                f"Node {node.name} has value < {torch.iinfo(torch.int32).min}"
+            )
+        if torch.max(tensor) > torch.iinfo(torch.int32).max:
+            raise RuntimeError(
+                f"Node {node.name} has value > {torch.iinfo(torch.int32).max}"
+            )
+
     def _to_int32(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             fake_tensor = node.meta["val"]
-            if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
-                if node.meta["val"].dtype == torch.int64 and is_param_node(
-                    self.exported_program, node
-                ):
-                    if is_buffer(self.exported_program, node):
-                        node.meta["val"] = node.meta["val"].to(torch.int32)
-                        buffer_name = (
-                            self.exported_program.graph_signature.inputs_to_buffers[
-                                node.name
-                            ]
-                        )
-                        buffer = self.exported_program.state_dict[node.name]
-                        logger.warning(
-                            f"Casting buffer {node.name} from torch.int64 to torch.int32"
-                            f" defined in {node.meta['stack_trace']}"
-                        )
-                        if torch.min(buffer) < torch.iinfo(torch.int32).min:
-                            raise RuntimeError(
-                                f"Buffer {node.name} has value < {torch.iinfo(torch.int32).min}"
-                            )
-                        if torch.max(buffer) > torch.iinfo(torch.int32).max:
-                            raise RuntimeError(
-                                f"Buffer {node.name} has value > {torch.iinfo(torch.int32).max}"
-                            )
-                        buffer_int32 = buffer.to(torch.int32)
-                        self.exported_program.state_dict[buffer_name] = buffer_int32
+            if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
+                continue
+            if fake_tensor.dtype != torch.int64:
+                continue
+            if is_buffer(self.exported_program, node):
+                node.meta["val"] = fake_tensor.to(torch.int32)
+                buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
+                    node.name
+                ]
+                buffer = self.exported_program.state_dict[node.name]
+                self._assert_within_int32(buffer, node)
+                logger.warning(
+                    f"Casting buffer {node.name} from torch.int64 to torch.int32"
+                    f" defined in {node.meta.get('stack_trace','[no stack trace found]')}"
+                )
+                buffer_int32 = buffer.to(torch.int32)
+                self.exported_program.state_dict[buffer_name] = buffer_int32
+                continue
 
     def call(self, graph_module: torch.fx.GraphModule):
         self._to_int32(graph_module)
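The refactor flattens the nested conditionals into guard clauses and factors the int32 range check out into _assert_within_int32. A standalone sketch of that check (the boolean helper below is illustrative, not part of the patch):

import torch

def fits_in_int32(tensor: torch.Tensor) -> bool:
    """Illustrative boolean variant of _assert_within_int32: an int64
    buffer may only be narrowed to int32 if every element is in range."""
    info = torch.iinfo(torch.int32)
    return bool(torch.min(tensor) >= info.min) and bool(torch.max(tensor) <= info.max)

assert fits_in_int32(torch.tensor([0, 2**31 - 1]))
assert not fits_in_int32(torch.tensor([2**31]))

Two behavioral details also change: the int64 check now applies to any FakeTensor-valued node rather than only param nodes, and node.meta.get("stack_trace", ...) avoids a KeyError when a node carries no stack trace.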

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py (+2, -15)
@@ -174,11 +174,8 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
 class QuantizeOperatorArguments(ExportPass):
     """
-    This pass makes sure that the arguments to full.default and clamp.default are quantized correctly.
+    This pass makes sure that the arguments to clamp.default are quantized correctly.
     More specifically, this pass:
-    - Makes sure the fill_value for full.default is quantized. This pass needs to be run before
-      the folding pass above to make sure that the retraced output of the full.default op is
-      the right dtype.
     - Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator.
     """
 
@@ -189,7 +186,6 @@ def call(self, graph_module: GraphModule) -> PassResult:
             n = cast(Node, n)
             if n.target not in {
                 exir_ops.edge.aten.clamp.default,
-                exir_ops.edge.aten.full.default,
             }:
                 continue
 
@@ -200,16 +196,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
             qargs = QuantArgs.from_operator(user.target, user.args)
 
-            if n.target == exir_ops.edge.aten.full.default:
-                if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
-                    # replace the node arg with a quantized dito and also set dtype
-                    # to get the right output according to the Edge IR specification:
-                    # exir/dialects/edge/edge.yaml:3596
-                    quantized_full_value = qargs.quantize_value(n.args[1]).item()
-                    n.update_arg(1, quantized_full_value)
-                    n.update_kwarg("dtype", qargs.dtype)
-                    modified = True
-            elif n.target == exir_ops.edge.aten.clamp.default:
+            if n.target == exir_ops.edge.aten.clamp.default:
                 # Quantize the min and max arguments of clamp, if they are not None
                 min_val = n.args[1]
                 max_val = None if len(n.args) <= 2 else n.args[2]
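With full.default now precomputed ahead of time by ComputeConstantOpsAOT, only clamp.default still needs its scalar min/max arguments quantized here. For reference, a minimal sketch of the affine mapping that QuantArgs.quantize_value is expected to perform (assumed behavior; QuantArgs itself is not shown in this diff):

import torch

def quantize_value(value: float, scale: float, zero_point: int,
                   qmin: int, qmax: int, dtype: torch.dtype) -> torch.Tensor:
    """Assumed affine quantization: map a float scalar (e.g. a clamp
    min or max) into the integer domain of the consuming quantized op."""
    q = round(value / scale) + zero_point
    return torch.tensor(min(max(q, qmin), qmax), dtype=dtype)

# e.g. a clamp min of 0.0 with scale=0.1, zero_point=-128 in int8:
print(quantize_value(0.0, 0.1, -128, -128, 127, torch.int8))  # tensor(-128, dtype=torch.int8)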
