
Commit 6fd3c20

[MLIR] Add a utility pass to linearize memref (#136797)
To add a transformation that simplifies memory access patterns, this PR adds a memref linearizer based on the GPU/DecomposeMemRefs pass, with the following changes:

* support vector dialect ops
* instead of decomposing memrefs into rank-0 memrefs, flatten higher-ranked memrefs to rank-1

Notes:

* After linearization, a MemRef's offset is kept, so a `memref<4x8xf32, strided<[8, 1], offset: 100>>` becomes `memref<32xf32, strided<[1], offset: 100>>`.
* It also works with dynamic shapes, strides, and offsets (see the test cases for details).
* The shape of the cast memref is computed as 1-D, flattened.
1 parent dc68166 commit 6fd3c20
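To make the commit message concrete, here is a sketch of the before/after IR for a 2-D load over the strided type mentioned above; the function name and the exact folding of the index arithmetic are illustrative assumptions, not taken from the commit's test cases.

```mlir
// Before:
func.func @load_2d(%m: memref<4x8xf32, strided<[8, 1], offset: 100>>,
                   %i: index, %j: index) -> f32 {
  %v = memref.load %m[%i, %j] : memref<4x8xf32, strided<[8, 1], offset: 100>>
  return %v : f32
}

// After --flatten-memref (schematic): offset preserved, index linearized.
func.func @load_2d(%m: memref<4x8xf32, strided<[8, 1], offset: 100>>,
                   %i: index, %j: index) -> f32 {
  %idx = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%i, %j]
  %flat = memref.reinterpret_cast %m to offset: [100], sizes: [32], strides: [1]
      : memref<4x8xf32, strided<[8, 1], offset: 100>>
        to memref<32xf32, strided<[1], offset: 100>>
  %v = memref.load %flat[%idx] : memref<32xf32, strided<[1], offset: 100>>
  return %v : f32
}
```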

File tree

5 files changed, +599 −0 lines changed

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td

Lines changed: 10 additions & 0 deletions
@@ -245,5 +245,15 @@ def ExpandReallocPass : Pass<"expand-realloc"> {
   ];
 }
 
+def FlattenMemrefsPass : Pass<"flatten-memref"> {
+  let summary = "Flatten a multiple dimensional memref to 1-dimensional";
+  let description = [{
+
+  }];
+  let dependentDialects = [
+      "affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES

mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h

Lines changed: 2 additions & 0 deletions
@@ -144,6 +144,8 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
 /// ```
 void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
+void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
+
 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
 /// given independencies. If the op is already independent of all
 /// independencies, the same AllocaOp result is returned.

mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   EmulateWideInt.cpp
   EmulateNarrowType.cpp
   ExtractAddressComputations.cpp
+  FlattenMemRefs.cpp
   FoldMemRefAliasOps.cpp
   IndependenceTransforms.cpp
   MultiBuffer.cpp

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 286 additions & 0 deletions
@@ -0,0 +1,286 @@
//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for flattening multi-rank memref-related ops
// into 1-D memref ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FLATTENMEMREFSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

/// Materializes an OpFoldResult as a Value, emitting an index constant when
/// the OpFoldResult holds an attribute.
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
                                      OpFoldResult in) {
  if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
    return rewriter.create<arith::ConstantIndexOp>(
        loc, cast<IntegerAttr>(offsetAttr).getInt());
  }
  return cast<Value>(in);
}

/// Returns a collapsed memref and the linearized index to access the element
/// at the specified indices.
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
                                                         Location loc,
                                                         Value source,
                                                         ValueRange indices) {
  int64_t sourceOffset;
  SmallVector<int64_t, 4> sourceStrides;
  auto sourceType = cast<MemRefType>(source.getType());
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
    assert(false);
  }

  memref::ExtractStridedMetadataOp stridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);

  auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
  OpFoldResult linearizedIndices;
  memref::LinearizedMemRefInfo linearizedInfo;
  std::tie(linearizedInfo, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          rewriter, loc, typeBit, typeBit,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(),
          getAsOpFoldResult(indices));

  return std::make_pair(
      rewriter.create<memref::ReinterpretCastOp>(
          loc, source,
          /* offset = */ linearizedInfo.linearizedOffset,
          /* shapes = */
          ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
          /* strides = */
          ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
      getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
}

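As an editorial illustration: for a dynamic-shape source, the metadata feeding the linearization comes from `memref.extract_strided_metadata`, and the helper reinterprets the source as a flat rank-1 memref. A rough sketch of the emitted IR (value names hypothetical; the affine arithmetic for the linearized size and index is elided):

```mlir
// %src : memref<?x?xf32>, accessed at [%i, %j].
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %src
    : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
// %linear_size and %linear_idx are computed from the sizes, strides, and
// indices via affine.apply ops (elided).
%flat = memref.reinterpret_cast %src to
    offset: [%offset], sizes: [%linear_size], strides: [1]
    : memref<?x?xf32> to memref<?xf32, strided<[1], offset: ?>>
```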
/// Returns true if the memref has rank > 1 and therefore needs flattening.
static bool needFlattening(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() > 1;
}

/// Only identity and strided layouts are supported; other affine-map layouts
/// are rejected.
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}
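Illustrative types for this predicate (assumed examples, not from the commit's tests):

```mlir
memref<4x8xf32>                                    // identity layout: accepted
memref<4x8xf32, strided<[8, 1], offset: 100>>      // strided layout: accepted
memref<4x8xf32, affine_map<(d0, d1) -> (d1, d0)>>  // other affine map: rejected
```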

namespace {
/// Returns the memref operand that the given op loads from, stores to, or
/// allocates; returns a null Value for unsupported ops.
static Value getTargetMemref(Operation *op) {
  return llvm::TypeSwitch<Operation *, Value>(op)
      .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
                     memref::AllocOp>([](auto op) { return op.getMemref(); })
      .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
                     vector::MaskedStoreOp, vector::TransferReadOp,
                     vector::TransferWriteOp>(
          [](auto op) { return op.getBase(); })
      .Default([](auto) { return Value{}; });
}

/// Reinterprets the flat allocation back to the original multi-dimensional
/// type so that existing users of the allocation remain valid.
template <typename T>
static void castAllocResult(T oper, T newOper, Location loc,
                            PatternRewriter &rewriter) {
  memref::ExtractStridedMetadataOp stridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      oper, cast<MemRefType>(oper.getType()), newOper,
      /*offset=*/rewriter.getIndexAttr(0),
      stridedMetadata.getConstifiedMixedSizes(),
      stridedMetadata.getConstifiedMixedStrides());
}

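For allocations, the flat buffer is created first and then reinterpreted back to the original rank so existing users stay valid. Roughly, under the same assumptions as the earlier sketches, `%orig = memref.alloc() : memref<4x8xf32>` becomes:

```mlir
%flat = memref.alloc() : memref<32xf32>
%orig = memref.reinterpret_cast %flat to offset: [0], sizes: [4, 8], strides: [8, 1]
    : memref<32xf32> to memref<4x8xf32>
```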
/// Rewrites `op` to operate on the flattened memref `flatMemref` at the
/// linearized index `offset`, preserving attributes where applicable.
template <typename T>
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
                      Value offset) {
  Location loc = op->getLoc();
  llvm::TypeSwitch<Operation *>(op.getOperation())
      .template Case<memref::AllocOp>([&](auto oper) {
        auto newAlloc = rewriter.create<memref::AllocOp>(
            loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloc, loc, rewriter);
      })
      .template Case<memref::AllocaOp>([&](auto oper) {
        auto newAlloca = rewriter.create<memref::AllocaOp>(
            loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloca, loc, rewriter);
      })
      .template Case<memref::LoadOp>([&](auto op) {
        auto newLoad = rewriter.create<memref::LoadOp>(
            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .template Case<memref::StoreOp>([&](auto op) {
        auto newStore = rewriter.create<memref::StoreOp>(
            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .template Case<vector::LoadOp>([&](auto op) {
        auto newLoad = rewriter.create<vector::LoadOp>(
            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .template Case<vector::StoreOp>([&](auto op) {
        auto newStore = rewriter.create<vector::StoreOp>(
            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .template Case<vector::MaskedLoadOp>([&](auto op) {
        auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
            loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
            op.getPassThru());
        newMaskedLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedLoad.getResult());
      })
      .template Case<vector::MaskedStoreOp>([&](auto op) {
        auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
            loc, flatMemref, ValueRange{offset}, op.getMask(),
            op.getValueToStore());
        newMaskedStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedStore);
      })
      .template Case<vector::TransferReadOp>([&](auto op) {
        auto newTransferRead = rewriter.create<vector::TransferReadOp>(
            loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding());
        rewriter.replaceOp(op, newTransferRead.getResult());
      })
      .template Case<vector::TransferWriteOp>([&](auto op) {
        auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
            loc, op.getVector(), flatMemref, ValueRange{offset});
        rewriter.replaceOp(op, newTransferWrite);
      })
      .Default([&](auto op) {
        op->emitOpError("unimplemented: do not know how to replace op.");
      });
}

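For example, a rank-2 `vector.load` ends up as a single load from the flattened memref at the linearized index; a schematic sketch with the same caveats as the earlier examples:

```mlir
// Before:
%v = vector.load %m[%i, %j] : memref<4x8xf32>, vector<8xf32>

// After (schematic):
%idx  = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%i, %j]
%flat = memref.reinterpret_cast %m to offset: [0], sizes: [32], strides: [1]
    : memref<4x8xf32> to memref<32xf32, strided<[1]>>
%v    = vector.load %flat[%idx] : memref<32xf32, strided<[1]>>, vector<8xf32>
```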
/// Returns the access indices of the op; allocations have none.
template <typename T>
static ValueRange getIndices(T op) {
  if constexpr (std::is_same_v<T, memref::AllocaOp> ||
                std::is_same_v<T, memref::AllocOp>) {
    return ValueRange{};
  } else {
    return op.getIndices();
  }
}

template <typename T>
static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
  return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
      .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
          [&](auto oper) {
            // For vector.transfer_read/write, we must make sure that:
            // 1. all accesses are in-bounds, and
            // 2. the permutation map is an identity or minor identity map.
            auto permutationMap = oper.getPermutationMap();
            if (!permutationMap.isIdentity() &&
                !permutationMap.isMinorIdentity()) {
              return rewriter.notifyMatchFailure(
                  oper, "only identity permutation map is supported");
            }
            mlir::ArrayAttr inbounds = oper.getInBounds();
            if (llvm::any_of(inbounds, [](Attribute attr) {
                  return !cast<BoolAttr>(attr).getValue();
                })) {
              return rewriter.notifyMatchFailure(oper,
                                                 "only inbounds are supported");
            }
            return success();
          })
      .Default([&](auto op) { return success(); });
}

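To illustrate the precondition (assumed examples): the first transfer below has the default minor-identity permutation map and is fully in-bounds, so it qualifies; the second carries a non-(minor-)identity map and is left untouched.

```mlir
// Flattenable: minor-identity permutation map, in-bounds.
%r0 = vector.transfer_read %m[%i, %j], %pad {in_bounds = [true]}
    : memref<4x8xf32>, vector<8xf32>

// Not flattened: non-(minor-)identity permutation map.
%r1 = vector.transfer_read %m[%i, %j], %pad
    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
    : memref<4x8xf32>, vector<4xf32>
```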
template <typename T>
struct MemRefRewritePattern : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;
  LogicalResult matchAndRewrite(T op,
                                PatternRewriter &rewriter) const override {
    LogicalResult canFlatten = canBeFlattened(op, rewriter);
    if (failed(canFlatten)) {
      return canFlatten;
    }

    Value memref = getTargetMemref(op);
    if (!needFlattening(memref) || !checkLayout(memref))
      return failure();
    auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
        rewriter, op->getLoc(), memref, getIndices<T>(op));
    replaceOp<T>(op, rewriter, flatMemref, offset);
    return success();
  }
};

struct FlattenMemrefsPass
    : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, arith::ArithDialect,
                    memref::MemRefDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    memref::populateFlattenMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
                  MemRefRewritePattern<memref::StoreOp>,
                  MemRefRewritePattern<memref::AllocOp>,
                  MemRefRewritePattern<memref::AllocaOp>,
                  MemRefRewritePattern<vector::LoadOp>,
                  MemRefRewritePattern<vector::StoreOp>,
                  MemRefRewritePattern<vector::TransferReadOp>,
                  MemRefRewritePattern<vector::TransferWriteOp>,
                  MemRefRewritePattern<vector::MaskedLoadOp>,
                  MemRefRewritePattern<vector::MaskedStoreOp>>(
      patterns.getContext());
}
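A hedged usage note: the pass is exposed as `flatten-memref`, so it can be driven with `mlir-opt --flatten-memref`; alternatively, downstream passes can pull the same rewrites into their own pattern sets via `populateFlattenMemrefsPatterns`. A minimal test along these lines (the CHECK lines are assumptions about the output shape, not copied from the commit's tests) might look like:

```mlir
// RUN: mlir-opt --flatten-memref %s | FileCheck %s
func.func @example(%m: memref<2x2xf32>, %i: index, %j: index) -> f32 {
  // CHECK: memref.reinterpret_cast
  // CHECK: memref.load {{.*}} : memref<4xf32
  %v = memref.load %m[%i, %j] : memref<2x2xf32>
  return %v : f32
}
```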
