-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][emitc] Refactor ArithToEmitC: perform sign adaptation, type conversions / cast insertion in a single place #95789
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][emitc] Refactor ArithToEmitC: perform sign adaptation, type conversions / cast insertion in a single place #95789
Conversation
@llvm/pr-subscribers-mlir-emitc @llvm/pr-subscribers-mlir Author: Corentin Ferry (cferry-AMD) ChangesThis PR lays the ground for a next PR that will use the newly introduced EmitC types. When certain values have to be interpreted as signed/unsigned (e.g. for ops like While doing this refactoring, emerged the need for one more Full diff: https://github.com/llvm/llvm-project/pull/95789.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 452302c565139..25d1983ec583b 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -288,6 +288,7 @@ def EmitC_CastOp : EmitC_Op<"cast",
let arguments = (ins EmitCType:$source);
let results = (outs EmitCType:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+ let hasFolder = 1;
}
def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 74f0f61d04a1a..9214bc5b2c13e 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -270,19 +270,11 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
- Type arithmeticType = type;
- if (type.isUnsignedInteger() != needsUnsigned) {
- arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
- /*isSigned=*/!needsUnsigned);
- }
- Value lhs = adaptor.getLhs();
- Value rhs = adaptor.getRhs();
- if (arithmeticType != type) {
- lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
- lhs);
- rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
- rhs);
- }
+
+ Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
+ Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+ Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
return success();
}
@@ -328,37 +320,26 @@ class CastConversion : public OpConversionPattern<ArithOp> {
return success();
}
- bool isTruncation = operandType.getIntOrFloatBitWidth() >
- opReturnType.getIntOrFloatBitWidth();
+ bool isTruncation =
+ (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
+ operandType.getIntOrFloatBitWidth() >
+ opReturnType.getIntOrFloatBitWidth());
bool doUnsigned = castToUnsigned || isTruncation;
- Type castType = opReturnType;
- // If the op is a ui variant and the type wanted as
- // return type isn't unsigned, we need to issue an unsigned type to do
- // the conversion.
- if (castType.isUnsignedInteger() != doUnsigned) {
- castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
- /*isSigned=*/!doUnsigned);
- }
+ // Adapt the signedness of the result (bitwidth-preserving cast)
+ // This is needed e.g., if the return type is signless.
+ Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
- Value actualOp = adaptor.getIn();
- // Adapt the signedness of the operand if necessary
- if (operandType.isUnsignedInteger() != doUnsigned) {
- Type correctSignednessType =
- rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
- /*isSigned=*/!doUnsigned);
- actualOp = rewriter.template create<emitc::CastOp>(
- op.getLoc(), correctSignednessType, actualOp);
- }
+ // Adapt the signedness of the operand (bitwidth-preserving cast)
+ Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
+ Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
- auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
- actualOp);
+ // Actual cast (may change bitwidth)
+ auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
+ castDestType, actualOp);
// Cast to the expected output type
- if (castType != opReturnType) {
- result = rewriter.template create<emitc::CastOp>(op.getLoc(),
- opReturnType, result);
- }
+ auto result = adaptValueType(cast, rewriter, opReturnType);
rewriter.replaceOp(op, result);
return success();
@@ -410,8 +391,6 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
}
- Value lhs = adaptor.getLhs();
- Value rhs = adaptor.getRhs();
Type arithmeticType = type;
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
!bitEnumContainsAll(op.getOverflowFlags(),
@@ -421,20 +400,15 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
/*isSigned=*/false);
}
- if (arithmeticType != type) {
- lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
- lhs);
- rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
- rhs);
- }
- Value result = rewriter.template create<EmitCOp>(op.getLoc(),
- arithmeticType, lhs, rhs);
+ Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+ Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
+ Value arithmeticResult = rewriter.template create<EmitCOp>(
+ op.getLoc(), arithmeticType, lhs, rhs);
+
+ Value result = adaptValueType(arithmeticResult, rewriter, type);
- if (arithmeticType != type) {
- result =
- rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
- }
rewriter.replaceOp(op, result);
return success();
}
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index b2556bb6065d8..c3c9b4e6a1d3e 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -241,6 +241,12 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
}
+OpFoldResult emitc::CastOp::fold(FoldAdaptor adaptor) {
+ if (getOperand().getType() == getResult().getType())
+ return getOperand();
+ return nullptr;
+}
+
//===----------------------------------------------------------------------===//
// CallOpaqueOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 71f1a6abd913b..607e5bf9b1a3b 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -466,6 +466,13 @@ func.func @arith_trunci(%arg0: i32) -> i8 {
// CHECK: emitc.cast %[[Trunc]] : ui8 to i8
%truncd = arith.trunci %arg0 : i32 to i8
+ // CHECK: %[[Const:.*]] = "emitc.constant"
+ // CHECK-SAME: value = 1
+ // CHECK-SAME: () -> i32
+ // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
+ // CHECK: %[[Conv:.*]] = emitc.cast %[[AndOne]] : i32 to i1
+ %bool = arith.trunci %arg0 : i32 to i1
+
return %truncd : i8
}
|
Thanks @simon-camp. Let's merge this refactor with a single review, we'll need to spend more time discussing #95795 so further updates can land in there as well. |
…nversions / cast insertion in a single place (llvm#95789) Factor EmitC type signedness adaptation and cast operations in ArithToEmitC using adaptValueType and adaptIntegralTypeSignedness.
Factor EmitC type signedness adaptation and cast operations in ArithToEmitC using adaptValueType and adaptIntegralTypeSignedness.