Skip to content

Commit 10c3872

Browse files
committed
update comments
1 parent 8e62bf5 commit 10c3872

File tree

2 files changed

+72
-42
lines changed

2 files changed

+72
-42
lines changed

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -51,31 +51,6 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
5151
return cast<Value>(in);
5252
}
5353

54-
/// Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the
55-
/// span of the memref.
56-
static OpFoldResult computeSize(OpBuilder &builder, Location loc,
57-
ArrayRef<OpFoldResult> dims,
58-
ArrayRef<OpFoldResult> strides) {
59-
assert(dims.size() == strides.size() &&
60-
"number of dimensions and strides should be equal");
61-
SmallVector<AffineExpr> symbols(2 * dims.size());
62-
bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
63-
SmallVector<AffineExpr> productExpressions;
64-
SmallVector<OpFoldResult> values;
65-
size_t symbolIndex = 0;
66-
for (auto &&[dim, stride] : llvm::zip(dims, strides)) {
67-
AffineExpr dimExpr = symbols[symbolIndex++];
68-
AffineExpr strideExpr = symbols[symbolIndex++];
69-
productExpressions.push_back(dimExpr * strideExpr);
70-
values.push_back(dim);
71-
values.push_back(stride);
72-
}
73-
74-
AffineMap maxMap = AffineMap::get(0, symbols.size(), productExpressions,
75-
builder.getContext());
76-
return affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values);
77-
}
78-
7954
/// Returns a collapsed memref and the linearized index to access the element
8055
/// at the specified indices.
8156
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -108,9 +83,7 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
10883
loc, source,
10984
/* offset = */ linearizedInfo.linearizedOffset,
11085
/* shapes = */
111-
ArrayRef<OpFoldResult>{computeSize(
112-
rewriter, loc, stridedMetadata.getConstifiedMixedSizes(),
113-
stridedMetadata.getConstifiedMixedStrides())},
86+
ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
11487
/* strides = */
11588
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
11689
getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
@@ -133,16 +106,15 @@ static Value getTargetMemref(Operation *op) {
133106
.template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
134107
memref::AllocOp>([](auto op) { return op.getMemref(); })
135108
.template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
136-
vector::MaskedStoreOp>(
109+
vector::MaskedStoreOp, vector::TransferReadOp,
110+
vector::TransferWriteOp>(
137111
[](auto op) { return op.getBase(); })
138-
.template Case<vector::TransferReadOp, vector::TransferWriteOp>(
139-
[](auto op) { return op.getSource(); })
140112
.Default([](auto) { return Value{}; });
141113
}
142114

143115
template <typename T>
144-
static void castResult(T oper, T newOper, Location loc,
145-
PatternRewriter &rewriter) {
116+
static void castAllocResult(T oper, T newOper, Location loc,
117+
PatternRewriter &rewriter) {
146118
memref::ExtractStridedMetadataOp stridedMetadata =
147119
rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
148120
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
@@ -155,19 +127,19 @@ static void castResult(T oper, T newOper, Location loc,
155127
template <typename T>
156128
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
157129
Value offset) {
158-
auto loc = op->getLoc();
130+
Location loc = op->getLoc();
159131
llvm::TypeSwitch<Operation *>(op.getOperation())
160132
.template Case<memref::AllocOp>([&](auto oper) {
161133
auto newAlloc = rewriter.create<memref::AllocOp>(
162134
loc, cast<MemRefType>(flatMemref.getType()),
163135
oper.getAlignmentAttr());
164-
castResult(oper, newAlloc, loc, rewriter);
136+
castAllocResult(oper, newAlloc, loc, rewriter);
165137
})
166138
.template Case<memref::AllocaOp>([&](auto oper) {
167139
auto newAlloca = rewriter.create<memref::AllocaOp>(
168140
loc, cast<MemRefType>(flatMemref.getType()),
169141
oper.getAlignmentAttr());
170-
castResult(oper, newAlloca, loc, rewriter);
142+
castAllocResult(oper, newAlloca, loc, rewriter);
171143
})
172144
.template Case<memref::LoadOp>([&](auto op) {
173145
auto newLoad = rewriter.create<memref::LoadOp>(
@@ -232,11 +204,42 @@ static ValueRange getIndices(T op) {
232204
}
233205
}
234206

207+
template <typename T>
208+
static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
209+
return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
210+
.template Case<vector::TransferReadOp, vector::TransferWriteOp>(
211+
[&](auto oper) {
212+
// For vector.transfer_read/write, must make sure:
213+
// 1. all accesses are inbound, and
214+
// 2. has an identity or minor identity permutation map.
215+
auto permutationMap = oper.getPermutationMap();
216+
if (!permutationMap.isIdentity() &&
217+
!permutationMap.isMinorIdentity()) {
218+
return rewriter.notifyMatchFailure(
219+
oper, "only identity permutation map is supported");
220+
}
221+
mlir::ArrayAttr inbounds = oper.getInBounds();
222+
if (llvm::any_of(inbounds, [](Attribute attr) {
223+
return !cast<BoolAttr>(attr).getValue();
224+
})) {
225+
return rewriter.notifyMatchFailure(oper,
226+
"only inbounds are supported");
227+
}
228+
return success();
229+
})
230+
.Default([&](auto op) { return success(); });
231+
}
232+
235233
template <typename T>
236234
struct MemRefRewritePattern : public OpRewritePattern<T> {
237235
using OpRewritePattern<T>::OpRewritePattern;
238236
LogicalResult matchAndRewrite(T op,
239237
PatternRewriter &rewriter) const override {
238+
LogicalResult canFlatten = canBeFlattened(op, rewriter);
239+
if (failed(canFlatten)) {
240+
return canFlatten;
241+
}
242+
240243
Value memref = getTargetMemref(op);
241244
if (!needFlattening(memref) || !checkLayout(memref))
242245
return failure();

mlir/test/Dialect/MemRef/flatten_memref.mlir

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func.func @load_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[
2626
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
2727
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
2828
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
29-
// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[SIZES]]#0, %[[STRIDES]]#0, %[[SIZES]]#1, %[[STRIDES]]#1]
29+
// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[STRIDES]]#1, %[[SIZES]]#1]
3030
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
3131
// CHECK: memref.load %[[REINT]][%[[IDX]]]
3232

@@ -70,7 +70,7 @@ func.func @store_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<
7070
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
7171
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
7272
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
73-
// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[SIZES]]#0, %[[STRIDES]]#0, %[[SIZES]]#1, %[[STRIDES]]#1]
73+
// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[STRIDES]]#1, %[[SIZES]]#1]
7474
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [1]
7575
// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]]
7676

@@ -196,22 +196,49 @@ func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: in
196196

197197
func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
198198
%c0 = arith.constant 0 : i2
199-
%0 = vector.transfer_read %input[%col, %row], %c0 : memref<4x8xi2>, vector<8xi2>
199+
%0 = vector.transfer_read %input[%col, %row], %c0 {in_bounds = [true]} : memref<4x8xi2>, vector<8xi2>
200200
return %0 : vector<8xi2>
201201
}
202-
// CHECK-LABEL: func @transfer_read_memref
202+
203+
// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
204+
// CHECK: func @transfer_read_memref
203205
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
204206
// CHECK: %[[C0:.*]] = arith.constant 0 : i2
205-
// CHECK: %[[IDX:.*]] = affine.apply #map()[%[[ARG3]], %[[ARG2]]]
207+
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]]
206208
// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
207209
// CHECK-NEXT: vector.transfer_read %[[REINT]][%[[IDX]]], %[[C0]]
208210

209211
// -----
210212

213+
func.func @transfer_read_memref_not_inbound(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
214+
%c0 = arith.constant 0 : i2
215+
%0 = vector.transfer_read %input[%col, %row], %c0 {in_bounds = [false]} : memref<4x8xi2>, vector<8xi2>
216+
return %0 : vector<8xi2>
217+
}
218+
219+
// CHECK-LABEL: func @transfer_read_memref_not_inbound
220+
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
221+
// CHECK: vector.transfer_read %[[ARG0]][%[[ARG3]], %[[ARG2]]]
222+
223+
// -----
224+
225+
func.func @transfer_read_memref_non_id(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
226+
%c0 = arith.constant 0 : i2
227+
%0 = vector.transfer_read %input[%col, %row], %c0 {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<4x8xi2>, vector<8xi2>
228+
return %0 : vector<8xi2>
229+
}
230+
231+
// CHECK-LABEL: func @transfer_read_memref_non_id
232+
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
233+
// CHECK: vector.transfer_read %[[ARG0]][%[[ARG3]], %[[ARG2]]]
234+
235+
// -----
236+
211237
func.func @transfer_write_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) {
212-
vector.transfer_write %value, %input[%col, %row] : vector<8xi2>, memref<4x8xi2>
238+
vector.transfer_write %value, %input[%col, %row] {in_bounds = [true]} : vector<8xi2>, memref<4x8xi2>
213239
return
214240
}
241+
215242
// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
216243
// CHECK: func @transfer_write_memref
217244
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)

0 commit comments

Comments
 (0)