Skip to content

Commit 10b9ecd

Browse files
authored
Merge pull request #1329 from pytorch/python_tests
refactor(//tests) : Refactor the test suite
2 parents 99db0cd + af20761 commit 10b9ecd

35 files changed

+713
-335
lines changed

.circleci/config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ commands:
435435
mkdir -p /tmp/artifacts/test_results
436436
cd tests/py
437437
pytest --junitxml=/tmp/artifacts/test_results/api/api_test_results.xml api/
438+
pytest --junitxml=/tmp/artifacts/test_results/models/models_test_results.xml models/
438439
pytest --junitxml=/tmp/artifacts/test_results/integrations/integrations_test_results.xml integrations/
439440
cd ~/project
440441

.github/workflows/docgen.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
- name: Set up Python 3.9.4
3232
uses: actions/setup-python@v2
3333
with:
34-
python-version: 3.9.4
34+
python-version: 3.9.4
3535
- uses: actions/checkout@v2
3636
with:
3737
ref: ${{github.head_ref}}

.github/workflows/linter.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jobs:
3939
pip3 install -r $GITHUB_WORKSPACE/.github/scripts/requirements.txt
4040
pip3 install -r $GITHUB_WORKSPACE/requirements-dev.txt
4141
- name: Lint C++
42-
run: |
42+
run: |
4343
cd $GITHUB_WORKSPACE
4444
python3 $GITHUB_WORKSPACE/.github/scripts/run_cpp_linter.py
4545
env:

noxfile.py

Lines changed: 39 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
if USE_HOST_DEPS:
3131
print("Using dependencies from host python")
3232

33+
# Set epochs to train VGG model for accuracy tests
34+
EPOCHS = 25
35+
3336
SUPPORTED_PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10"]
3437

3538
nox.options.sessions = [
@@ -63,31 +66,6 @@ def install_torch_trt(session):
6366
session.run("python", "setup.py", "develop")
6467

6568

66-
def download_datasets(session):
67-
print(
68-
"Downloading dataset to path",
69-
os.path.join(TOP_DIR, "examples/int8/training/vgg16"),
70-
)
71-
session.chdir(os.path.join(TOP_DIR, "examples/int8/training/vgg16"))
72-
session.run_always(
73-
"wget", "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz", external=True
74-
)
75-
session.run_always("tar", "-xvzf", "cifar-10-binary.tar.gz", external=True)
76-
session.run_always(
77-
"mkdir",
78-
"-p",
79-
os.path.join(TOP_DIR, "tests/accuracy/datasets/data"),
80-
external=True,
81-
)
82-
session.run_always(
83-
"cp",
84-
"-rpf",
85-
os.path.join(TOP_DIR, "examples/int8/training/vgg16/cifar-10-batches-bin"),
86-
os.path.join(TOP_DIR, "tests/accuracy/datasets/data/cidar-10-batches-bin"),
87-
external=True,
88-
)
89-
90-
9169
def train_model(session):
9270
session.chdir(os.path.join(TOP_DIR, "examples/int8/training/vgg16"))
9371
session.install("-r", "requirements.txt")
@@ -107,14 +85,14 @@ def train_model(session):
10785
"--ckpt-dir",
10886
"vgg16_ckpts",
10987
"--epochs",
110-
"25",
88+
str(EPOCHS),
11189
env={"PYTHONPATH": PYT_PATH},
11290
)
11391

11492
session.run_always(
11593
"python",
11694
"export_ckpt.py",
117-
"vgg16_ckpts/ckpt_epoch25.pth",
95+
"vgg16_ckpts/ckpt_epoch" + str(EPOCHS) + ".pth",
11896
env={"PYTHONPATH": PYT_PATH},
11997
)
12098
else:
@@ -130,10 +108,12 @@ def train_model(session):
130108
"--ckpt-dir",
131109
"vgg16_ckpts",
132110
"--epochs",
133-
"25",
111+
str(EPOCHS),
134112
)
135113

136-
session.run_always("python", "export_ckpt.py", "vgg16_ckpts/ckpt_epoch25.pth")
114+
session.run_always(
115+
"python", "export_ckpt.py", "vgg16_ckpts/ckpt_epoch" + str(EPOCHS) + ".pth"
116+
)
137117

138118

139119
def finetune_model(session):
@@ -156,17 +136,17 @@ def finetune_model(session):
156136
"--ckpt-dir",
157137
"vgg16_ckpts",
158138
"--start-from",
159-
"25",
139+
str(EPOCHS),
160140
"--epochs",
161-
"26",
141+
str(EPOCHS + 1),
162142
env={"PYTHONPATH": PYT_PATH},
163143
)
164144

