Skip to content

Commit f64e96a

Browse files
committed
Use new types in EmitC, lower index_cast
1 parent 7f0ab5e commit f64e96a

File tree

3 files changed

+120
-20
lines changed

3 files changed

+120
-20
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

+40-12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/Dialect/Arith/IR/Arith.h"
1717
#include "mlir/Dialect/EmitC/IR/EmitC.h"
18+
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1819
#include "mlir/IR/BuiltinAttributes.h"
1920
#include "mlir/IR/BuiltinTypes.h"
2021
#include "mlir/Support/LogicalResult.h"
@@ -36,8 +37,11 @@ class ArithConstantOpConversionPattern
3637
matchAndRewrite(arith::ConstantOp arithConst,
3738
arith::ConstantOp::Adaptor adaptor,
3839
ConversionPatternRewriter &rewriter) const override {
39-
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(
40-
arithConst, arithConst.getType(), adaptor.getValue());
40+
Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
41+
if (!newTy)
42+
return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
43+
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
44+
adaptor.getValue());
4145
return success();
4246
}
4347
};
@@ -52,6 +56,12 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
5256
return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
5357
signedness);
5458
}
59+
} else if (emitc::isPointerWideType(ty)) {
60+
if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
61+
if (needsUnsigned)
62+
return emitc::SizeTType::get(ty.getContext());
63+
return emitc::PtrDiffTType::get(ty.getContext());
64+
}
5565
}
5666
return ty;
5767
}
@@ -264,8 +274,9 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
264274
ConversionPatternRewriter &rewriter) const override {
265275

266276
Type type = adaptor.getLhs().getType();
267-
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
268-
return rewriter.notifyMatchFailure(op, "expected integer or index type");
277+
if (type && !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
278+
return rewriter.notifyMatchFailure(
279+
op, "expected integer or size_t/ssize_t type");
269280
}
270281

271282
bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
@@ -290,17 +301,21 @@ class CastConversion : public OpConversionPattern<ArithOp> {
290301
ConversionPatternRewriter &rewriter) const override {
291302

292303
Type opReturnType = this->getTypeConverter()->convertType(op.getType());
293-
if (!isa_and_nonnull<IntegerType>(opReturnType))
294-
return rewriter.notifyMatchFailure(op, "expected integer result type");
304+
if (opReturnType && !(isa_and_nonnull<IntegerType>(opReturnType) ||
305+
emitc::isPointerWideType(opReturnType)))
306+
return rewriter.notifyMatchFailure(
307+
op, "expected integer or size_t/ssize_t result type");
295308

296309
if (adaptor.getOperands().size() != 1) {
297310
return rewriter.notifyMatchFailure(
298311
op, "CastConversion only supports unary ops");
299312
}
300313

301314
Type operandType = adaptor.getIn().getType();
302-
if (!isa_and_nonnull<IntegerType>(operandType))
303-
return rewriter.notifyMatchFailure(op, "expected integer operand type");
315+
if (operandType && !(isa_and_nonnull<IntegerType>(operandType) ||
316+
emitc::isPointerWideType(operandType)))
317+
return rewriter.notifyMatchFailure(
318+
op, "expected integer or size_t/ssize_t operand type");
304319

305320
// Signed (sign-extending) casts from i1 are not supported.
306321
if (operandType.isInteger(1) && !castToUnsigned)
@@ -311,8 +326,11 @@ class CastConversion : public OpConversionPattern<ArithOp> {
311326
// equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
312327
// truncation.
313328
if (opReturnType.isInteger(1)) {
329+
Type attrType = (emitc::isPointerWideType(operandType))
330+
? rewriter.getIndexType()
331+
: operandType;
314332
auto constOne = rewriter.create<emitc::ConstantOp>(
315-
op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1));
333+
op.getLoc(), operandType, rewriter.getIntegerAttr(attrType, 1));
316334
auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
317335
op.getLoc(), operandType, adaptor.getIn(), constOne);
318336
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
@@ -365,7 +383,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
365383
matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
366384
ConversionPatternRewriter &rewriter) const override {
367385

368-
rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
386+
Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
387+
if (!newTy)
388+
return rewriter.notifyMatchFailure(arithOp,
389+
"converting result type failed");
390+
rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
369391
adaptor.getOperands());
370392

