Skip to content

[mlir][emitc] Lower arith.index_cast, arith.index_castui, arith.shli, arith.shrui, arith.shrsi #95795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 10, 2024

Conversation

cferry-AMD
Copy link
Contributor

@cferry-AMD cferry-AMD commented Jun 17, 2024

This PR makes use of the newly introduced EmitC types, and it is now possible to lower:

  • ops dealing with index types (index_cast, index_castui),
  • ops where size_t is used as part of the lowering (shli, shrui, shrsi).

For the shli, shrui, shrsi operations, we have to check for overflow, as overflow is UB per C99 specification, and gives a poison value in the MLIR world. Where the bitwidth is not known (i.e. for variables of type index), the check is performed using sizeof. It is then up to the target compiler to optimize it away and perform constant propagation.

Copy link

github-actions bot commented Jun 17, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@cferry-AMD cferry-AMD force-pushed the corentin.upstream_emitc_use_types branch from 7700cb8 to 334a916 Compare June 19, 2024 07:23
@cferry-AMD cferry-AMD marked this pull request as ready for review June 19, 2024 07:24
@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2024

@llvm/pr-subscribers-mlir-emitc

@llvm/pr-subscribers-mlir

Author: Corentin Ferry (cferry-AMD)

Changes

This PR makes use of the newly introduced EmitC types, and it is now possible to lower:

  • ops dealing with index types (index_cast, index_castui),
  • ops where size_t is used as part of the lowering (shli, shrui, shrsi).

For the shli, shrui, shrsi operations, we have to check for overflow, as overflow is UB per C99 specification, and gives a poison value in the MLIR world. Where the bitwidth is not known (i.e. for variables of type index), the check is performed using sizeof. It is then up to the target compiler to optimize it away and perform constant propagation.


Patch is 23.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95795.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp (+128-12)
  • (modified) mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt (+1)
  • (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir (+24)
  • (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir (+189-8)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 93717e3b02ef0..b0c9d083ddd88 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -15,8 +15,10 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Region.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -36,8 +38,11 @@ class ArithConstantOpConversionPattern
   matchAndRewrite(arith::ConstantOp arithConst,
                   arith::ConstantOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<emitc::ConstantOp>(
-        arithConst, arithConst.getType(), adaptor.getValue());
+    Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
+    if (!newTy)
+      return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
+    rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
+                                                   adaptor.getValue());
     return success();
   }
 };
@@ -52,6 +57,12 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
       return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
                               signedness);
     }
+  } else if (emitc::isPointerWideType(ty)) {
+    if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
+      if (needsUnsigned)
+        return emitc::SizeTType::get(ty.getContext());
+      return emitc::PtrDiffTType::get(ty.getContext());
+    }
   }
   return ty;
 }
@@ -264,8 +275,9 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type type = adaptor.getLhs().getType();
-    if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
-      return rewriter.notifyMatchFailure(op, "expected integer or index type");
+    if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
+      return rewriter.notifyMatchFailure(
+          op, "expected integer or size_t/ssize_t/ptrdiff_t type");
     }
 
     bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
