Skip to content

Commit f90cb21

Browse files
committed
Merge branch 'master' into resolve_inputs_fix
2 parents d345407 + 79b8392 commit f90cb21

File tree

4 files changed

+40
-28
lines changed

4 files changed

+40
-28
lines changed

core/conversion/evaluators/aten.cpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -342,15 +342,22 @@ auto aten_registrations TORCHTRT_UNUSED =
342342
auto a = args.at(n->input(0)).unwrapToDouble();
343343
auto b = args.at(n->input(1)).unwrapToDouble();
344344
return a + b;
345+
} else if (args.at(n->input(0)).IValue()->isString()) {
346+
auto a = args.at(n->input(0)).unwrapToString();
347+
auto b = args.at(n->input(1)).unwrapToString();
348+
return a + b;
345349
} else {
346350
TORCHTRT_THROW_ERROR(
347351
"Unimplemented data type for aten::add evaluator: "
348352
<< args.at(n->input(0)).IValue()->type()->str());
349353
return {};
350354
}
351355
},
352-
EvalOptions().validSchemas(
353-
{"aten::add.int(int a, int b) -> (int)", "aten::add.float(float a, float b) -> (float)"})})
356+
EvalOptions().validSchemas({
357+
"aten::add.int(int a, int b) -> (int)",
358+
"aten::add.float(float a, float b) -> (float)",
359+
"aten::add.str(str a, str b) -> (str)"
360+
})})
354361
.evaluator({c10::Symbol::fromQualString("aten::add_"),
355362
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
356363
if (args.at(n->input(0)).IValue()->isList()) {

core/lowering/passes/exception_elimination.cpp

+18-19
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ struct ExceptionOrPassPatternElimination {
4444
bool arm1_starts_with_exception = (*arm1_start)->kind() == prim::RaiseException;
4545
bool arm2_starts_with_exception = (*arm2_start)->kind() == prim::RaiseException;
4646

47-
if (!arm1_starts_with_exception && !arm2_starts_with_exception) {
47+
//if (!arm1_starts_with_exception && !arm2_starts_with_exception) {
4848
// Neither arm matches the pattern
49-
return false;
50-
}
49+
// return false;
50+
//}
5151

5252
/// Check if this Node hosts a pattern like so:
5353
/// = prim::If(%5958)
@@ -57,14 +57,12 @@ struct ExceptionOrPassPatternElimination {
5757
/// block1():
5858
/// -> ()
5959
if (arm1_starts_with_exception) {
60-
if ((*(++arm1_start))->kind() != prim::Return) {
60+
if ((*(++arm1_start))->kind() == prim::Return) {
6161
// Make sure that block0 is solely just the exception and the return
62-
return false;
63-
}
64-
65-
if ((*(arm2_start))->kind() != prim::Return) {
66-
// Make sure that block1 is solely the return
67-
return false;
62+
if ((*(arm2_start))->kind() == prim::Return) {
63+
// Make sure that block1 is solely the return
64+
return true;
65+
}
6866
}
6967
}
7068

@@ -76,25 +74,23 @@ struct ExceptionOrPassPatternElimination {
7674
/// = prim::RaiseException(%45)
7775
/// -> ()
7876
if (arm2_starts_with_exception) {
79-
if ((*(++arm2_start))->kind() != prim::Return) {
77+
if ((*(++arm2_start))->kind() == prim::Return) {
8078
// Make sure that block1 is solely just the exception and the return
81-
return false;
82-
}
83-
84-
if ((*(arm1_start))->kind() != prim::Return) {
85-
// Make sure that block0 is solely the return
86-
return false;
79+
if ((*(arm1_start))->kind() == prim::Return) {
80+
// Make sure that block0 is solely the return
81+
return true;
82+
}
8783
}
8884
}
8985

90-
return true;
86+
return false;
9187
}
9288

9389
void findExceptionOrPassNodes(Block* b) {
9490
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
9591
auto n = *it;
9692
if (n->kind() == prim::If && isExceptionOrPassNode(n)) {
97-
LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl);
93+
LOG_ERROR("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl);
9894
it.destroyCurrent();
9995
}
10096
}
@@ -107,6 +103,9 @@ struct ExceptionOrPassPatternElimination {
107103
void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
108104
ExceptionOrPassPatternElimination eppe(std::move(graph));
109105
eppe.run();
106+
if (graph) {
107+
LOG_ERROR("Post Eliminate Exception or Pass Patterns: " << *graph);
108+
}
110109
}
111110

112111
} // namespace passes

tests/accuracy/test_fp16_accuracy.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
2525
}
2626
torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;
2727

28-
std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
29-
auto compile_spec = torch_tensorrt::ts::CompileSpec({input_shape});
28+
std::vector<int64_t> input_shape = {32, 3, 32, 32};
29+
auto input = torch_tensorrt::Input(input_shape);
30+
input.dtype = torch::kF16;
31+
auto compile_spec = torch_tensorrt::ts::CompileSpec({input});
3032
compile_spec.enabled_precisions.insert(torch::kF16);
3133

3234
auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

tests/core/lowering/test_exception_elimination_pass.cpp

+9-5
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,17 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) {
4242
g->insertNode(bool_node);
4343
auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0);
4444
auto if_block0 = if_node->addBlock();
45-
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val}, 0);
45+
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
4646
if_block0->appendNode(exception_node);
47-
auto if_block1 = if_node->addBlock();
47+
/*auto if_block1 =*/ if_node->addBlock();
4848
g->insertNode(if_node);
4949
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
5050
g->insertNode(cat_node);
5151
g->registerOutput(cat_node->output());
5252

53+
std::cout << "Source Graph: " << *g << std::endl;
5354
torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g);
55+
std::cout << "Modified Graph: " << *g << std::endl;
5456
for (auto node : g->nodes()) {
5557
EXPECT_NE(node, if_node);
5658
}
@@ -95,16 +97,18 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block1) {
9597
bool_node->output()->setType(torch::jit::BoolType::get());
9698
g->insertNode(bool_node);
9799
auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0);
98-
auto if_block0 = if_node->addBlock();
100+
/*auto if_block0 = */if_node->addBlock();
99101
auto if_block1 = if_node->addBlock();
100-
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val}, 0);
102+
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
101103
if_block1->appendNode(exception_node);
102104
g->insertNode(if_node);
103105
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
104106
g->insertNode(cat_node);
105107
g->registerOutput(cat_node->output());
106108

109+
std::cout << "Source Graph: " << *g << std::endl;
107110
torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g);
111+
std::cout << "Modified Graph: " << *g << std::endl;
108112
for (auto node : g->nodes()) {
109113
EXPECT_NE(node, if_node);
110114
}
@@ -150,7 +154,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) {
150154
auto if_block0 = if_node->addBlock();
151155
auto append_node = g->create(torch::jit::aten::append, {list_node->output(), y});
152156
if_block0->appendNode(append_node);
153-
auto if_block1 = if_node->addBlock();
157+
/*auto if_block1 = */if_node->addBlock();
154158
g->insertNode(if_node);
155159
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
156160
g->insertNode(cat_node);

0 commit comments

Comments
 (0)