371393
return success();
@@ -382,8 +404,10 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
382404
ConversionPatternRewriter &rewriter) const override {
383405

384406
Type type = this->getTypeConverter()->convertType(op.getType());
385-
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
386-
return rewriter.notifyMatchFailure(op, "expected integer type");
407+
if (type && !(isa_and_nonnull<IntegerType>(type) ||
408+
emitc::isPointerWideType(type))) {
409+
return rewriter.notifyMatchFailure(
410+
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
387411
}
388412

389413
if (type.isInteger(1)) {
@@ -578,6 +602,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
578602
RewritePatternSet &patterns) {
579603
MLIRContext *ctx = patterns.getContext();
580604

605+
mlir::populateEmitCSizeTTypeConversions(typeConverter);
606+
581607
// clang-format off
582608
patterns.add<
583609
ArithConstantOpConversionPattern,
@@ -600,6 +626,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
600626
UnsignedCastConversion<arith::TruncIOp>,
601627
SignedCastConversion<arith::ExtSIOp>,
602628
UnsignedCastConversion<arith::ExtUIOp>,
629+
SignedCastConversion<arith::IndexCastOp>,
630+
UnsignedCastConversion<arith::IndexCastUIOp>,
603631
ItoFCastOpConversion<arith::SIToFPOp>,
604632
ItoFCastOpConversion<arith::UIToFPOp>,
605633
FtoICastOpConversion<arith::FPToSIOp>,

mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRArithToEmitC
1111
LINK_LIBS PUBLIC
1212
MLIRArithDialect
1313
MLIREmitCDialect
14+
MLIREmitCTransforms
1415
MLIRPass
1516
MLIRTransformUtils
1617
)

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

+79-8
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
// CHECK-LABEL: arith_constants
44
func.func @arith_constants() {
55
// CHECK: emitc.constant
6-
// CHECK-SAME: value = 0 : index
6+
// CHECK-SAME: value = 0
7+
// CHECK-SAME: () -> !emitc.size_t
78
%c_index = arith.constant 0 : index
89
// CHECK: emitc.constant
910
// CHECK-SAME: value = 0 : i32
@@ -75,13 +76,18 @@ func.func @arith_integer_ops_signed_nsw(%arg0: i32, %arg1: i32) {
7576
// -----
7677

7778
// CHECK-LABEL: arith_index
78-
func.func @arith_index(%arg0: index, %arg1: index) {
79-
// CHECK: emitc.add %arg0, %arg1 : (index, index) -> index
80-
%0 = arith.addi %arg0, %arg1 : index
81-
// CHECK: emitc.sub %arg0, %arg1 : (index, index) -> index
82-
%1 = arith.subi %arg0, %arg1 : index
83-
// CHECK: emitc.mul %arg0, %arg1 : (index, index) -> index
84-
%2 = arith.muli %arg0, %arg1 : index
79+
func.func @arith_index(%arg0: i32, %arg1: i32) {
80+
// CHECK: %[[CST0:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
81+
%cst0 = arith.index_cast %arg0 : i32 to index
82+
// CHECK: %[[CST1:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
83+
%cst1 = arith.index_cast %arg1 : i32 to index
84+
85+
// CHECK: emitc.add %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
86+
%0 = arith.addi %cst0, %cst1 : index
87+
// CHECK: emitc.sub %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
88+
%1 = arith.subi %cst0, %cst1 : index
89+
// CHECK: emitc.mul %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
90+
%2 = arith.muli %cst0, %cst1 : index
8591

8692
return
8793
}
@@ -420,6 +426,27 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) {
420426
return
421427
}
422428

429+
func.func @arith_cmpi_index(%arg0: i32, %arg1: i32) -> i1 {
430+
// CHECK-LABEL: arith_cmpi_index
431+
432+
// CHECK: %[[Cst0:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
433+
%idx0 = arith.index_cast %arg0 : i32 to index
434+
// CHECK: %[[Cst1:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t
435+
%idx1 = arith.index_cast %arg0 : i32 to index
436+
437+
// CHECK-DAG: [[ULT:[^ ]*]] = emitc.cmp lt, %[[Cst0]], %[[Cst1]] : (!emitc.size_t, !emitc.size_t) -> i1
438+
%ult = arith.cmpi ult, %idx0, %idx1 : index
439+
440+
// CHECK-DAG: %[[CastArg0:[^ ]*]] = emitc.cast %[[Cst0]] : !emitc.size_t to !emitc.ptrdiff_t
441+
// CHECK-DAG: %[[CastArg1:[^ ]*]] = emitc.cast %[[Cst1]] : !emitc.size_t to !emitc.ptrdiff_t
442+
// CHECK-DAG: %[[SLT:[^ ]*]] = emitc.cmp lt, %[[CastArg0]], %[[CastArg1]] : (!emitc.ptrdiff_t, !emitc.ptrdiff_t) -> i1
443+
%slt = arith.cmpi slt, %idx0, %idx1 : index
444+
445+
// CHECK: return %[[SLT]]
446+
return %slt: i1
447+
}
448+
449+
423450
// -----
424451

425452
func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
@@ -525,3 +552,47 @@ func.func @arith_extui_i1_to_i32(%arg0: i1) {
525552
%idx = arith.extui %arg0 : i1 to i32
526553
return
527554
}
555+
556+
// -----
557+
558+
func.func @arith_index_cast(%arg0: i32) -> i32 {
559+
// CHECK-LABEL: arith_index_cast
560+
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
561+
// CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to !emitc.ptrdiff_t
562+
// CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : !emitc.ptrdiff_t to !emitc.size_t
563+
%idx = arith.index_cast %arg0 : i32 to index
564+
// CHECK: %[[Conv2:.*]] = emitc.cast %[[Conv1]] : !emitc.size_t to !emitc.ptrdiff_t
565+
// CHECK: %[[Conv3:.*]] = emitc.cast %[[Conv2]] : !emitc.ptrdiff_t to i32
566+
%int = arith.index_cast %idx : index to i32
567+
568+
// CHECK: %[[Const:.*]] = "emitc.constant"
569+
// CHECK-SAME: value = 1
570+
// CHECK-SAME: () -> !emitc.size_t
571+
// CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Conv1]], %[[Const]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
572+
// CHECK: %[[Conv4:.*]] = emitc.cast %[[AndOne]] : !emitc.size_t to i1
573+
%bool = arith.index_cast %idx : index to i1
574+
575+
return %int : i32
576+
}
577+
578+
// -----
579+
580+
func.func @arith_index_castui(%arg0: i32) -> i32 {
581+
// CHECK-LABEL: arith_index_castui
582+
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
583+
// CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
584+
// CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to !emitc.size_t
585+
%idx = arith.index_castui %arg0 : i32 to index
586+
// CHECK: %[[Conv2:.*]] = emitc.cast %[[Conv1]] : !emitc.size_t to ui32
587+
// CHECK: %[[Conv3:.*]] = emitc.cast %[[Conv2]] : ui32 to i32
588+
%int = arith.index_castui %idx : index to i32
589+
590+
// CHECK: %[[Const:.*]] = "emitc.constant"
591+
// CHECK-SAME: value = 1
592+
// CHECK-SAME: () -> !emitc.size_t
593+
// CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Conv1]], %[[Const]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
594+
// CHECK: %[[Conv4:.*]] = emitc.cast %[[AndOne]] : !emitc.size_t to i1
595+
%bool = arith.index_castui %idx : index to i1
596+
597+
return %int : i32
598+
}

0 commit comments

Comments
 (0)