Skip to content

Commit 571831a

Browse files
authored
[mlir] Add sub-byte type emulation support for memref.collapse_shape (#89962)
This PR adds support for `memref.collapse_shape` to sub-byte type emulation. The `memref.collapse_shape` becomes a no-op given that we are flattening the memref as part of the emulation (i.e., we are collapsing all the dimensions).
1 parent 64d514a commit 571831a

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include "mlir/Dialect/Arith/Transforms/Passes.h"
1414
#include "mlir/Dialect/Arith/Utils/Utils.h"
1515
#include "mlir/Dialect/MemRef/IR/MemRef.h"
16-
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
1716
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
1817
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1918
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -24,7 +23,6 @@
2423
#include "mlir/Support/MathExtras.h"
2524
#include "mlir/Transforms/DialectConversion.h"
2625
#include "llvm/Support/FormatVariadic.h"
27-
#include "llvm/Support/MathExtras.h"
2826
#include <cassert>
2927
#include <type_traits>
3028

@@ -430,6 +428,33 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
430428
}
431429
};
432430

431+
//===----------------------------------------------------------------------===//
432+
// ConvertMemRefCollapseShape
433+
//===----------------------------------------------------------------------===//
434+
435+
/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
436+
/// that we flatten memrefs to a single dimension as part of the emulation and
437+
/// there is no dimension to collapse any further.
///
/// The pattern replaces the op with the source value obtained from the
/// adaptor when that value has a rank-1 `MemRefType`. In every other case it
/// returns failure without modifying the op.
438+
struct ConvertMemRefCollapseShape final
439+
: OpConversionPattern<memref::CollapseShapeOp> {
440+
using OpConversionPattern::OpConversionPattern;
441+
442+
LogicalResult
443+
matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
444+
ConversionPatternRewriter &rewriter) const override {
445+
// Source operand as supplied by the adaptor (presumably already
// type-converted by the emulation -- TODO confirm against the type
// converter).
Value srcVal = adaptor.getSrc();
446+
auto newTy = dyn_cast<MemRefType>(srcVal.getType());
447+
// The source is not a MemRefType: this pattern does not apply.
if (!newTy)
448+
return failure();
449+
450+
// Only a one-dimensional source makes the collapse redundant; bail out on
// any other rank.
if (newTy.getRank() != 1)
451+
return failure();
452+
453+
// Forward the source value in place of the collapse result.
rewriter.replaceOp(collapseShapeOp, srcVal);
454+
return success();
455+
}
456+
};
457+
433458
} // end anonymous namespace
434459

435460
//===----------------------------------------------------------------------===//
@@ -442,7 +467,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
442467

443468
// Populate `memref.*` conversion patterns.
444469
patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
445-
ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
470+
ConvertMemRefAllocation<memref::AllocaOp>,
471+
ConvertMemRefCollapseShape, ConvertMemRefLoad,
446472
ConvertMemrefStore, ConvertMemRefAssumeAlignment,
447473
ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
448474
typeConverter, patterns.getContext());

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,3 +430,23 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
430430
// CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
431431
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
432432
// CHECK32: return
433+
434+
// -----
435+
436+
// Test input: allocates a 32x8x128 memref of i4, collapses its two leading
// dimensions into one (giving 256x128), and loads a single element from the
// collapsed memref at the two index arguments.
func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
437+
%arr = memref.alloc() : memref<32x8x128xi4>
438+
%collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
439+
%1 = memref.load %collapse[%idx0, %idx1] : memref<256x128xi4>
440+
return %1 : i4
441+
}
442+
443+
// CHECK-LABEL: func.func @memref_collapse_shape_i4(
444+
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
445+
// CHECK-NOT: memref.collapse_shape
446+
// CHECK: memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>
447+
448+
// CHECK32-LABEL: func.func @memref_collapse_shape_i4(
449+
// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
450+
// CHECK32-NOT: memref.collapse_shape
451+
// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
452+

0 commit comments

Comments
 (0)