@@ -90,14 +90,19 @@ namespace {
90
90
// / Note that an alternative is to transform it to linalg.transpose +
91
91
// / vector.transfer_read to do the transpose in memory instead.
92
92
struct TransferReadPermutationLowering
93
- : public OpRewritePattern <vector::TransferReadOp> {
94
- using OpRewritePattern::OpRewritePattern ;
93
+ : public MaskableOpRewritePattern <vector::TransferReadOp> {
94
+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
95
95
96
- LogicalResult matchAndRewrite (vector::TransferReadOp op,
97
- PatternRewriter &rewriter) const override {
96
+ FailureOr<mlir::Value>
97
+ matchAndRewriteMaskableOp (vector::TransferReadOp op,
98
+ MaskingOpInterface maskOp,
99
+ PatternRewriter &rewriter) const override {
98
100
// TODO: support 0-d corner case.
99
101
if (op.getTransferRank () == 0 )
100
102
return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
103
+ // TODO: Support transfer_read inside MaskOp case.
104
+ if (maskOp)
105
+ return rewriter.notifyMatchFailure (op, " Masked case not supported" );
101
106
102
107
SmallVector<unsigned > permutation;
103
108
AffineMap map = op.getPermutationMap ();
@@ -142,9 +147,9 @@ struct TransferReadPermutationLowering
142
147
143
148
// Transpose result of transfer_read.
144
149
SmallVector<int64_t > transposePerm (permutation.begin (), permutation.end ());
145
- rewriter. replaceOpWithNewOp <vector::TransposeOp>(op, newRead,
146
- transposePerm);
147
- return success ();
150
+ return rewriter
151
+ . create <vector::TransposeOp>(op. getLoc (), newRead, transposePerm)
152
+ . getResult ();
148
153
}
149
154
};
150
155
@@ -165,14 +170,19 @@ struct TransferReadPermutationLowering
165
170
// / %v = vector.transfer_write %tmp ...
166
171
// / permutation_map: (d0, d1, d2, d3) -> (d2, d3)
167
172
struct TransferWritePermutationLowering
168
- : public OpRewritePattern <vector::TransferWriteOp> {
169
- using OpRewritePattern::OpRewritePattern ;
173
+ : public MaskableOpRewritePattern <vector::TransferWriteOp> {
174
+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
170
175
171
- LogicalResult matchAndRewrite (vector::TransferWriteOp op,
172
- PatternRewriter &rewriter) const override {
176
+ FailureOr<mlir::Value>
177
+ matchAndRewriteMaskableOp (vector::TransferWriteOp op,
178
+ MaskingOpInterface maskOp,
179
+ PatternRewriter &rewriter) const override {
173
180
// TODO: support 0-d corner case.
174
181
if (op.getTransferRank () == 0 )
175
182
return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
183
+ // TODO: Support transfer_write inside MaskOp case.
184
+ if (maskOp)
185
+ return rewriter.notifyMatchFailure (op, " Masked case not supported" );
176
186
177
187
SmallVector<unsigned > permutation;
178
188
AffineMap map = op.getPermutationMap ();
@@ -207,11 +217,14 @@ struct TransferWritePermutationLowering
207
217
op.getLoc (), op.getVector (), indices);
208
218
auto newMap = AffineMap::getMinorIdentityMap (
209
219
map.getNumDims (), map.getNumResults (), rewriter.getContext ());
210
- rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
211
- op, newVec, op.getSource (), op.getIndices (), AffineMapAttr::get (newMap),
212
- op.getMask (), newInBoundsAttr);
213
-
214
- return success ();
220
+ auto newWrite = rewriter.create <vector::TransferWriteOp>(
221
+ op.getLoc (), newVec, op.getSource (), op.getIndices (),
222
+ AffineMapAttr::get (newMap), op.getMask (), newInBoundsAttr);
223
+ if (newWrite.hasPureTensorSemantics ())
224
+ return newWrite.getResult ();
225
+ // In the memref case there's no return value. Use empty value to signal
226
+ // success.
227
+ return Value ();
215
228
}
216
229
};
217
230
@@ -231,14 +244,19 @@ struct TransferWritePermutationLowering
231
244
// / vector<1x8x16xf32>
232
245
// / ```
233
246
struct TransferWriteNonPermutationLowering
234
- : public OpRewritePattern <vector::TransferWriteOp> {
235
- using OpRewritePattern::OpRewritePattern ;
247
+ : public MaskableOpRewritePattern <vector::TransferWriteOp> {
248
+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
236
249
237
- LogicalResult matchAndRewrite (vector::TransferWriteOp op,
238
- PatternRewriter &rewriter) const override {
250
+ FailureOr<mlir::Value>
251
+ matchAndRewriteMaskableOp (vector::TransferWriteOp op,
252
+ MaskingOpInterface maskOp,
253
+ PatternRewriter &rewriter) const override {
239
254
// TODO: support 0-d corner case.
240
255
if (op.getTransferRank () == 0 )
241
256
return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
257
+ // TODO: Support transfer_write inside MaskOp case.
258
+ if (maskOp)
259
+ return rewriter.notifyMatchFailure (op, " Masked case not supported" );
242
260
243
261
SmallVector<unsigned > permutation;
244
262
AffineMap map = op.getPermutationMap ();
@@ -285,10 +303,14 @@ struct TransferWriteNonPermutationLowering
285
303
newInBoundsValues.push_back (op.isDimInBounds (i));
286
304
}
287
305
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr (newInBoundsValues);
288
- rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
289
- op, newVec, op.getSource (), op.getIndices (), AffineMapAttr::get (newMap),
290
- newMask, newInBoundsAttr);
291
- return success ();
306
+ auto newWrite = rewriter.create <vector::TransferWriteOp>(
307
+ op.getLoc (), newVec, op.getSource (), op.getIndices (),
308
+ AffineMapAttr::get (newMap), newMask, newInBoundsAttr);
309
+ if (newWrite.hasPureTensorSemantics ())
310
+ return newWrite.getResult ();
311
+ // In the memref case there's no return value. Use empty value to signal
312
+ // success.
313
+ return Value ();
292
314
}
293
315
};
294
316
0 commit comments