Commit 766d42c

Merge branch 'main' into set_TAs_correctly_depending_on_mem_mode
2 parents bc433a3 + e37129d commit 766d42c

File tree

39 files changed: +401 -1329 lines


.ci/scripts/test_model.sh

Lines changed: 8 additions & 0 deletions
@@ -100,6 +100,14 @@ test_model() {
     rm "./${MODEL_NAME}.pte"
     return # Skip running with portable executor runner since portable doesn't support Qwen's biased linears.
   fi
+  if [[ "${MODEL_NAME}" == "phi4_mini" ]]; then
+    # Install requirements for export_llama
+    bash examples/models/llama/install_requirements.sh
+    # Test export_llama script: python3 -m examples.models.llama.export_llama.
+    "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/phi-4-mini/config.json
+    run_portable_executor_runner
+    rm "./${MODEL_NAME}.pte"
+  fi

   # Export a basic .pte and run the model.
   "${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}" "${STRICT}"

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ jobs:
       # see if we can import the module successfully
       ${CONDA_RUN} python -c "from executorch.extension.pybindings import portable_lib; print('success!')"

-  test-static-llama-ane:
+  test-static-llama-ane:
     name: test-static-llama-ane
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
     with:

backends/cadence/aot/pass_utils.py

Lines changed: 10 additions & 0 deletions
@@ -104,6 +104,16 @@ def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target)
     return total


+def op_counts_match(
+    graph_module: torch.fx.GraphModule,
+    expected_op_counts: dict[EdgeOpOverload, int],
+) -> bool:
+    for op, count in expected_op_counts.items():
+        if count_node(graph_module, op) != count:
+            return False
+    return True
+
+
 # Testing utils
 # Return the compute/function nodes in the graph
 def get_compute_nodes_in_gm(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]:
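
For context, a minimal sketch of how the new op_counts_match helper might be used in a test, assuming a graph_module produced by export_to_edge as in the Cadence tests further down; the expected counts here are only illustrative:

from executorch.backends.cadence.aot.pass_utils import op_counts_match
from executorch.exir.dialects._ops import ops as exir_ops

# Illustrative expectation: exactly two abs nodes and no dequantize nodes remain.
expected = {
    exir_ops.edge.aten.abs.default: 2,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
}
# Returns True only if every listed op appears exactly the expected number of times.
assert op_counts_match(graph_module, expected)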

backends/cadence/aot/remove_ops.py

Lines changed: 64 additions & 1 deletion
@@ -33,7 +33,7 @@
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
 from executorch.exir.pass_manager import PassManager, PassType
 from executorch.exir.passes import dead_code_elimination_pass
@@ -745,6 +745,68 @@ def permute_shape(
     return [shape[p] for p in permute_dims]


+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class RemoveBranchedQuantDequant(ExportPass):
+    """
+    This pass looks for adjacent quant and dequant nodes with identical
+    parameters, where the quant node has other users in addition to the
+    dequant. The quant and dequant pair would be removed by the
+    FuseQuantDequantToRequantizePass if not for the multiple users. This pass
+    removes just the dequant node by connecting it to the quant's parent node
+    """
+
+    quantize_op_packets: set[EdgeOpOverloadPacket] = {
+        exir_ops.edge.cadence.quantize_per_tensor,
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor,
+    }
+    dequantize_op_packets: set[EdgeOpOverloadPacket] = {
+        exir_ops.edge.cadence.dequantize_per_tensor,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor,
+    }
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        self.remove_branched(
+            graph_module, self.quantize_op_packets, self.dequantize_op_packets
+        )
+        self.remove_branched(
+            graph_module, self.dequantize_op_packets, self.quantize_op_packets
+        )
+
+        graph_module.graph.eliminate_dead_code()
+        result = super().call(graph_module)
+        return result
+
+    def remove_branched(
+        self,
+        graph_module: torch.fx.GraphModule,
+        producer_pkts: set[EdgeOpOverloadPacket],
+        consumer_pkts: set[EdgeOpOverloadPacket],
+    ) -> None:
+        for node in graph_module.graph.nodes:
+            if (
+                node.op != "call_function"
+                or not isinstance(node.target, EdgeOpOverload)
+                or get_edge_overload_packet(node.target) not in producer_pkts
+            ):
+                continue
+
+            if len(node.users) < 2:
+                continue
+
+            for user in node.users:
+                if (
+                    not isinstance(user.target, EdgeOpOverload)
+                    or get_edge_overload_packet(user.target) not in consumer_pkts
+                ):
+                    continue
+
+                # check qparams match
+                if node.args[1:] != user.args[1:]:
+                    continue
+
+                user.replace_all_uses_with(node.args[0])
+
+
 # The following class consolidates functions to remove ops that are redundant
 # in Jarvis. Currently, each function in this class iterates over each node of
 # the graph module once. In future, we could consolidate them into a monolithic
@@ -765,4 +827,5 @@ class CadenceRemoveNops:
         RemoveNopMulOpPass,
         RemoveNopAddOpPass,
         RemoveNopLinalgVectorNormOpPass,
+        RemoveBranchedQuantDequant,
     ]
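
To make the branched pattern concrete, a hedged sketch of the graph before and after the pass; it mirrors the new test below, and the shapes and qparams are only examples:

# Before RemoveBranchedQuantDequant: the quantize node has a second user (abs),
# so FuseQuantDequantToRequantizePass would not touch the quant/dequant pair.
#
#   x -> quantize_per_tensor -+-> abs
#                             +-> dequantize_per_tensor -> view
#
# After the pass: the dequantize's users are rewired to the quantize node's
# input (x), and the dequantize node itself is dropped by dead-code elimination.
#
#   x -> quantize_per_tensor -> abs
#   x -> view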

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 2 additions & 3 deletions
@@ -20,7 +20,7 @@
     FuseTransposeOpPairsPass,
 )
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
-from executorch.backends.cadence.aot.pass_utils import count_node
+from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from torch import nn
@@ -32,8 +32,7 @@ def check_op_counts(
         graph_module: torch.fx.GraphModule,
         expected_op_counts: dict[EdgeOpOverload, int],
     ) -> None:
-        for op, count in expected_op_counts.items():
-            self.assertEqual(count_node(graph_module, op), count)
+        self.assertTrue(op_counts_match(graph_module, expected_op_counts))


 class TestFusionPasses(TestFusionPassesBase):

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 33 additions & 1 deletion
@@ -17,10 +17,11 @@
 from executorch.backends.cadence.aot import compiler
 from executorch.backends.cadence.aot.compiler import export_to_edge

-from executorch.backends.cadence.aot.pass_utils import count_node
+from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
 from executorch.backends.cadence.aot.remove_ops import (
     RemoveAliasCopyOpPass,
+    RemoveBranchedQuantDequant,
     RemoveCloneOpPass,
     RemoveContiguousOpPass,
     RemoveDetachCopyPass,
@@ -709,3 +710,34 @@ def forward(self, x):
         self.assertEqual(
             count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2
         )
+
+    def test_remove_dequant_on_branch(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                x = torch.abs(x)
+                x0 = torch.ops.quantized_decomposed.quantize_per_tensor(
+                    x, 1.2, 3, 0, 127, torch.int8
+                )
+                x1 = torch.abs(x0)
+                y0 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+                    x0, 1.2, 3, 0, 127, torch.int8
+                )
+                y1 = y0.view(-1)
+                return x1, y1
+
+        inputs = torch.rand(1, 8, 4, 6)
+        model = M()
+        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
+
+        graph_module = RemoveBranchedQuantDequant()(graph_module).graph_module
+        self.assertTrue(
+            op_counts_match(
+                graph_module,
+                expected_op_counts={
+                    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
+                    # we expect the pass to remove the dequantize node
+                    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
+                    exir_ops.edge.aten.abs.default: 2,
+                },
+            )
+        )

backends/xnnpack/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -148,6 +148,10 @@ if(NOT CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$")

   target_link_libraries(xnn_executor_runner gflags ${xnn_executor_runner_libs})
   target_compile_options(xnn_executor_runner PUBLIC ${_common_compile_options})
+  if(EXECUTORCH_BUILD_PTHREADPOOL)
+    target_link_libraries(xnn_executor_runner extension_threadpool pthreadpool)
+    target_compile_definitions(xnn_executor_runner PRIVATE ET_USE_THREADPOOL)
+  endif()
 endif()

 install(

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 25 additions & 13 deletions
@@ -96,9 +96,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
     def _overwrite_precision(self, node: torch.fx.Node):
         precision = self._detect_precision(node)
         if precision not in self.enabled_precision_types:
-            # detected precision is not enabled, lets try to partition it as fp32
+            # detected precision is not enabled, try to partition it as fp32
             if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
-                # if only fp32 is enabled, then we can still partition fp32 gemms
+                # when only fp32 is enabled, then we can still partition fp32 gemms
                 # even within a quantized graph
                 if precision in [
                     ConfigPrecisionType.STATIC_QUANT,
@@ -107,6 +107,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
                 precision = ConfigPrecisionType.FP32
                 logging.info(f"Overwriting precision, partitioning {node} as FP32")
                 return True, precision
+
         return False, precision

     def get_deps(
@@ -226,8 +227,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force force_fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)

@@ -305,8 +309,11 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, [])

@@ -412,9 +419,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        # TODO(maxren, T210537195):
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])

@@ -501,11 +510,11 @@ def find_partition_args(input_node):
             node.args = old_args
             node.users = old_users

-            # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+            # When using force_non_static_weights_for_f32_linear, we want get_deps to overwrite the source partition nodes.
             # Else we want to be greedy.
             ret_deps = (
                 list(set(deps) & set(src_partition.nodes))
-                if self.force_fp32_dynamic_linear
+                if self.force_non_static_weights_for_f32_linear
                 else list(set(deps) | set(src_partition.nodes))
             )

@@ -531,8 +540,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])


backends/xnnpack/partition/config/xnnpack_config.py

Lines changed: 3 additions & 1 deletion
@@ -41,7 +41,9 @@ def __init__(self, **kwargs):
         super().__init__()
         self.enabled_precision_types = self.supported_precision_types()
         # Flag used in GEMMConfig()
-        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", False)
+        self.force_non_static_weights_for_f32_linear = kwargs.get(
+            "force_non_static_weights_for_f32_linear", False
+        )

     def get_partition(
         self, node: torch.fx.Node, ep: ExportedProgram
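
For reference, a minimal sketch of how the renamed kwarg reaches the partitioner; it mirrors the updated XNNPACK tests below, and the import path is assumed from the repository layout:

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# The flag was previously called force_fp32_dynamic_linear; with the rename,
# callers opt in to keeping f32 linear weights as runtime inputs rather than
# static data owned by the delegate.
partitioner = XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)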

backends/xnnpack/test/ops/test_linear.py

Lines changed: 2 additions & 2 deletions
@@ -948,7 +948,7 @@ def test_linear_qd8_as_fp32(self):
             },
         )

-    def test_linear_fp32_with_force_as_mm(self):
+    def test_linear_with_force_non_static_weights_for_f32_linear(self):
         def check_signature(
             signature: ExportGraphSignature,
             force_flag: bool,
@@ -981,7 +981,7 @@ def check_signature(
             inputs = module.get_inputs()
             tester = Tester(module, inputs).export()
             partitioner = XnnpackPartitioner(
-                force_fp32_dynamic_linear=force_flag
+                force_non_static_weights_for_f32_linear=force_flag
             )
             if legacy_mode:
                 tester.to_edge()

backends/xnnpack/test/ops/test_lstm.py

Lines changed: 5 additions & 3 deletions
@@ -43,18 +43,20 @@ def test_fp32_lstm(self):
             .run_method_and_compare_outputs()
         )

-    def test_fp32_lstm_force_dynamic_linear(self):
+    def test_lstm_with_force_non_static_weights_for_f32_linear(self):
         (
             Tester(self.LSTMLinear(32, 32, 10), (torch.rand(1, 32, 32),))
             .export()
             .to_edge_transform_and_lower(
                 ToEdgeTransformAndLower(
-                    partitioners=[XnnpackPartitioner(force_fp32_dynamic_linear=True)]
+                    partitioners=[
+                        XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
+                    ]
                 )
             )
             .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
             # Weights are supplied as input to linears
-            # Biases are not owned by delegates when force_fp32_dynamic_linear is set
+            # Biases are not owned by delegates when force_non_static_weights_for_f32_linear is set
             .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
             .to_executorch()
             .serialize()

build/build_android_library.sh

Lines changed: 3 additions & 1 deletion
@@ -178,7 +178,9 @@ collect_artifacts_to_be_uploaded() {
 }

 main() {
-  BUILD_AAR_DIR="$(mktemp -d)"
+  if [[ -z "${BUILD_AAR_DIR:-}" ]]; then
+    BUILD_AAR_DIR="$(mktemp -d)"
+  fi
   export BUILD_AAR_DIR
   if [ -z "$ANDROID_ABIS" ]; then
     ANDROID_ABIS=("arm64-v8a" "x86_64")
