
Commit 7a7fe89

digantdesai authored and facebook-github-bot committed
Fix partition logic for force_fp32_dynamic_linear (#8596)
Summary: As title

Reviewed By: mcr229

Differential Revision: D69906370
1 parent 9de9ed4 commit 7a7fe89
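For context, force_fp32_dynamic_linear is an XnnpackPartitioner flag that keeps fp32 dynamic linear weights (and, with this fix, biases) outside the XNNPACK delegate. A minimal usage sketch of the flag, assuming the standard executorch export and lowering flow; the module, shapes, and variable names are illustrative, not taken from this commit:

# Minimal sketch: lowering a single fp32 linear with force_fp32_dynamic_linear on.
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(13, 17)

    def forward(self, x):
        return self.linear(x)


exported = torch.export.export(TinyLinear().eval(), (torch.randn(8, 13),))

# With the flag on, the linear weight (and bias) are expected to remain
# parameters of the program instead of being consumed by the XNNPACK delegate.
edge = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner(force_fp32_dynamic_linear=True)],
)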

File tree

3 files changed: +89 -4 lines changed


backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 14 additions & 1 deletion
@@ -210,6 +210,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force force_fp32_dynamic_linear is enabled, then we
+            # do not partition the weight node
+            return (True, gemm_deps)
+
         if len(node.all_input_nodes) > 2 and self.bias_idx is not None:
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
@@ -477,7 +482,15 @@ def find_partition_args(input_node):
             node.args = old_args
             node.users = old_users
 
-        return valid_deps, list(set(deps) | set(src_partition.nodes))
+        # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+        # Else we want to be greedy.
+        ret_deps = (
+            list(set(deps) & set(src_partition.nodes))
+            if self.force_fp32_dynamic_linear
+            else list(set(deps) | set(src_partition.nodes))
+        )
+
+        return valid_deps, ret_deps
 
     def supported_precision_types(self):
         return [
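The second hunk swaps a union for an intersection when the flag is set: instead of greedily claiming every node in the source partition, the partition is restricted to the nodes get_deps actually returned, so the weight and bias stay out of the delegate. A small self-contained illustration of the two set operations, with hypothetical node names standing in for torch.fx.Node objects:

# Illustration of "overwrite" (intersection) vs "greedy" (union) dep selection.
deps = {"activation", "linear"}  # nodes get_deps chose to partition
src_partition_nodes = {"activation", "linear", "weight", "bias"}

greedy = deps | src_partition_nodes        # force_fp32_dynamic_linear=False
overwrite = deps & src_partition_nodes     # force_fp32_dynamic_linear=True

print(sorted(greedy))     # ['activation', 'bias', 'linear', 'weight']
print(sorted(overwrite))  # ['activation', 'linear'] -- weight and bias stay outside the delegate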

backends/xnnpack/test/ops/test_linear.py

Lines changed: 73 additions & 0 deletions
@@ -31,6 +31,11 @@
     ToEdgeTransformAndLower,
 )
 
+from torch.export.graph_signature import (
+    ExportGraphSignature,
+    InputKind,
+)
+
 try:
     from torchao.quantization.quant_api import (
         int8_dynamic_activation_int4_weight,
@@ -871,3 +876,71 @@ def test_linear_qd8_as_fp32(self):
                 "dequantize_per_channel.default": 1,  # 1: weight
             },
         )
+
+    def test_linear_fp32_with_force_as_mm(self):
+        def check_signature(
+            signature: ExportGraphSignature,
+            force_flag: bool,
+            use_bias: bool,
+            legacy_mode: bool,
+        ):
+            num_params = 0
+            if force_flag:
+                num_params = 1  # weight_param
+                if use_bias:
+                    num_params += 1  # bias_param
+            sign_params: int = 0
+            input_specs = signature.input_specs
+            for input_spec in input_specs:
+                if input_spec.kind == InputKind.PARAMETER:
+                    sign_params += 1
+            assert (
+                sign_params == num_params
+            ), f"Expected {num_params} params, got {sign_params} with force_flag={force_flag}, use_bias={use_bias}, legacy_mode={legacy_mode}"
+
+        for force_flag in (True, False):
+            for use_bias in (True, False):
+                for legacy_mode in (True, False):
+                    module = BaseLinear(
+                        in_size=8,
+                        input_channels=13,
+                        output_channels=17,
+                        use_bias=use_bias,
+                    )
+                    inputs = module.get_inputs()
+                    tester = Tester(module, inputs).export()
+                    partitioner = XnnpackPartitioner(
+                        force_fp32_dynamic_linear=force_flag
+                    )
+                    if legacy_mode:
+                        tester.to_edge()
+                        partitioner_stage = Partition(partitioner=partitioner)
+                        tester.partition(partition_stage=partitioner_stage)
+                        tester.check_not(
+                            [
+                                (
+                                    "executorch_exir_dialects_edge__ops_aten_mm_default"
+                                    if use_bias
+                                    else "executorch_exir_dialects_edge__ops_aten_addmm_default"
+                                )
+                            ]
+                        )
+                    else:
+                        to_edge_and_transform_stage = ToEdgeTransformAndLower(
+                            partitioners=[partitioner]
+                        )
+                        tester.to_edge_transform_and_lower(
+                            to_edge_and_transform_stage=to_edge_and_transform_stage
+                        )
+                        tester.check_not(
+                            ["executorch_exir_dialects_edge__ops_aten_linear_default"]
+                        )
+
+                    signature: ExportGraphSignature = (
+                        tester.get_artifact().exported_program().graph_signature
+                    )
+                    check_signature(signature, force_flag, use_bias, legacy_mode)
+
+                    tester.to_executorch()
+                    tester.serialize()
+                    tester.run_method_and_compare_outputs()
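The new test's check_signature helper counts how many program inputs are still parameters after lowering, which is how it verifies whether the delegate absorbed the weight and bias. A standalone sketch of the same counting logic, assuming an already-lowered ExportedProgram bound to a variable named program (the variable name is hypothetical):

# Count the inputs of an exported program that are still parameters,
# i.e. were not absorbed into the XNNPACK delegate payload.
from torch.export.graph_signature import InputKind


def count_parameter_inputs(program) -> int:
    return sum(
        1
        for spec in program.graph_signature.input_specs
        if spec.kind == InputKind.PARAMETER
    )

# Expectation mirrored from the test: with force_fp32_dynamic_linear=True a
# biased linear keeps 2 parameters (weight + bias); with the flag off, 0 remain.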

backends/xnnpack/test/ops/test_lstm.py

Lines changed: 2 additions & 3 deletions
@@ -54,9 +54,8 @@ def test_fp32_lstm_force_dynamic_linear(self):
             )
             .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
             # Weights are supplied as input to linears
-            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0"])
-            # Biases are owned by delegates
-            .check_not(["p_lstm_bias"])
+            # Biases are not owned by delegates when force_fp32_dynamic_linear is set
+            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
             .to_executorch()
             .serialize()
             .run_method_and_compare_outputs()
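For readers unfamiliar with the fluent Tester API used in this test, here is a rough sketch of the flow it exercises. The wrapper module, shapes, and import paths are assumptions modeled on the test files touched by this commit, not verbatim from them:

# Rough sketch of the Tester flow in test_fp32_lstm_force_dynamic_linear.
# LSTMWrapper and its shapes are hypothetical; the chained calls mirror the diff above.
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.test.tester import Tester, ToEdgeTransformAndLower


class LSTMWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=32, hidden_size=32, batch_first=True)

    def forward(self, x):
        out, _ = self.lstm(x)
        return out


(
    Tester(LSTMWrapper(), (torch.randn(1, 16, 32),))
    .export()
    .to_edge_transform_and_lower(
        ToEdgeTransformAndLower(
            partitioners=[XnnpackPartitioner(force_fp32_dynamic_linear=True)]
        )
    )
    # the decomposed linears should not appear as addmm inside the delegate
    .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
    # weights and biases remain visible as program parameters
    .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
    .to_executorch()
    .serialize()
    .run_method_and_compare_outputs()
)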
