Skip to content

Commit fa53c9c

Browse files
committed
Use MaskableOp interface for FoldTensorSubsetOps and LowerVectorTransfer
1 parent 57863a4 commit fa53c9c

File tree

3 files changed

+104
-85
lines changed

3 files changed

+104
-85
lines changed

mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,6 @@ add_mlir_dialect_library(MLIRTensorTransforms
3939
MLIRTilingInterface
4040
MLIRTransforms
4141
MLIRVectorDialect
42+
MLIRVectorUtils
4243
MLIRValueBoundsOpInterface
4344
)

mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1919
#include "mlir/Dialect/Utils/IndexingUtils.h"
2020
#include "mlir/Dialect/Vector/IR/VectorOps.h"
21+
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
2122
#include "mlir/IR/AffineMap.h"
2223
#include "mlir/IR/BuiltinAttributes.h"
2324
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -48,12 +49,14 @@ static Value getTensorOperand(tensor::InsertSliceOp op) {
4849
namespace {
4950
/// Merge extract_slice operation with load/transferRead operation.
5051
class TransferReadOfExtractSliceOpFolder final
51-
: public OpRewritePattern<vector::TransferReadOp> {
52+
: public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
5253
public:
53-
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
54+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
5455

55-
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
56-
PatternRewriter &rewriter) const override;
56+
FailureOr<mlir::Value>
57+
matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
58+
vector::MaskingOpInterface maskOp,
59+
PatternRewriter &rewriter) const override;
5760
};
5861

5962
/// Merge insert_slice operation with store/transferWriteOp operation.
@@ -84,8 +87,10 @@ static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
8487
return success();
8588
}
8689

87-
LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
88-
vector::TransferReadOp readOp, PatternRewriter &rewriter) const {
90+
FailureOr<mlir::Value>
91+
TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
92+
vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
93+
PatternRewriter &rewriter) const {
8994
auto extractSliceOp =
9095
getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
9196
if (!extractSliceOp)
@@ -95,31 +100,29 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
95100
preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
96101
extractSliceOp);
97102
if (failed(preconditionResult))
98-
return preconditionResult;
103+
return rewriter.notifyMatchFailure(readOp, "Failed preconditions");
99104

100105
SmallVector<Value> indices(readOp.getIndices().begin(),
101106
readOp.getIndices().end());
102107
SmallVector<Value> sourceIndices;
103108
// In case transfer_read is located inside a MaskOp we want to avoid creating
104109
// more ops inside it.
105-
if (isa<vector::MaskOp>(readOp->getParentOp()))
106-
rewriter.setInsertionPoint(readOp->getParentOp());
107110
affine::resolveIndicesIntoOpWithOffsetsAndStrides(
108111
rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
109112
extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
110113
indices, sourceIndices);
111114

112-
// Reset the insertion point.
113-
rewriter.setInsertionPoint(readOp);
114-
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
115-
readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices,
115+
Operation *newOp = rewriter.create<vector::TransferReadOp>(
116+
readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(),
117+
sourceIndices,
116118
AffineMapAttr::get(expandDimsToRank(
117119
readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
118120
extractSliceOp.getDroppedDims())),
119121
readOp.getPadding(),
120122
/*mask=*/Value(), readOp.getInBoundsAttr());
121-
122-
return success();
123+
if (maskOp)
124+
newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
125+
return newOp->getResults()[0];
123126
}
124127

