-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][emitc] Lower arith.index_cast, arith.index_castui, arith.shli, arith.shrui, arith.shrsi #95795
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][emitc] Lower arith.index_cast, arith.index_castui, arith.shli, arith.shrui, arith.shrsi #95795
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
7700cb8
to
334a916
Compare
@llvm/pr-subscribers-mlir-emitc @llvm/pr-subscribers-mlir Author: Corentin Ferry (cferry-AMD) ChangesThis PR makes use of the newly introduced EmitC types, and it is now possible to lower:
For the Patch is 23.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95795.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 93717e3b02ef0..b0c9d083ddd88 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -15,8 +15,10 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Region.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -36,8 +38,11 @@ class ArithConstantOpConversionPattern
matchAndRewrite(arith::ConstantOp arithConst,
arith::ConstantOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<emitc::ConstantOp>(
- arithConst, arithConst.getType(), adaptor.getValue());
+ Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
+ if (!newTy)
+ return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
+ rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
+ adaptor.getValue());
return success();
}
};
@@ -52,6 +57,12 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
signedness);
}
+ } else if (emitc::isPointerWideType(ty)) {
+ if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
+ if (needsUnsigned)
+ return emitc::SizeTType::get(ty.getContext());
+ return emitc::PtrDiffTType::get(ty.getContext());
+ }
}
return ty;
}
@@ -264,8 +275,9 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
ConversionPatternRewriter &rewriter) const override {
Type type = adaptor.getLhs().getType();
- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
- return rewriter.notifyMatchFailure(op, "expected integer or index type");
+ if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
+ return rewriter.notifyMatchFailure(
+ op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}
bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
@@ -318,8 +330,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {
Type opReturnType = this->getTypeConverter()->convertType(op.getType());
- if (!isa_and_nonnull<IntegerType>(opReturnType))
- return rewriter.notifyMatchFailure(op, "expected integer result type");
+ if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
+ emitc::isPointerWideType(opReturnType)))
+ return rewriter.notifyMatchFailure(
+ op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
if (adaptor.getOperands().size() != 1) {
return rewriter.notifyMatchFailure(
@@ -327,8 +341,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
}
Type operandType = adaptor.getIn().getType();
- if (!isa_and_nonnull<IntegerType>(operandType))
- return rewriter.notifyMatchFailure(op, "expected integer operand type");
+ if (!operandType || !(isa<IntegerType>(operandType) ||
+ emitc::isPointerWideType(operandType)))
+ return rewriter.notifyMatchFailure(
+ op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
// Signed (sign-extending) casts from i1 are not supported.
if (operandType.isInteger(1) && !castToUnsigned)
@@ -339,8 +355,11 @@ class CastConversion : public OpConversionPattern<ArithOp> {
// equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
// truncation.
if (opReturnType.isInteger(1)) {
+ Type attrType = (emitc::isPointerWideType(operandType))
+ ? rewriter.getIndexType()
+ : operandType;
auto constOne = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1));
+ op.getLoc(), operandType, rewriter.getIntegerAttr(attrType, 1));
auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
op.getLoc(), operandType, adaptor.getIn(), constOne);
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
@@ -393,7 +412,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
+ Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
+ if (!newTy)
+ return rewriter.notifyMatchFailure(arithOp,
+ "converting result type failed");
+ rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
adaptor.getOperands());
return success();
@@ -410,8 +433,10 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {
Type type = this->getTypeConverter()->convertType(op.getType());
- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
- return rewriter.notifyMatchFailure(op, "expected integer type");
+ if (!type || !(isa_and_nonnull<IntegerType>(type) ||
+ emitc::isPointerWideType(type))) {
+ return rewriter.notifyMatchFailure(
+ op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}
if (type.isInteger(1)) {
@@ -482,6 +507,90 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
}
};
+template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
+class ShiftOpConversion : public OpConversionPattern<ArithOp> {
+public:
+ using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Type type = this->getTypeConverter()->convertType(op.getType());
+ if (!type || !(isa_and_nonnull<IntegerType>(type) ||
+ emitc::isPointerWideType(type))) {
+ return rewriter.notifyMatchFailure(
+ op, "expected integer or size_t/ssize_t/ptrdiff_t type");
+ }
+
+ if (type.isInteger(1)) {
+ return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
+ }
+
+ Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
+
+ Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+ // Shift amount interpreted as unsigned per Arith dialect spec.
+ Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
+ /*needsUnsigned=*/true);
+ Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
+
+ // Add a runtime check for overflow
+ Value width;
+ if (emitc::isPointerWideType(type)) {
+ Value eight = rewriter.create<emitc::ConstantOp>(
+ op.getLoc(), rhsType, rewriter.getIndexAttr(8));
+ emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
+ op.getLoc(), rhsType, "sizeof", SmallVector<Value, 1>({eight}));
+ width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
+ sizeOfCall.getResult(0));
+ } else {
+ width = rewriter.create<emitc::ConstantOp>(
+ op.getLoc(), rhsType,
+ rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
+ }
+
+ Value excessCheck = rewriter.create<emitc::CmpOp>(
+ op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
+
+ // Any concrete value is a valid refinement of poison.
+ Value poison = rewriter.create<emitc::ConstantOp>(
+ op.getLoc(), arithmeticType,
+ (isa<IntegerType>(arithmeticType)
+ ? rewriter.getIntegerAttr(arithmeticType, 0)
+ : rewriter.getIndexAttr(0)));
+
+ emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
+ op.getLoc(), arithmeticType, /*do_not_inline=*/false);
+ Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
+ auto currentPoint = rewriter.getInsertionPoint();
+ rewriter.setInsertionPointToStart(&bodyBlock);
+ Value arithmeticResult =
+ rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
+ Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
+ op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
+ rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
+ rewriter.setInsertionPoint(op->getBlock(), currentPoint);
+
+ Value result = adaptValueType(ternary, rewriter, type);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+template <typename ArithOp, typename EmitCOp>
+class SignedShiftOpConversion final
+ : public ShiftOpConversion<ArithOp, EmitCOp, false> {
+ using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
+};
+
+template <typename ArithOp, typename EmitCOp>
+class UnsignedShiftOpConversion final
+ : public ShiftOpConversion<ArithOp, EmitCOp, true> {
+ using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
+};
+
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
public:
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -606,6 +715,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
+ mlir::populateEmitCSizeTTypeConversions(typeConverter);
+
// clang-format off
patterns.add<
ArithConstantOpConversionPattern,
@@ -621,6 +732,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
+ UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
+ SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
+ UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
CmpFOpConversion,
CmpIOpConversion,
NegFOpConversion,
@@ -629,6 +743,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
UnsignedCastConversion<arith::TruncIOp>,
SignedCastConversion<arith::ExtSIOp>,
UnsignedCastConversion<arith::ExtUIOp>,
+ SignedCastConversion<arith::IndexCastOp>,
+ UnsignedCastConversion<arith::IndexCastUIOp>,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,
diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
index a3784f47c3bc2..730a4b341673d 100644
--- a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRArithToEmitC
LINK_LIBS PUBLIC
MLIRArithDialect
MLIREmitCDialect
+ MLIREmitCTransforms
MLIRPass
MLIRTransformUtils
)
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index caef04052aa8c..766ad4039335e 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -110,3 +110,27 @@ func.func @arith_extsi_i1_to_i32(%arg0: i1) {
%idx = arith.extsi %arg0 : i1 to i32
return
}
+
+// -----
+
+func.func @arith_shli_i1(%arg0: i1, %arg1: i1) {
+ // expected-error @+1 {{failed to legalize operation 'arith.shli'}}
+ %shli = arith.shli %arg0, %arg1 : i1
+ return
+}
+
+// -----
+
+func.func @arith_shrsi_i1(%arg0: i1, %arg1: i1) {
+ // expected-error @+1 {{failed to legalize operation 'arith.shrsi'}}
+ %shrsi = arith.shrsi %arg0, %arg1 : i1
+ return
+}
+
+// -----
+
+func.func @arith_shrui_i1(%arg0: i1, %arg1: i1) {
+ // expected-error @+1 {{failed to legalize operation 'arith.shrui'}}
+ %shrui = arith.shrui %arg0, %arg1 : i1
+ return
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 0289b7dc0728f..858ccd1171445 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -3,7 +3,8 @@
// CHECK-LABEL: arith_constants
func.func @arith_constants() {
// CHECK: emitc.constant
- // CHECK-SAME: value = 0 : index
+ // CHECK-SAME: value = 0
+ // CHECK-SAME: () -> !emitc.size_t
%c_index = arith.constant 0 : index
// CHECK: emitc.constant
// CHECK-SAME: value = 0 : i32
@@ -75,13 +76,18 @@ func.func @arith_integer_ops_signed_nsw(%arg0: i32, %arg1: i32) {
// -----
// CHECK-LABEL: arith_index
-func.func @arith_index(%arg0: index, %arg1: index) {
- // CHECK: emitc.add %arg0, %arg1 : (index, index) -> index
- %0 = arith.addi %arg0, %arg1 : index
- // CHECK: emitc.sub %arg0, %arg1 : (index, index) -> index
- %1 = arith.subi %arg0, %arg1 : index
- // CHECK: emitc.mul %arg0, %arg1 : (index, index) -> index
- %2 = arith.muli %arg0, %arg1 : index
+func.func @arith_index(%arg0: i32, %arg1: i32) {
+ // CHECK: %[[CST0:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
+ %cst0 = arith.index_cast %arg0 : i32 to index
+ // CHECK: %[[CST1:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
+ %cst1 = arith.index_cast %arg1 : i32 to index
+
+ // CHECK: emitc.add %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ %0 = arith.addi %cst0, %cst1 : index
+ // CHECK: emitc.sub %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ %1 = arith.subi %cst0, %cst1 : index
+ // CHECK: emitc.mul %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ %2 = arith.muli %cst0, %cst1 : index
return
}
@@ -138,6 +144,116 @@ func.func @arith_signed_integer_div_rem(%arg0: i32, %arg1: i32) {
// -----
+// CHECK-LABEL: arith_shift_left
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func.func @arith_shift_left(%arg0: i32, %arg1: i32) {
+ // CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
+ // CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+ // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32
+ // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
+ // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
+ // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
+ // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+ // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : ui32
+ // CHECK: emitc.yield %[[Ternary]] : ui32
+ // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
+ %1 = arith.shli %arg0, %arg1 : i32
+ return
+}
+
+// -----
+
+// CHECK-LABEL: arith_shift_right
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func.func @arith_shift_right(%arg0: i32, %arg1: i32) {
+ // CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
+ // CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+ // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
+ // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
+ // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}ui32
+ // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
+ // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+ // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : ui32
+ // CHECK: emitc.yield %[[Ternary]] : ui32
+ // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
+ %2 = arith.shrui %arg0, %arg1 : i32
+
+ // CHECK-DAG: %[[SC2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+ // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
+ // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SC2]], %[[SSizeConstant]] : (ui32, ui32) -> i1
+ // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32
+ // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32
+ // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[SC2]] : (i32, ui32) -> i32
+ // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : i32
+ // CHECK: emitc.yield %[[STernary]] : i32
+ %3 = arith.shrsi %arg0, %arg1 : i32
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: arith_shift_left_index
+// CHECK-SAME: %[[AMOUNT:.*]]: i32
+func.func @arith_shift_left_index(%amount: i32) {
+ %cst0 = "arith.constant"() {value = 42 : index} : () -> (index)
+ %cast1 = arith.index_cast %amount : i32 to index
+ // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
+ // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ptrdiff_t
+ // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ptrdiff_t to !emitc.size_t
+ // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
+ // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
+ // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
+ // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
+ // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
+ // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : !emitc.size_t
+ // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
+ %1 = arith.shli %cst0, %cast1 : index
+ return
+}
+
+// -----
+
+// CHECK-LABEL: arith_shift_right_index
+// CHECK-SAME: %[[AMOUNT:.*]]: i32
+func.func @arith_shift_right_index(%amount: i32) {
+ // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
+ // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ptrdiff_t
+ // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ptrdiff_t to !emitc.size_t
+ %arg0 = "arith.constant"() {value = 42 : index} : () -> (index)
+ %arg1 = arith.index_cast %amount : i32 to index
+
+ // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
+ // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
+ // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
+ // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.size_t
+ // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
+ // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : !emitc.size_t
+ // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
+ %2 = arith.shrui %arg0, %arg1 : index
+
+ // CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ptrdiff_t
+ // CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.size_t
+ // CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.size_t) -> !emitc.size_t
+ // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+ // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emi...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Look good overall
… arith.shrui, arith.shrsi (llvm#95795) This PR makes use of the newly introduced EmitC types, and lowers: * ops dealing with index types (index_cast, index_castui), * ops where `size_t` is used as part of the lowering (shli, shrui, shrsi, to check for overflow and avoid UB in this case).
This PR makes use of the newly introduced EmitC types, and it is now possible to lower:
index_cast
,index_castui
),size_t
is used as part of the lowering (shli
,shrui
,shrsi
).For the
shli
,shrui
,shrsi
operations, we have to check for overflow, as overflow is UB per C99 specification, and gives a poison value in the MLIR world. Where the bitwidth is not known (i.e. for variables of typeindex
), the check is performed usingsizeof
. It is then up to the target compiler to optimize it away and perform constant propagation.