
Commit 4655097

Merge branch 'release/0.6' into cherry-pick-9799-by-pytorch_bot_bot_
2 parents 7845699 + ff9fcaa

11 files changed (+381, -394 lines)

backends/xnnpack/operators/op_static_constant_pad.py

Lines changed: 13 additions & 1 deletion
@@ -7,6 +7,7 @@
 from typing import cast, Dict, List
 
 import torch
+
 from executorch.backends.xnnpack.operators.node_visitor import (
     get_tensor_value,
     NodeVisitor,
@@ -17,7 +18,11 @@
     XNNStaticConstantPad,
     XNode,
 )
-from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node
+from executorch.backends.xnnpack.utils.utils import (
+    check_or_raise,
+    get_input_node,
+    PERM_NCHW_TO_NHWC,
+)
 
 
 @register_node_visitor
@@ -113,8 +118,15 @@ def define_node(
         # b)
         # tuple[0] = prepadding dim[-1]
         # tuple[1] = postpadding dim[-1]
+        is_channels_last = node.meta.get("XNN_NHWC_NODE", False)
         pre_paddings = all_paddings[-2::-2]  # even index elements in reverse order
         post_paddings = all_paddings[::-2]  # odd index elements in reverse order
+        if is_channels_last:
+            check_or_raise(len(pre_paddings) == 4, "Expecting prepaddings to be 4D")
+            check_or_raise(len(post_paddings) == 4, "Expecting postpaddings to be 4D")
+
+            pre_paddings = [pre_paddings[i] for i in PERM_NCHW_TO_NHWC]
+            post_paddings = [post_paddings[i] for i in PERM_NCHW_TO_NHWC]
 
         # the padding value, which defaults to 0.0
         padding_value = cast(float, node.args[2]) if len(node.args) > 2 else 0.0
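
To make the new channels-last branch concrete, here is a small standalone sketch of the padding reorder (an illustration, not the visitor code itself). It assumes `PERM_NCHW_TO_NHWC = [0, 2, 3, 1]` (dim order N, H, W, C) and that `all_paddings` has already been expanded to one (pre, post) pair per input dimension in `torch.nn.functional.pad` order:

```python
# Illustrative sketch only; assumes PERM_NCHW_TO_NHWC = [0, 2, 3, 1] (N, H, W, C).
PERM_NCHW_TO_NHWC = [0, 2, 3, 1]

# torch.nn.functional.pad lists pairs last-dimension-first; for a 4D NCHW input
# padded with pad=(1, 2, 3, 4, 5, 6), the fully expanded list would be:
all_paddings = [1, 2, 3, 4, 5, 6, 0, 0]  # (W_pre, W_post, H_pre, H_post, C_pre, C_post, N_pre, N_post)

pre_paddings = all_paddings[-2::-2]  # [0, 5, 3, 1] == per-dim pre-padding in (N, C, H, W) order
post_paddings = all_paddings[::-2]   # [0, 6, 4, 2] == per-dim post-padding in (N, C, H, W) order

# If the node was tagged channels-last ("XNN_NHWC_NODE"), reorder the amounts
# so that each one pads the correct dimension of the NHWC tensor.
pre_paddings = [pre_paddings[i] for i in PERM_NCHW_TO_NHWC]    # [0, 3, 1, 5] == (N, H, W, C)
post_paddings = [post_paddings[i] for i in PERM_NCHW_TO_NHWC]  # [0, 4, 2, 6] == (N, H, W, C)
```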

backends/xnnpack/test/ops/test_static_constant_pad.py

Lines changed: 45 additions & 0 deletions
@@ -14,6 +14,30 @@ class TestStaticConstantPad(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()
 
+    class NHWCStaticConstantPad(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=1)
+            self.conv2 = torch.nn.Conv2d(in_channels=13, out_channels=13, kernel_size=1)
+
+        def forward(self, x):
+            a = self.conv1(x)
+            pad_6 = (1, 2, 3, 4, 5, 6)
+            a = torch.nn.functional.pad(
+                input=a,
+                pad=pad_6,
+                mode="constant",
+                value=3.1,
+            )
+            # tensorshape = [1, 13, 10, 7]
+            a = self.conv2(a)
+
+            return a
+
+        def sample_inputs(self):
+            # NCHW
+            return (torch.randn(1, 2, 3, 4),)
+
     class StaticConstantPadFunctional(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -205,3 +229,24 @@ def test_qs8_static_constant_pad_2d(self):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_static_constant_pad_nhwc(self):
+        conv = self.NHWCStaticConstantPad()
+        inputs = conv.sample_inputs()
+        (
+            Tester(conv, inputs)
+            .export()
+            .check_count({"torch.ops.aten.pad.default": 1})
+            .dump_artifact()
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(
+                [
+                    "executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default",
+                    "executorch_exir_dialects_edge__ops_aten_convolution_default",
+                ]
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
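
As a quick standalone check of the shapes used by `NHWCStaticConstantPad` (and of why `conv2` needs `in_channels=13`), the padding arithmetic can be reproduced with plain PyTorch; the values below are taken directly from the test above:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 3, 4)                   # NCHW sample input, as in sample_inputs()
a = torch.nn.Conv2d(2, 2, kernel_size=1)(x)   # 1x1 conv keeps the shape: (1, 2, 3, 4)
a = F.pad(a, pad=(1, 2, 3, 4, 5, 6), mode="constant", value=3.1)

# W: 4+1+2=7, H: 3+3+4=10, C: 2+5+6=13 -> matches the "# tensorshape = [1, 13, 10, 7]" comment
print(a.shape)  # torch.Size([1, 13, 10, 7])
```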
Binary files (139 KB and 1.92 MB) changed; contents not shown.

docs/source/backend-delegates-xnnpack-reference.md

Lines changed: 1 addition & 1 deletion
@@ -142,5 +142,5 @@ def _qdq_quantized_linear(
 You can read more indepth explanations on PyTorch 2 quantization [here](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html).
 
 ## See Also
-- [Integrating XNNPACK Delegate Android App](demo-apps-android.md)
+- [Integrating XNNPACK Delegate in Android AAR](using-executorch-android.md)
 - [Complete the Lowering to XNNPACK Tutorial](tutorial-xnnpack-delegate-lowering.md)

docs/source/backends-qualcomm.md

Lines changed: 0 additions & 5 deletions
@@ -351,11 +351,6 @@ The command-line arguments are written in [utils.py](https://github.com/pytorch/
 The model, inputs, and output location are passed to `qnn_executorch_runner` by `--model_path`, `--input_list_path`, and `--output_folder_path`.
 
 
-### Running a model via ExecuTorch's android demo-app
-
-An Android demo-app using Qualcomm AI Engine Direct Backend can be found in
-`examples`. Please refer to android demo app [tutorial](demo-apps-android.md).
-
 ## Supported model list
 
 Please refer to `$EXECUTORCH_ROOT/examples/qualcomm/scripts/` and `EXECUTORCH_ROOT/examples/qualcomm/oss_scripts/` to the list of supported models.

docs/source/index.md

Lines changed: 287 additions & 0 deletions
@@ -0,0 +1,287 @@
+(home)=
+# Welcome to the ExecuTorch Documentation
+
+**ExecuTorch** is PyTorch's solution to training and inference on the
+Edge.
+
+## Key Value Propositions
+
+- **Portability:** Compatibility with a wide variety of computing
+  platforms, from high-end mobile phones to highly constrained
+  embedded systems and microcontrollers.
+- **Productivity:** Enabling developers to use the same toolchains and
+  Developer Tools from PyTorch model authoring and conversion, to
+  debugging and deployment to a wide variety of platforms.
+- **Performance:** Providing end users with a seamless and
+  high-performance experience due to a lightweight runtime and
+  utilizing full hardware capabilities such as CPUs, NPUs, and DSPs.
+
+ExecuTorch provides support for:
+
+* **Strong Model Support** LLMs (Large Language Models),
+  CV (Computer Vision), ASR (Automatic Speech Recognition), TTS (Text To Speech)
+* **All Major Platforms** Android, Mac, Linux, Windows
+* **Rich Acceleration Support** Apple, Arm, Cadence, MediaTek,
+  Qualcomm, Vulkan, XNNPACK
+
+### Documentation Navigation
+#### Introduction
+- [Overview](intro-overview)
+- [How it Works](intro-how-it-works)
+- [Getting Started with Architecture](getting-started-architecture)
+- [Concepts](concepts)
+#### Usage
+- [Getting Started](getting-started)
+- [Using Executorch Export](using-executorch-export)
+- [Using Executorch on Android](using-executorch-android)
+- [Using Executorch on iOS](using-executorch-ios)
+- [Using Executorch with C++](using-executorch-cpp)
+- [Runtime Integration](using-executorch-runtime-integration)
+- [Troubleshooting](using-executorch-troubleshooting)
+- [Building from Source](using-executorch-building-from-source)
+- [FAQs](using-executorch-faqs)
+#### Examples
+- [Android Demo Apps](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo#executorch-android-demo-app)
+- [iOS Demo Apps](demo-apps-ios.md)
+#### Backends
+- [Overview](backends-overview)
+- [XNNPACK](backends-xnnpack)
+- [Core ML](backends-coreml)
+- [MPS](backends-mps)
+- [Vulkan](backends-vulkan)
+- [ARM Ethos-U](backends-arm-ethos-u)
+- [Qualcomm](backends-qualcomm)
+- [MediaTek](backends-mediatek)
+- [Cadence](backends-cadence)
+#### Developer Tools
+- [Overview](devtools-overview)
+- [Bundled IO](bundled-io)
+- [ETRecord](etrecord)
+- [ETDump](etdump)
+- [Runtime Profiling](runtime-profiling)
+- [Model Debugging](model-debugging)
+- [Model Inspector](model-inspector)
+- [Memory Planning Inspection](memory-planning-inspection)
+- [Delegate Debugging](delegate-debugging)
+- [Tutorial](devtools-tutorial)
+#### Runtime
+- [Overview](runtime-overview)
+- [Extension Module](extension-module)
+- [Extension Tensor](extension-tensor)
+- [Running a Model (C++ Tutorial)](running-a-model-cpp-tutorial)
+- [Backend Delegate Implementation and Linking](runtime-backend-delegate-implementation-and-linking)
+- [Platform Abstraction Layer](runtime-platform-abstraction-layer)
+#### Portable C++ Programming
+- [PTE File Format](pte-file-format)
+#### API Reference
+- [Export to Executorch API Reference](export-to-executorch-api-reference)
+- [Executorch Runtime API Reference](executorch-runtime-api-reference)
+- [Runtime Python API Reference](runtime-python-api-reference)
+- [API Life Cycle](api-life-cycle)
+- [Javadoc](https://pytorch.org/executorch/main/javadoc/)
+#### Quantization
+- [Overview](quantization-overview)
+#### Kernel Library
+- [Overview](kernel-library-overview)
+- [Custom ATen Kernel](kernel-library-custom-aten-kernel)
+- [Selective Build](kernel-library-selective-build)
+#### Working with LLMs
+- [Llama](llm/llama)
+- [Llama on Android](llm/llama-demo-android)
+- [Llama on iOS](llm/llama-demo-ios)
+- [Llama on Android via Qualcomm backend](llm/build-run-llama3-qualcomm-ai-engine-direct-backend)
+- [Intro to LLMs in Executorch](llm/getting-started)
+#### Backend Development
+- [Delegates Integration](backend-delegates-integration)
+- [XNNPACK Reference](backend-delegates-xnnpack-reference)
+- [Dependencies](backend-delegates-dependencies)
+- [Compiler Delegate and Partitioner](compiler-delegate-and-partitioner)
+- [Debug Backend Delegate](debug-backend-delegate)
+#### IR Specification
+- [EXIR](ir-exir)
+- [Ops Set Definition](ir-ops-set-definition)
+#### Compiler Entry Points
+- [Backend Dialect](compiler-backend-dialect)
+- [Custom Compiler Passes](compiler-custom-compiler-passes)
+- [Memory Planning](compiler-memory-planning)
+#### Contributing
+- [Contributing](contributing)
+
+```{toctree}
+:glob:
+:maxdepth: 1
+:caption: Introduction
+:hidden:
+
+intro-overview
+intro-how-it-works
+getting-started-architecture
+concepts
+```
+
+```{toctree}
+:glob:
+:maxdepth: 1
+:caption: Usage
+:hidden:
+
+getting-started
+using-executorch-export
+using-executorch-android
+using-executorch-ios
+using-executorch-cpp
+using-executorch-runtime-integration
+using-executorch-troubleshooting
+using-executorch-building-from-source
+using-executorch-faqs
+```
+
+```{toctree}
+:glob:
+:maxdepth: 1
+:caption: Examples
+:hidden:
+
+Building an ExecuTorch Android Demo App <https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo#executorch-android-demo-app>
+demo-apps-ios.md
+```
+
+```{toctree}
+:glob:
+:maxdepth: 1
+:caption: Backends
+:hidden:
+
+backends-overview
+backends-xnnpack
+backends-coreml
+backends-mps
+backends-vulkan
+backends-arm-ethos-u
+backends-qualcomm
+backends-mediatek
+backends-cadence
+```
+
+```{toctree}
+:glob:
+:maxdepth: 1
+:caption: Developer Tools
+:hidden:
+
+devtools-overview
+bundled-io
+etrecord
+etdump
+runtime-profiling
+model-debugging
+model-inspector
+memory-planning-inspection
+delegate-debugging
+devtools-tutorial
+```
+
+```{toctree}
+:glob:
+:maxdepth: 1
+:caption: Runtime
+:hidden:
+
+runtime-overview
+extension-module
+extension-tensor
+running-a-model-cpp-tutorial
+runtime-backend-delegate-implementation-and-linking
+runtime-platform-abstraction-layer
+portable-cpp-programming
+pte-file-format
+```
+
+```{toctree}
+:glob:
+:maxdepth: 1
+:caption: API Reference
+:hidden:
+
+export-to-executorch-api-reference
+executorch-runtime-api-reference
+runtime-python-api-reference
+api-life-cycle
+Javadoc <https://pytorch.org/executorch/main/javadoc/>
+```
+
+```{toctree}
+:glob:
+:maxdepth: 1
+:caption: Quantization
+:hidden:
+
+quantization-overview
+```
+
+```{toctree}
+:glob:
+:maxdepth: 1
+:caption: Kernel Library
+:hidden:
+
+kernel-library-overview
+kernel-library-custom-aten-kernel
+kernel-library-selective-build
+```
+
+```{toctree}
+:glob:
+:maxdepth: 2
+:caption: Working with LLMs
+:hidden:
+
+Llama <llm/llama>
+Llama on Android <llm/llama-demo-android>
+Llama on iOS <llm/llama-demo-ios>
+Llama on Android via Qualcomm backend <llm/build-run-llama3-qualcomm-ai-engine-direct-backend>
+Intro to LLMs in Executorch <llm/getting-started>
+```
+
+```{toctree}
+:glob:
+:maxdepth: 1
+:caption: Backend Development
+:hidden:
+
+backend-delegates-integration
+backend-delegates-xnnpack-reference
+backend-delegates-dependencies
+compiler-delegate-and-partitioner
+debug-backend-delegate
+```
+
+```{toctree}
+:glob:
+:maxdepth: 1
+:caption: IR Specification
+:hidden:
+
+ir-exir
+ir-ops-set-definition
+```
+
+```{toctree}
+:glob:
+:maxdepth: 1
+:caption: Compiler Entry Points
+:hidden:
+
+compiler-backend-dialect
+compiler-custom-compiler-passes
+compiler-memory-planning
+```
+
+```{toctree}
+:glob:
+:maxdepth: 1
+:caption: Contributing
+:hidden:
+
+contributing
+```