165145
# Export model
166146
session.run_always(
167147
"python",
168148
"export_qat.py",
169-
"vgg16_ckpts/ckpt_epoch26.pth",
149+
"vgg16_ckpts/ckpt_epoch" + str(EPOCHS + 1) + ".pth",
170150
env={"PYTHONPATH": PYT_PATH},
171151
)
172152
else:
@@ -182,13 +162,17 @@ def finetune_model(session):
182162
"--ckpt-dir",
183163
"vgg16_ckpts",
184164
"--start-from",
185-
"25",
165+
str(EPOCHS),
186166
"--epochs",
187-
"26",
167+
str(EPOCHS + 1),
188168
)
189169

190170
# Export model
191-
session.run_always("python", "export_qat.py", "vgg16_ckpts/ckpt_epoch26.pth")
171+
session.run_always(
172+
"python",
173+
"export_qat.py",
174+
"vgg16_ckpts/ckpt_epoch" + str(EPOCHS + 1) + ".pth",
175+
)
192176

193177

194178
def cleanup(session):
@@ -219,6 +203,19 @@ def run_base_tests(session):
219203
session.run_always("pytest", test)
220204

221205

206+
def run_model_tests(session):
207+
print("Running model tests")
208+
session.chdir(os.path.join(TOP_DIR, "tests/py"))
209+
tests = [
210+
"models",
211+
]
212+
for test in tests:
213+
if USE_HOST_DEPS:
214+
session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH})
215+
else:
216+
session.run_always("pytest", test)
217+
218+
222219
def run_accuracy_tests(session):
223220
print("Running accuracy tests")
224221
session.chdir(os.path.join(TOP_DIR, "tests/py"))
@@ -282,7 +279,7 @@ def run_dla_tests(session):
282279
print("Running DLA tests")
283280
session.chdir(os.path.join(TOP_DIR, "tests/py"))
284281
tests = [
285-
"test_api_dla.py",
282+
"hw/test_api_dla.py",
286283
]
287284
for test in tests:
288285
if USE_HOST_DEPS:
@@ -322,21 +319,19 @@ def run_l0_dla_tests(session):
322319
cleanup(session)
323320

324321

325-
def run_l1_accuracy_tests(session):
322+
def run_l1_model_tests(session):
326323
if not USE_HOST_DEPS:
327324
install_deps(session)
328325
install_torch_trt(session)
329-
download_datasets(session)
330-
train_model(session)
331-
run_accuracy_tests(session)
326+
download_models(session)
327+
run_model_tests(session)
332328
cleanup(session)
333329

334330

335331
def run_l1_int8_accuracy_tests(session):
336332
if not USE_HOST_DEPS:
337333
install_deps(session)
338334
install_torch_trt(session)
339-
download_datasets(session)
340335
train_model(session)
341336
finetune_model(session)
342337
run_int8_accuracy_tests(session)
@@ -347,9 +342,6 @@ def run_l2_trt_compatibility_tests(session):
347342
if not USE_HOST_DEPS:
348343
install_deps(session)
349344
install_torch_trt(session)
350-
download_models(session)
351-
download_datasets(session)
352-
train_model(session)
353345
run_trt_compatibility_tests(session)
354346
cleanup(session)
355347

@@ -376,9 +368,9 @@ def l0_dla_tests(session):
376368

377369

378370
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
379-
def l1_accuracy_tests(session):
380-
"""Checking accuracy performance on various usecases"""
381-
run_l1_accuracy_tests(session)
371+
def l1_model_tests(session):
372+
"""When a user needs to test the functionality of standard models compilation and results"""
373+
run_l1_model_tests(session)
382374

383375

384376
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
@@ -397,13 +389,3 @@ def l2_trt_compatibility_tests(session):
397389
def l2_multi_gpu_tests(session):
398390
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
399391
run_l2_multi_gpu_tests(session)
400-
401-
402-
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
403-
def download_test_models(session):
404-
"""Grab all the models needed for testing"""
405-
try:
406-
import torch
407-
except ModuleNotFoundError:
408-
install_deps(session)
409-
download_models(session)

py/torch_tensorrt/ptq.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ def write_calibration_cache(self, cache):
5656
return b""
5757

5858

