Skip to content

Commit 667e472

Browse files
committed
Update on "[ET-VK] Minor performance improvements to native layer norm."
This diff introduces minor performance improvements to the native layer norm function in the Vulkan backend of Executorch. In this new approach: The mean and variance values are calculated in 2 separate passes. Shader is dispatched based on input texture size, and input texel is read and stored in shared memory. Input stored in shard memory is then summed up using a reduce function. This implementation better utilizes a GPUs parallel processing capabilities. Differential Revision: [D72430290](https://our.internmc.facebook.com/intern/diff/D72430290/) [ghstack-poisoned]
2 parents fc5cdde + be1f29a commit 667e472

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1193
-144
lines changed

.ci/scripts/gather_benchmark_configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"samsung_galaxy_s22": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa",
2424
"samsung_galaxy_s24": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98f8788c-2e25-4a3c-8bb2-0d1e8897c0db",
2525
"google_pixel_8_pro": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/d65096ab-900b-4521-be8b-a3619b69236a",
26+
"google_pixel_3_private_rooted": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98d23ca8-ea9e-4fb7-b725-d402017b198d",
2627
}
2728

2829
# Predefined benchmark configurations
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
name: android-perf (private devices)
2+
3+
on:
4+
schedule:
5+
- cron: 0 0,4,8,12,16,20 * * *
6+
pull_request:
7+
paths:
8+
- .github/workflows/android-perf-private-device-experiment.yml
9+
push:
10+
branches:
11+
- main
12+
paths:
13+
- .github/workflows/android-perf-private-device-experiment.yml
14+
# Note: GitHub has an upper limit of 10 inputs
15+
workflow_dispatch:
16+
inputs:
17+
models:
18+
description: Models to be benchmarked
19+
required: false
20+
type: string
21+
default: mv3,meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8,meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8
22+
devices:
23+
description: Target devices to run benchmark
24+
required: false
25+
type: string
26+
default: google_pixel_3_private_rooted
27+
benchmark_configs:
28+
description: The list of configs used the benchmark
29+
required: false
30+
type: string
31+
workflow_call:
32+
inputs:
33+
models:
34+
description: Models to be benchmarked
35+
required: false
36+
type: string
37+
default: mv3,meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8,meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8
38+
devices:
39+
description: Target devices to run benchmark
40+
required: false
41+
type: string
42+
default: google_pixel_3_private_rooted
43+
benchmark_configs:
44+
description: The list of configs used the benchmark
45+
required: false
46+
type: string
47+
48+
concurrency:
49+
group: android-perf-private-devices-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
50+
cancel-in-progress: true
51+
52+
jobs:
53+
android:
54+
uses: ./.github/workflows/android-perf.yml
55+
secrets: inherit
56+
permissions:
57+
id-token: write
58+
contents: read
59+
with:
60+
models: ${{ inputs.models }}
61+
devices: google_pixel_3_private_rooted
62+
benchmark_configs: ${{ inputs.benchmark_configs }}

.github/workflows/android-release-artifacts.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ jobs:
4949
contents: read
5050
with:
5151
secrets-env: EXECUTORCH_MAVEN_SIGNING_KEYID EXECUTORCH_MAVEN_SIGNING_PASSWORD EXECUTORCH_MAVEN_CENTRAL_PASSWORD EXECUTORCH_MAVEN_CENTRAL_USERNAME EXECUTORCH_MAVEN_SIGNING_GPG_KEY_CONTENTS
52-
runner: linux.2xlarge
52+
# As this job has access to Maven credential, run this on a fresh ephemeral runner
53+
runner: ephemeral.linux.2xlarge
5354
docker-image: executorch-ubuntu-22.04-clang12-android
5455
submodules: 'recursive'
5556
ref: ${{ github.sha }}

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .decompose_linear_pass import DecomposeLinearPass # noqa
2727
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
2828
from .decompose_select import DecomposeSelectPass # noqa
29+
from .decompose_silu_pass import DecomposeSiluPass # noqa
2930
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
3031
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
3132
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
DecomposeLinearPass,
3232
DecomposeMeanDimPass,
3333
DecomposeSelectPass,
34+
DecomposeSiluPass,
3435
DecomposeSoftmaxPass,
3536
DecomposeSoftmaxUnstablePass,
3637
DecomposeSqrtPass,
@@ -196,6 +197,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
196197
self.add_pass(DecomposeDivPass())
197198
self.add_pass(DecomposeLeakyReLUPass())
198199
self.add_pass(DecomposeSqrtPass())
200+
self.add_pass(DecomposeSiluPass())
199201

200202
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
201203
# Numerically stable softmax uses amax which is not supported on Ethos-U55
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import torch
9+
from executorch.exir.pass_base import ExportPass
10+
11+
aten_silu_ops = (torch.ops.aten.silu.default, torch.ops.aten.silu_.default)
12+
13+
14+
class DecomposeSiluPass(ExportPass):
15+
"""
16+
This pass decomposes silu into a mul and a sigmoid node.
17+
18+
Example:
19+
y = silu(a)
20+
Becomes:
21+
x = sigmoid(a)
22+
y = mul(a,x)
23+
"""
24+
25+
def call_operator(self, op, args, kwargs, meta):
26+
if op not in (aten_silu_ops):
27+
return super().call_operator(op, args, kwargs, meta)
28+
sigmoid_op = torch.ops.aten.sigmoid.default
29+
mul_op = torch.ops.aten.mul.Tensor
30+
31+
original = args[0]
32+
sigmoid = super().call_operator(sigmoid_op, (original,), {}, meta)
33+
34+
return super().call_operator(mul_op, (original, sigmoid), {}, meta)

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, exported_program):
4949
exir_ops.edge.aten.bitwise_left_shift.Tensor,
5050
exir_ops.edge.aten.eq.Tensor,
5151
exir_ops.edge.aten.gt.Tensor,
52+
exir_ops.edge.aten.ge.Tensor,
5253
exir_ops.edge.aten.lt.Tensor,
5354
exir_ops.edge.aten.pow.Tensor_Tensor,
5455
exir_ops.edge.aten.where.self,

backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor,
2828
exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor,
2929
exir_ops.edge.aten.gt.Scalar: exir_ops.edge.aten.gt.Tensor,
30+
exir_ops.edge.aten.ge.Scalar: exir_ops.edge.aten.ge.Tensor,
3031
exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor,
3132
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
3233
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
@@ -36,6 +37,7 @@
3637
torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor,
3738
torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
3839
torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
40+
torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
3941
torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
4042
}
4143

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class EthosU55NotSupported(OperatorSupportBase):
134134
exir_ops.edge.aten.eq.Tensor,
135135
exir_ops.edge.aten.eq.Scalar,
136136
exir_ops.edge.aten.ge.Tensor,
137+
exir_ops.edge.aten.ge.Scalar,
137138
exir_ops.edge.aten.gt.Tensor,
138139
exir_ops.edge.aten.gt.Scalar,
139140
exir_ops.edge.aten.le.Tensor,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def is_node_supported(
178178
exir_ops.edge.aten.full.default,
179179
exir_ops.edge.aten.full_like.default,
180180
exir_ops.edge.aten.ge.Tensor,
181+
exir_ops.edge.aten.ge.Scalar,
181182
exir_ops.edge.aten.gt.Tensor,
182183
exir_ops.edge.aten.gt.Scalar,
183184
exir_ops.edge.aten.le.Tensor,
@@ -228,6 +229,7 @@ def is_node_supported(
228229
exir_ops.edge.aten.__lshift__.Scalar,
229230
torch.ops.aten.scalar_tensor.default,
230231
exir_ops.edge.aten.gelu.default,
232+
exir_ops.edge.aten.alias_copy.default,
231233
]
232234

233235
return supported

backends/arm/operators/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
op_erf,
2323
op_exp,
2424
op_ge,
25-
op_get_item,
2625
op_gt,
2726
op_le,
2827
op_log,
@@ -51,5 +50,6 @@
5150
op_view,
5251
op_where,
5352
ops_binary,
53+
ops_identity,
5454
ops_unary,
5555
)

backends/arm/operators/op_get_item.py

Lines changed: 0 additions & 35 deletions
This file was deleted.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import torch
11+
import torch.fx
12+
13+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
14+
15+
from executorch.backends.arm.operators.node_visitor import (
16+
NodeVisitor,
17+
register_node_visitor,
18+
)
19+
from executorch.backends.arm.tosa_mapping import TosaArg
20+
21+
22+
def identity_operator_factory(identity_target: str):
23+
"""
24+
Creates and registers NodeVisitors for operators that map directly
25+
to a TOSA IDENTITY op.
26+
"""
27+
28+
class IdentityOperatorVisitor(NodeVisitor):
29+
target = identity_target
30+
31+
def define_node(
32+
self,
33+
node: torch.fx.Node,
34+
tosa_graph: ts.TosaSerializer,
35+
inputs: List[TosaArg],
36+
output: TosaArg,
37+
) -> None:
38+
# Simply add an identityOp
39+
tosa_graph.addOperator(
40+
ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name]
41+
)
42+
43+
register_node_visitor(IdentityOperatorVisitor)
44+
45+
46+
identity_operator_factory("getitem")
47+
identity_operator_factory("aten.alias_copy.default")

backends/arm/quantizer/arm_quantizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,10 @@ def _annotate_all_static_patterns(
286286
quantization_config: Optional[QuantizationConfig],
287287
filter_fn: Optional[Callable[[Node], bool]] = None,
288288
) -> GraphModule:
289-
"""Loops over all STATIC_OPS and runs the corresponding registred annotator.
289+
"""Loops over all STATIC_OPS and runs the corresponding registered annotator.
290290
Args:
291291
model: The model to annotate statically.
292-
quantization_config: Specifices the QuantizationSpecs for the model's
292+
quantization_config: Specifies the QuantizationSpecs for the model's
293293
input activations, output activations, weights and biases.
294294
filter_fn: An optional filter function that takes a node and returns whether the node should be annotated.
295295
Returns:

backends/arm/quantizer/quantization_annotator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,11 @@ def _match_pattern(
244244
operator.getitem,
245245
]
246246

247+
_one_to_one_shared_input_or_input_act_qspec = [
248+
torch.ops.aten.adaptive_avg_pool2d.default,
249+
torch.ops.aten.alias_copy.default,
250+
]
251+
247252

248253
def get_quant_properties( # noqa: C901
249254
node: Node, gm: torch.fx.GraphModule, quantization_config
@@ -332,7 +337,7 @@ def any_or_hardtanh_min_zero(n: Node):
332337
_QuantProperty(2, shared_qspec), # type: ignore[arg-type]
333338
]
334339
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
335-
elif node.target == torch.ops.aten.adaptive_avg_pool2d.default:
340+
elif node.target in _one_to_one_shared_input_or_input_act_qspec:
336341
input_qspec = (
337342
SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
338343
if arm_quantizer_utils.is_output_annotated(node.args[0]) # type: ignore

0 commit comments

Comments
 (0)