Skip to content

[mlir][emitc] Lower arith.index_cast, arith.index_castui, arith.shli, arith.shrui, arith.shrsi #95795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 125 additions & 12 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
Expand All @@ -36,8 +37,11 @@ class ArithConstantOpConversionPattern
matchAndRewrite(arith::ConstantOp arithConst,
arith::ConstantOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(
arithConst, arithConst.getType(), adaptor.getValue());
Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
if (!newTy)
return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
adaptor.getValue());
return success();
}
};
Expand All @@ -52,6 +56,12 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
signedness);
}
} else if (emitc::isPointerWideType(ty)) {
if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
if (needsUnsigned)
return emitc::SizeTType::get(ty.getContext());
return emitc::PtrDiffTType::get(ty.getContext());
}
}
return ty;
}
Expand Down Expand Up @@ -264,8 +274,9 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = adaptor.getLhs().getType();
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
return rewriter.notifyMatchFailure(op, "expected integer or index type");
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}

bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
Expand Down Expand Up @@ -318,17 +329,21 @@ class CastConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type opReturnType = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType>(opReturnType))
return rewriter.notifyMatchFailure(op, "expected integer result type");
if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
emitc::isPointerWideType(opReturnType)))
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t result type");

if (adaptor.getOperands().size() != 1) {
return rewriter.notifyMatchFailure(
op, "CastConversion only supports unary ops");
}

Type operandType = adaptor.getIn().getType();
if (!isa_and_nonnull<IntegerType>(operandType))
return rewriter.notifyMatchFailure(op, "expected integer operand type");
if (!operandType || !(isa<IntegerType>(operandType) ||
emitc::isPointerWideType(operandType)))
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");

// Signed (sign-extending) casts from i1 are not supported.
if (operandType.isInteger(1) && !castToUnsigned)
Expand All @@ -339,8 +354,11 @@ class CastConversion : public OpConversionPattern<ArithOp> {
// equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
// truncation.
if (opReturnType.isInteger(1)) {
Type attrType = (emitc::isPointerWideType(operandType))
? rewriter.getIndexType()
: operandType;
auto constOne = rewriter.create<emitc::ConstantOp>(
op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1));
op.getLoc(), operandType, rewriter.getOneAttr(attrType));
auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
op.getLoc(), operandType, adaptor.getIn(), constOne);
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
Expand Down Expand Up @@ -393,7 +411,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
if (!newTy)
return rewriter.notifyMatchFailure(arithOp,
"converting result type failed");
rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
adaptor.getOperands());

return success();
Expand All @@ -410,8 +432,9 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
return rewriter.notifyMatchFailure(op, "expected integer type");
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}

if (type.isInteger(1)) {
Expand Down Expand Up @@ -482,6 +505,89 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
}
};

template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
class ShiftOpConversion : public OpConversionPattern<ArithOp> {
public:
using OpConversionPattern<ArithOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}

if (type.isInteger(1)) {
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
}

Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);

Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
// Shift amount interpreted as unsigned per Arith dialect spec.
Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
/*needsUnsigned=*/true);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);

// Add a runtime check for overflow
Value width;
if (emitc::isPointerWideType(type)) {
Value eight = rewriter.create<emitc::ConstantOp>(
op.getLoc(), rhsType, rewriter.getIndexAttr(8));
emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
sizeOfCall.getResult(0));
} else {
width = rewriter.create<emitc::ConstantOp>(
op.getLoc(), rhsType,
rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
}

Value excessCheck = rewriter.create<emitc::CmpOp>(
op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);

// Any concrete value is a valid refinement of poison.
Value poison = rewriter.create<emitc::ConstantOp>(
op.getLoc(), arithmeticType,
(isa<IntegerType>(arithmeticType)
? rewriter.getIntegerAttr(arithmeticType, 0)
: rewriter.getIndexAttr(0)));

emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
op.getLoc(), arithmeticType, /*do_not_inline=*/false);
Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
auto currentPoint = rewriter.getInsertionPoint();
rewriter.setInsertionPointToStart(&bodyBlock);
Value arithmeticResult =
rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
rewriter.setInsertionPoint(op->getBlock(), currentPoint);

Value result = adaptValueType(ternary, rewriter, type);

rewriter.replaceOp(op, result);
return success();
}
};

template <typename ArithOp, typename EmitCOp>
class SignedShiftOpConversion final
: public ShiftOpConversion<ArithOp, EmitCOp, false> {
using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
};

template <typename ArithOp, typename EmitCOp>
class UnsignedShiftOpConversion final
: public ShiftOpConversion<ArithOp, EmitCOp, true> {
using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
};

class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
public:
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
Expand Down Expand Up @@ -606,6 +712,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();

mlir::populateEmitCSizeTTypeConversions(typeConverter);

// clang-format off
patterns.add<
ArithConstantOpConversionPattern,
Expand All @@ -621,6 +729,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
CmpFOpConversion,
CmpIOpConversion,
NegFOpConversion,
Expand All @@ -629,6 +740,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
UnsignedCastConversion<arith::TruncIOp>,
SignedCastConversion<arith::ExtSIOp>,
UnsignedCastConversion<arith::ExtUIOp>,
SignedCastConversion<arith::IndexCastOp>,
UnsignedCastConversion<arith::IndexCastUIOp>,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRArithToEmitC
LINK_LIBS PUBLIC
MLIRArithDialect
MLIREmitCDialect
MLIREmitCTransforms
MLIRPass
MLIRTransformUtils
)
24 changes: 24 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,27 @@ func.func @arith_extsi_i1_to_i32(%arg0: i1) {
%idx = arith.extsi %arg0 : i1 to i32
return
}

// -----

func.func @arith_shli_i1(%arg0: i1, %arg1: i1) {
// expected-error @+1 {{failed to legalize operation 'arith.shli'}}
%shli = arith.shli %arg0, %arg1 : i1
return
}

// -----

func.func @arith_shrsi_i1(%arg0: i1, %arg1: i1) {
// expected-error @+1 {{failed to legalize operation 'arith.shrsi'}}
%shrsi = arith.shrsi %arg0, %arg1 : i1
return
}

// -----

func.func @arith_shrui_i1(%arg0: i1, %arg1: i1) {
// expected-error @+1 {{failed to legalize operation 'arith.shrui'}}
%shrui = arith.shrui %arg0, %arg1 : i1
return
}
Loading
Loading