Skip to content

Commit 8ade738

Browse files
committed
Update
[ghstack-poisoned]
2 parents 14a55be + 5814a3b commit 8ade738

File tree

22 files changed

+459
-94
lines changed

22 files changed

+459
-94
lines changed

.ci/scripts/test_ane_static_llama.sh

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#!/bin/bash
2+
# Copyright (c) Qualcomm Innovation Center, Inc.
3+
# All rights reserved
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
set -exu
9+
10+
source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"
11+
12+
export EXECUTORCH_ROOT="$(dirname "${BASH_SOURCE[0]}")/../.."
13+
14+
if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
15+
PYTHON_EXECUTABLE=python3
16+
fi
17+
18+
which "${PYTHON_EXECUTABLE}"
19+
20+
pushd $EXECUTORCH_ROOT/examples/apple/coreml/llama
21+
22+
# Download stories llama110m artifacts
23+
download_stories_model_artifacts
24+
25+
python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w
26+
27+
popd

.github/workflows/trunk.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,28 @@ jobs:
229229
# see if we can import the module successfully
230230
${CONDA_RUN} python -c "from executorch.extension.pybindings import portable_lib; print('success!')"
231231
232+
test-static-llama-ane:
233+
name: test-static-llama-ane
234+
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
235+
with:
236+
runner: macos-m1-stable
237+
python-version: '3.11'
238+
submodules: 'true'
239+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
240+
script: |
241+
set -eux
242+
bash .ci/scripts/setup-conda.sh
243+
eval "$(conda shell.bash hook)"
244+
245+
# Install requirements
246+
sh install_requirements.sh
247+
sh backends/apple/coreml/scripts/install_requirements.sh
248+
python install_executorch.py --pybind coreml
249+
sh examples/models/llama/install_requirements.sh
250+
251+
# Test ANE llama
252+
sh .ci/scripts/test_ane_static_llama.sh
253+
232254
test-llama-runner-macos:
233255
name: test-llama-runner-mac
234256
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main

backends/cadence/aot/pass_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,16 @@ def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target)
104104
return total
105105

106106

107+
def op_counts_match(
108+
graph_module: torch.fx.GraphModule,
109+
expected_op_counts: dict[EdgeOpOverload, int],
110+
) -> bool:
111+
for op, count in expected_op_counts.items():
112+
if count_node(graph_module, op) != count:
113+
return False
114+
return True
115+
116+
107117
# Testing utils
108118
# Return the compute/function nodes in the graph
109119
def get_compute_nodes_in_gm(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]:

