Skip to content

Commit e07687d

Browse files
authored
Merge pull request #1148 from pytorch/fix_parsing
fix: fix a parsing-related model-loading bug
2 parents 5cb5947 + e7c359d commit e07687d

File tree

3 files changed

+61
-2
lines changed

3 files changed

+61
-2
lines changed

core/compiler.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,11 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
428428
auto graph_and_mapping =
429429
ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
430430
new_g = graph_and_mapping.first;
431-
LOG_INFO("Segmented Graph: " << *new_g);
431+
// Rename the graph's input names after fallback so that PyTorch can deserialize the saved module correctly
432+
for (size_t i = 0; i < new_g->inputs().size(); ++i) {
433+
new_g->inputs()[i]->setDebugName(std::string("input_") + std::to_string(i));
434+
}
435+
LOG_INFO(*new_g << "(GraphAfterFallback)");
432436

433437
// if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
434438
// module

tests/core/partitioning/BUILD

+17-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,21 @@ partitioning_test(
3737
name = "test_resolve_nontensor_inputs",
3838
)
3939

40+
cc_test(
41+
name = "test_loading_model",
42+
srcs = ["test_loading_model.cpp"],
43+
deps = [
44+
"//tests/util",
45+
"@googletest//:gtest_main",
46+
] + select({
47+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
48+
"//conditions:default": ["@libtorch//:libtorch"],
49+
}),
50+
data = [
51+
":jit_models"
52+
]
53+
)
54+
4055
cc_test(
4156
name = "test_fallback_graph_output",
4257
srcs = ["test_fallback_graph_output.cpp"],
@@ -92,6 +107,7 @@ test_suite(
92107
":test_fallback_graph_output",
93108
":test_loop_fallback",
94109
":test_conditionals",
95-
":test_resolve_nontensor_inputs"
110+
":test_resolve_nontensor_inputs",
111+
":test_loading_model"
96112
]
97113
)
Original file line numberDiff line numberDiff line change
// tests/core/partitioning/test_loading_model.cpp
//
// Regression test for the fallback-graph input-renaming fix: compile a
// module with partitioning (fallback) enabled, save the resulting module,
// and reload it with torch::jit::load. The load step throws if the
// serialized graph's input names cannot be parsed by PyTorch.
#include <string>
#include <unordered_set>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/script.h"

#ifndef DISABLE_TEST_IN_CI

// NOTE(review): the test name mentions ResNet50, but the module loaded below
// is conditional_scripted.jit.pt — confirm which model this is meant to cover.
TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) {
  torch::jit::script::Module mod;
  try {
    mod = torch::jit::load("tests/modules/conditional_scripted.jit.pt");
  } catch (const c10::Error& e) {
    // Best-effort: if the fixture model is missing, report and bail out
    // (the test then passes vacuously rather than erroring).
    std::cerr << "error loading the model\n";
    return;
  }

  // Build identical random CUDA inputs for the JIT and TRT paths.
  const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
  std::vector<torch::jit::IValue> jit_inputs_ivalues;
  std::vector<torch::jit::IValue> trt_inputs_ivalues;
  for (auto in_shape : input_shapes) {
    auto in = at::randint(5, in_shape, {at::kCUDA});
    jit_inputs_ivalues.push_back(in.clone());
    trt_inputs_ivalues.push_back(in.clone());
  }

  std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};

  // Enable partitioning so the compiled module contains a fallback graph.
  torch_tensorrt::core::CompileSpec cfg(input_ranges);
  cfg.partition_info.enabled = true;

  // Sanity: the original TorchScript module runs forward successfully.
  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
  auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
  // Round-trip: save the partially-compiled module, then reload it.
  // Deserialization is the behavior under test.
  trt_mod.save("loading_model.ts");
  auto loaded_model = torch::jit::load("loading_model.ts");
}

#endif

0 commit comments

Comments
 (0)