125128
LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 85 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,18 @@ namespace {
9090
/// Note that an alternative is to transform it to linalg.transpose +
9191
/// vector.transfer_read to do the transpose in memory instead.
9292
struct TransferReadPermutationLowering
93-
: public OpRewritePattern<vector::TransferReadOp> {
94-
using OpRewritePattern::OpRewritePattern;
93+
: public MaskableOpRewritePattern<vector::TransferReadOp> {
94+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
9595

96-
LogicalResult matchAndRewrite(vector::TransferReadOp op,
97-
PatternRewriter &rewriter) const override {
96+
FailureOr<mlir::Value>
97+
matchAndRewriteMaskableOp(vector::TransferReadOp op,
98+
MaskingOpInterface maskOp,
99+
PatternRewriter &rewriter) const override {
98100
// TODO: support 0-d corner case.
99101
if (op.getTransferRank() == 0)
100102
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
101-
if (isa<vector::MaskOp>(op->getParentOp()))
102-
return rewriter.notifyMatchFailure(
103-
op, "Cannot expand transfer read inside a Mask Op");
103+
if (maskOp)
104+
return rewriter.notifyMatchFailure(op, "Masked case not supported");
104105

105106
SmallVector<unsigned> permutation;
106107
AffineMap map = op.getPermutationMap();
@@ -145,9 +146,9 @@ struct TransferReadPermutationLowering
145146

146147
// Transpose result of transfer_read.
147148
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
148-
rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
149-
transposePerm);
150-
return success();
149+
return rewriter
150+
.create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
151+
.getResult();
151152
}
152153
};
153154

@@ -168,17 +169,18 @@ struct TransferReadPermutationLowering
168169
/// %v = vector.transfer_write %tmp ...
169170
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
170171
struct TransferWritePermutationLowering
171-
: public OpRewritePattern<vector::TransferWriteOp> {
172-
using OpRewritePattern::OpRewritePattern;
172+
: public MaskableOpRewritePattern<vector::TransferWriteOp> {
173+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
173174

174-
LogicalResult matchAndRewrite(vector::TransferWriteOp op,
175-
PatternRewriter &rewriter) const override {
175+
FailureOr<mlir::Value>
176+
matchAndRewriteMaskableOp(vector::TransferWriteOp op,
177+
MaskingOpInterface maskOp,
178+
PatternRewriter &rewriter) const override {
176179
// TODO: support 0-d corner case.
177180
if (op.getTransferRank() == 0)
178181
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
179-
if (isa<vector::MaskOp>(op->getParentOp()))
180-
return rewriter.notifyMatchFailure(
181-
op, "Cannot expand transfer write inside a Mask Op");
182+
if (maskOp)
183+
return rewriter.notifyMatchFailure(op, "Masked case not supported");
182184

183185
SmallVector<unsigned> permutation;
184186
AffineMap map = op.getPermutationMap();
@@ -213,11 +215,11 @@ struct TransferWritePermutationLowering
213215
op.getLoc(), op.getVector(), indices);
214216
auto newMap = AffineMap::getMinorIdentityMap(
215217
map.getNumDims(), map.getNumResults(), rewriter.getContext());
216-
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
217-
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
218-
op.getMask(), newInBoundsAttr);
219-
220-
return success();
218+
return rewriter
219+
.create<vector::TransferWriteOp>(
220+
op.getLoc(), newVec, op.getSource(), op.getIndices(),
221+
AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr)
222+
.getResult();
221223
}
222224
};
223225

@@ -237,17 +239,18 @@ struct TransferWritePermutationLowering
237239
/// vector<1x8x16xf32>
238240
/// ```
239241
struct TransferWriteNonPermutationLowering
240-
: public OpRewritePattern<vector::TransferWriteOp> {
241-
using OpRewritePattern::OpRewritePattern;
242+
: public MaskableOpRewritePattern<vector::TransferWriteOp> {
243+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
242244

243-
LogicalResult matchAndRewrite(vector::TransferWriteOp op,
244-
PatternRewriter &rewriter) const override {
245+
FailureOr<mlir::Value>
246+
matchAndRewriteMaskableOp(vector::TransferWriteOp op,
247+
MaskingOpInterface maskOp,
248+
PatternRewriter &rewriter) const override {
245249
// TODO: support 0-d corner case.
246250
if (op.getTransferRank() == 0)
247251
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
248-
if (isa<vector::MaskOp>(op->getParentOp()))
249-
return rewriter.notifyMatchFailure(
250-
op, "Cannot expand transfer write inside a Mask Op");
252+
if (maskOp)
253+
return rewriter.notifyMatchFailure(op, "Masked case not supported");
251254

252255
SmallVector<unsigned> permutation;
253256
AffineMap map = op.getPermutationMap();
@@ -294,10 +297,11 @@ struct TransferWriteNonPermutationLowering
294297
newInBoundsValues.push_back(op.isDimInBounds(i));
295298
}
296299
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
297-
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
298-
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
299-
newMask, newInBoundsAttr);
300-
return success();
300+
return rewriter
301+
.create<vector::TransferWriteOp>(
302+
op.getLoc(), newVec, op.getSource(), op.getIndices(),
303+
AffineMapAttr::get(newMap), newMask, newInBoundsAttr)
304+
.getResult();
301305
}
302306
};
303307

@@ -309,14 +313,19 @@ struct TransferWriteNonPermutationLowering
309313
/// %v = vector.transfer_read ...
310314
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
311315
/// vector.broadcast %v
312-
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
313-
using OpRewritePattern::OpRewritePattern;
314-
315-
LogicalResult matchAndRewrite(vector::TransferReadOp op,
316-
PatternRewriter &rewriter) const override {
316+
struct TransferOpReduceRank
317+
: public MaskableOpRewritePattern<vector::TransferReadOp> {
318+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
319+
320+
FailureOr<mlir::Value>
321+
matchAndRewriteMaskableOp(vector::TransferReadOp op,
322+
MaskingOpInterface maskOp,
323+
PatternRewriter &rewriter) const override {
317324
// TODO: support 0-d corner case.
318325
if (op.getTransferRank() == 0)
319326
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
327+
if (maskOp)
328+
return rewriter.notifyMatchFailure(op, "Masked case not supported");
320329

321330
AffineMap map = op.getPermutationMap();
322331
unsigned numLeadingBroadcast = 0;
@@ -356,9 +365,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
356365
op.getLoc(), originalVecType.getElementType(), op.getSource(),
357366
op.getIndices());
358367
}
359-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
360-
newRead);
361-
return success();
368+
return rewriter
369+
.create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
370+
.getVector();
362371
}
363372

364373
SmallVector<int64_t> newShape(
@@ -380,9 +389,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
380389
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
381390
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
382391
newInBoundsAttr);
383-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
384-
newRead);
385-
return success();
392+
return rewriter
393+
.create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
394+
.getVector();
386395
}
387396
};
388397

