Skip to content

Commit 061280a

Browse files
authored
Merge branch 'main' into arm-passes-init
2 parents b92b8bd + 78ee0e6 commit 061280a

File tree

7 files changed

+506
-27
lines changed

7 files changed

+506
-27
lines changed

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,17 @@ def rescale_fake(
3838
"""Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op.
3939
Additionally validates TOSA constraints of a RESCALE op.
4040
"""
41-
if not (dtype == torch.int32 or dtype == torch.int8):
41+
if dtype not in (torch.int32, torch.int8, torch.int16):
4242
raise NotImplementedError(
43-
"tosa::rescale currently only supports int32 and int8."
43+
f"tosa::rescale currently only supports int32, int16 and int8, not {dtype}"
4444
)
45-
if dtype == torch.int32 and out_zp != 0:
45+
if dtype in (torch.int32, torch.int16) and out_zp != 0:
4646
raise ValueError(
47-
"TOSA requires output_zp to be zero when the output dtype is int32."
47+
f"TOSA requires output_zp to be zero when the output dtype is {dtype}."
4848
)
49-
if x.dtype == torch.int32 and in_zp != 0:
49+
if x.dtype in (torch.int32, torch.int16) and in_zp != 0:
5050
raise ValueError(
51-
"TOSA requires input_zp to be zero when the input dtype is int32."
51+
f"TOSA requires input_zp to be zero when the input dtype is {dtype}"
5252
)
5353
if x.dtype == torch.int8 and not -128 <= in_zp <= 127:
5454
raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.")

backends/arm/_passes/insert_table_ops.py

Lines changed: 106 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -18,6 +17,7 @@
1817

1918
from executorch.exir.pass_base import ExportPass, PassResult
2019
from torch.fx import GraphModule
20+
2121
from torch.library import impl, Library
2222

2323
lib = Library("tosa", "DEF")
@@ -26,7 +26,10 @@
2626

2727
@impl(lib, "_table")
2828
def _table_impl(*args, **kwargs): # pyre-ignore
29-
return args[0]
29+
in_dtype = args[0].dtype
30+
if in_dtype == torch.int8:
31+
return args[0]
32+
return args[0].to(dtype=torch.int32)
3033

3134

3235
class InsertTableOpsPass(ExportPass):
@@ -59,29 +62,105 @@ def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
5962
"""
6063
self.exported_program.state_dict[buffer_name] = buffer
6164

62-
def generate_table_values(
65+
def generate_8bit_table_values(
6366
self,
6467
torch_op: Callable[[torch.Tensor], torch.Tensor],
6568
in_quantargs: QuantArgs,
6669
out_quantargs: QuantArgs,
67-
) -> torch.Tensor:
70+
) -> tuple[torch.Tensor, int]:
71+
"""Compute LUT values for a INT8 TOSA.TABLE. Also returns 0 since no shifting is required after 8bit table.
72+
The INT8 table is a simple 256 value 1-1 LUT.
73+
"""
74+
6875
def f(x: torch.Tensor) -> torch.Tensor:
6976
x = in_quantargs.dequantize_value(x)
7077
x = torch_op(x)
7178
return out_quantargs.quantize_value(x)
7279

73-
input_dtype = in_quantargs.dtype
74-
steps = in_quantargs.qmax - in_quantargs.qmin + 1
75-
return f(
80+
return (
81+
f(
82+
torch.linspace(
83+
start=in_quantargs.qmin,
84+
end=in_quantargs.qmax,
85+
steps=256,
86+
# use torch.int64 to avoid overflow when dequantizing (subtracting zp).
87+
# e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
88+
dtype=torch.int64,
89+
)
90+
).to(dtype=torch.int8),
91+
0,
92+
)
93+
94+
def generate_16_bit_table_values(
95+
self,
96+
torch_op: Callable[[torch.Tensor], torch.Tensor],
97+
in_quantargs: QuantArgs,
98+
out_quantargs: QuantArgs,
99+
) -> tuple[torch.Tensor, int]:
100+
"""Compute LUT values for a INT16 TOSA.TABLE with 32 bit output.
101+
In practice the output is 23 bits that should be interpreted as 16 'whole' bits and 7 fractional bits, see
102+
the specification: https://www.mlplatform.org/tosa/tosa_spec.html#_table. This means that the output
103+
will interpreted as 2**7=128 times too large unless accounted for by rescaling down the table output.
104+
105+
Quantization can be either int16 or int32 which means that the op output could be larger than the 23 bits from
106+
the TOSA.TABLE output. In that case, we need to rescale up the output.
107+
108+
To handle this we need to:
109+
1) Make sure that our table values fit within 16 bits.
110+
2) Insert a rescale after the table to handle the x128 from the fractional bits and match the quantization.
111+
112+
The function returns rescale_lshift which says how much to rescale after the table. This value can negative.
113+
"""
114+
115+
def f(x: torch.Tensor) -> torch.Tensor:
116+
# Dont use the 7 LSBs.
117+
x = in_quantargs.dequantize_value((x & ~0x7F))
118+
x = torch_op(x)
119+
return out_quantargs.quantize_value(x)
120+
121+
lut_values = f(
76122
torch.linspace(
77123
start=in_quantargs.qmin,
78-
end=in_quantargs.qmax,
79-
steps=steps,
124+
end=in_quantargs.qmax + 1,
125+
steps=513,
80126
# use torch.int64 to avoid overflow when dequantizing (subtracting zp).
81127
# e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
82128
dtype=torch.int64,
83129
)
84-
).to(dtype=input_dtype)
130+
)
131+
# Calculate how much we need to shift table values to fit in 16 signed bits
132+
# ceil(log2(max absolute table value)) + 1 bit for signedness - 16
133+
# Example:
134+
# Max value in the table is 70 000. We want to fit it in 16 signed bits.
135+
# 70 000=0b10001000101110000 (17 digits) has ceil(log2(70 000)) = ceil(16.095) = 17 bits.
136+
# If we shift it 17-16=1 bit, we do get 16 bits (0b1000100010111000),
137+
# but due to signedness this is a negative number! So we need to shift it one more bit.
138+
# Note: for out_quantargs.dtype=torch.int16, rshift == 0 and rescale_lshift = -7.
139+
rshift = int(torch.ceil(torch.log2(lut_values.abs().max()))) + 1 - 16
140+
# The 7 fractional bits are equivalent to a lshift of 7, so subtract 7 from the lshift we do.
141+
rescale_lshift = rshift - 7
142+
lut_values = lut_values >> rshift
143+
return lut_values.to(dtype=torch.int16), rescale_lshift
144+
145+
def generate_table_values(
146+
self,
147+
torch_op: Callable[[torch.Tensor], torch.Tensor],
148+
in_quantargs: QuantArgs,
149+
out_quantargs: QuantArgs,
150+
) -> tuple[torch.Tensor, int]:
151+
match out_quantargs.dtype:
152+
case torch.int8:
153+
return self.generate_8bit_table_values(
154+
torch_op, in_quantargs, out_quantargs
155+
)
156+
case torch.int16 | torch.int32:
157+
return self.generate_16_bit_table_values(
158+
torch_op, in_quantargs, out_quantargs
159+
)
160+
case _:
161+
raise ValueError(
162+
f"Unsupported output dtype for table: {out_quantargs.dtype}"
163+
)
85164

86165
def call(self, graph_module: GraphModule) -> PassResult:
87166
modified = False
@@ -100,10 +179,12 @@ def call(self, graph_module: GraphModule) -> PassResult:
100179
op_target=torch.ops.tosa._table.default,
101180
args=(node.args[0],),
102181
)
182+
output_node = table_node
103183
assert len(input_qparams) == 1
104184
assert len(output_qparams) == 1
105-
# Generate table buffer
106-
buffer = self.generate_table_values(
185+
186+
# Generate table buffer and how much to lshift the table output.
187+
buffer, lshift = self.generate_table_values(
107188
torch_op=self.table_ops[node.target],
108189
in_quantargs=input_qparams[0],
109190
out_quantargs=output_qparams[0],
@@ -114,10 +195,20 @@ def call(self, graph_module: GraphModule) -> PassResult:
114195
self.register_buffer(
115196
buffer_name=table_node.name.replace("_default", ""), buffer=buffer
116197
)
117-
node.replace_all_uses_with(table_node)
198+
199+
if lshift != 0:
200+
scale = 2.0**lshift
201+
rescale_node = create_node(
202+
graph=graph_module.graph,
203+
op_target=torch.ops.tosa._rescale.default,
204+
args=(table_node, output_qparams[0].dtype, scale, 0, 0),
205+
)
206+
output_node = rescale_node
207+
208+
node.replace_all_uses_with(output_node)
118209
graph_module.graph.erase_node(node)
119-
table_node.meta["input_qparams"] = input_qparams
120-
table_node.meta["output_qparams"] = output_qparams
210+
output_node.meta["input_qparams"] = input_qparams
211+
output_node.meta["output_qparams"] = output_qparams
121212
modified = True
122213

123214
if modified:

backends/arm/operators/node_visitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class NodeVisitor:
3030
]
3131

3232
def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
33-
self._exported_program = exported_program or None
33+
self._exported_program = exported_program
3434
self.tosa_spec = tosa_spec
3535

3636
def define_node(

backends/arm/operators/op_rescale.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def define_node(
3838
input_zp = cast(int, node.args[3])
3939
output_zp = cast(int, node.args[4])
4040

41-
# Skip int16 cases for now.
4241
if input_dtype != map_dtype(torch.int8) and input_zp != 0:
4342
raise ValueError(
4443
f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}"
@@ -48,7 +47,10 @@ def define_node(
4847
f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}"
4948
)
5049

51-
scale_width = 32 if output_dtype == torch.int32 else 16
50+
# scale32 gives higher accuracy but for a higher HW cost.
51+
# For now, always go for scale32.
52+
scale_32 = True
53+
scale_width = 32 if scale_32 else 16
5254
multiplier, shift = tosa_quant_utils.compute_multiplier_and_shift(
5355
[scale], scale_width
5456
)
@@ -58,7 +60,7 @@ def define_node(
5860
output_zp=output_zp,
5961
multiplier=multiplier,
6062
shift=shift,
61-
scale32=output_dtype == torch.int32,
63+
scale32=scale_32,
6264
double_round=False,
6365
per_channel=False,
6466
input_unsigned=False,

backends/arm/operators/op_table.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,24 @@ def define_node(
3030
inputs: List[TosaArg],
3131
output: TosaArg,
3232
) -> None:
33-
assert node.name in self._exported_program.state_dict.keys() # type: ignore[union-attr]
34-
assert inputs[0].dtype == output.dtype == ts.DType.INT8
33+
if node.name not in self._exported_program.state_dict.keys(): # type: ignore[union-attr]
34+
raise RuntimeError(
35+
f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}."
36+
)
37+
if inputs[0].dtype == ts.DType.INT8 and output.dtype != ts.DType.INT8:
38+
raise ValueError(f"Int8 tables need int8 output, got {output.dtype=}.")
39+
if inputs[0].dtype == ts.DType.INT16 and output.dtype != ts.DType.INT32:
40+
raise ValueError(f"Int16 tables need int32 output, got {output.dtype=}.")
41+
42+
if inputs[0].dtype not in (ts.DType.INT8, ts.DType.INT16):
43+
raise ValueError(
44+
f"TOSA.TABLE only supports int8 or int16 inputs, got {ts.DTypeNames[inputs[0]]}"
45+
)
46+
3547
table = self._exported_program.state_dict[node.name] # type: ignore[union-attr]
3648
table_attr = ts.TosaSerializerAttribute()
3749
table_attr.TableAttribute(np.array(table))
50+
3851
tosa_graph.addOperator(
3952
TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
4053
)

0 commit comments

Comments
 (0)