@@ -318,8 +330,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type opReturnType = this->getTypeConverter()->convertType(op.getType());
-    if (!isa_and_nonnull<IntegerType>(opReturnType))
-      return rewriter.notifyMatchFailure(op, "expected integer result type");
+    if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
+                           emitc::isPointerWideType(opReturnType)))
+      return rewriter.notifyMatchFailure(
+          op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
 
     if (adaptor.getOperands().size() != 1) {
       return rewriter.notifyMatchFailure(
@@ -327,8 +341,10 @@ class CastConversion : public OpConversionPattern<ArithOp> {
     }
 
     Type operandType = adaptor.getIn().getType();
-    if (!isa_and_nonnull<IntegerType>(operandType))
-      return rewriter.notifyMatchFailure(op, "expected integer operand type");
+    if (!operandType || !(isa<IntegerType>(operandType) ||
+                          emitc::isPointerWideType(operandType)))
+      return rewriter.notifyMatchFailure(
+          op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
 
     // Signed (sign-extending) casts from i1 are not supported.
     if (operandType.isInteger(1) && !castToUnsigned)
@@ -339,8 +355,11 @@ class CastConversion : public OpConversionPattern<ArithOp> {
     // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
     // truncation.
     if (opReturnType.isInteger(1)) {
+      Type attrType = (emitc::isPointerWideType(operandType))
+                          ? rewriter.getIndexType()
+                          : operandType;
       auto constOne = rewriter.create<emitc::ConstantOp>(
-          op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1));
+          op.getLoc(), operandType, rewriter.getIntegerAttr(attrType, 1));
       auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
           op.getLoc(), operandType, adaptor.getIn(), constOne);
       rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
@@ -393,7 +412,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
   matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
+    Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
+    if (!newTy)
+      return rewriter.notifyMatchFailure(arithOp,
+                                         "converting result type failed");
+    rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
                                                   adaptor.getOperands());
 
     return success();
@@ -410,8 +433,10 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type type = this->getTypeConverter()->convertType(op.getType());
-    if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
-      return rewriter.notifyMatchFailure(op, "expected integer type");
+    if (!type || !(isa_and_nonnull<IntegerType>(type) ||
+                   emitc::isPointerWideType(type))) {
+      return rewriter.notifyMatchFailure(
+          op, "expected integer or size_t/ssize_t/ptrdiff_t type");
     }
 
     if (type.isInteger(1)) {
@@ -482,6 +507,90 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
   }
 };
 
+template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
+class ShiftOpConversion : public OpConversionPattern<ArithOp> {
+public:
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Type type = this->getTypeConverter()->convertType(op.getType());
+    if (!type || !(isa_and_nonnull<IntegerType>(type) ||
+                   emitc::isPointerWideType(type))) {
+      return rewriter.notifyMatchFailure(
+          op, "expected integer or size_t/ssize_t/ptrdiff_t type");
+    }
+
+    if (type.isInteger(1)) {
+      return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
+    }
+
+    Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
+
+    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+    // Shift amount interpreted as unsigned per Arith dialect spec.
+    Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
+                                               /*needsUnsigned=*/true);
+    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
+
+    // Add a runtime check for overflow
+    Value width;
+    if (emitc::isPointerWideType(type)) {
+      Value eight = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), rhsType, rewriter.getIndexAttr(8));
+      emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
+          op.getLoc(), rhsType, "sizeof", SmallVector<Value, 1>({eight}));
+      width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
+                                            sizeOfCall.getResult(0));
+    } else {
+      width = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), rhsType,
+          rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
+    }
+
+    Value excessCheck = rewriter.create<emitc::CmpOp>(
+        op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
+
+    // Any concrete value is a valid refinement of poison.
+    Value poison = rewriter.create<emitc::ConstantOp>(
+        op.getLoc(), arithmeticType,
+        (isa<IntegerType>(arithmeticType)
+             ? rewriter.getIntegerAttr(arithmeticType, 0)
+             : rewriter.getIndexAttr(0)));
+
+    emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
+        op.getLoc(), arithmeticType, /*do_not_inline=*/false);
+    Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
+    auto currentPoint = rewriter.getInsertionPoint();
+    rewriter.setInsertionPointToStart(&bodyBlock);
+    Value arithmeticResult =
+        rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
+    Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
+        op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
+    rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
+    rewriter.setInsertionPoint(op->getBlock(), currentPoint);
+
+    Value result = adaptValueType(ternary, rewriter, type);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+template <typename ArithOp, typename EmitCOp>
+class SignedShiftOpConversion final
+    : public ShiftOpConversion<ArithOp, EmitCOp, false> {
+  using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
+};
+
+template <typename ArithOp, typename EmitCOp>
+class UnsignedShiftOpConversion final
+    : public ShiftOpConversion<ArithOp, EmitCOp, true> {
+  using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
+};
+
 class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
 public:
   using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -606,6 +715,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
                                         RewritePatternSet &patterns) {
   MLIRContext *ctx = patterns.getContext();
 
+  mlir::populateEmitCSizeTTypeConversions(typeConverter);
+
   // clang-format off
   patterns.add<
     ArithConstantOpConversionPattern,
@@ -621,6 +732,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
     BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
     BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
+    UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
+    SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
+    UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
     CmpFOpConversion,
     CmpIOpConversion,
     NegFOpConversion,
@@ -629,6 +743,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     UnsignedCastConversion<arith::TruncIOp>,
     SignedCastConversion<arith::ExtSIOp>,
     UnsignedCastConversion<arith::ExtUIOp>,
+    SignedCastConversion<arith::IndexCastOp>,
+    UnsignedCastConversion<arith::IndexCastUIOp>,
     ItoFCastOpConversion<arith::SIToFPOp>,
     ItoFCastOpConversion<arith::UIToFPOp>,
     FtoICastOpConversion<arith::FPToSIOp>,
diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
index a3784f47c3bc2..730a4b341673d 100644
--- a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRArithToEmitC
   LINK_LIBS PUBLIC
   MLIRArithDialect
   MLIREmitCDialect
+  MLIREmitCTransforms
   MLIRPass
   MLIRTransformUtils
   )
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index caef04052aa8c..766ad4039335e 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -110,3 +110,27 @@ func.func @arith_extsi_i1_to_i32(%arg0: i1) {
   %idx = arith.extsi %arg0 : i1 to i32
   return
 }
