Skip to content

Commit c6817f8

Browse files
author
Simon Camphausen
committed
Do not inline expressions into ops with the CExpression trait
1 parent aa4b1bf commit c6817f8

File tree

2 files changed

+35
-28
lines changed

2 files changed

+35
-28
lines changed

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,12 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
303303

304304
// Do not inline expressions used by other expressions, as any desired
305305
// expression folding was taken care of by transformations.
306-
return !user->getParentOfType<ExpressionOp>();
306+
if (user->getParentOfType<ExpressionOp>())
307+
return false;
308+
309+
// Do not inline expressions used by ops with the CExpression trait. If this
310+
// was intended, the user could have been merged into the expression op.
311+
return !user->hasTrait<OpTrait::emitc::CExpression>();
307312
}
308313

309314
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
@@ -1339,17 +1344,7 @@ LogicalResult CppEmitter::emitOperand(Value value) {
13391344

13401345
auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
13411346
if (expressionOp && shouldBeInlined(expressionOp)) {
1342-
Operation *user = *expressionOp->getUsers().begin();
1343-
const bool safeToSkipParentheses =
1344-
isa<emitc::AssignOp, emitc::CallOp, emitc::CallOpaqueOp, emitc::ForOp,
1345-
emitc::IfOp, emitc::ReturnOp, func::CallOp, func::ReturnOp>(user);
1346-
if (!safeToSkipParentheses)
1347-
os << "(";
1348-
if (failed(emitExpression(expressionOp)))
1349-
return failure();
1350-
if (!safeToSkipParentheses)
1351-
os << ")";
1352-
return success();
1347+
return emitExpression(expressionOp);
13531348
}
13541349

13551350
auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());

mlir/test/Target/Cpp/expressions.mlir

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -100,34 +100,46 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
100100
return %e : i32
101101
}
102102

103-
// CPP-DEFAULT: int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
103+
// CPP-DEFAULT: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
104104
// CPP-DEFAULT-NEXT: int32_t v4 = 0;
105-
// CPP-DEFAULT-NEXT: bool v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
106-
// CPP-DEFAULT-NEXT: int32_t v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
107-
// CPP-DEFAULT-NEXT: int32_t v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4);
108-
// CPP-DEFAULT-NEXT: int32_t v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
109-
// CPP-DEFAULT-NEXT: int32_t v9;
110-
// CPP-DEFAULT-NEXT: v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
105+
// CPP-DEFAULT-NEXT: int32_t [[EXP_0:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
106+
// CPP-DEFAULT-NEXT: int32_t [[EXP_1:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
107+
// CPP-DEFAULT-NEXT: int32_t [[EXP_2:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
108+
// CPP-DEFAULT-NEXT: int32_t [[EXP_3:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
109+
// CPP-DEFAULT-NEXT: bool v9 = (bool) [[EXP_0]];
110+
// CPP-DEFAULT-NEXT: int32_t v10 = [[EXP_1]] + v4;
111+
// CPP-DEFAULT-NEXT: int32_t v11 = bar([[EXP_2]], v4);
112+
// CPP-DEFAULT-NEXT: int32_t v12 = v9 ? [[EXP_3]] : v4;
113+
// CPP-DEFAULT-NEXT: int32_t v13;
114+
// CPP-DEFAULT-NEXT: v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
111115
// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
112116
// CPP-DEFAULT-NEXT: }
113117

114-
// CPP-DECLTOP: int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
118+
// CPP-DECLTOP: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
115119
// CPP-DECLTOP-NEXT: int32_t v4;
116-
// CPP-DECLTOP-NEXT: bool v5;
120+
// CPP-DECLTOP-NEXT: int32_t v5;
117121
// CPP-DECLTOP-NEXT: int32_t v6;
118122
// CPP-DECLTOP-NEXT: int32_t v7;
119123
// CPP-DECLTOP-NEXT: int32_t v8;
120-
// CPP-DECLTOP-NEXT: int32_t v9;
124+
// CPP-DECLTOP-NEXT: bool v9;
125+
// CPP-DECLTOP-NEXT: int32_t v10;
126+
// CPP-DECLTOP-NEXT: int32_t v11;
127+
// CPP-DECLTOP-NEXT: int32_t v12;
128+
// CPP-DECLTOP-NEXT: int32_t v13;
121129
// CPP-DECLTOP-NEXT: v4 = 0;
122-
// CPP-DECLTOP-NEXT: v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
123-
// CPP-DECLTOP-NEXT: v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
124-
// CPP-DECLTOP-NEXT: v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4);
125-
// CPP-DECLTOP-NEXT: v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
130+
// CPP-DECLTOP-NEXT: v5 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
131+
// CPP-DECLTOP-NEXT: v6 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
132+
// CPP-DECLTOP-NEXT: v7 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
133+
// CPP-DECLTOP-NEXT: v8 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
134+
// CPP-DECLTOP-NEXT: v9 = (bool) v5;
135+
// CPP-DECLTOP-NEXT: v10 = v6 + v4;
136+
// CPP-DECLTOP-NEXT: v11 = bar(v7, v4);
137+
// CPP-DECLTOP-NEXT: v12 = v9 ? v8 : v4;
126138
// CPP-DECLTOP-NEXT: ;
127-
// CPP-DECLTOP-NEXT: v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
139+
// CPP-DECLTOP-NEXT: v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
128140
// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
129141
// CPP-DECLTOP-NEXT: }
130-
func.func @parentheses_for_expression_users(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
142+
func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
131143
%c0 = "emitc.constant"() {value = 0 : i32} : () -> i32
132144
%e0 = emitc.expression : i32 {
133145
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32

0 commit comments

Comments
 (0)