Skip to content

Commit ec0fcee

Browse files
benkli01 authored and facebook-github-bot committed
Minor fixes around the Arm testing framework (#5976)
Summary: Pull Request resolved: #5976 Reviewed By: mergennachin Differential Revision: D64047755 Pulled By: digantdesai fbshipit-source-id: 91d32e771526eff2cb9863a04b0e18a2b40173e0
1 parent e904e56 commit ec0fcee

File tree

3 files changed

+39
-34
lines changed

3 files changed

+39
-34
lines changed

backends/arm/test/runner_utils.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -457,9 +457,10 @@ def prep_data_for_save(
457457
data_np = np.array(data.detach(), order="C").astype(np.float32)
458458

459459
if is_quantized:
460-
assert (
461-
quant_param.node_name in input_name
462-
), "These quantization params do not match the input tensor name"
460+
assert quant_param.node_name in input_name, (
461+
f"The quantization params name '{quant_param.node_name}' does not "
462+
f"match the input tensor name '{input_name}'."
463+
)
463464
data_np = (
464465
((data_np / np.float32(quant_param.scale)) + quant_param.zp)
465466
.round()

backends/arm/test/tester/arm_tester.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def __init__(
150150
model: torch.nn.Module,
151151
example_inputs: Tuple[torch.Tensor],
152152
compile_spec: List[CompileSpec] = None,
153+
tosa_ref_model_path: str | None = None,
153154
):
154155
"""
155156
Args:
@@ -160,7 +161,10 @@ def __init__(
160161

161162
# Initiate runner_util
162163
intermediate_path = get_intermediate_path(compile_spec)
163-
self.runner_util = RunnerUtil(intermediate_path=intermediate_path)
164+
self.runner_util = RunnerUtil(
165+
intermediate_path=intermediate_path,
166+
tosa_ref_model_path=tosa_ref_model_path,
167+
)
164168

165169
self.compile_spec = compile_spec
166170
super().__init__(model, example_inputs)

backends/arm/tosa_mapping.py

+30-30
Original file line numberDiff line numberDiff line change
@@ -15,37 +15,37 @@
1515
import torch
1616

1717

18+
UNSUPPORTED_DTYPES = (
19+
torch.float64,
20+
torch.double,
21+
torch.complex64,
22+
torch.cfloat,
23+
torch.complex128,
24+
torch.cdouble,
25+
torch.uint8,
26+
torch.int64,
27+
torch.long,
28+
)
29+
30+
DTYPE_MAP = {
31+
torch.float32: ts.DType.FP32,
32+
torch.float: ts.DType.FP32,
33+
torch.float16: ts.DType.FP16,
34+
torch.half: ts.DType.FP16,
35+
torch.bfloat16: ts.DType.BF16,
36+
torch.int8: ts.DType.INT8,
37+
torch.int16: ts.DType.INT16,
38+
torch.short: ts.DType.INT16,
39+
torch.int32: ts.DType.INT32,
40+
torch.int: ts.DType.INT32,
41+
torch.bool: ts.DType.BOOL,
42+
}
43+
44+
1845
def map_dtype(data_type):
19-
unsupported = (
20-
torch.float64,
21-
torch.double,
22-
torch.complex64,
23-
torch.cfloat,
24-
torch.complex128,
25-
torch.cdouble,
26-
torch.uint8,
27-
torch.int64,
28-
torch.long,
29-
)
30-
31-
dmap = {
32-
torch.float32: ts.DType.FP32,
33-
torch.float: ts.DType.FP32,
34-
torch.float16: ts.DType.FP16,
35-
torch.half: ts.DType.FP16,
36-
torch.bfloat16: ts.DType.BF16,
37-
torch.int8: ts.DType.INT8,
38-
torch.int16: ts.DType.INT16,
39-
torch.short: ts.DType.INT16,
40-
torch.int32: ts.DType.INT32,
41-
torch.int: ts.DType.INT32,
42-
torch.bool: ts.DType.BOOL,
43-
}
44-
45-
assert unsupported.count(data_type) == 0, "Unsupported type"
46-
rtype = dmap.get(data_type)
47-
assert rtype is not None, "Unknown type"
48-
return rtype
46+
assert data_type not in UNSUPPORTED_DTYPES, f"Unsupported type: {data_type}"
47+
assert data_type in DTYPE_MAP, f"Unknown type: {data_type}"
48+
return DTYPE_MAP[data_type]
4949

5050

5151
# Returns the shape and type of a node

0 commit comments

Comments (0)