18
18
#include " mlir/Dialect/Tensor/Transforms/Transforms.h"
19
19
#include " mlir/Dialect/Utils/IndexingUtils.h"
20
20
#include " mlir/Dialect/Vector/IR/VectorOps.h"
21
+ #include " mlir/Dialect/Vector/Utils/VectorUtils.h"
21
22
#include " mlir/IR/AffineMap.h"
22
23
#include " mlir/IR/BuiltinAttributes.h"
23
24
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -48,12 +49,14 @@ static Value getTensorOperand(tensor::InsertSliceOp op) {
48
49
namespace {
49
50
// / Merge extract_slice operation with load/transferRead operation.
50
51
class TransferReadOfExtractSliceOpFolder final
51
- : public OpRewritePattern <vector::TransferReadOp> {
52
+ : public vector::MaskableOpRewritePattern <vector::TransferReadOp> {
52
53
public:
53
- using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern ;
54
+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
54
55
55
- LogicalResult matchAndRewrite (vector::TransferReadOp readOp,
56
- PatternRewriter &rewriter) const override ;
56
+ FailureOr<mlir::Value>
57
+ matchAndRewriteMaskableOp (vector::TransferReadOp readOp,
58
+ vector::MaskingOpInterface maskOp,
59
+ PatternRewriter &rewriter) const override ;
57
60
};
58
61
59
62
// / Merge insert_slice operation with store/transferWriteOp operation.
@@ -84,8 +87,10 @@ static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
84
87
return success ();
85
88
}
86
89
87
- LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite (
88
- vector::TransferReadOp readOp, PatternRewriter &rewriter) const {
90
+ FailureOr<mlir::Value>
91
+ TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp (
92
+ vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
93
+ PatternRewriter &rewriter) const {
89
94
auto extractSliceOp =
90
95
getTensorOperand (readOp).getDefiningOp <tensor::ExtractSliceOp>();
91
96
if (!extractSliceOp)
@@ -95,7 +100,7 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
95
100
preconditionsFoldExtractOrInsertWithTransferOp (rewriter, readOp,
96
101
extractSliceOp);
97
102
if (failed (preconditionResult))
98
- return preconditionResult ;
103
+ return rewriter. notifyMatchFailure (readOp, " Failed preconditions " ) ;
99
104
100
105
SmallVector<Value> indices (readOp.getIndices ().begin (),
101
106
readOp.getIndices ().end ());
@@ -105,15 +110,17 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
105
110
extractSliceOp.getMixedStrides (), extractSliceOp.getDroppedDims (),
106
111
indices, sourceIndices);
107
112
108
- rewriter.replaceOpWithNewOp <vector::TransferReadOp>(
109
- readOp, readOp.getVectorType (), extractSliceOp.getSource (), sourceIndices,
113
+ Operation *newOp = rewriter.create <vector::TransferReadOp>(
114
+ readOp.getLoc (), readOp.getVectorType (), extractSliceOp.getSource (),
115
+ sourceIndices,
110
116
AffineMapAttr::get (expandDimsToRank (
111
117
readOp.getPermutationMap (), extractSliceOp.getSourceType ().getRank (),
112
118
extractSliceOp.getDroppedDims ())),
113
119
readOp.getPadding (),
114
120
/* mask=*/ Value (), readOp.getInBoundsAttr ());
115
-
116
- return success ();
121
+ if (maskOp)
122
+ newOp = mlir::vector::maskOperation (rewriter, newOp, maskOp.getMask ());
123
+ return newOp->getResults ()[0 ];
117
124
}
118
125
119
126
LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite (
0 commit comments