
Commit 9d3a7be

digantdesai authored and facebook-github-bot committed
Fix partition logic for force_fp32_dynamic_linear
Summary: As title

Differential Revision: D69906370
1 parent 2fff01a commit 9d3a7be
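
For context, force_fp32_dynamic_linear is a constructor flag on XnnpackPartitioner that keeps FP32 dynamic linear weights (and, after this fix, biases) as program inputs instead of consuming them inside the delegate. A minimal sketch of how the flag is passed, mirroring the new test below; the import path is an assumption and not taken from this diff:

# Sketch only: import path assumed, not part of the commit.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# With the flag set, FP32 dynamic linear weights/biases stay as program inputs
# rather than being folded into the XNNPACK delegate payload.
partitioner = XnnpackPartitioner(force_fp32_dynamic_linear=True)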

3 files changed: +56 -4 lines changed


backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 14 additions & 1 deletion
@@ -210,6 +210,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force_fp32_dynamic_linear is enabled, then we
+            # do not partition the weight node
+            return (True, gemm_deps)
+
         if len(node.all_input_nodes) > 2 and self.bias_idx is not None:
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
@@ -477,7 +482,15 @@ def find_partition_args(input_node):
         node.args = old_args
         node.users = old_users
 
-        return valid_deps, list(set(deps) | set(src_partition.nodes))
+        # When using force_fp32_dynamic_linear, we want get_deps to overwrite the source partition nodes.
+        # Else we want to be greedy.
+        ret_deps = (
+            list(set(deps) & set(src_partition.nodes))
+            if self.force_fp32_dynamic_linear
+            else list(set(deps) | set(src_partition.nodes))
+        )
+
+        return valid_deps, ret_deps
 
     def supported_precision_types(self):
         return [
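
The behavioral change in the second hunk comes down to set intersection versus union over the dependency nodes and the source partition nodes. A tiny standalone illustration; the node names here are hypothetical placeholders, not real fx nodes:

# Illustration only; in the real code these are torch.fx.Node objects.
deps = {"linear", "weight"}
src_partition_nodes = {"linear", "weight", "bias", "some_view"}

# force_fp32_dynamic_linear=True: intersection, so get_deps effectively
# overrides the source partition and the extra nodes are left out.
print(sorted(deps & src_partition_nodes))  # ['linear', 'weight']

# Default (greedy): union, so everything in the source partition is also kept.
print(sorted(deps | src_partition_nodes))  # ['bias', 'linear', 'some_view', 'weight']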

backends/xnnpack/test/ops/test_linear.py

Lines changed: 40 additions & 0 deletions
@@ -31,6 +31,11 @@
     ToEdgeTransformAndLower,
 )
 
+from torch.export.graph_signature import (
+    ExportGraphSignature,
+    InputKind,
+)
+
 try:
     from torchao.quantization.quant_api import (
         int8_dynamic_activation_int4_weight,
@@ -871,3 +876,38 @@ def test_linear_qd8_as_fp32(self):
                 "dequantize_per_channel.default": 1,  # 1: weight
             },
         )
+
+    def test_linear_fp32_with_force_as_mm(self):
+        def check_signature(signature: ExportGraphSignature, num_params: int = 0):
+            sign_params: int = 0
+            input_specs = signature.input_specs
+            for input_spec in input_specs:
+                if input_spec.kind == InputKind.PARAMETER:
+                    sign_params += 1
+            assert sign_params == num_params, f"Expected {num_params} params, got {sign_params} with force_flag={force_flag}, use_bias={use_bias}, legacy_mode={legacy_mode}"
+
+        for force_flag in (True, False):
+            partitioner = XnnpackPartitioner(force_fp32_dynamic_linear=force_flag)
+            weight_param = force_flag == True  # weight param
+            for use_bias in (True, False):
+                bias_param = use_bias == True and force_flag == True  # bias param
+                for legacy_mode in (True, False):
+                    module = BaseLinear(in_size=8, input_channels=13, output_channels=17, use_bias=use_bias)
+                    inputs = module.get_inputs()
+                    tester = Tester(module, inputs).export()
+                    if legacy_mode:
+                        tester.to_edge()
+                        partitioner_stage = Partition(partitioner=partitioner)
+                        tester.partition(partition_stage=partitioner_stage)
+                        tester.check_not(["executorch_exir_dialects_edge__ops_aten_mm_default" if use_bias else "executorch_exir_dialects_edge__ops_aten_addmm_default"])
+                    else:
+                        to_edge_and_transform_stage = ToEdgeTransformAndLower(partitioners=[partitioner])
+                        tester.to_edge_transform_and_lower(to_edge_and_transform_stage=to_edge_and_transform_stage)
+                        tester.check_not(["executorch_exir_dialects_edge__ops_aten_linear_default"])
+
+                    signature: ExportGraphSignature = tester.get_artifact().exported_program().graph_signature
+                    check_signature(signature, num_params=weight_param + bias_param)
+
+                    tester.to_executorch()
+                    tester.serialize()
+                    tester.run_method_and_compare_outputs()
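
The check_signature helper above counts how many tensors are still lifted as parameters in the exported program's graph signature after lowering. A minimal standalone sketch of the same counting using plain torch.export, independent of the ExecuTorch tester; the numbers shown are for an unlowered Linear, where weight and bias are both still parameters:

# Sketch only: uses public torch.export APIs, no ExecuTorch involved.
import torch
from torch.export import export
from torch.export.graph_signature import InputKind

ep = export(torch.nn.Linear(13, 17, bias=True), (torch.randn(8, 13),))
num_params = sum(
    1 for spec in ep.graph_signature.input_specs if spec.kind == InputKind.PARAMETER
)
print(num_params)  # 2: weight and bias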

backends/xnnpack/test/ops/test_lstm.py

Lines changed: 2 additions & 3 deletions
@@ -54,9 +54,8 @@ def test_fp32_lstm_force_dynamic_linear(self):
             )
             .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
             # Weights are supplied as input to linears
-            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0"])
-            # Biases are owned by delegates
-            .check_not(["p_lstm_bias"])
+            # Biases are not owned by delegates when force_fp32_dynamic_linear is set
+            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
             .to_executorch()
             .serialize()
             .run_method_and_compare_outputs()
