Skip to content

[mlir][emitc] Refactor ArithToEmitC: perform sign adaptation, type conversions / cast insertion in a single place #95789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 26 additions & 52 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,19 +270,11 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {

bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
Type arithmeticType = type;
if (type.isUnsignedInteger() != needsUnsigned) {
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
/*isSigned=*/!needsUnsigned);
}
Value lhs = adaptor.getLhs();
Value rhs = adaptor.getRhs();
if (arithmeticType != type) {
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
lhs);
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
rhs);
}

Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);

rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
return success();
}
Expand Down Expand Up @@ -328,37 +320,26 @@ class CastConversion : public OpConversionPattern<ArithOp> {
return success();
}

bool isTruncation = operandType.getIntOrFloatBitWidth() >
opReturnType.getIntOrFloatBitWidth();
bool isTruncation =
(isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
operandType.getIntOrFloatBitWidth() >
opReturnType.getIntOrFloatBitWidth());
bool doUnsigned = castToUnsigned || isTruncation;

Type castType = opReturnType;
// If the op is a ui variant and the type wanted as
// return type isn't unsigned, we need to issue an unsigned type to do
// the conversion.
if (castType.isUnsignedInteger() != doUnsigned) {
castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
/*isSigned=*/!doUnsigned);
}
// Adapt the signedness of the result (bitwidth-preserving cast)
// This is needed e.g., if the return type is signless.
Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);

Value actualOp = adaptor.getIn();
// Adapt the signedness of the operand if necessary
if (operandType.isUnsignedInteger() != doUnsigned) {
Type correctSignednessType =
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
/*isSigned=*/!doUnsigned);
actualOp = rewriter.template create<emitc::CastOp>(
op.getLoc(), correctSignednessType, actualOp);
}
// Adapt the signedness of the operand (bitwidth-preserving cast)
Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);

auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
actualOp);
// Actual cast (may change bitwidth)
auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
castDestType, actualOp);

// Cast to the expected output type
if (castType != opReturnType) {
result = rewriter.template create<emitc::CastOp>(op.getLoc(),
opReturnType, result);
}
auto result = adaptValueType(cast, rewriter, opReturnType);

rewriter.replaceOp(op, result);
return success();
Expand Down Expand Up @@ -410,8 +391,6 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
}

Value lhs = adaptor.getLhs();
Value rhs = adaptor.getRhs();
Type arithmeticType = type;
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
!bitEnumContainsAll(op.getOverflowFlags(),
Expand All @@ -421,20 +400,15 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
/*isSigned=*/false);
}
if (arithmeticType != type) {
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
lhs);
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
rhs);
}

Value result = rewriter.template create<EmitCOp>(op.getLoc(),
arithmeticType, lhs, rhs);
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);

Value arithmeticResult = rewriter.template create<EmitCOp>(
op.getLoc(), arithmeticType, lhs, rhs);

Value result = adaptValueType(arithmeticResult, rewriter, type);

if (arithmeticType != type) {
result =
rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
}
rewriter.replaceOp(op, result);
return success();
}
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,13 @@ func.func @arith_trunci(%arg0: i32) -> i8 {
// CHECK: emitc.cast %[[Trunc]] : ui8 to i8
%truncd = arith.trunci %arg0 : i32 to i8

// CHECK: %[[Const:.*]] = "emitc.constant"
// CHECK-SAME: value = 1
// CHECK-SAME: () -> i32
// CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
// CHECK: %[[Conv:.*]] = emitc.cast %[[AndOne]] : i32 to i1
%bool = arith.trunci %arg0 : i32 to i1

return %truncd : i8
}

Expand Down
Loading