Skip to content

Commit f863fe4

Browse files
authored
Sub, Sqrt, Pad
Differential Revision: D60492340 Pull Request resolved: #4533
1 parent 6607c7d commit f863fe4

File tree

5 files changed

+36
-38
lines changed

5 files changed

+36
-38
lines changed

backends/xnnpack/partition/config/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
CatConfig,
2121
CeilConfig,
2222
ClampConfig,
23+
ConstantPadConfig,
2324
DeQuantizedPerTensorConfig,
2425
DivConfig,
2526
FloorConfig,
@@ -40,6 +41,8 @@
4041
SigmoidConfig,
4142
SliceCopyConfig,
4243
SoftmaxConfig,
44+
SquareRootConfig,
45+
SubConfig,
4346
UpsampleBilinear2dConfig,
4447
)
4548
from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -55,6 +58,7 @@
5558
# GEMM-like Configs
5659
AddmmConfig,
5760
LinearConfig,
61+
ConstantPadConfig,
5862
ConvolutionConfig,
5963
# BatchNorm Config
6064
BatchNormConfig,
@@ -82,6 +86,8 @@
8286
SoftmaxConfig,
8387
SigmoidConfig,
8488
SliceCopyConfig,
89+
SquareRootConfig,
90+
SubConfig,
8591
PermuteConfig,
8692
# EluConfig, # Waiting for PyTorch Pin Update
8793
ReLUConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,3 +382,24 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
382382

383383
def supported_precision_types(self) -> List[ConfigPrecisionType]:
384384
return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]
385+
386+
387+
class SquareRootConfig(GenericNodePartitionerConfig):
    """Partitioner config matching the aten ``sqrt.default`` operator.

    Only FP32 is advertised as supported; no quantized precision is listed
    here (presumably no quantized sqrt kernel exists — confirm in backend).
    """

    target_name = "sqrt.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        # Floating point only — see class docstring.
        return [ConfigPrecisionType.FP32]
392+
393+
394+
class ConstantPadConfig(GenericNodePartitionerConfig):
    """Partitioner config matching the aten ``constant_pad_nd.default`` operator.

    Only FP32 is advertised as supported by this config.
    """

    target_name = "constant_pad_nd.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        # Floating point only — see class docstring.
        return [ConfigPrecisionType.FP32]
399+
400+
401+
class SubConfig(GenericNodePartitionerConfig):
    """Partitioner config matching the aten ``sub.Tensor`` operator.

    Both FP32 and statically-quantized execution are advertised as supported.
    """

    target_name = "sub.Tensor"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        # Subtraction is supported in both float and static-quant flows.
        return [
            ConfigPrecisionType.FP32,
            ConfigPrecisionType.STATIC_QUANT,
        ]

backends/xnnpack/test/ops/sqrt.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ def _test_sqrt(self, inputs):
2525
Tester(self.Sqrt(), inputs)
2626
.export()
2727
.check_count({"torch.ops.aten.sqrt.default": 1})
28-
.to_edge()
29-
.check_count({"executorch_exir_dialects_edge__ops_aten_sqrt_default": 1})
30-
.partition()
28+
.to_edge_transform_and_lower()
3129
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
3230
.check_not(["executorch_exir_dialects_edge__ops_aten_sqrt_default"])
3331
.to_executorch()

backends/xnnpack/test/ops/static_constant_pad.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,7 @@ def _test_static_constant_pad_functional(self, inputs):
8888
Tester(self.StaticConstantPadFunctional(), inputs)
8989
.export()
9090
.check_count({"torch.ops.aten.pad.default": 8})
91-
.to_edge()
92-
.check_count(
93-
{"executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default": 8}
94-
)
95-
.partition()
91+
.to_edge_transform_and_lower()
9692
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
9793
.check_not(
9894
["executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default"]
@@ -139,11 +135,7 @@ def forward(self, x):
139135
.export()
140136
.check_count({"torch.ops.aten.pad.default": 1})
141137
.check(["torch.ops.quantized_decomposed"])
142-
.to_edge()
143-
.check_count(
144-
{"executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default": 1}
145-
)
146-
.partition()
138+
.to_edge_transform_and_lower()
147139
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
148140
.check_not(
149141
[
@@ -164,11 +156,7 @@ def test_qs8_static_constant_pad_2d(self):
164156
.export()
165157
.check_count({"torch.ops.aten.pad.default": 1})
166158
.check(["torch.ops.quantized_decomposed"])
167-
.to_edge()
168-
.check_count(
169-
{"executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default": 1}
170-
)
171-
.partition()
159+
.to_edge_transform_and_lower()
172160
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
173161
.check_not(
174162
[

backends/xnnpack/test/ops/sub.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ def _test_sub(self, inputs):
3232
Tester(self.Sub(), inputs)
3333
.export()
3434
.check_count({"torch.ops.aten.sub.Tensor": 1})
35-
.to_edge()
36-
.check_count({"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1})
37-
.partition()
35+
.to_edge_transform_and_lower()
3836
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
3937
.check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"])
4038
.to_executorch()
@@ -62,9 +60,7 @@ def _test_qs8_sub(self):
6260
.export()
6361
.check_count({"torch.ops.aten.sub.Tensor": 1})
6462
.check(["torch.ops.quantized_decomposed"])
65-
.to_edge()
66-
.check_count({"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1})
67-
.partition()
63+
.to_edge_transform_and_lower()
6864
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
6965
.check_not(
7066
[
@@ -86,9 +82,7 @@ def _test_qs8_sub2(self):
8682
.export()
8783
.check_count({"torch.ops.aten.sub.Tensor": 1})
8884
.check(["torch.ops.quantized_decomposed"])
89-
.to_edge()
90-
.check_count({"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1})
91-
.partition()
85+
.to_edge_transform_and_lower()
9286
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
9387
.check_not(
9488
[
@@ -110,9 +104,7 @@ def _test_qs8_sub3(self):
110104
.export()
111105
.check_count({"torch.ops.aten.sub.Tensor": 1})
112106
.check(["torch.ops.quantized_decomposed"])
113-
.to_edge()
114-
.check_count({"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1})
115-
.partition()
107+
.to_edge_transform_and_lower()
116108
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
117109
.check_not(
118110
[
@@ -144,14 +136,7 @@ def forward(self, x, y):
144136
}
145137
)
146138
.check(["torch.ops.quantized_decomposed"])
147-
.to_edge()
148-
.check_count(
149-
{
150-
"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1,
151-
"executorch_exir_dialects_edge__ops_aten_relu_default": 1,
152-
}
153-
)
154-
.partition()
139+
.to_edge_transform_and_lower()
155140
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
156141
.check_not(
157142
[

0 commit comments

Comments (0)