Skip to content

Commit 6f7627f

Browse files
fix: dependency order of inserted long input casts (#1833)
1 parent a245b86 commit 6f7627f

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

core/lowering/lowering.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ int AutocastLongInputs(
3131
ir::TypeMap input_type_map,
3232
std::string target_device_name) {
3333
int num_autocasts = 0;
34+
auto old_insert_point = g->insertPoint();
35+
g->setInsertPoint(g->nodes().front());
3436
// For each graph input, determine if it can be autocasted
3537
for (size_t i = 0; i < g->inputs().size(); i++) {
3638
auto input = g->inputs()[i];
@@ -71,7 +73,7 @@ int AutocastLongInputs(
7173
auto cast_node = g->create(torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val});
7274

7375
// Replace all uses of the original tensor with that of the casted tensor
74-
g->prependNode(cast_node);
76+
g->insertNode(cast_node);
7577
input->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
7678

7779
// Mark the cast node to run in PyTorch for ease of casting
@@ -80,7 +82,7 @@ int AutocastLongInputs(
8082
num_autocasts++;
8183
}
8284
}
83-
85+
g->setInsertPoint(old_insert_point);
8486
LOG_GRAPH("Inserted " << num_autocasts << " autocasts");
8587

8688
if (num_autocasts > 0) {

tests/core/lowering/BUILD

+5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ cc_test(
2727
}),
2828
)
2929

30+
lowering_test(
31+
name = "test_autocast_long_inputs",
32+
)
33+
3034
lowering_test(
3135
name = "test_conv_pass",
3236
)
@@ -102,6 +106,7 @@ lowering_test(
102106
test_suite(
103107
name = "lowering_tests",
104108
tests = [
109+
":test_autocast_long_inputs",
105110
":test_conv_pass",
106111
":test_device_casting",
107112
":test_exception_elimination_pass",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, AutocastLongInputs) {
10+
std::string source_graph = R"IR(
11+
graph(%long_0 : Tensor, %long_1 : Tensor):
12+
%res : Tensor = aten::add(%long_0, %long_1)
13+
return (%res))IR";
14+
std::string target_graph = R"IR(
15+
graph(%long_0 : Tensor, %long_1 : Tensor):
16+
%3 : bool = prim::Constant[value=0]()
17+
%4 : Device = prim::Constant[value="cuda:0"]()
18+
%5 : NoneType = prim::Constant()
19+
%6 : int = prim::Constant[value=4]()
20+
%7 : Tensor = aten::to[to_compile=0](%long_0, %4, %6, %3, %3, %5)
21+
%8 : int = prim::Constant[value=4]()
22+
%9 : Tensor = aten::to[to_compile=0](%long_1, %4, %8, %3, %3, %5)
23+
%2 : Tensor = aten::add(%7, %9)
24+
return (%2))IR";
25+
26+
auto sg = std::make_shared<torch::jit::Graph>();
27+
torch::jit::parseIR(source_graph, &*sg);
28+
std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> type_map;
29+
type_map[sg->inputs()[0]] = at::kLong;
30+
type_map[sg->inputs()[1]] = at::kLong;
31+
torch_tensorrt::core::lowering::AutocastLongInputs(sg, type_map, "cuda:0");
32+
auto tg = std::make_shared<torch::jit::Graph>();
33+
torch::jit::parseIR(target_graph, &*tg);
34+
ASSERT_TRUE(sg->nodes().front()->kind() == torch::jit::prim::Constant); // confirm constants are added before casts
35+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
36+
}

0 commit comments

Comments
 (0)