backends/cadence/aot/remove_ops.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
3434
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
3535
from executorch.exir.dialects._ops import ops as exir_ops
36-
from executorch.exir.dialects.edge._ops import EdgeOpOverload
36+
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
3737
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
3838
from executorch.exir.pass_manager import PassManager, PassType
3939
from executorch.exir.passes import dead_code_elimination_pass
@@ -745,6 +745,68 @@ def permute_shape(
745745
return [shape[p] for p in permute_dims]
746746

747747

748+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
749+
class RemoveBranchedQuantDequant(ExportPass):
750+
"""
751+
This pass looks for adjacent quant and dequant nodes with identical
752+
parameters, where the quant node has other users in addition to the
753+
dequant. The quant and dequant pair would be removed by the
754+
FuseQuantDequantToRequantizePass if not for the multiple users. This pass
755+
removes just the dequant node by connecting it to the quant's parent node
756+
"""
757+
758+
quantize_op_packets: set[EdgeOpOverloadPacket] = {
759+
exir_ops.edge.cadence.quantize_per_tensor,
760+
exir_ops.edge.quantized_decomposed.quantize_per_tensor,
761+
}
762+
dequantize_op_packets: set[EdgeOpOverloadPacket] = {
763+
exir_ops.edge.cadence.dequantize_per_tensor,
764+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor,
765+
}
766+
767+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
768+
self.remove_branched(
769+
graph_module, self.quantize_op_packets, self.dequantize_op_packets
770+
)
771+
self.remove_branched(
772+
graph_module, self.dequantize_op_packets, self.quantize_op_packets
773+
)
774+
775+
graph_module.graph.eliminate_dead_code()
776+
result = super().call(graph_module)
777+
return result
778+
779+
def remove_branched(
780+
self,
781+
graph_module: torch.fx.GraphModule,
782+
producer_pkts: set[EdgeOpOverloadPacket],
783+
consumer_pkts: set[EdgeOpOverloadPacket],
784+
) -> None:
785+
for node in graph_module.graph.nodes:
786+
if (
787+
node.op != "call_function"
788+
or not isinstance(node.target, EdgeOpOverload)
789+
or get_edge_overload_packet(node.target) not in producer_pkts
790+
):
791+
continue
792+
793+
if len(node.users) < 2:
794+
continue
795+
796+
for user in node.users:
797+
if (
798+
not isinstance(user.target, EdgeOpOverload)
799+
or get_edge_overload_packet(user.target) not in consumer_pkts
800+
):
801+
continue
802+
803+
# check qparams match
804+
if node.args[1:] != user.args[1:]:
805+
continue
806+
807+
user.replace_all_uses_with(node.args[0])
808+
809+
748810
# The following class consolidates functions to remove ops that are redundant
749811
# in Jarvis. Currently, each function in this class iterates over each node of
750812
# the graph module once. In future, we could consolidate them into a monolithic
@@ -765,4 +827,5 @@ class CadenceRemoveNops:
765827
RemoveNopMulOpPass,
766828
RemoveNopAddOpPass,
767829
RemoveNopLinalgVectorNormOpPass,
830+
RemoveBranchedQuantDequant,
768831
]

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
FuseTransposeOpPairsPass,
2121
)
2222
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
23-
from executorch.backends.cadence.aot.pass_utils import count_node
23+
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
2424
from executorch.exir.dialects._ops import ops as exir_ops
2525
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2626
from torch import nn
@@ -32,8 +32,7 @@ def check_op_counts(
3232
graph_module: torch.fx.GraphModule,
3333
expected_op_counts: dict[EdgeOpOverload, int],
3434
) -> None:
35-
for op, count in expected_op_counts.items():
36-
self.assertEqual(count_node(graph_module, op), count)
35+
self.assertTrue(op_counts_match(graph_module, expected_op_counts))
3736

3837

3938
class TestFusionPasses(TestFusionPassesBase):

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717
from executorch.backends.cadence.aot import compiler
1818
from executorch.backends.cadence.aot.compiler import export_to_edge
1919

