Commit 80d5e5a

Fix partition logic for force_fp32_dynamic_linear

Differential Revision: D69906370
Pull Request resolved: #8596
1 parent 5e4b75b commit 80d5e5a
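
For context, force_fp32_dynamic_linear is an XnnpackPartitioner flag: when enabled, fp32 dynamic linear layers are lowered without the delegate taking ownership of their weights (and, with this fix, their biases), so those tensors remain inputs to the program. Below is a minimal usage sketch under the standard export-and-lower flow; the TinyLinear module, tensor sizes, and variable names are illustrative assumptions, not part of this commit.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

# Illustrative module only; any fp32 linear would do.
class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(13, 17, bias=True)

    def forward(self, x):
        return self.linear(x)

module = TinyLinear().eval()
example_inputs = (torch.randn(8, 13),)

# With force_fp32_dynamic_linear=True, the linear is lowered as a dynamic fp32
# GEMM and its weight/bias stay as program inputs rather than being owned by
# the XNNPACK delegate.
partitioner = XnnpackPartitioner(force_fp32_dynamic_linear=True)
exported = torch.export.export(module, example_inputs)
edge = to_edge_transform_and_lower(exported, partitioner=[partitioner])
executorch_program = edge.to_executorch()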

File tree: 3 files changed (+86, -4 lines)


backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 14 additions & 1 deletion
@@ -210,6 +210,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force_fp32_dynamic_linear is enabled, then we
+            # do not partition the bias node
+            return (True, gemm_deps)
+
         if len(node.all_input_nodes) > 2 and self.bias_idx is not None:
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
@@ -477,7 +482,15 @@ def find_partition_args(input_node):
         node.args = old_args
         node.users = old_users

-        return valid_deps, list(set(deps) | set(src_partition.nodes))
+        # When force_fp32_dynamic_linear is used, we want the deps returned by get_deps
+        # to overwrite the source partition nodes. Otherwise we want to be greedy.
+        ret_deps = (
+            list(set(deps) & set(src_partition.nodes))
+            if self.force_fp32_dynamic_linear
+            else list(set(deps) | set(src_partition.nodes))
+        )
+
+        return valid_deps, ret_deps

     def supported_precision_types(self):
         return [
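
The second hunk is the actual behavioral change: a union over dependency sets becomes an intersection when force_fp32_dynamic_linear is set. Here is a tiny standalone sketch with made-up node names (real entries are torch.fx.Node objects) to make the "overwrite vs. greedy" comment concrete.

# Hypothetical node names standing in for torch.fx.Node objects.
deps = {"linear", "weight"}  # what get_deps decided to keep
src_partition_nodes = {"linear", "weight", "bias", "addmm_decomposed"}

# force_fp32_dynamic_linear=True: the deps computed by get_deps win, so anything
# get_deps dropped (here the bias) stays out of the partition.
overwrite = deps & src_partition_nodes  # {"linear", "weight"}

# force_fp32_dynamic_linear=False: be greedy and take every node the source
# partition discovered.
greedy = deps | src_partition_nodes  # all four names

print(sorted(overwrite))
print(sorted(greedy))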

backends/xnnpack/test/ops/test_linear.py

Lines changed: 70 additions & 0 deletions
@@ -31,6 +31,8 @@
     ToEdgeTransformAndLower,
 )

+from torch.export.graph_signature import ExportGraphSignature, InputKind
+
 try:
     from torchao.quantization.quant_api import (
         int8_dynamic_activation_int4_weight,
@@ -871,3 +873,71 @@ def test_linear_qd8_as_fp32(self):
                 "dequantize_per_channel.default": 1,  # 1: weight
             },
         )
+
+    def test_linear_fp32_with_force_as_mm(self):
+        def check_signature(
+            signature: ExportGraphSignature,
+            force_flag: bool,
+            use_bias: bool,
+            legacy_mode: bool,
+        ):
+            num_params = 0
+            if force_flag:
+                num_params = 1  # weight_param
+                if use_bias:
+                    num_params += 1  # bias_param
+            sign_params: int = 0
+            input_specs = signature.input_specs
+            for input_spec in input_specs:
+                if input_spec.kind == InputKind.PARAMETER:
+                    sign_params += 1
+            assert (
+                sign_params == num_params
+            ), f"Expected {num_params} params, got {sign_params} with force_flag={force_flag}, use_bias={use_bias}, legacy_mode={legacy_mode}"
+
+        for force_flag in (True, False):
+            for use_bias in (True, False):
+                for legacy_mode in (True, False):
+                    module = BaseLinear(
+                        in_size=8,
+                        input_channels=13,
+                        output_channels=17,
+                        use_bias=use_bias,
+                    )
+                    inputs = module.get_inputs()
+                    tester = Tester(module, inputs).export()
+                    partitioner = XnnpackPartitioner(
+                        force_fp32_dynamic_linear=force_flag
+                    )
+                    if legacy_mode:
+                        tester.to_edge()
+                        partitioner_stage = Partition(partitioner=partitioner)
+                        tester.partition(partition_stage=partitioner_stage)
+                        tester.check_not(
+                            [
+                                (
+                                    "executorch_exir_dialects_edge__ops_aten_mm_default"
+                                    if use_bias
+                                    else "executorch_exir_dialects_edge__ops_aten_addmm_default"
+                                )
+                            ]
+                        )
+                    else:
+                        to_edge_and_transform_stage = ToEdgeTransformAndLower(
+                            partitioners=[partitioner]
+                        )
+                        tester.to_edge_transform_and_lower(
+                            to_edge_and_transform_stage=to_edge_and_transform_stage
+                        )
+                        tester.check_not(
+                            ["executorch_exir_dialects_edge__ops_aten_linear_default"]
+                        )
+
+                    signature: ExportGraphSignature = (
+                        tester.get_artifact().exported_program().graph_signature
+                    )
+                    check_signature(signature, force_flag, use_bias, legacy_mode)
+
+                    tester.to_executorch()
+                    tester.serialize()
+                    tester.run_method_and_compare_outputs()
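
The check_signature helper above counts parameters that survive lowering as program inputs. The same count can be taken directly from any torch.export.ExportedProgram; below is a minimal sketch, here on a plain (unlowered) linear, where both weight and bias are expected to show up as PARAMETER input specs.

import torch
from torch.export import export
from torch.export.graph_signature import InputKind

module = torch.nn.Linear(13, 17, bias=True)
exported = export(module, (torch.randn(8, 13),))

# Every parameter still fed into the program appears as an InputSpec with
# kind PARAMETER in the graph signature.
num_params = sum(
    1
    for spec in exported.graph_signature.input_specs
    if spec.kind == InputKind.PARAMETER
)
print(num_params)  # 2 here: weight and bias, since nothing was delegated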

backends/xnnpack/test/ops/test_lstm.py

Lines changed: 2 additions & 3 deletions
@@ -54,9 +54,8 @@ def test_fp32_lstm_force_dynamic_linear(self):
             )
             .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
             # Weights are supplied as input to linears
-            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0"])
-            # Biases are owned by delegates
-            .check_not(["p_lstm_bias"])
+            # Biases are not owned by delegates when force_fp32_dynamic_linear is set
+            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
             .to_executorch()
             .serialize()
             .run_method_and_compare_outputs()
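
The check/check_not assertions above look for parameter names in the printed edge graph. An equivalent manual inspection, sketched here as a helper that takes a hypothetical already-lowered ExportedProgram, is to list the placeholder nodes and confirm the LSTM weights and bias remain program inputs.

from torch.export import ExportedProgram

def lstm_params_remain_as_inputs(ep: ExportedProgram) -> bool:
    # With force_fp32_dynamic_linear set, the LSTM weights and bias should remain
    # program inputs (placeholders) instead of constants owned by the delegate.
    placeholders = [
        node.name for node in ep.graph_module.graph.nodes if node.op == "placeholder"
    ]
    return any(n.startswith("p_lstm_weight") for n in placeholders) and any(
        n.startswith("p_lstm_bias") for n in placeholders
    )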
