Skip to content

[mlir][EmitC] Do not inline expressions used by ops with the CExpression trait #93691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 4, 2024

Conversation

simon-camp
Copy link
Contributor

@simon-camp simon-camp commented May 29, 2024

Currently an expression is inlined without emitting enclosing parentheses regardless of the context of the user. This could led to wrong evaluation order depending on the precedence of both expressions. If the inlining is intended, the user operation should be merged into the expression op.

Fixes #93470.

Copy link
Contributor

@mgehre-amd mgehre-amd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@aniragil
Copy link
Contributor

I wonder if we could instead add another restriction to shouldBeInlined: Don't inline if the user has the CExpression trait.
The rationale is that the using CExpression could have been fused into its operand ExpressionOp. If it wasn't, then the pass(es) that formed the expressions did not want it to be fused. Inlining the ExpressionOp operand into it goes against that intention.
This would also make sure that we handle expression emission in one place, i.e. ExpressionOps, where cast should already be emitted correctly with parenthesis as needed.

Copy link
Contributor

@cferry-AMD cferry-AMD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks :)

@simon-camp
Copy link
Contributor Author

I wonder if we could instead add another restriction to shouldBeInlined: Don't inline if the user has the CExpression trait. The rationale is that the using CExpression could have been fused into its operand ExpressionOp. If it wasn't, then the pass(es) that formed the expressions did not want it to be fused. Inlining the ExpressionOp operand into it goes against that intention. This would also make sure that we handle expression emission in one place, i.e. ExpressionOps, where cast should already be emitted correctly with parenthesis as needed.

I implemented this in the last commit. I don't know which of these solutions is preferable.

@simon-camp simon-camp marked this pull request as ready for review June 3, 2024 11:57
@llvmbot
Copy link
Member

llvmbot commented Jun 3, 2024

@llvm/pr-subscribers-mlir-emitc

@llvm/pr-subscribers-mlir

Author: Simon Camphausen (simon-camp)

Changes

Inlined expression ops where emitted as is irrespective to the context of the op in which they were used. This is corrected by emitting additional parentheses around the expression. These are omitted if it is safe to so in the context of the user operation.

Fixes #93470.


Full diff: https://github.com/llvm/llvm-project/pull/93691.diff

2 Files Affected:

  • (modified) mlir/lib/Target/Cpp/TranslateToCpp.cpp (+7-2)
  • (modified) mlir/test/Target/Cpp/expressions.mlir (+80)
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index f19e0f8c4c2a4..bbc8aa7c9fe91 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -303,7 +303,11 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
 
   // Do not inline expressions used by other expressions, as any desired
   // expression folding was taken care of by transformations.
-  return !user->getParentOfType<ExpressionOp>();
+  if (user->getParentOfType<ExpressionOp>())
+    return false;
+  
+  // Do not inline expressions used by ops with the CExpression trait.
+  return !user->hasTrait<OpTrait::emitc::CExpression>();
 }
 
 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
@@ -1338,8 +1342,9 @@ LogicalResult CppEmitter::emitOperand(Value value) {
   }
 
   auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
-  if (expressionOp && shouldBeInlined(expressionOp))
+  if (expressionOp && shouldBeInlined(expressionOp)) {
     return emitExpression(expressionOp);
+  }
 
   auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
   if (!literalOp && !hasValueInScope(value))
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index aaddd5af874a9..810a629c71533 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -100,6 +100,86 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
   return %e : i32
 }
 
+// CPP-DEFAULT:      int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DEFAULT-NEXT:   int32_t v4 = 0;
+// CPP-DEFAULT-NEXT:   int32_t [[EXP_0:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT:   int32_t [[EXP_1:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT:   int32_t [[EXP_2:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT:   int32_t [[EXP_3:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT:   bool v9 = (bool) [[EXP_0]];
+// CPP-DEFAULT-NEXT:   int32_t v10 = [[EXP_1]] + v4;
+// CPP-DEFAULT-NEXT:   int32_t v11 = bar([[EXP_2]], v4);
+// CPP-DEFAULT-NEXT:   int32_t v12 = v9 ? [[EXP_3]] : v4;
+// CPP-DEFAULT-NEXT:   int32_t v13;
+// CPP-DEFAULT-NEXT:   v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT:   return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP:      int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DECLTOP-NEXT:   int32_t v4;
+// CPP-DECLTOP-NEXT:   int32_t v5;
+// CPP-DECLTOP-NEXT:   int32_t v6;
+// CPP-DECLTOP-NEXT:   int32_t v7;
+// CPP-DECLTOP-NEXT:   int32_t v8;
+// CPP-DECLTOP-NEXT:   bool v9;
+// CPP-DECLTOP-NEXT:   int32_t v10;
+// CPP-DECLTOP-NEXT:   int32_t v11;
+// CPP-DECLTOP-NEXT:   int32_t v12;
+// CPP-DECLTOP-NEXT:   int32_t v13;
+// CPP-DECLTOP-NEXT:   v4 = 0;
+// CPP-DECLTOP-NEXT:   v5 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT:   v6 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT:   v7 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT:   v8 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT:   v9 = (bool) v5;
+// CPP-DECLTOP-NEXT:   v10 = v6 + v4;
+// CPP-DECLTOP-NEXT:   v11 = bar(v7, v4);
+// CPP-DECLTOP-NEXT:   v12 = v9 ? v8 : v4;
+// CPP-DECLTOP-NEXT:   ;
+// CPP-DECLTOP-NEXT:   v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT:   return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: }
+func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+  %c0 = "emitc.constant"() {value = 0 : i32} : () -> i32
+  %e0 = emitc.expression : i32 {
+      %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+      %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+      emitc.yield %1 : i32
+    }
+  %e1 = emitc.expression : i32 {
+      %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+      %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+      emitc.yield %1 : i32
+    }
+  %e2 = emitc.expression : i32 {
+      %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+      %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+      emitc.yield %1 : i32
+    }
+  %e3 = emitc.expression : i32 {
+      %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+      %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+      emitc.yield %1 : i32
+    }
+  %e4 = emitc.expression : i32 {
+      %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+      %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+      emitc.yield %1 : i32
+    }
+  %e5 = emitc.expression : i32 {
+      %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+      %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+      emitc.yield %1 : i32
+    }
+  %cast = emitc.cast %e0 : i32 to i1
+  %add = emitc.add %e1, %c0 : (i32, i32) -> i32
+  %call = emitc.call_opaque "bar" (%e2, %c0) : (i32, i32) -> (i32)
+  %cond = emitc.conditional %cast, %e3, %c0 : i32
+  %var = "emitc.variable"() {value = #emitc.opaque<"">} : () -> i32
+  emitc.assign %e4 : i32 to %var : i32
+  return %e5 : i32
+}
+
 // CPP-DEFAULT:      int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
 // CPP-DEFAULT-NEXT:   bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
 // CPP-DEFAULT-NEXT:   int32_t [[VAL_6:v[0-9]+]];

Copy link

github-actions bot commented Jun 3, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@simon-camp simon-camp force-pushed the emitc.expression-parenthesis branch from 7b34b21 to c6817f8 Compare June 3, 2024 12:03
Copy link
Contributor

@aniragil aniragil left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM (with a nit)

Copy link
Contributor

@TinaAMD TinaAMD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking care of this!

@simon-camp simon-camp changed the title [mlir][EmitC] Emit parentheses when expression ops are used as operands [mlir][EmitC] Do not inline expressions used by ops with the CExpression trait Jun 4, 2024
@simon-camp simon-camp requested review from aniragil and TinaAMD June 4, 2024 08:05
@simon-camp simon-camp requested a review from aniragil June 4, 2024 10:07
Copy link
Contributor

@aniragil aniragil left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @simon-camp !

Copy link
Contributor

@TinaAMD TinaAMD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@simon-camp simon-camp merged commit a934ddc into llvm:main Jun 4, 2024
7 checks passed
@simon-camp simon-camp deleted the emitc.expression-parenthesis branch June 4, 2024 11:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][emitc] Parenthesization issue in TranslateToCpp
6 participants