+
+// -----
+
+func.func @arith_shli_i1(%arg0: i1, %arg1: i1) {
+  // expected-error @+1 {{failed to legalize operation 'arith.shli'}}
+  %shli = arith.shli %arg0, %arg1 : i1
+  return
+}
+
+// -----
+
+func.func @arith_shrsi_i1(%arg0: i1, %arg1: i1) {
+  // expected-error @+1 {{failed to legalize operation 'arith.shrsi'}}
+  %shrsi = arith.shrsi %arg0, %arg1 : i1
+  return
+}
+
+// -----
+
+func.func @arith_shrui_i1(%arg0: i1, %arg1: i1) {
+  // expected-error @+1 {{failed to legalize operation 'arith.shrui'}}
+  %shrui = arith.shrui %arg0, %arg1 : i1
+  return
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 0289b7dc0728f..858ccd1171445 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -3,7 +3,8 @@
 // CHECK-LABEL: arith_constants
 func.func @arith_constants() {
   // CHECK: emitc.constant
-  // CHECK-SAME: value = 0 : index
+  // CHECK-SAME: value = 0
+  // CHECK-SAME: () -> !emitc.size_t
   %c_index = arith.constant 0 : index
   // CHECK: emitc.constant
   // CHECK-SAME: value = 0 : i32
@@ -75,13 +76,18 @@ func.func @arith_integer_ops_signed_nsw(%arg0: i32, %arg1: i32) {
 // -----
 
 // CHECK-LABEL: arith_index
-func.func @arith_index(%arg0: index, %arg1: index) {
-  // CHECK: emitc.add %arg0, %arg1 : (index, index) -> index
-  %0 = arith.addi %arg0, %arg1 : index
-  // CHECK: emitc.sub %arg0, %arg1 : (index, index) -> index
-  %1 = arith.subi %arg0, %arg1 : index
-  // CHECK: emitc.mul %arg0, %arg1 : (index, index) -> index
-  %2 = arith.muli %arg0, %arg1 : index
+func.func @arith_index(%arg0: i32, %arg1: i32) {
+  // CHECK: %[[CST0:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
+  %cst0 = arith.index_cast %arg0 : i32 to index
+  // CHECK: %[[CST1:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
+  %cst1 = arith.index_cast %arg1 : i32 to index
+
+  // CHECK: emitc.add %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  %0 = arith.addi %cst0, %cst1 : index
+  // CHECK: emitc.sub %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  %1 = arith.subi %cst0, %cst1 : index
+  // CHECK: emitc.mul %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  %2 = arith.muli %cst0, %cst1 : index
 
   return
 }
@@ -138,6 +144,116 @@ func.func @arith_signed_integer_div_rem(%arg0: i32, %arg1: i32) {
 
 // -----
 
+// CHECK-LABEL: arith_shift_left
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func.func @arith_shift_left(%arg0: i32, %arg1: i32) {
+  // CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
+  // CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+  // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32
+  // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
+  // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
+  // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
+  // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+  // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : ui32
+  // CHECK: emitc.yield %[[Ternary]] : ui32
+  // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
+  %1 = arith.shli %arg0, %arg1 : i32
+  return
+}
+
+// -----
+
+// CHECK-LABEL: arith_shift_right
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func.func @arith_shift_right(%arg0: i32, %arg1: i32) {
+  // CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
+  // CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+  // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
+  // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
+  // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}ui32
+  // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
+  // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+  // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : ui32
+  // CHECK: emitc.yield %[[Ternary]] : ui32
+  // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
+  %2 = arith.shrui %arg0, %arg1 : i32
+
+  // CHECK-DAG: %[[SC2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+  // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
+  // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SC2]], %[[SSizeConstant]] : (ui32, ui32) -> i1
+  // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32
+  // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32
+  // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[SC2]] : (i32, ui32) -> i32
+  // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : i32
+  // CHECK: emitc.yield %[[STernary]] : i32
+  %3 = arith.shrsi %arg0, %arg1 : i32
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: arith_shift_left_index
+// CHECK-SAME: %[[AMOUNT:.*]]: i32
+func.func @arith_shift_left_index(%amount: i32) {
+  %cst0 = "arith.constant"() {value = 42 : index} : () -> (index)
+  %cast1 = arith.index_cast %amount : i32 to index
+  // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
+  // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ptrdiff_t
+  // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ptrdiff_t to !emitc.size_t
+  // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
+  // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
+  // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
+  // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
+  // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
+  // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : !emitc.size_t
+  // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
+  %1 = arith.shli %cst0, %cast1 : index
+  return
+}
+
+// -----
+
+// CHECK-LABEL: arith_shift_right_index
+// CHECK-SAME: %[[AMOUNT:.*]]: i32
+func.func @arith_shift_right_index(%amount: i32) {
+  // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
+  // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ptrdiff_t
+  // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ptrdiff_t to !emitc.size_t
+  %arg0 = "arith.constant"() {value = 42 : index} : () -> (index)
+  %arg1 = arith.index_cast %amount : i32 to index
+
+  // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
+  // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
+  // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
+  // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.size_t
+  // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
+  // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : !emitc.size_t
+  // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
+  %2 = arith.shrui %arg0, %arg1 : index
+
+  // CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ptrdiff_t
+  // CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.size_t
+  // CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.size_t) -> !emitc.size_t
+  // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emi...
[truncated]

Copy link
Contributor

@simon-camp simon-camp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look good overall

@cferry-AMD cferry-AMD requested a review from simon-camp July 8, 2024 12:35
@cferry-AMD cferry-AMD merged commit 5c09dda into llvm:main Jul 10, 2024
7 checks passed
@cferry-AMD cferry-AMD deleted the corentin.upstream_emitc_use_types branch July 10, 2024 09:32
aaryanshukla pushed a commit to aaryanshukla/llvm-project that referenced this pull request Jul 14, 2024
… arith.shrui, arith.shrsi (llvm#95795)

This PR makes use of the newly introduced EmitC types, and lowers:
* ops dealing with index types (index_cast, index_castui),
* ops where `size_t` is used as part of the lowering (shli, shrui,
shrsi, to check for overflow and avoid UB in this case).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants