Skip to content

[mlir][emitc] Refactor ArithToEmitC: perform sign adaptation, type conversions / cast insertion in a single place #95789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 19, 2024

Conversation

cferry-AMD
Copy link
Contributor

@cferry-AMD cferry-AMD commented Jun 17, 2024

Factor EmitC type signedness adaptation and cast operations in ArithToEmitC using adaptValueType and adaptIntegralTypeSignedness.

@cferry-AMD cferry-AMD marked this pull request as ready for review June 17, 2024 14:05
@llvmbot
Copy link
Member

llvmbot commented Jun 17, 2024

@llvm/pr-subscribers-mlir-emitc

@llvm/pr-subscribers-mlir

Author: Corentin Ferry (cferry-AMD)

Changes

This PR lays the ground for a next PR that will use the newly introduced EmitC types.

When certain values have to be interpreted as signed/unsigned (e.g. for ops like extsi and extui), type conversions in EmitC may be required (from int to unsigned int for instance), and such conversions are not needed in arith ops that only take signless values as operands. Lowering of such ops now uses adaptIntegralTypeSignedness (to give out the type with the right signedness) and adaptValueType (to insert an explicit cast of a value to a given EmitC type). A folder is added to emitc::CastOp so that any cast from self to self will fold away, so the source and destination types don't need to be checked for equality.

While doing this refactoring, emerged the need for one more trunci case (to i1, which is well-defined and uses a special case of CastOpConversion).


Full diff: https://github.com/llvm/llvm-project/pull/95789.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+1)
  • (modified) mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp (+26-52)
  • (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+6)
  • (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir (+7)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 452302c565139..25d1983ec583b 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -288,6 +288,7 @@ def EmitC_CastOp : EmitC_Op<"cast",
   let arguments = (ins EmitCType:$source);
   let results = (outs EmitCType:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+  let hasFolder = 1;
 }
 
 def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 74f0f61d04a1a..9214bc5b2c13e 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -270,19 +270,11 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
 
     bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
     emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
-    Type arithmeticType = type;
-    if (type.isUnsignedInteger() != needsUnsigned) {
-      arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
-                                               /*isSigned=*/!needsUnsigned);
-    }
-    Value lhs = adaptor.getLhs();
-    Value rhs = adaptor.getRhs();
-    if (arithmeticType != type) {
-      lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    lhs);
-      rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    rhs);
-    }
+
+    Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
+    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
     rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
     return success();
   }
@@ -328,37 +320,26 @@ class CastConversion : public OpConversionPattern<ArithOp> {
       return success();
     }
 
-    bool isTruncation = operandType.getIntOrFloatBitWidth() >
-                        opReturnType.getIntOrFloatBitWidth();
+    bool isTruncation =
+        (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
+         operandType.getIntOrFloatBitWidth() >
+             opReturnType.getIntOrFloatBitWidth());
     bool doUnsigned = castToUnsigned || isTruncation;
 
-    Type castType = opReturnType;
-    // If the op is a ui variant and the type wanted as
-    // return type isn't unsigned, we need to issue an unsigned type to do
-    // the conversion.
-    if (castType.isUnsignedInteger() != doUnsigned) {
-      castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
-                                         /*isSigned=*/!doUnsigned);
-    }
+    // Adapt the signedness of the result (bitwidth-preserving cast)
+    // This is needed e.g., if the return type is signless.
+    Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
 
-    Value actualOp = adaptor.getIn();
-    // Adapt the signedness of the operand if necessary
-    if (operandType.isUnsignedInteger() != doUnsigned) {
-      Type correctSignednessType =
-          rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
-                                  /*isSigned=*/!doUnsigned);
-      actualOp = rewriter.template create<emitc::CastOp>(
-          op.getLoc(), correctSignednessType, actualOp);
-    }
+    // Adapt the signedness of the operand (bitwidth-preserving cast)
+    Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
+    Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
 
-    auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
-                                                          actualOp);
+    // Actual cast (may change bitwidth)
+    auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
+                                                        castDestType, actualOp);
 
     // Cast to the expected output type
-    if (castType != opReturnType) {
-      result = rewriter.template create<emitc::CastOp>(op.getLoc(),
-                                                       opReturnType, result);
-    }
+    auto result = adaptValueType(cast, rewriter, opReturnType);
 
     rewriter.replaceOp(op, result);
     return success();
@@ -410,8 +391,6 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
       return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
     }
 
-    Value lhs = adaptor.getLhs();
-    Value rhs = adaptor.getRhs();
     Type arithmeticType = type;
     if ((type.isSignlessInteger() || type.isSignedInteger()) &&
         !bitEnumContainsAll(op.getOverflowFlags(),
@@ -421,20 +400,15 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
       arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
                                                /*isSigned=*/false);
     }
-    if (arithmeticType != type) {
-      lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    lhs);
-      rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
-                                                    rhs);
-    }
 
-    Value result = rewriter.template create<EmitCOp>(op.getLoc(),
-                                                     arithmeticType, lhs, rhs);
+    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
+    Value arithmeticResult = rewriter.template create<EmitCOp>(
+        op.getLoc(), arithmeticType, lhs, rhs);
+
+    Value result = adaptValueType(arithmeticResult, rewriter, type);
 
-    if (arithmeticType != type) {
-      result =
-          rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
-    }
     rewriter.replaceOp(op, result);
     return success();
   }
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index b2556bb6065d8..c3c9b4e6a1d3e 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -241,6 +241,12 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
        emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
 }
 
+OpFoldResult emitc::CastOp::fold(FoldAdaptor adaptor) {
+  if (getOperand().getType() == getResult().getType())
+    return getOperand();
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // CallOpaqueOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 71f1a6abd913b..607e5bf9b1a3b 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -466,6 +466,13 @@ func.func @arith_trunci(%arg0: i32) -> i8 {
   // CHECK: emitc.cast %[[Trunc]] : ui8 to i8
   %truncd = arith.trunci %arg0 : i32 to i8
 
+  // CHECK: %[[Const:.*]] = "emitc.constant"
+  // CHECK-SAME: value = 1
+  // CHECK-SAME: () -> i32
+  // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
+  // CHECK: %[[Conv:.*]] = emitc.cast %[[AndOne]] : i32 to i1
+  %bool = arith.trunci %arg0 : i32 to i1
+
   return %truncd : i8
 }
 

@cferry-AMD
Copy link
Contributor Author

Thanks @simon-camp. Let's merge this refactor with a single review, we'll need to spend more time discussing #95795 so further updates can land in there as well.

@cferry-AMD cferry-AMD merged commit 519175c into llvm:main Jun 19, 2024
7 checks passed
@cferry-AMD cferry-AMD deleted the corentin.upstream_emitc_refactor branch June 19, 2024 07:19
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…nversions / cast insertion in a single place (llvm#95789)

Factor EmitC type signedness adaptation and cast operations in ArithToEmitC using adaptValueType and adaptIntegralTypeSignedness.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants