Skip to content

Commit 79b8392

Browse files
committed
refactor(//core/lowering): Make logic a bit clearer in EE pass
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 65dbf90 commit 79b8392

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

core/lowering/passes/exception_elimination.cpp

+18-19
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ struct ExceptionOrPassPatternElimination {
4444
bool arm1_starts_with_exception = (*arm1_start)->kind() == prim::RaiseException;
4545
bool arm2_starts_with_exception = (*arm2_start)->kind() == prim::RaiseException;
4646

47-
if (!arm1_starts_with_exception && !arm2_starts_with_exception) {
47+
//if (!arm1_starts_with_exception && !arm2_starts_with_exception) {
4848
// Neither arm matches the pattern
49-
return false;
50-
}
49+
// return false;
50+
//}
5151

5252
/// Check if this Node hosts a pattern like so:
5353
/// = prim::If(%5958)
@@ -57,14 +57,12 @@ struct ExceptionOrPassPatternElimination {
5757
/// block1():
5858
/// -> ()
5959
if (arm1_starts_with_exception) {
60-
if ((*(++arm1_start))->kind() != prim::Return) {
60+
if ((*(++arm1_start))->kind() == prim::Return) {
6161
// Make sure that block0 is solely just the exception and the return
62-
return false;
63-
}
64-
65-
if ((*(arm2_start))->kind() != prim::Return) {
66-
// Make sure that block1 is solely the return
67-
return false;
62+
if ((*(arm2_start))->kind() == prim::Return) {
63+
// Make sure that block1 is solely the return
64+
return true;
65+
}
6866
}
6967
}
7068

@@ -76,25 +74,23 @@ struct ExceptionOrPassPatternElimination {
7674
/// = prim::RaiseException(%45)
7775
/// -> ()
7876
if (arm2_starts_with_exception) {
79-
if ((*(++arm2_start))->kind() != prim::Return) {
77+
if ((*(++arm2_start))->kind() == prim::Return) {
8078
// Make sure that block1 is solely just the exception and the return
81-
return false;
82-
}
83-
84-
if ((*(arm1_start))->kind() != prim::Return) {
85-
// Make sure that block0 is solely the return
86-
return false;
79+
if ((*(arm1_start))->kind() == prim::Return) {
80+
// Make sure that block0 is solely the return
81+
return true;
82+
}
8783
}
8884
}
8985

90-
return true;
86+
return false;
9187
}
9288

9389
void findExceptionOrPassNodes(Block* b) {
9490
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
9591
auto n = *it;
9692
if (n->kind() == prim::If && isExceptionOrPassNode(n)) {
97-
LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl);
93+
LOG_ERROR("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl);
9894
it.destroyCurrent();
9995
}
10096
}
@@ -107,6 +103,9 @@ struct ExceptionOrPassPatternElimination {
107103
void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
108104
ExceptionOrPassPatternElimination eppe(std::move(graph));
109105
eppe.run();
106+
if (graph) {
107+
LOG_ERROR("Post Eliminate Exception or Pass Patterns: " << *graph);
108+
}
110109
}
111110

112111
} // namespace passes

tests/core/lowering/test_exception_elimination_pass.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) {
4444
auto if_block0 = if_node->addBlock();
4545
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
4646
if_block0->appendNode(exception_node);
47-
auto if_block1 = if_node->addBlock();
47+
/*auto if_block1 =*/ if_node->addBlock();
4848
g->insertNode(if_node);
4949
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
5050
g->insertNode(cat_node);
@@ -97,7 +97,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block1) {
9797
bool_node->output()->setType(torch::jit::BoolType::get());
9898
g->insertNode(bool_node);
9999
auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0);
100-
auto if_block0 = if_node->addBlock();
100+
/*auto if_block0 = */if_node->addBlock();
101101
auto if_block1 = if_node->addBlock();
102102
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
103103
if_block1->appendNode(exception_node);
@@ -154,7 +154,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) {
154154
auto if_block0 = if_node->addBlock();
155155
auto append_node = g->create(torch::jit::aten::append, {list_node->output(), y});
156156
if_block0->appendNode(append_node);
157-
auto if_block1 = if_node->addBlock();
157+
/*auto if_block1 = */if_node->addBlock();
158158
g->insertNode(if_node);
159159
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
160160
g->insertNode(cat_node);

0 commit comments

Comments
 (0)