Skip to content

Commit edab231

Browse files
authored
Arm backend: Add upsample_bilinear2d op (#10349)
1 parent 95c663e commit edab231

File tree

6 files changed

+352
-0
lines changed

6 files changed

+352
-0
lines changed

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def is_node_supported(
207207
exir_ops.edge.aten._log_softmax.default,
208208
exir_ops.edge.aten.sub.Tensor,
209209
exir_ops.edge.aten.tanh.default,
210+
exir_ops.edge.aten.upsample_bilinear2d.vec,
210211
exir_ops.edge.aten.upsample_nearest2d.vec,
211212
exir_ops.edge.aten.var.correction,
212213
exir_ops.edge.aten.var.dim,
@@ -365,6 +366,7 @@ def is_node_supported(
365366
exir_ops.edge.aten.sigmoid.default,
366367
exir_ops.edge.aten.sub.Tensor,
367368
exir_ops.edge.aten.tanh.default,
369+
exir_ops.edge.aten.upsample_bilinear2d.vec,
368370
exir_ops.edge.aten.upsample_nearest2d.vec,
369371
exir_ops.edge.aten.gelu.default,
370372
):

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
op_to_copy,
4747
op_to_dim_order_copy,
4848
op_transpose,
49+
op_upsample_bilinear2d,
4950
op_upsample_nearest2d,
5051
op_view,
5152
op_where,
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import torch
10+
11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12+
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
from executorch.backends.arm.tosa_quant_utils import build_rescale
19+
from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape
20+
from tosa_tools.v0_80.tosa.ResizeMode import ResizeMode # type: ignore
21+
22+
23+
@register_node_visitor
24+
class UpsampleBilinear2dVisitor_0_80(NodeVisitor):
25+
target = "aten.upsample_bilinear2d.vec"
26+
27+
def __init__(self, *args):
28+
super().__init__(*args)
29+
30+
def define_node(
31+
self,
32+
node: torch.fx.Node,
33+
tosa_graph: ts.TosaSerializer,
34+
inputs: List[TosaArg],
35+
output: TosaArg,
36+
) -> None:
37+
assert (
38+
inputs[0].shape is not None and output.shape is not None
39+
), "Only static shapes are supported"
40+
41+
input_dtype = inputs[0].dtype
42+
43+
# tosa_shape output is NHWC, take HW
44+
input_size_yx = torch.tensor(
45+
tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3]
46+
)
47+
# Ignore scale and size parameters, directly use the output size as
48+
# we only support static shapes currently
49+
output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3])
50+
51+
scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
52+
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True
53+
)
54+
55+
def in_int16_range(x):
56+
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)
57+
58+
assert in_int16_range(scale_n_yx)
59+
assert in_int16_range(scale_d_yx)
60+
assert in_int16_range(border_yx)
61+
62+
attr = ts.TosaSerializerAttribute()
63+
attr.ResizeAttribute(
64+
scale=[scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]],
65+
offset=offset_yx.tolist(),
66+
border=border_yx.tolist(),
67+
mode=ResizeMode.BILINEAR,
68+
)
69+
70+
if input_dtype == output.dtype == ts.DType.FP32:
71+
tosa_graph.addOperator(
72+
ts.TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr
73+
)
74+
return
75+
elif input_dtype == output.dtype == ts.DType.INT8:
76+
intermediate = tosa_graph.addIntermediate(
77+
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
78+
)
79+
80+
tosa_graph.addOperator(
81+
ts.TosaOp.Op().RESIZE, [inputs[0].name], [intermediate.name], attr
82+
)
83+
84+
final_output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1]))
85+
86+
build_rescale(
87+
tosa_fb=tosa_graph,
88+
scale=[final_output_scale],
89+
input_node=intermediate,
90+
output_name=output.name,
91+
output_type=ts.DType.INT8,
92+
output_shape=output.shape,
93+
input_zp=0,
94+
output_zp=0,
95+
is_double_round=False,
96+
)
97+
else:
98+
raise ValueError(
99+
"Input/output dtype not in {float32, int8}: {input_dtype=} {output.dtype=}"
100+
)

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def _match_pattern(
215215
torch.ops.aten.flip.default,
216216
torch.ops.aten.chunk.default,
217217
torch.ops.aten.contiguous.default,
218+
torch.ops.aten.upsample_bilinear2d.vec,
218219
torch.ops.aten.upsample_nearest2d.vec,
219220
torch.ops.aten.pad.default,
220221
torch.ops.aten.amax.default,
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Optional, Tuple
7+
8+
import torch
9+
from executorch.backends.arm.test import common
10+
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU85PipelineBI,
13+
TosaPipelineBI,
14+
TosaPipelineMI,
15+
)
16+
17+
aten_op = "torch.ops.aten.upsample_bilinear2d.vec"
18+
input_t1 = Tuple[torch.Tensor] # Input x
19+
20+
test_data_suite_tosa = {
21+
# (test_name, test_data, size, scale_factor, compare_outputs)
22+
"rand_double_scale": (torch.rand(2, 4, 8, 3), None, 2.0, True),
23+
"rand_double_scale_one_dim": (torch.rand(2, 4, 8, 3), None, (1.0, 2.0), True),
24+
"rand_double_size": (torch.rand(2, 4, 8, 3), (16, 6), None, True),
25+
"rand_one_double_scale": (torch.rand(2, 4, 1, 1), None, 2.0, True),
26+
"rand_one_double_size": (torch.rand(2, 4, 1, 1), (2, 2), None, True),
27+
"rand_one_same_scale": (torch.rand(2, 4, 1, 1), None, 1.0, True),
28+
"rand_one_same_size": (torch.rand(2, 4, 1, 1), (1, 1), None, True),
29+
# Can't compare outputs as the rounding when selecting the nearest pixel is
30+
# different between PyTorch and TOSA. Just check the legalization went well.
31+
# TODO Improve the test infrastructure to support more in depth verification
32+
# of the TOSA legalization results.
33+
"rand_half_scale": (torch.rand(2, 4, 8, 6), None, 0.5, False),
34+
"rand_half_size": (torch.rand(2, 4, 8, 6), (4, 3), None, False),
35+
"rand_one_and_half_scale": (torch.rand(2, 4, 8, 3), None, 1.5, False),
36+
"rand_one_and_half_size": (torch.rand(2, 4, 8, 3), (12, 4), None, False),
37+
# Use randn for a bunch of tests to get random numbers from the
38+
# normal distribution where negative is also a possibilty
39+
"randn_double_scale_negative": (torch.randn(2, 4, 8, 3), None, 2.0, True),
40+
"randn_double_scale_one_dim_negative": (
41+
torch.randn(2, 4, 8, 3),
42+
None,
43+
(1.0, 2.0),
44+
True,
45+
),
46+
"randn_double_size_negative": (torch.randn(2, 4, 8, 3), (16, 6), None, True),
47+
"randn_one_double_scale_negative": (torch.randn(2, 4, 1, 1), None, 2.0, True),
48+
"randn_one_double_size_negative": (torch.randn(2, 4, 1, 1), (2, 2), None, True),
49+
"randn_one_same_scale_negative": (torch.randn(2, 4, 1, 1), None, 1.0, True),
50+
"randn_one_same_size_negative": (torch.randn(2, 4, 1, 1), (1, 1), None, True),
51+
}
52+
53+
test_data_suite_Uxx = {
54+
"rand_half_scale": (torch.rand(2, 4, 8, 6), None, 0.5, False),
55+
"rand_half_size": (torch.rand(2, 4, 8, 6), (4, 3), None, False),
56+
"rand_one_and_half_scale": (torch.rand(2, 4, 8, 3), None, 1.5, False),
57+
"rand_one_and_half_size": (torch.rand(2, 4, 8, 3), (12, 4), None, False),
58+
}
59+
60+
61+
class UpsamplingBilinear2d(torch.nn.Module):
62+
def __init__(
63+
self,
64+
size: Optional[Tuple[int]],
65+
scale_factor: Optional[float | Tuple[float]],
66+
):
67+
super().__init__()
68+
self.upsample = torch.nn.UpsamplingBilinear2d( # noqa: TOR101
69+
size=size, scale_factor=scale_factor
70+
)
71+
72+
def forward(self, x):
73+
return self.upsample(x)
74+
75+
76+
class Upsample(torch.nn.Module):
77+
def __init__(
78+
self,
79+
size: Optional[Tuple[int]],
80+
scale_factor: Optional[float | Tuple[float]],
81+
):
82+
super().__init__()
83+
self.upsample = torch.nn.Upsample(
84+
size=size, scale_factor=scale_factor, mode="bilinear", align_corners=True
85+
)
86+
87+
def forward(self, x):
88+
return self.upsample(x)
89+
90+
91+
class Interpolate(torch.nn.Module):
92+
def __init__(
93+
self,
94+
size: Optional[Tuple[int]],
95+
scale_factor: Optional[float | Tuple[float]],
96+
):
97+
super().__init__()
98+
self.upsample = lambda x: torch.nn.functional.interpolate(
99+
x, size=size, scale_factor=scale_factor, mode="bilinear", align_corners=True
100+
)
101+
102+
def forward(self, x):
103+
return self.upsample(x)
104+
105+
106+
@common.parametrize("test_data", test_data_suite_tosa)
107+
def test_upsample_bilinear2d_vec_tosa_MI_UpsamplingBilinear2d(
108+
test_data: torch.Tensor,
109+
):
110+
test_data, size, scale_factor, compare_outputs = test_data
111+
112+
pipeline = TosaPipelineMI[input_t1](
113+
UpsamplingBilinear2d(size, scale_factor),
114+
(test_data,),
115+
aten_op,
116+
exir_op=[],
117+
)
118+
if not compare_outputs:
119+
pipeline.pop_stage(-1)
120+
pipeline.run()
121+
122+
123+
@common.parametrize("test_data", test_data_suite_tosa)
124+
def test_upsample_bilinear2d_vec_tosa_MI_Upsample(
125+
test_data: torch.Tensor,
126+
):
127+
test_data, size, scale_factor, compare_outputs = test_data
128+
129+
pipeline = TosaPipelineMI[input_t1](
130+
Upsample(size, scale_factor),
131+
(test_data,),
132+
aten_op,
133+
exir_op=[],
134+
)
135+
if not compare_outputs:
136+
pipeline.pop_stage(-1)
137+
138+
pipeline.run()
139+
140+
141+
@common.parametrize("test_data", test_data_suite_tosa)
142+
def test_upsample_bilinear2d_vec_tosa_MI_Interpolate(
143+
test_data: torch.Tensor,
144+
):
145+
test_data, size, scale_factor, compare_outputs = test_data
146+
147+
pipeline = TosaPipelineMI[input_t1](
148+
Interpolate(size, scale_factor),
149+
(test_data,),
150+
aten_op,
151+
exir_op=[],
152+
)
153+
if not compare_outputs:
154+
pipeline.pop_stage(-1)
155+
pipeline.run()
156+
157+
158+
@common.parametrize("test_data", test_data_suite_tosa)
159+
def test_upsample_bilinear2d_vec_tosa_BI_intropolate(
160+
test_data: torch.Tensor,
161+
):
162+
test_data, size, scale_factor, compare_outputs = test_data
163+
164+
pipeline = TosaPipelineBI[input_t1](
165+
UpsamplingBilinear2d(size, scale_factor),
166+
(test_data,),
167+
aten_op,
168+
exir_op=[],
169+
)
170+
if not compare_outputs:
171+
pipeline.pop_stage(-1)
172+
pipeline.run()
173+
174+
175+
@common.parametrize("test_data", test_data_suite_tosa)
176+
def test_upsample_bilinear2d_vec_tosa_BI_Upsample(
177+
test_data: torch.Tensor,
178+
):
179+
test_data, size, scale_factor, compare_outputs = test_data
180+
181+
pipeline = TosaPipelineBI[input_t1](
182+
Upsample(size, scale_factor),
183+
(test_data,),
184+
aten_op,
185+
exir_op=[],
186+
)
187+
if not compare_outputs:
188+
pipeline.pop_stage(-1)
189+
pipeline.run()
190+
191+
192+
@common.parametrize("test_data", test_data_suite_Uxx)
193+
@common.XfailIfNoCorstone320
194+
def test_upsample_bilinear2d_vec_U85_BI_Upsample(test_data: input_t1):
195+
test_data, size, scale_factor, compare_outputs = test_data
196+
197+
pipeline = EthosU85PipelineBI[input_t1](
198+
Upsample(size, scale_factor),
199+
(test_data,),
200+
aten_op,
201+
run_on_fvp=True,
202+
qtol=1,
203+
use_to_edge_transform_and_lower=True,
204+
)
205+
if not compare_outputs:
206+
pipeline.pop_stage(-1)
207+
pipeline.run()
208+
209+
210+
@common.parametrize("test_data", test_data_suite_Uxx)
211+
@common.XfailIfNoCorstone320
212+
def test_upsample_bilinear2d_vec_U85_BI_Interpolate(
213+
test_data: torch.Tensor,
214+
):
215+
test_data, size, scale_factor, compare_outputs = test_data
216+
217+
pipeline = EthosU85PipelineBI[input_t1](
218+
Interpolate(size, scale_factor),
219+
(test_data,),
220+
aten_op,
221+
run_on_fvp=True,
222+
qtol=1,
223+
use_to_edge_transform_and_lower=True,
224+
)
225+
if not compare_outputs:
226+
pipeline.pop_stage(-1)
227+
pipeline.run()
228+
229+
230+
@common.parametrize("test_data", test_data_suite_Uxx)
231+
@common.XfailIfNoCorstone320
232+
def test_upsample_bilinear2d_vec_U85_BI_UpsamplingBilinear2d(
233+
test_data: torch.Tensor,
234+
):
235+
test_data, size, scale_factor, compare_outputs = test_data
236+
237+
pipeline = EthosU85PipelineBI[input_t1](
238+
UpsamplingBilinear2d(size, scale_factor),
239+
(test_data,),
240+
aten_op,
241+
run_on_fvp=True,
242+
qtol=1,
243+
use_to_edge_transform_and_lower=True,
244+
)
245+
if not compare_outputs:
246+
pipeline.pop_stage(-1)
247+
pipeline.run()

backends/arm/tosa_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def filter_fn(node: torch.fx.Node) -> bool:
170170

171171
ops_to_not_decompose = [
172172
torch.ops.aten.linear.default,
173+
torch.ops.aten.upsample_bilinear2d.vec,
173174
torch.ops.aten.upsample_nearest2d.vec,
174175
torch.ops.aten.eye.default,
175176
torch.ops.aten.linspace.default,

0 commit comments

Comments
 (0)