-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][EmitC] Do not inline expressions used by ops with the CExpression trait #93691
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
I wonder if we could instead add another restriction to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks :)
I implemented this in the last commit. I don't know which of these solutions is preferable. |
@llvm/pr-subscribers-mlir-emitc @llvm/pr-subscribers-mlir Author: Simon Camphausen (simon-camp) ChangesInlined expression ops where emitted as is irrespective to the context of the op in which they were used. This is corrected by emitting additional parentheses around the expression. These are omitted if it is safe to so in the context of the user operation. Fixes #93470. Full diff: https://github.com/llvm/llvm-project/pull/93691.diff 2 Files Affected:
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index f19e0f8c4c2a4..bbc8aa7c9fe91 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -303,7 +303,11 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
// Do not inline expressions used by other expressions, as any desired
// expression folding was taken care of by transformations.
- return !user->getParentOfType<ExpressionOp>();
+ if (user->getParentOfType<ExpressionOp>())
+ return false;
+
+ // Do not inline expressions used by ops with the CExpression trait.
+ return !user->hasTrait<OpTrait::emitc::CExpression>();
}
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
@@ -1338,8 +1342,9 @@ LogicalResult CppEmitter::emitOperand(Value value) {
}
auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
- if (expressionOp && shouldBeInlined(expressionOp))
+ if (expressionOp && shouldBeInlined(expressionOp)) {
return emitExpression(expressionOp);
+ }
auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
if (!literalOp && !hasValueInScope(value))
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index aaddd5af874a9..810a629c71533 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -100,6 +100,86 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
return %e : i32
}
+// CPP-DEFAULT: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DEFAULT-NEXT: int32_t v4 = 0;
+// CPP-DEFAULT-NEXT: int32_t [[EXP_0:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: int32_t [[EXP_1:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: int32_t [[EXP_2:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: int32_t [[EXP_3:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: bool v9 = (bool) [[EXP_0]];
+// CPP-DEFAULT-NEXT: int32_t v10 = [[EXP_1]] + v4;
+// CPP-DEFAULT-NEXT: int32_t v11 = bar([[EXP_2]], v4);
+// CPP-DEFAULT-NEXT: int32_t v12 = v9 ? [[EXP_3]] : v4;
+// CPP-DEFAULT-NEXT: int32_t v13;
+// CPP-DEFAULT-NEXT: v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DECLTOP-NEXT: int32_t v4;
+// CPP-DECLTOP-NEXT: int32_t v5;
+// CPP-DECLTOP-NEXT: int32_t v6;
+// CPP-DECLTOP-NEXT: int32_t v7;
+// CPP-DECLTOP-NEXT: int32_t v8;
+// CPP-DECLTOP-NEXT: bool v9;
+// CPP-DECLTOP-NEXT: int32_t v10;
+// CPP-DECLTOP-NEXT: int32_t v11;
+// CPP-DECLTOP-NEXT: int32_t v12;
+// CPP-DECLTOP-NEXT: int32_t v13;
+// CPP-DECLTOP-NEXT: v4 = 0;
+// CPP-DECLTOP-NEXT: v5 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: v6 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: v7 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: v8 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: v9 = (bool) v5;
+// CPP-DECLTOP-NEXT: v10 = v6 + v4;
+// CPP-DECLTOP-NEXT: v11 = bar(v7, v4);
+// CPP-DECLTOP-NEXT: v12 = v9 ? v8 : v4;
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: }
+func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+ %c0 = "emitc.constant"() {value = 0 : i32} : () -> i32
+ %e0 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e1 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e2 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e3 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e4 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e5 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %cast = emitc.cast %e0 : i32 to i1
+ %add = emitc.add %e1, %c0 : (i32, i32) -> i32
+ %call = emitc.call_opaque "bar" (%e2, %c0) : (i32, i32) -> (i32)
+ %cond = emitc.conditional %cast, %e3, %c0 : i32
+ %var = "emitc.variable"() {value = #emitc.opaque<"">} : () -> i32
+ emitc.assign %e4 : i32 to %var : i32
+ return %e5 : i32
+}
+
// CPP-DEFAULT: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]];
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
7b34b21
to
c6817f8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM (with a nit)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for taking care of this!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @simon-camp !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Currently an expression is inlined without emitting enclosing parentheses regardless of the context of the user. This could led to wrong evaluation order depending on the precedence of both expressions. If the inlining is intended, the user operation should be merged into the expression op.
Fixes #93470.