59+
# deepcopy (which involves pickling) is performed on the compile_spec internally during compilation.
60+
# We register this __reduce__ function for pickler to identity the calibrator object returned by DataLoaderCalibrator during deepcopy.
61+
# This should be the object's local name relative to the module https://docs.python.org/3/library/pickle.html#object.__reduce__
62+
def __reduce__(self):
63+
return self.__class__.__name__
64+
65+
5966
class DataLoaderCalibrator(object):
6067
"""
6168
Constructs a calibrator class in TensorRT and uses pytorch dataloader to load/preproces
@@ -114,24 +121,27 @@ def __new__(cls, *args, **kwargs):
114121
"get_batch": get_cache_mode_batch if use_cache else get_batch,
115122
"read_calibration_cache": read_calibration_cache,
116123
"write_calibration_cache": write_calibration_cache,
124+
"__reduce__": __reduce__, # used when you deepcopy the DataLoaderCalibrator object
117125
}
118126

119127
# Using type metaclass to construct calibrator class based on algorithm type
120128
if algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
121129
return type(
122-
"DataLoaderCalibrator", (_C.IInt8EntropyCalibrator,), attribute_mapping
130+
"Int8EntropyCalibrator", (_C.IInt8EntropyCalibrator,), attribute_mapping
123131
)()
124132
elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
125133
return type(
126-
"DataLoaderCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping
134+
"Int8EntropyCalibrator2",
135+
(_C.IInt8EntropyCalibrator2,),
136+
attribute_mapping,
127137
)()
128138
elif algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
129139
return type(
130-
"DataLoaderCalibrator", (_C.IInt8LegacyCalibrator,), attribute_mapping
140+
"Int8LegacyCalibrator", (_C.IInt8LegacyCalibrator,), attribute_mapping
131141
)()
132142
elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
133143
return type(
134-
"DataLoaderCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping
144+
"Int8MinMaxCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping
135145
)()
136146
else:
137147
log(

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def _parse_input_signature(input_signature: Any):
225225

226226

227227
def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
228-
# TODO: Remove deep copy once collections does not need partial compilation
228+
# TODO: Use deepcopy to support partial compilation of collections
229229
compile_spec = deepcopy(compile_spec_)
230230
info = _ts_C.CompileSpec()
231231

@@ -301,7 +301,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
301301
compile_spec["enabled_precisions"]
302302
)
303303

304-
if "calibrator" in compile_spec:
304+
if "calibrator" in compile_spec and compile_spec["calibrator"]:
305305
info.ptq_calibrator = compile_spec["calibrator"]
306306

307307
if "sparse_weights" in compile_spec:

tests/core/lowering/test_module_fallback_passes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,5 +124,5 @@ TEST(Lowering, LowerAndPartitionSimpleModuleFallbackCorrectly) {
124124
}
125125

126126
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
127-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6));
127+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99));
128128
}

tests/core/partitioning/test_fallback_graph_output.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) {
3434
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
3535
auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
3636
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
37-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6));
37+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99));
3838
}
3939

4040
TEST(Partitioning, ComputeMobileNetFallbackGraphCorrectly) {
@@ -64,6 +64,6 @@ TEST(Partitioning, ComputeMobileNetFallbackGraphCorrectly) {
6464
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
6565
auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
6666
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
67-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6));
67+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99));
6868
}
6969
#endif

tests/cpp/test_collections.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) {
4242
auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
4343
auto trt_out = trt_mod.forward(inputs_);
4444

45-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
45+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor(), 0.99));
4646
}
4747

4848
TEST(CppAPITests, TestCollectionTupleInput) {
@@ -85,7 +85,7 @@ TEST(CppAPITests, TestCollectionTupleInput) {
8585
auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
8686
auto trt_out = trt_mod.forward(complex_inputs);
8787

88-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
88+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor(), 0.99));
8989
}
9090

9191
TEST(CppAPITests, TestCollectionListInput) {
@@ -144,7 +144,7 @@ TEST(CppAPITests, TestCollectionListInput) {
144144
LOG_DEBUG("Finish compile");
145145
auto trt_out = trt_mod.forward(complex_inputs);
146146

147-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5));
147+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor(), 0.99));
148148
}
149149

150150
TEST(CppAPITests, TestCollectionTupleInputOutput) {
@@ -317,4 +317,4 @@ TEST(CppAPITests, TestCollectionComplexModel) {
317317
out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor(), 1e-5));
318318
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(
319319
out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor(), 1e-5));
320-
}
320+
}

tests/cpp/test_compiled_modules.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ TEST_P(CppAPITests, CompiledModuleIsClose) {
4242

4343
for (size_t i = 0; i < trt_results.size(); i++) {
4444
ASSERT_TRUE(
45-
torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]), threshold));
45+
torch_tensorrt::tests::util::cosineSimEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]), 0.99));
4646
}
4747
}
4848

@@ -52,11 +52,7 @@ INSTANTIATE_TEST_SUITE_P(
5252
CompiledModuleForwardIsCloseSuite,
5353
CppAPITests,
5454
testing::Values(
55-
PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
56-
PathAndInput({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
57-
PathAndInput({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
5855
PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
59-
PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
6056
PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}),
6157
PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-3}),
6258
PathAndInput({"tests/modules/bert_base_uncased_traced.jit.pt", {{1, 14}, {1, 14}}, {at::kInt, at::kInt}, 8e-2}),

tests/cpp/test_module_fallback.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ TEST(CppAPITest, ResNetModuleFallbacksCorrectly) {
3030
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
3131
auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);
3232
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
33-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6));
33+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99));
3434
}
3535

3636
TEST(CppAPITest, MobileNetModuleFallbacksCorrectlyWithOneEngine) {
@@ -69,6 +69,6 @@ TEST(CppAPITest, MobileNetModuleFallbacksCorrectlyWithOneEngine) {
6969
ASSERT_TRUE(trt_count == 1);
7070

7171
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
72-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6));
72+
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99));
7373
}
7474
#endif

0 commit comments

Comments
 (0)