@@ -410,20 +419,23 @@ namespace {
410419
/// result type.
411420
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
412421
struct TransferReadToVectorLoadLowering
413-
: public OpRewritePattern<vector::TransferReadOp> {
422+
: public MaskableOpRewritePattern<vector::TransferReadOp> {
414423
TransferReadToVectorLoadLowering(MLIRContext *context,
415424
std::optional<unsigned> maxRank,
416425
PatternBenefit benefit = 1)
417-
: OpRewritePattern<vector::TransferReadOp>(context, benefit),
426+
: MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
418427
maxTransferRank(maxRank) {}
419428

420-
LogicalResult matchAndRewrite(vector::TransferReadOp read,
421-
PatternRewriter &rewriter) const override {
429+
FailureOr<mlir::Value>
430+
matchAndRewriteMaskableOp(vector::TransferReadOp read,
431+
MaskingOpInterface maskOp,
432+
PatternRewriter &rewriter) const override {
422433
if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
423434
return rewriter.notifyMatchFailure(
424435
read, "vector type is greater than max transfer rank");
425436
}
426-
437+
if (maskOp)
438+
return rewriter.notifyMatchFailure(read, "Masked case not supported");
427439
SmallVector<unsigned> broadcastedDims;
428440
// Permutations are handled by VectorToSCF or
429441
// populateVectorTransferPermutationMapLoweringPatterns.
@@ -466,7 +478,7 @@ struct TransferReadToVectorLoadLowering
466478
return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");
467479

468480
// Create vector load op.
469-
Operation *loadOp;
481+
Operation *res;
470482
if (read.getMask()) {
471483
if (read.getVectorType().getRank() != 1)
472484
// vector.maskedload operates on 1-D vectors.
@@ -476,24 +488,20 @@ struct TransferReadToVectorLoadLowering
476488

477489
Value fill = rewriter.create<vector::SplatOp>(
478490
read.getLoc(), unbroadcastedVectorType, read.getPadding());
479-
loadOp = rewriter.create<vector::MaskedLoadOp>(
491+
res = rewriter.create<vector::MaskedLoadOp>(
480492
read.getLoc(), unbroadcastedVectorType, read.getSource(),
481493
read.getIndices(), read.getMask(), fill);
482494
} else {
483-
loadOp = rewriter.create<vector::LoadOp>(
495+
res = rewriter.create<vector::LoadOp>(
484496
read.getLoc(), unbroadcastedVectorType, read.getSource(),
485497
read.getIndices());
486498
}
487499

488500
// Insert a broadcasting op if required.
489-
if (!broadcastedDims.empty()) {
490-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
491-
read, read.getVectorType(), loadOp->getResult(0));
492-
} else {
493-
rewriter.replaceOp(read, loadOp->getResult(0));
494-
}
495-
496-
return success();
501+
if (!broadcastedDims.empty())
502+
res = rewriter.create<vector::BroadcastOp>(
503+
read.getLoc(), read.getVectorType(), res->getResult(0));
504+
return res->getResults()[0];
497505
}
498506

499507
std::optional<unsigned> maxTransferRank;
@@ -562,19 +570,23 @@ struct VectorStoreToMemrefStoreLowering
562570
/// - The permutation map is the minor identity map (neither permutation nor
563571
/// broadcasting is allowed).
564572
struct TransferWriteToVectorStoreLowering
565-
: public OpRewritePattern<vector::TransferWriteOp> {
573+
: public MaskableOpRewritePattern<vector::TransferWriteOp> {
566574
TransferWriteToVectorStoreLowering(MLIRContext *context,
567575
std::optional<unsigned> maxRank,
568576
PatternBenefit benefit = 1)
569-
: OpRewritePattern<vector::TransferWriteOp>(context, benefit),
577+
: MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
570578
maxTransferRank(maxRank) {}
571579

572-
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
573-
PatternRewriter &rewriter) const override {
580+
FailureOr<mlir::Value>
581+
matchAndRewriteMaskableOp(vector::TransferWriteOp write,
582+
MaskingOpInterface maskOp,
583+
PatternRewriter &rewriter) const override {
574584
if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
575585
return rewriter.notifyMatchFailure(
576586
write, "vector type is greater than max transfer rank");
577587
}
588+
if (maskOp)
589+
return rewriter.notifyMatchFailure(write, "Masked case not supported");
578590

579591
// Permutations are handled by VectorToSCF or
580592
// populateVectorTransferPermutationMapLoweringPatterns.
@@ -626,14 +638,17 @@ struct TransferWriteToVectorStoreLowering
626638
<< write;
627639
});
628640

629-
rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
630-
write, write.getSource(), write.getIndices(), write.getMask(),
631-
write.getVector());
641+
return rewriter
642+
.create<vector::MaskedStoreOp>(write.getLoc(), write.getSource(),
643+
write.getIndices(), write.getMask(),
644+
write.getVector())
645+
.getBase();
632646
} else {
633-
rewriter.replaceOpWithNewOp<vector::StoreOp>(
634-
write, write.getVector(), write.getSource(), write.getIndices());
647+
return rewriter
648+
.create<vector::StoreOp>(write.getLoc(), write.getVector(),
649+
write.getSource(), write.getIndices())
650+
.getBase();
635651
}
636-
return success();
637652
}
638653

639654
std::optional<unsigned> maxTransferRank;

0 commit comments

Comments
 (0)