20-
from executorch.backends.cadence.aot.pass_utils import count_node
20+
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
2121
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
2222
from executorch.backends.cadence.aot.remove_ops import (
2323
RemoveAliasCopyOpPass,
24+
RemoveBranchedQuantDequant,
2425
RemoveCloneOpPass,
2526
RemoveContiguousOpPass,
2627
RemoveDetachCopyPass,
@@ -709,3 +710,34 @@ def forward(self, x):
709710
self.assertEqual(
710711
count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2
711712
)
713+
714+
def test_remove_dequant_on_branch(self):
715+
class M(torch.nn.Module):
716+
def forward(self, x):
717+
x = torch.abs(x)
718+
x0 = torch.ops.quantized_decomposed.quantize_per_tensor(
719+
x, 1.2, 3, 0, 127, torch.int8
720+
)
721+
x1 = torch.abs(x0)
722+
y0 = torch.ops.quantized_decomposed.dequantize_per_tensor(
723+
x0, 1.2, 3, 0, 127, torch.int8
724+
)
725+
y1 = y0.view(-1)
726+
return x1, y1
727+
728+
inputs = torch.rand(1, 8, 4, 6)
729+
model = M()
730+
graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
731+
732+
graph_module = RemoveBranchedQuantDequant()(graph_module).graph_module
733+
self.assertTrue(
734+
op_counts_match(
735+
graph_module,
736+
expected_op_counts={
737+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
738+
# we expect the pass to remove the dequantize node
739+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
740+
exir_ops.edge.aten.abs.default: 2,
741+
},
742+
)
743+
)

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/backends/vulkan/runtime/api/containers/Tensor.h>
10+
#include <algorithm>
1011
#include <cassert>
1112
#include <cstring>
1213

backends/xnnpack/CMakeLists.txt

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,14 @@ if(NOT PYTHON_EXECUTABLE)
3333
resolve_python_executable()
3434
endif()
3535

36-
# NB: Enabling this will serialize execution of delegate instances
37-
# Keeping this OFF by default to maintain existing behavior, to be revisited.
36+
# NB: Enabling this will serialize execution of delegate instances Keeping this
37+
# OFF by default to maintain existing behavior, to be revisited.
3838
option(EXECUTORCH_XNNPACK_SHARED_WORKSPACE
39-
"Enable workspace sharing across different delegate instances" ON)
40-
# Keeping this OFF by default due to regressions in decode
41-
# and model load with kleidi kernels
42-
option(EXECUTORCH_XNNPACK_ENABLE_KLEIDI
43-
"Enable Arm Kleidi kernels" OFF)
39+
"Enable workspace sharing across different delegate instances" ON
40+
)
41+
# Keeping this OFF by default due to regressions in decode and model load with
42+
# kleidi kernels
43+
option(EXECUTORCH_XNNPACK_ENABLE_KLEIDI "Enable Arm Kleidi kernels" OFF)
4444
if(EXECUTORCH_XNNPACK_SHARED_WORKSPACE)
4545
add_definitions(-DENABLE_XNNPACK_SHARED_WORKSPACE)
4646
endif()
@@ -100,8 +100,7 @@ include(cmake/Dependencies.cmake)
100100
list(TRANSFORM _xnnpack_backend__srcs PREPEND "${EXECUTORCH_ROOT}/")
101101
add_library(xnnpack_backend STATIC ${_xnnpack_backend__srcs})
102102
target_link_libraries(
103-
xnnpack_backend PRIVATE ${xnnpack_third_party} executorch_core
104-
xnnpack_schema
103+
xnnpack_backend PRIVATE ${xnnpack_third_party} executorch_core xnnpack_schema
105104
)
106105

107106
target_include_directories(
@@ -119,6 +118,12 @@ target_include_directories(
119118
target_compile_options(xnnpack_backend PUBLIC ${_common_compile_options})
120119
target_link_options_shared_lib(xnnpack_backend)
121120

121+
if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED)
122+
list(APPEND xnn_executor_runner_libs optimized_native_cpu_ops_lib)
123+
else()
124+
list(APPEND xnn_executor_runner_libs portable_ops_lib)
125+
endif()
126+
122127
list(APPEND xnn_executor_runner_libs xnnpack_backend executorch)
123128

124129
# ios can only build library but not binary
@@ -134,14 +139,19 @@ if(NOT CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$")
134139
if(EXECUTORCH_BUILD_DEVTOOLS)
135140
list(APPEND xnn_executor_runner_libs etdump)
136141
else()
137-
message(SEND_ERROR "Use of 'EXECUTORCH_ENABLE_EVENT_TRACER' requires 'EXECUTORCH_BUILD_DEVTOOLS' to be enabled.")
142+
message(
143+
SEND_ERROR
144+
"Use of 'EXECUTORCH_ENABLE_EVENT_TRACER' requires 'EXECUTORCH_BUILD_DEVTOOLS' to be enabled."
145+
)
138146
endif()
139147
endif()
140148

141-
target_link_libraries(
142-
xnn_executor_runner gflags portable_ops_lib ${xnn_executor_runner_libs}
143-
)
149+
target_link_libraries(xnn_executor_runner gflags ${xnn_executor_runner_libs})
144150
target_compile_options(xnn_executor_runner PUBLIC ${_common_compile_options})
151+
if(EXECUTORCH_BUILD_PTHREADPOOL)
152+
target_link_libraries(xnn_executor_runner extension_threadpool pthreadpool)
153+
target_compile_definitions(xnn_executor_runner PRIVATE ET_USE_THREADPOOL)
154+
endif()
145155
endif()
146156

147157
install(

0 commit comments

Comments
 (0)