15
15
16
16
#include " mlir/Dialect/Arith/IR/Arith.h"
17
17
#include " mlir/Dialect/EmitC/IR/EmitC.h"
18
+ #include " mlir/Dialect/EmitC/Transforms/TypeConversions.h"
18
19
#include " mlir/IR/BuiltinAttributes.h"
19
20
#include " mlir/IR/BuiltinTypes.h"
20
21
#include " mlir/Support/LogicalResult.h"
@@ -36,8 +37,11 @@ class ArithConstantOpConversionPattern
36
37
matchAndRewrite (arith::ConstantOp arithConst,
37
38
arith::ConstantOp::Adaptor adaptor,
38
39
ConversionPatternRewriter &rewriter) const override {
39
- rewriter.replaceOpWithNewOp <emitc::ConstantOp>(
40
- arithConst, arithConst.getType (), adaptor.getValue ());
40
+ Type newTy = this ->getTypeConverter ()->convertType (arithConst.getType ());
41
+ if (!newTy)
42
+ return rewriter.notifyMatchFailure (arithConst, " type conversion failed" );
43
+ rewriter.replaceOpWithNewOp <emitc::ConstantOp>(arithConst, newTy,
44
+ adaptor.getValue ());
41
45
return success ();
42
46
}
43
47
};
@@ -52,6 +56,12 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
52
56
return IntegerType::get (ty.getContext (), ty.getIntOrFloatBitWidth (),
53
57
signedness);
54
58
}
59
+ } else if (emitc::isPointerWideType (ty)) {
60
+ if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
61
+ if (needsUnsigned)
62
+ return emitc::SizeTType::get (ty.getContext ());
63
+ return emitc::PtrDiffTType::get (ty.getContext ());
64
+ }
55
65
}
56
66
return ty;
57
67
}
@@ -264,8 +274,9 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
264
274
ConversionPatternRewriter &rewriter) const override {
265
275
266
276
Type type = adaptor.getLhs ().getType ();
267
- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
268
- return rewriter.notifyMatchFailure (op, " expected integer or index type" );
277
+ if (type && !(isa<IntegerType>(type) || emitc::isPointerWideType (type))) {
278
+ return rewriter.notifyMatchFailure (
279
+ op, " expected integer or size_t/ssize_t type" );
269
280
}
270
281
271
282
bool needsUnsigned = needsUnsignedCmp (op.getPredicate ());
@@ -290,17 +301,21 @@ class CastConversion : public OpConversionPattern<ArithOp> {
290
301
ConversionPatternRewriter &rewriter) const override {
291
302
292
303
Type opReturnType = this ->getTypeConverter ()->convertType (op.getType ());
293
- if (!isa_and_nonnull<IntegerType>(opReturnType))
294
- return rewriter.notifyMatchFailure (op, " expected integer result type" );
304
+ if (opReturnType && !(isa_and_nonnull<IntegerType>(opReturnType) ||
305
+ emitc::isPointerWideType (opReturnType)))
306
+ return rewriter.notifyMatchFailure (
307
+ op, " expected integer or size_t/ssize_t result type" );
295
308
296
309
if (adaptor.getOperands ().size () != 1 ) {
297
310
return rewriter.notifyMatchFailure (
298
311
op, " CastConversion only supports unary ops" );
299
312
}
300
313
301
314
Type operandType = adaptor.getIn ().getType ();
302
- if (!isa_and_nonnull<IntegerType>(operandType))
303
- return rewriter.notifyMatchFailure (op, " expected integer operand type" );
315
+ if (operandType && !(isa_and_nonnull<IntegerType>(operandType) ||
316
+ emitc::isPointerWideType (operandType)))
317
+ return rewriter.notifyMatchFailure (
318
+ op, " expected integer or size_t/ssize_t operand type" );
304
319
305
320
// Signed (sign-extending) casts from i1 are not supported.
306
321
if (operandType.isInteger (1 ) && !castToUnsigned)
@@ -311,8 +326,11 @@ class CastConversion : public OpConversionPattern<ArithOp> {
311
326
// equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
312
327
// truncation.
313
328
if (opReturnType.isInteger (1 )) {
329
+ Type attrType = (emitc::isPointerWideType (operandType))
330
+ ? rewriter.getIndexType ()
331
+ : operandType;
314
332
auto constOne = rewriter.create <emitc::ConstantOp>(
315
- op.getLoc (), operandType, rewriter.getIntegerAttr (operandType , 1 ));
333
+ op.getLoc (), operandType, rewriter.getIntegerAttr (attrType , 1 ));
316
334
auto oneAndOperand = rewriter.create <emitc::BitwiseAndOp>(
317
335
op.getLoc (), operandType, adaptor.getIn (), constOne);
318
336
rewriter.replaceOpWithNewOp <emitc::CastOp>(op, opReturnType,
@@ -365,7 +383,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
365
383
matchAndRewrite (ArithOp arithOp, typename ArithOp::Adaptor adaptor,
366
384
ConversionPatternRewriter &rewriter) const override {
367
385
368
- rewriter.template replaceOpWithNewOp <EmitCOp>(arithOp, arithOp.getType (),
386
+ Type newTy = this ->getTypeConverter ()->convertType (arithOp.getType ());
387
+ if (!newTy)
388
+ return rewriter.notifyMatchFailure (arithOp,
389
+ " converting result type failed" );
390
+ rewriter.template replaceOpWithNewOp <EmitCOp>(arithOp, newTy,
369
391
adaptor.getOperands ());
370
392
371
393
return success ();
@@ -382,8 +404,10 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
382
404
ConversionPatternRewriter &rewriter) const override {
383
405
384
406
Type type = this ->getTypeConverter ()->convertType (op.getType ());
385
- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
386
- return rewriter.notifyMatchFailure (op, " expected integer type" );
407
+ if (type && !(isa_and_nonnull<IntegerType>(type) ||
408
+ emitc::isPointerWideType (type))) {
409
+ return rewriter.notifyMatchFailure (
410
+ op, " expected integer or size_t/ssize_t/ptrdiff_t type" );
387
411
}
388
412
389
413
if (type.isInteger (1 )) {
@@ -578,6 +602,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
578
602
RewritePatternSet &patterns) {
579
603
MLIRContext *ctx = patterns.getContext ();
580
604
605
+ mlir::populateEmitCSizeTTypeConversions (typeConverter);
606
+
581
607
// clang-format off
582
608
patterns.add <
583
609
ArithConstantOpConversionPattern,
@@ -600,6 +626,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
600
626
UnsignedCastConversion<arith::TruncIOp>,
601
627
SignedCastConversion<arith::ExtSIOp>,
602
628
UnsignedCastConversion<arith::ExtUIOp>,
629
+ SignedCastConversion<arith::IndexCastOp>,
630
+ UnsignedCastConversion<arith::IndexCastUIOp>,
603
631
ItoFCastOpConversion<arith::SIToFPOp>,
604
632
ItoFCastOpConversion<arith::UIToFPOp>,
605
633
FtoICastOpConversion<arith::FPToSIOp>,
0 commit comments