@@ -90,17 +90,18 @@ namespace {
90
90
// / Note that an alternative is to transform it to linalg.transpose +
91
91
// / vector.transfer_read to do the transpose in memory instead.
92
92
struct TransferReadPermutationLowering
93
- : public OpRewritePattern <vector::TransferReadOp> {
94
- using OpRewritePattern::OpRewritePattern ;
93
+ : public MaskableOpRewritePattern <vector::TransferReadOp> {
94
+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
95
95
96
- LogicalResult matchAndRewrite (vector::TransferReadOp op,
97
- PatternRewriter &rewriter) const override {
96
+ FailureOr<mlir::Value>
97
+ matchAndRewriteMaskableOp (vector::TransferReadOp op,
98
+ MaskingOpInterface maskOp,
99
+ PatternRewriter &rewriter) const override {
98
100
// TODO: support 0-d corner case.
99
101
if (op.getTransferRank () == 0 )
100
102
return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
101
- if (isa<vector::MaskOp>(op->getParentOp ()))
102
- return rewriter.notifyMatchFailure (
103
- op, " Cannot expand transfer read inside a Mask Op" );
103
+ if (maskOp)
104
+ return rewriter.notifyMatchFailure (op, " Masked case not supported" );
104
105
105
106
SmallVector<unsigned > permutation;
106
107
AffineMap map = op.getPermutationMap ();
@@ -145,9 +146,9 @@ struct TransferReadPermutationLowering
145
146
146
147
// Transpose result of transfer_read.
147
148
SmallVector<int64_t > transposePerm (permutation.begin (), permutation.end ());
148
- rewriter. replaceOpWithNewOp <vector::TransposeOp>(op, newRead,
149
- transposePerm);
150
- return success ();
149
+ return rewriter
150
+ . create <vector::TransposeOp>(op. getLoc (), newRead, transposePerm)
151
+ . getResult ();
151
152
}
152
153
};
153
154
@@ -168,17 +169,18 @@ struct TransferReadPermutationLowering
168
169
// / %v = vector.transfer_write %tmp ...
169
170
// / permutation_map: (d0, d1, d2, d3) -> (d2, d3)
170
171
struct TransferWritePermutationLowering
171
- : public OpRewritePattern <vector::TransferWriteOp> {
172
- using OpRewritePattern::OpRewritePattern ;
172
+ : public MaskableOpRewritePattern <vector::TransferWriteOp> {
173
+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
173
174
174
- LogicalResult matchAndRewrite (vector::TransferWriteOp op,
175
- PatternRewriter &rewriter) const override {
175
+ FailureOr<mlir::Value>
176
+ matchAndRewriteMaskableOp (vector::TransferWriteOp op,
177
+ MaskingOpInterface maskOp,
178
+ PatternRewriter &rewriter) const override {
176
179
// TODO: support 0-d corner case.
177
180
if (op.getTransferRank () == 0 )
178
181
return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
179
- if (isa<vector::MaskOp>(op->getParentOp ()))
180
- return rewriter.notifyMatchFailure (
181
- op, " Cannot expand transfer write inside a Mask Op" );
182
+ if (maskOp)
183
+ return rewriter.notifyMatchFailure (op, " Masked case not supported" );
182
184
183
185
SmallVector<unsigned > permutation;
184
186
AffineMap map = op.getPermutationMap ();
@@ -213,11 +215,11 @@ struct TransferWritePermutationLowering
213
215
op.getLoc (), op.getVector (), indices);
214
216
auto newMap = AffineMap::getMinorIdentityMap (
215
217
map.getNumDims (), map.getNumResults (), rewriter.getContext ());
216
- rewriter. replaceOpWithNewOp <vector::TransferWriteOp>(
217
- op, newVec, op. getSource (), op. getIndices (), AffineMapAttr::get (newMap),
218
- op.getMask (), newInBoundsAttr);
219
-
220
- return success ();
218
+ return rewriter
219
+ . create <vector::TransferWriteOp>(
220
+ op.getLoc (), newVec, op. getSource (), op. getIndices (),
221
+ AffineMapAttr::get (newMap), op. getMask (), newInBoundsAttr)
222
+ . getResult ();
221
223
}
222
224
};
223
225
@@ -237,17 +239,18 @@ struct TransferWritePermutationLowering
237
239
// / vector<1x8x16xf32>
238
240
// / ```
239
241
struct TransferWriteNonPermutationLowering
240
- : public OpRewritePattern <vector::TransferWriteOp> {
241
- using OpRewritePattern::OpRewritePattern ;
242
+ : public MaskableOpRewritePattern <vector::TransferWriteOp> {
243
+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
242
244
243
- LogicalResult matchAndRewrite (vector::TransferWriteOp op,
244
- PatternRewriter &rewriter) const override {
245
+ FailureOr<mlir::Value>
246
+ matchAndRewriteMaskableOp (vector::TransferWriteOp op,
247
+ MaskingOpInterface maskOp,
248
+ PatternRewriter &rewriter) const override {
245
249
// TODO: support 0-d corner case.
246
250
if (op.getTransferRank () == 0 )
247
251
return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
248
- if (isa<vector::MaskOp>(op->getParentOp ()))
249
- return rewriter.notifyMatchFailure (
250
- op, " Cannot expand transfer write inside a Mask Op" );
252
+ if (maskOp)
253
+ return rewriter.notifyMatchFailure (op, " Masked case not supported" );
251
254
252
255
SmallVector<unsigned > permutation;
253
256
AffineMap map = op.getPermutationMap ();
@@ -294,10 +297,11 @@ struct TransferWriteNonPermutationLowering
294
297
newInBoundsValues.push_back (op.isDimInBounds (i));
295
298
}
296
299
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr (newInBoundsValues);
297
- rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
298
- op, newVec, op.getSource (), op.getIndices (), AffineMapAttr::get (newMap),
299
- newMask, newInBoundsAttr);
300
- return success ();
300
+ return rewriter
301
+ .create <vector::TransferWriteOp>(
302
+ op.getLoc (), newVec, op.getSource (), op.getIndices (),
303
+ AffineMapAttr::get (newMap), newMask, newInBoundsAttr)
304
+ .getResult ();
301
305
}
302
306
};
303
307
@@ -309,14 +313,19 @@ struct TransferWriteNonPermutationLowering
309
313
// / %v = vector.transfer_read ...
310
314
// / permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
311
315
// / vector.broadcast %v
312
- struct TransferOpReduceRank : public OpRewritePattern <vector::TransferReadOp> {
313
- using OpRewritePattern::OpRewritePattern;
314
-
315
- LogicalResult matchAndRewrite (vector::TransferReadOp op,
316
- PatternRewriter &rewriter) const override {
316
+ struct TransferOpReduceRank
317
+ : public MaskableOpRewritePattern<vector::TransferReadOp> {
318
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
319
+
320
+ FailureOr<mlir::Value>
321
+ matchAndRewriteMaskableOp (vector::TransferReadOp op,
322
+ MaskingOpInterface maskOp,
323
+ PatternRewriter &rewriter) const override {
317
324
// TODO: support 0-d corner case.
318
325
if (op.getTransferRank () == 0 )
319
326
return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
327
+ if (maskOp)
328
+ return rewriter.notifyMatchFailure (op, " Masked case not supported" );
320
329
321
330
AffineMap map = op.getPermutationMap ();
322
331
unsigned numLeadingBroadcast = 0 ;
@@ -356,9 +365,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
356
365
op.getLoc (), originalVecType.getElementType (), op.getSource (),
357
366
op.getIndices ());
358
367
}
359
- rewriter. replaceOpWithNewOp <vector::BroadcastOp>(op, originalVecType,
360
- newRead);
361
- return success ();
368
+ return rewriter
369
+ . create <vector::BroadcastOp>(op. getLoc (), originalVecType, newRead)
370
+ . getVector ();
362
371
}
363
372
364
373
SmallVector<int64_t > newShape (
@@ -380,9 +389,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
380
389
op.getLoc (), newReadType, op.getSource (), op.getIndices (),
381
390
AffineMapAttr::get (newMap), op.getPadding (), op.getMask (),
382
391
newInBoundsAttr);
383
- rewriter. replaceOpWithNewOp <vector::BroadcastOp>(op, originalVecType,
384
- newRead);
385
- return success ();
392
+ return rewriter
393
+ . create <vector::BroadcastOp>(op. getLoc (), originalVecType, newRead)
394
+ . getVector ();
386
395
}
387
396
};
388
397
@@ -410,20 +419,23 @@ namespace {
410
419
// / result type.
411
420
// / - The permutation map doesn't perform permutation (broadcasting is allowed).
412
421
struct TransferReadToVectorLoadLowering
413
- : public OpRewritePattern <vector::TransferReadOp> {
422
+ : public MaskableOpRewritePattern <vector::TransferReadOp> {
414
423
TransferReadToVectorLoadLowering (MLIRContext *context,
415
424
std::optional<unsigned > maxRank,
416
425
PatternBenefit benefit = 1 )
417
- : OpRewritePattern <vector::TransferReadOp>(context, benefit),
426
+ : MaskableOpRewritePattern <vector::TransferReadOp>(context, benefit),
418
427
maxTransferRank (maxRank) {}
419
428
420
- LogicalResult matchAndRewrite (vector::TransferReadOp read,
421
- PatternRewriter &rewriter) const override {
429
+ FailureOr<mlir::Value>
430
+ matchAndRewriteMaskableOp (vector::TransferReadOp read,
431
+ MaskingOpInterface maskOp,
432
+ PatternRewriter &rewriter) const override {
422
433
if (maxTransferRank && read .getVectorType ().getRank () > *maxTransferRank) {
423
434
return rewriter.notifyMatchFailure (
424
435
read , " vector type is greater than max transfer rank" );
425
436
}
426
-
437
+ if (maskOp)
438
+ return rewriter.notifyMatchFailure (read , " Masked case not supported" );
427
439
SmallVector<unsigned > broadcastedDims;
428
440
// Permutations are handled by VectorToSCF or
429
441
// populateVectorTransferPermutationMapLoweringPatterns.
@@ -466,7 +478,7 @@ struct TransferReadToVectorLoadLowering
466
478
return rewriter.notifyMatchFailure (read , " out-of-bounds needs mask" );
467
479
468
480
// Create vector load op.
469
- Operation *loadOp ;
481
+ Operation *res ;
470
482
if (read .getMask ()) {
471
483
if (read .getVectorType ().getRank () != 1 )
472
484
// vector.maskedload operates on 1-D vectors.
@@ -476,24 +488,20 @@ struct TransferReadToVectorLoadLowering
476
488
477
489
Value fill = rewriter.create <vector::SplatOp>(
478
490
read .getLoc (), unbroadcastedVectorType, read .getPadding ());
479
- loadOp = rewriter.create <vector::MaskedLoadOp>(
491
+ res = rewriter.create <vector::MaskedLoadOp>(
480
492
read .getLoc (), unbroadcastedVectorType, read .getSource (),
481
493
read .getIndices (), read .getMask (), fill);
482
494
} else {
483
- loadOp = rewriter.create <vector::LoadOp>(
495
+ res = rewriter.create <vector::LoadOp>(
484
496
read .getLoc (), unbroadcastedVectorType, read .getSource (),
485
497
read .getIndices ());
486
498
}
487
499
488
500
// Insert a broadcasting op if required.
489
- if (!broadcastedDims.empty ()) {
490
- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
491
- read , read .getVectorType (), loadOp->getResult (0 ));
492
- } else {
493
- rewriter.replaceOp (read , loadOp->getResult (0 ));
494
- }
495
-
496
- return success ();
501
+ if (!broadcastedDims.empty ())
502
+ res = rewriter.create <vector::BroadcastOp>(
503
+ read .getLoc (), read .getVectorType (), res->getResult (0 ));
504
+ return res->getResults ()[0 ];
497
505
}
498
506
499
507
std::optional<unsigned > maxTransferRank;
@@ -562,19 +570,23 @@ struct VectorStoreToMemrefStoreLowering
562
570
// / - The permutation map is the minor identity map (neither permutation nor
563
571
// / broadcasting is allowed).
564
572
struct TransferWriteToVectorStoreLowering
565
- : public OpRewritePattern <vector::TransferWriteOp> {
573
+ : public MaskableOpRewritePattern <vector::TransferWriteOp> {
566
574
TransferWriteToVectorStoreLowering (MLIRContext *context,
567
575
std::optional<unsigned > maxRank,
568
576
PatternBenefit benefit = 1 )
569
- : OpRewritePattern <vector::TransferWriteOp>(context, benefit),
577
+ : MaskableOpRewritePattern <vector::TransferWriteOp>(context, benefit),
570
578
maxTransferRank (maxRank) {}
571
579
572
- LogicalResult matchAndRewrite (vector::TransferWriteOp write,
573
- PatternRewriter &rewriter) const override {
580
+ FailureOr<mlir::Value>
581
+ matchAndRewriteMaskableOp (vector::TransferWriteOp write,
582
+ MaskingOpInterface maskOp,
583
+ PatternRewriter &rewriter) const override {
574
584
if (maxTransferRank && write .getVectorType ().getRank () > *maxTransferRank) {
575
585
return rewriter.notifyMatchFailure (
576
586
write , " vector type is greater than max transfer rank" );
577
587
}
588
+ if (maskOp)
589
+ return rewriter.notifyMatchFailure (write , " Masked case not supported" );
578
590
579
591
// Permutations are handled by VectorToSCF or
580
592
// populateVectorTransferPermutationMapLoweringPatterns.
@@ -626,14 +638,17 @@ struct TransferWriteToVectorStoreLowering
626
638
<< write ;
627
639
});
628
640
629
- rewriter.replaceOpWithNewOp <vector::MaskedStoreOp>(
630
- write , write .getSource (), write .getIndices (), write .getMask (),
631
- write .getVector ());
641
+ return rewriter
642
+ .create <vector::MaskedStoreOp>(write .getLoc (), write .getSource (),
643
+ write .getIndices (), write .getMask (),
644
+ write .getVector ())
645
+ .getBase ();
632
646
} else {
633
- rewriter.replaceOpWithNewOp <vector::StoreOp>(
634
- write , write .getVector (), write .getSource (), write .getIndices ());
647
+ return rewriter
648
+ .create <vector::StoreOp>(write .getLoc (), write .getVector (),
649
+ write .getSource (), write .getIndices ())
650
+ .getBase ();
635
651
}
636
- return success ();
637
652
}
638
653
639
654
std::optional<unsigned > maxTransferRank;
0 commit comments