13
13
#include " mlir/Dialect/Arith/Transforms/Passes.h"
14
14
#include " mlir/Dialect/Arith/Utils/Utils.h"
15
15
#include " mlir/Dialect/MemRef/IR/MemRef.h"
16
- #include " mlir/Dialect/MemRef/Transforms/Passes.h"
17
16
#include " mlir/Dialect/MemRef/Transforms/Transforms.h"
18
17
#include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
19
18
#include " mlir/Dialect/Vector/IR/VectorOps.h"
24
23
#include " mlir/Support/MathExtras.h"
25
24
#include " mlir/Transforms/DialectConversion.h"
26
25
#include " llvm/Support/FormatVariadic.h"
27
- #include " llvm/Support/MathExtras.h"
28
26
#include < cassert>
29
27
#include < type_traits>
30
28
@@ -430,6 +428,33 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
430
428
}
431
429
};
432
430
431
+ // ===----------------------------------------------------------------------===//
432
+ // ConvertMemRefCollapseShape
433
+ // ===----------------------------------------------------------------------===//
434
+
435
+ // / Emulating a `memref.collapse_shape` becomes a no-op after emulation given
436
+ // / that we flatten memrefs to a single dimension as part of the emulation and
437
+ // / there is no dimension to collapse any further.
438
+ struct ConvertMemRefCollapseShape final
439
+ : OpConversionPattern<memref::CollapseShapeOp> {
440
+ using OpConversionPattern::OpConversionPattern;
441
+
442
+ LogicalResult
443
+ matchAndRewrite (memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
444
+ ConversionPatternRewriter &rewriter) const override {
445
+ Value srcVal = adaptor.getSrc ();
446
+ auto newTy = dyn_cast<MemRefType>(srcVal.getType ());
447
+ if (!newTy)
448
+ return failure ();
449
+
450
+ if (newTy.getRank () != 1 )
451
+ return failure ();
452
+
453
+ rewriter.replaceOp (collapseShapeOp, srcVal);
454
+ return success ();
455
+ }
456
+ };
457
+
433
458
} // end anonymous namespace
434
459
435
460
// ===----------------------------------------------------------------------===//
@@ -442,7 +467,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
442
467
443
468
// Populate `memref.*` conversion patterns.
444
469
patterns.add <ConvertMemRefAllocation<memref::AllocOp>,
445
- ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
470
+ ConvertMemRefAllocation<memref::AllocaOp>,
471
+ ConvertMemRefCollapseShape, ConvertMemRefLoad,
446
472
ConvertMemrefStore, ConvertMemRefAssumeAlignment,
447
473
ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
448
474
typeConverter, patterns.getContext ());
0 commit comments