@@ -51,31 +51,6 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
   return cast<Value>(in);
 }
 
-/// Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the
-/// span of the memref.
-static OpFoldResult computeSize(OpBuilder &builder, Location loc,
-                                ArrayRef<OpFoldResult> dims,
-                                ArrayRef<OpFoldResult> strides) {
-  assert(dims.size() == strides.size() &&
-         "number of dimensions and strides should be equal");
-  SmallVector<AffineExpr> symbols(2 * dims.size());
-  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
-  SmallVector<AffineExpr> productExpressions;
-  SmallVector<OpFoldResult> values;
-  size_t symbolIndex = 0;
-  for (auto &&[dim, stride] : llvm::zip(dims, strides)) {
-    AffineExpr dimExpr = symbols[symbolIndex++];
-    AffineExpr strideExpr = symbols[symbolIndex++];
-    productExpressions.push_back(dimExpr * strideExpr);
-    values.push_back(dim);
-    values.push_back(stride);
-  }
-
-  AffineMap maxMap = AffineMap::get(0, symbols.size(), productExpressions,
-                                    builder.getContext());
-  return affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values);
-}
-
 /// Returns a collapsed memref and the linearized index to access the element
 /// at the specified indices.
 static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
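
The deleted computeSize helper built an affine max over the per-dimension
products dims[i] * strides[i], i.e. roughly span = max_i(d_i * s_i), to bound
the number of elements the flat view must cover. As the next hunk shows, its
only caller now uses linearizedInfo.linearizedSize, which the linearization
utility already computes, so the helper can go away.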
@@ -108,9 +83,7 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
           loc, source,
           /*offset=*/linearizedInfo.linearizedOffset,
           /*shapes=*/
-          ArrayRef<OpFoldResult>{computeSize(
-              rewriter, loc, stridedMetadata.getConstifiedMixedSizes(),
-              stridedMetadata.getConstifiedMixedStrides())},
+          ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
           /*strides=*/
           ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
       getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
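
To illustrate the rewrite this helper drives, a hand-written sketch (not pass
output; the %src/%i/%j names and the 4x8 shape are invented). A 2-D access
such as

  %v = memref.load %src[%i, %j] : memref<4x8xf32>

becomes a load through a unit-stride 1-D view, where the sizes operand of the
reinterpret_cast is now linearizedInfo.linearizedSize (32 here) rather than
the result of the deleted computeSize:

  %flat = memref.reinterpret_cast %src to offset: [0], sizes: [32],
      strides: [1] : memref<4x8xf32> to memref<32xf32>
  %idx = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%i, %j]
  %v = memref.load %flat[%idx] : memref<32xf32>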
@@ -133,16 +106,15 @@ static Value getTargetMemref(Operation *op) {
       .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
                      memref::AllocOp>([](auto op) { return op.getMemref(); })
       .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
-                     vector::MaskedStoreOp>(
+                     vector::MaskedStoreOp, vector::TransferReadOp,
+                     vector::TransferWriteOp>(
           [](auto op) { return op.getBase(); })
-      .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
-          [](auto op) { return op.getSource(); })
       .Default([](auto) { return Value{}; });
 }
 
 template <typename T>
-static void castResult(T oper, T newOper, Location loc,
-                       PatternRewriter &rewriter) {
+static void castAllocResult(T oper, T newOper, Location loc,
+                            PatternRewriter &rewriter) {
   memref::ExtractStridedMetadataOp stridedMetadata =
       rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
   rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
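
castAllocResult (renamed from castResult, presumably because it is only ever
applied to alloc/alloca results) reinterpret-casts the new flat allocation
back to the original multi-dimensional type, so existing users of the
allocation still type-check. A hand-written sketch with an invented static
4x8 shape:

  %flat = memref.alloc() : memref<32xf32>
  %orig = memref.reinterpret_cast %flat to offset: [0], sizes: [4, 8],
      strides: [8, 1] : memref<32xf32> to memref<4x8xf32>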
@@ -155,19 +127,19 @@ static void castResult(T oper, T newOper, Location loc,
 template <typename T>
 static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
                       Value offset) {
-  auto loc = op->getLoc();
+  Location loc = op->getLoc();
   llvm::TypeSwitch<Operation *>(op.getOperation())
       .template Case<memref::AllocOp>([&](auto oper) {
         auto newAlloc = rewriter.create<memref::AllocOp>(
             loc, cast<MemRefType>(flatMemref.getType()),
             oper.getAlignmentAttr());
-        castResult(oper, newAlloc, loc, rewriter);
+        castAllocResult(oper, newAlloc, loc, rewriter);
       })
       .template Case<memref::AllocaOp>([&](auto oper) {
         auto newAlloca = rewriter.create<memref::AllocaOp>(
             loc, cast<MemRefType>(flatMemref.getType()),
             oper.getAlignmentAttr());
-        castResult(oper, newAlloca, loc, rewriter);
+        castAllocResult(oper, newAlloca, loc, rewriter);
       })
       .template Case<memref::LoadOp>([&](auto op) {
         auto newLoad = rewriter.create<memref::LoadOp>(
@@ -232,11 +204,42 @@ static ValueRange getIndices(T op) {
   }
 }
 
+template <typename T>
+static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
+  return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
+      .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
+          [&](auto oper) {
+            // For vector.transfer_read/write, we must make sure that:
+            // 1. all accesses are in bounds, and
+            // 2. the op has an identity or minor identity permutation map.
+            auto permutationMap = oper.getPermutationMap();
+            if (!permutationMap.isIdentity() &&
+                !permutationMap.isMinorIdentity()) {
+              return rewriter.notifyMatchFailure(
+                  oper, "only identity permutation map is supported");
+            }
+            mlir::ArrayAttr inbounds = oper.getInBounds();
+            if (llvm::any_of(inbounds, [](Attribute attr) {
+                  return !cast<BoolAttr>(attr).getValue();
+                })) {
+              return rewriter.notifyMatchFailure(oper,
+                                                 "only inbounds are supported");
+            }
+            return success();
+          })
+      .Default([&](auto op) { return success(); });
+}
+
 template <typename T>
 struct MemRefRewritePattern : public OpRewritePattern<T> {
   using OpRewritePattern<T>::OpRewritePattern;
   LogicalResult matchAndRewrite(T op,
                                 PatternRewriter &rewriter) const override {
+    LogicalResult canFlatten = canBeFlattened(op, rewriter);
+    if (failed(canFlatten)) {
+      return canFlatten;
+    }
+
     Value memref = getTargetMemref(op);
     if (!needFlattening(memref) || !checkLayout(memref))
       return failure();
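
The effect of the new canBeFlattened gate on vector.transfer_read, as a
hand-written example with invented names, values, and shapes:

  // Accepted: the default minor identity map (d0, d1) -> (d1), fully
  // in-bounds.
  %ok = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true]}
      : memref<4x8xf32>, vector<8xf32>

  // Rejected with "only identity permutation map is supported":
  // (d0, d1) -> (d0) is neither an identity nor a minor identity.
  %no = vector.transfer_read %mem[%i, %j], %pad
      {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
      : memref<4x8xf32>, vector<4xf32>

Ops other than transfer_read/transfer_write hit the Default case and remain
eligible unconditionally.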