Skip to content

Commit fd77339

Browse files
committed
[ExecuTorch] Arm Ethos: Do not depend on torch.testing._internal
This can cuase issues with `disable_global_flags` and internal state of the library, this is something which is set when importing this Differential Revision: [D70402061](https://our.internmc.facebook.com/intern/diff/D70402061/) ghstack-source-id: 269356031 Pull Request resolved: #8893
1 parent 542480c commit fd77339

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

backends/arm/test/passes/test_rescale_pass.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from executorch.backends.arm.test import common, conftest
1414
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1515
from parameterized import parameterized
16-
from torch.testing._internal import optests
1716

1817

1918
def test_rescale_op():
@@ -64,7 +63,7 @@ def test_nonzero_zp_for_int32():
6463
),
6564
]
6665
for sample_input in sample_inputs:
67-
with pytest.raises(optests.generate_tests.OpCheckError):
66+
with pytest.raises(Exception, match="opcheck"):
6867
torch.library.opcheck(torch.ops.tosa._rescale, sample_input)
6968

7069

@@ -87,7 +86,7 @@ def test_zp_outside_range():
8786
),
8887
]
8988
for sample_input in sample_inputs:
90-
with pytest.raises(optests.generate_tests.OpCheckError):
89+
with pytest.raises(Exception, match="opcheck"):
9190
torch.library.opcheck(torch.ops.tosa._rescale, sample_input)
9291

9392

backends/arm/test/runner_utils.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,33 @@
3434
from torch.fx.node import Node
3535

3636
from torch.overrides import TorchFunctionMode
37-
from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict
3837
from tosa import TosaGraph
3938

4039
logger = logging.getLogger(__name__)
4140
logger.setLevel(logging.CRITICAL)
4241

42+
# Copied from PyTorch.
43+
# From torch/testing/_internal/common_utils.py:torch_to_numpy_dtype_dict
44+
# To avoid a dependency on _internal stuff.
45+
_torch_to_numpy_dtype_dict = {
46+
torch.bool: np.bool_,
47+
torch.uint8: np.uint8,
48+
torch.uint16: np.uint16,
49+
torch.uint32: np.uint32,
50+
torch.uint64: np.uint64,
51+
torch.int8: np.int8,
52+
torch.int16: np.int16,
53+
torch.int32: np.int32,
54+
torch.int64: np.int64,
55+
torch.float16: np.float16,
56+
torch.float32: np.float32,
57+
torch.float64: np.float64,
58+
torch.bfloat16: np.float32,
59+
torch.complex32: np.complex64,
60+
torch.complex64: np.complex64,
61+
torch.complex128: np.complex128,
62+
}
63+
4364

4465
class QuantizationParams:
4566
__slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"]
@@ -335,7 +356,7 @@ def run_corstone(
335356
output_dtype = node.meta["val"].dtype
336357
tosa_ref_output = np.fromfile(
337358
os.path.join(intermediate_path, f"out-{i}.bin"),
338-
torch_to_numpy_dtype_dict[output_dtype],
359+
_torch_to_numpy_dtype_dict[output_dtype],
339360
)
340361

341362
output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape))
@@ -349,7 +370,7 @@ def prep_data_for_save(
349370
):
350371
if isinstance(data, torch.Tensor):
351372
data_np = np.array(data.detach(), order="C").astype(
352-
torch_to_numpy_dtype_dict[data.dtype]
373+
_torch_to_numpy_dtype_dict[data.dtype]
353374
)
354375
else:
355376
data_np = np.array(data)

0 commit comments

Comments
 (0)