# Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule
+
from torch.library import impl, Library

lib = Library("tosa", "DEF")

@impl(lib, "_table")
def _table_impl(*args, **kwargs):  # pyre-ignore
-    return args[0]
+    in_dtype = args[0].dtype
+    if in_dtype == torch.int8:
+        return args[0]
+    return args[0].to(dtype=torch.int32)


class InsertTableOpsPass(ExportPass):
@@ -59,29 +62,105 @@ def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
        """
        self.exported_program.state_dict[buffer_name] = buffer

-    def generate_table_values(
+    def generate_8bit_table_values(
        self,
        torch_op: Callable[[torch.Tensor], torch.Tensor],
        in_quantargs: QuantArgs,
        out_quantargs: QuantArgs,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, int]:
+        """Compute LUT values for an INT8 TOSA.TABLE. Also returns 0 since no shifting is required after an 8-bit table.
+        The INT8 table is a simple 256-value 1-to-1 LUT.
+        """
+
        def f(x: torch.Tensor) -> torch.Tensor:
            x = in_quantargs.dequantize_value(x)
            x = torch_op(x)
            return out_quantargs.quantize_value(x)

-        input_dtype = in_quantargs.dtype
-        steps = in_quantargs.qmax - in_quantargs.qmin + 1
-        return f(
+        return (
+            f(
+                torch.linspace(
+                    start=in_quantargs.qmin,
+                    end=in_quantargs.qmax,
+                    steps=256,
+                    # use torch.int64 to avoid overflow when dequantizing (subtracting zp).
+                    # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
+                    dtype=torch.int64,
+                )
+            ).to(dtype=torch.int8),
+            0,
+        )
+
+    def generate_16_bit_table_values(
+        self,
+        torch_op: Callable[[torch.Tensor], torch.Tensor],
+        in_quantargs: QuantArgs,
+        out_quantargs: QuantArgs,
+    ) -> tuple[torch.Tensor, int]:
+        """Compute LUT values for an INT16 TOSA.TABLE with 32-bit output.
+        In practice the output is 23 bits that should be interpreted as 16 'whole' bits and 7 fractional bits; see
+        the specification: https://www.mlplatform.org/tosa/tosa_spec.html#_table. This means that the output
+        will be interpreted as 2**7 = 128 times too large unless accounted for by rescaling down the table output.
+
+        Quantization can be either int16 or int32, which means that the op output could be larger than the 23 bits from
+        the TOSA.TABLE output. In that case, we need to rescale up the output.
+
+        To handle this we need to:
+        1) Make sure that our table values fit within 16 bits.
+        2) Insert a rescale after the table to handle the x128 from the fractional bits and match the quantization.
+
+        The function returns rescale_lshift, which says how much to rescale after the table. This value can be negative.
+        """
+
+        def f(x: torch.Tensor) -> torch.Tensor:
+            # Don't use the 7 LSBs.
+            x = in_quantargs.dequantize_value((x & ~0x7F))
+            x = torch_op(x)
+            return out_quantargs.quantize_value(x)
+
+        lut_values = f(
            torch.linspace(
                start=in_quantargs.qmin,
-                end=in_quantargs.qmax,
-                steps=steps,
+                end=in_quantargs.qmax + 1,
+                steps=513,
                # use torch.int64 to avoid overflow when dequantizing (subtracting zp).
                # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
                dtype=torch.int64,
            )
-        ).to(dtype=input_dtype)
+        )
+        # Calculate how much we need to shift table values to fit in 16 signed bits:
+        # ceil(log2(max absolute table value)) + 1 bit for signedness - 16.
+        # Example:
+        # Max value in the table is 70 000. We want to fit it in 16 signed bits.
+        # 70 000 = 0b10001000101110000 (17 digits) has ceil(log2(70 000)) = ceil(16.095) = 17 bits.
+        # If we shift it 17-16 = 1 bit, we do get 16 bits (0b1000100010111000),
+        # but due to signedness this is a negative number! So we need to shift it one more bit.
+        # Note: for out_quantargs.dtype == torch.int16, rshift == 0 and rescale_lshift == -7.
+        rshift = int(torch.ceil(torch.log2(lut_values.abs().max()))) + 1 - 16
+        # The 7 fractional bits are equivalent to a lshift of 7, so subtract 7 from the lshift we do.
+        rescale_lshift = rshift - 7
+        lut_values = lut_values >> rshift
+        return lut_values.to(dtype=torch.int16), rescale_lshift
+
+    def generate_table_values(
+        self,
+        torch_op: Callable[[torch.Tensor], torch.Tensor],
+        in_quantargs: QuantArgs,
+        out_quantargs: QuantArgs,
+    ) -> tuple[torch.Tensor, int]:
+        match out_quantargs.dtype:
+            case torch.int8:
+                return self.generate_8bit_table_values(
+                    torch_op, in_quantargs, out_quantargs
+                )
+            case torch.int16 | torch.int32:
+                return self.generate_16_bit_table_values(
+                    torch_op, in_quantargs, out_quantargs
+                )
+            case _:
+                raise ValueError(
+                    f"Unsupported output dtype for table: {out_quantargs.dtype}"
+                )

    def call(self, graph_module: GraphModule) -> PassResult:
        modified = False
@@ -100,10 +179,12 @@ def call(self, graph_module: GraphModule) -> PassResult:
                    op_target=torch.ops.tosa._table.default,
                    args=(node.args[0],),
                )
+                output_node = table_node
                assert len(input_qparams) == 1
                assert len(output_qparams) == 1
-                # Generate table buffer
-                buffer = self.generate_table_values(
+
+                # Generate table buffer and how much to lshift the table output.
+                buffer, lshift = self.generate_table_values(
                    torch_op=self.table_ops[node.target],
                    in_quantargs=input_qparams[0],
                    out_quantargs=output_qparams[0],
@@ -114,10 +195,20 @@ def call(self, graph_module: GraphModule) -> PassResult:
                self.register_buffer(
                    buffer_name=table_node.name.replace("_default", ""), buffer=buffer
                )
-                node.replace_all_uses_with(table_node)
+
+                if lshift != 0:
+                    scale = 2.0**lshift
+                    rescale_node = create_node(
+                        graph=graph_module.graph,
+                        op_target=torch.ops.tosa._rescale.default,
+                        args=(table_node, output_qparams[0].dtype, scale, 0, 0),
+                    )
+                    output_node = rescale_node
+
+                node.replace_all_uses_with(output_node)
                graph_module.graph.erase_node(node)
-                table_node.meta["input_qparams"] = input_qparams
-                table_node.meta["output_qparams"] = output_qparams
+                output_node.meta["input_qparams"] = input_qparams
+                output_node.meta["output_qparams"] = output_qparams
                modified = True

        if modified:
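
A minimal sketch of the 8-bit LUT generation above, for reference only and not part of the diff: the quantization parameters and the use of torch.sigmoid as the table op are made-up assumptions standing in for QuantArgs and self.table_ops[node.target].

```python
import torch

# Assumed example quantization parameters; QuantArgs.dequantize_value /
# quantize_value from the pass are replaced by inline scale/zero-point math.
in_scale, in_zp = 0.05, 0
out_scale, out_zp = 1.0 / 256, -128


def f(x: torch.Tensor) -> torch.Tensor:
    x = (x.to(torch.float32) - in_zp) * in_scale  # dequantize
    x = torch.sigmoid(x)                          # the op being replaced by TOSA.TABLE
    q = torch.round(x / out_scale) + out_zp       # quantize
    return torch.clamp(q, -128, 127)


# int64 input avoids int8 overflow while the zero-point is subtracted.
lut = f(torch.linspace(start=-128, end=127, steps=256, dtype=torch.int64)).to(torch.int8)
print(lut.shape)  # torch.Size([256]) -- one entry per int8 input value
```

Each of the 256 int8 input values maps to exactly one int8 output value, which is why the 8-bit path returns a shift of 0 and no rescale is inserted after the table.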
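A standalone check of the shift arithmetic in generate_16_bit_table_values, using only the two example magnitudes from the comments above (70 000 and the int16 maximum); the helper name table_shifts is hypothetical and not part of the pass.

```python
import torch


def table_shifts(max_abs_table_value: float) -> tuple[int, int]:
    # Bits needed for the magnitude, plus one sign bit, minus the 16 bits
    # available for an INT16 table entry.
    rshift = int(torch.ceil(torch.log2(torch.tensor(max_abs_table_value)))) + 1 - 16
    # The 16-bit TABLE output carries 7 fractional bits (an implicit x128),
    # so the rescale after the table has to undo that as well.
    rescale_lshift = rshift - 7
    return rshift, rescale_lshift


print(table_shifts(70_000.0))  # (2, -5): shift table values right by 2, rescale output by 2**2 / 128
print(table_shifts(32_767.0))  # (0, -7): int16-range table, only the /128 from the fractional bits remains
```

The returned rescale_lshift is what the pass later turns into the scale = 2.0**lshift argument of the inserted tosa._rescale node.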