Skip to content

Commit 2fff01a

Browse files
martinlsmmansnilsfreddan80
authored
Arm backend: Add Arm model test for Wav2letter (#8594)
Add Arm model test for Wav2letter Signed-off-by: Martin Lindström <[email protected]> Co-authored-by: Måns Nilsson <[email protected]> Co-authored-by: Fredrik Knutsson <[email protected]>
1 parent c35df8b commit 2fff01a

File tree

1 file changed

+150
-0
lines changed

1 file changed

+150
-0
lines changed
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import logging
9+
import unittest
10+
from typing import Tuple
11+
12+
import pytest
13+
14+
import torch
15+
from executorch.backends.arm.test import common, conftest
16+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
17+
18+
from executorch.exir.backend.compile_spec_schema import CompileSpec
19+
from torchaudio import models
20+
21+
22+
logger = logging.getLogger(__name__)
23+
logger.setLevel(logging.INFO)
24+
25+
26+
def get_test_inputs(batch_size, num_features, input_frames):
27+
return (torch.randn(batch_size, num_features, input_frames),)
28+
29+
30+
class TestW2L(unittest.TestCase):
31+
"""Tests Wav2Letter."""
32+
33+
batch_size = 10
34+
input_frames = 400
35+
num_features = 1
36+
37+
w2l = models.Wav2Letter(num_features=num_features).eval()
38+
model_example_inputs = get_test_inputs(batch_size, num_features, input_frames)
39+
40+
all_operators = {
41+
"executorch_exir_dialects_edge__ops_aten_convolution_default",
42+
"executorch_exir_dialects_edge__ops_aten__log_softmax_default",
43+
"executorch_exir_dialects_edge__ops_aten_relu_default",
44+
}
45+
46+
operators_after_quantization = all_operators - {
47+
"executorch_exir_dialects_edge__ops_aten__log_softmax_default",
48+
}
49+
50+
@pytest.mark.slow # about 3min on std laptop
51+
def test_w2l_tosa_MI(self):
52+
(
53+
ArmTester(
54+
self.w2l,
55+
example_inputs=self.model_example_inputs,
56+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
57+
)
58+
.export()
59+
.dump_operator_distribution()
60+
.to_edge_transform_and_lower()
61+
.dump_operator_distribution()
62+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
63+
.to_executorch()
64+
.run_method_and_compare_outputs(
65+
inputs=get_test_inputs(
66+
self.batch_size, self.num_features, self.input_frames
67+
)
68+
)
69+
)
70+
71+
@pytest.mark.slow # about 1min on std laptop
72+
def test_w2l_tosa_BI(self):
73+
(
74+
ArmTester(
75+
self.w2l,
76+
example_inputs=self.model_example_inputs,
77+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
78+
)
79+
.quantize()
80+
.export()
81+
.dump_operator_distribution()
82+
.to_edge_transform_and_lower()
83+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
84+
.to_executorch()
85+
.run_method_and_compare_outputs(
86+
atol=0.1,
87+
qtol=1,
88+
inputs=get_test_inputs(
89+
self.batch_size, self.num_features, self.input_frames
90+
),
91+
)
92+
)
93+
94+
def _test_w2l_ethos_BI_pipeline(
95+
self,
96+
module: torch.nn.Module,
97+
test_data: Tuple[torch.Tensor],
98+
compile_spec: CompileSpec,
99+
):
100+
tester = (
101+
ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
102+
.quantize()
103+
.export()
104+
.to_edge()
105+
.check(list(self.operators_after_quantization))
106+
.partition()
107+
.to_executorch()
108+
.serialize()
109+
)
110+
return tester
111+
112+
# TODO: expected fail as TOSA.Transpose is not supported by Ethos-U55
113+
@pytest.mark.slow
114+
@pytest.mark.corstone_fvp
115+
@conftest.expectedFailureOnFVP
116+
def test_w2l_u55_BI(self):
117+
tester = self._test_w2l_ethos_BI_pipeline(
118+
self.w2l,
119+
self.model_example_inputs,
120+
common.get_u55_compile_spec(),
121+
)
122+
123+
if conftest.is_option_enabled("corstone_fvp"):
124+
tester.run_method_and_compare_outputs(
125+
atol=1.0,
126+
qtol=1,
127+
inputs=get_test_inputs(
128+
self.batch_size, self.num_features, self.input_frames
129+
),
130+
)
131+
132+
@pytest.mark.slow
133+
@pytest.mark.corstone_fvp
134+
@unittest.skip("Blocked by MLBEDSW-10420")
135+
@conftest.expectedFailureOnFVP # TODO: MLBEDSW-10093
136+
def test_w2l_u85_BI(self):
137+
tester = self._test_w2l_ethos_BI_pipeline(
138+
self.w2l,
139+
self.model_example_inputs,
140+
common.get_u85_compile_spec(),
141+
)
142+
143+
if conftest.is_option_enabled("corstone_fvp"):
144+
tester.run_method_and_compare_outputs(
145+
atol=1.0,
146+
qtol=1,
147+
inputs=get_test_inputs(
148+
self.batch_size, self.num_features, self.input_frames
149+
),
150+
)

0 commit comments

Comments
 (0)