-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[MLIR][Vector] Fix transferOps optimization inside maskOp #90835
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,14 +90,18 @@ namespace { | |
/// Note that an alternative is to transform it to linalg.transpose + | ||
/// vector.transfer_read to do the transpose in memory instead. | ||
struct TransferReadPermutationLowering | ||
: public OpRewritePattern<vector::TransferReadOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
: public MaskableOpRewritePattern<vector::TransferReadOp> { | ||
using MaskableOpRewritePattern::MaskableOpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(vector::TransferReadOp op, | ||
PatternRewriter &rewriter) const override { | ||
FailureOr<mlir::Value> | ||
matchAndRewriteMaskableOp(vector::TransferReadOp op, | ||
MaskingOpInterface maskOp, | ||
PatternRewriter &rewriter) const override { | ||
// TODO: support 0-d corner case. | ||
if (op.getTransferRank() == 0) | ||
return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); | ||
if (maskOp) | ||
return rewriter.notifyMatchFailure(op, "Masked case not supported"); | ||
|
||
SmallVector<unsigned> permutation; | ||
AffineMap map = op.getPermutationMap(); | ||
|
@@ -142,9 +146,9 @@ struct TransferReadPermutationLowering | |
|
||
// Transpose result of transfer_read. | ||
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end()); | ||
rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead, | ||
transposePerm); | ||
return success(); | ||
return rewriter | ||
.create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm) | ||
.getResult(); | ||
} | ||
}; | ||
|
||
|
@@ -165,14 +169,18 @@ struct TransferReadPermutationLowering | |
/// %v = vector.transfer_write %tmp ... | ||
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3) | ||
struct TransferWritePermutationLowering | ||
: public OpRewritePattern<vector::TransferWriteOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
: public MaskableOpRewritePattern<vector::TransferWriteOp> { | ||
using MaskableOpRewritePattern::MaskableOpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(vector::TransferWriteOp op, | ||
PatternRewriter &rewriter) const override { | ||
FailureOr<mlir::Value> | ||
matchAndRewriteMaskableOp(vector::TransferWriteOp op, | ||
MaskingOpInterface maskOp, | ||
PatternRewriter &rewriter) const override { | ||
// TODO: support 0-d corner case. | ||
if (op.getTransferRank() == 0) | ||
return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); | ||
if (maskOp) | ||
return rewriter.notifyMatchFailure(op, "Masked case not supported"); | ||
Comment on lines
+182
to
+183
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could it be supported? |
||
|
||
SmallVector<unsigned> permutation; | ||
AffineMap map = op.getPermutationMap(); | ||
|
@@ -207,11 +215,11 @@ struct TransferWritePermutationLowering | |
op.getLoc(), op.getVector(), indices); | ||
auto newMap = AffineMap::getMinorIdentityMap( | ||
map.getNumDims(), map.getNumResults(), rewriter.getContext()); | ||
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( | ||
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap), | ||
op.getMask(), newInBoundsAttr); | ||
|
||
return success(); | ||
return rewriter | ||
.create<vector::TransferWriteOp>( | ||
op.getLoc(), newVec, op.getSource(), op.getIndices(), | ||
AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr) | ||
.getResult(); | ||
} | ||
}; | ||
|
||
|
@@ -231,14 +239,18 @@ struct TransferWritePermutationLowering | |
/// vector<1x8x16xf32> | ||
/// ``` | ||
struct TransferWriteNonPermutationLowering | ||
: public OpRewritePattern<vector::TransferWriteOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
: public MaskableOpRewritePattern<vector::TransferWriteOp> { | ||
using MaskableOpRewritePattern::MaskableOpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(vector::TransferWriteOp op, | ||
PatternRewriter &rewriter) const override { | ||
FailureOr<mlir::Value> | ||
matchAndRewriteMaskableOp(vector::TransferWriteOp op, | ||
MaskingOpInterface maskOp, | ||
PatternRewriter &rewriter) const override { | ||
// TODO: support 0-d corner case. | ||
if (op.getTransferRank() == 0) | ||
return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); | ||
if (maskOp) | ||
return rewriter.notifyMatchFailure(op, "Masked case not supported"); | ||
Comment on lines
+252
to
+253
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could it be suppported? |
||
|
||
SmallVector<unsigned> permutation; | ||
AffineMap map = op.getPermutationMap(); | ||
|
@@ -285,10 +297,11 @@ struct TransferWriteNonPermutationLowering | |
newInBoundsValues.push_back(op.isDimInBounds(i)); | ||
} | ||
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues); | ||
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( | ||
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap), | ||
newMask, newInBoundsAttr); | ||
return success(); | ||
return rewriter | ||
.create<vector::TransferWriteOp>( | ||
op.getLoc(), newVec, op.getSource(), op.getIndices(), | ||
AffineMapAttr::get(newMap), newMask, newInBoundsAttr) | ||
.getResult(); | ||
} | ||
}; | ||
|
||
|
@@ -300,14 +313,19 @@ struct TransferWriteNonPermutationLowering | |
/// %v = vector.transfer_read ... | ||
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) | ||
/// vector.broadcast %v | ||
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(vector::TransferReadOp op, | ||
PatternRewriter &rewriter) const override { | ||
struct TransferOpReduceRank | ||
: public MaskableOpRewritePattern<vector::TransferReadOp> { | ||
using MaskableOpRewritePattern::MaskableOpRewritePattern; | ||
|
||
FailureOr<mlir::Value> | ||
matchAndRewriteMaskableOp(vector::TransferReadOp op, | ||
MaskingOpInterface maskOp, | ||
PatternRewriter &rewriter) const override { | ||
// TODO: support 0-d corner case. | ||
if (op.getTransferRank() == 0) | ||
return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); | ||
if (maskOp) | ||
return rewriter.notifyMatchFailure(op, "Masked case not supported"); | ||
Comment on lines
+327
to
+328
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could it be supported? |
||
|
||
AffineMap map = op.getPermutationMap(); | ||
unsigned numLeadingBroadcast = 0; | ||
|
@@ -347,9 +365,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> { | |
op.getLoc(), originalVecType.getElementType(), op.getSource(), | ||
op.getIndices()); | ||
} | ||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType, | ||
newRead); | ||
return success(); | ||
return rewriter | ||
.create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead) | ||
.getVector(); | ||
} | ||
|
||
SmallVector<int64_t> newShape( | ||
|
@@ -371,9 +389,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> { | |
op.getLoc(), newReadType, op.getSource(), op.getIndices(), | ||
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), | ||
newInBoundsAttr); | ||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType, | ||
newRead); | ||
return success(); | ||
return rewriter | ||
.create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead) | ||
.getVector(); | ||
} | ||
}; | ||
|
||
|
@@ -401,20 +419,24 @@ namespace { | |
/// result type. | ||
/// - The permutation map doesn't perform permutation (broadcasting is allowed). | ||
struct TransferReadToVectorLoadLowering | ||
: public OpRewritePattern<vector::TransferReadOp> { | ||
: public MaskableOpRewritePattern<vector::TransferReadOp> { | ||
TransferReadToVectorLoadLowering(MLIRContext *context, | ||
std::optional<unsigned> maxRank, | ||
PatternBenefit benefit = 1) | ||
: OpRewritePattern<vector::TransferReadOp>(context, benefit), | ||
: MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit), | ||
maxTransferRank(maxRank) {} | ||
|
||
LogicalResult matchAndRewrite(vector::TransferReadOp read, | ||
PatternRewriter &rewriter) const override { | ||
FailureOr<mlir::Value> | ||
matchAndRewriteMaskableOp(vector::TransferReadOp read, | ||
MaskingOpInterface maskOp, | ||
PatternRewriter &rewriter) const override { | ||
if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) { | ||
return rewriter.notifyMatchFailure( | ||
read, "vector type is greater than max transfer rank"); | ||
} | ||
|
||
if (maskOp) | ||
return rewriter.notifyMatchFailure(read, "Masked case not supported"); | ||
Comment on lines
+438
to
+439
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could it be supported? |
||
SmallVector<unsigned> broadcastedDims; | ||
// Permutations are handled by VectorToSCF or | ||
// populateVectorTransferPermutationMapLoweringPatterns. | ||
|
@@ -457,7 +479,7 @@ struct TransferReadToVectorLoadLowering | |
return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask"); | ||
|
||
// Create vector load op. | ||
Operation *loadOp; | ||
Operation *res; | ||
if (read.getMask()) { | ||
if (read.getVectorType().getRank() != 1) | ||
// vector.maskedload operates on 1-D vectors. | ||
|
@@ -467,24 +489,20 @@ struct TransferReadToVectorLoadLowering | |
|
||
Value fill = rewriter.create<vector::SplatOp>( | ||
read.getLoc(), unbroadcastedVectorType, read.getPadding()); | ||
loadOp = rewriter.create<vector::MaskedLoadOp>( | ||
res = rewriter.create<vector::MaskedLoadOp>( | ||
read.getLoc(), unbroadcastedVectorType, read.getSource(), | ||
read.getIndices(), read.getMask(), fill); | ||
} else { | ||
loadOp = rewriter.create<vector::LoadOp>( | ||
res = rewriter.create<vector::LoadOp>( | ||
read.getLoc(), unbroadcastedVectorType, read.getSource(), | ||
read.getIndices()); | ||
} | ||
|
||
// Insert a broadcasting op if required. | ||
if (!broadcastedDims.empty()) { | ||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>( | ||
read, read.getVectorType(), loadOp->getResult(0)); | ||
} else { | ||
rewriter.replaceOp(read, loadOp->getResult(0)); | ||
} | ||
|
||
return success(); | ||
if (!broadcastedDims.empty()) | ||
res = rewriter.create<vector::BroadcastOp>( | ||
read.getLoc(), read.getVectorType(), res->getResult(0)); | ||
return res->getResults()[0]; | ||
} | ||
|
||
std::optional<unsigned> maxTransferRank; | ||
|
@@ -553,19 +571,23 @@ struct VectorStoreToMemrefStoreLowering | |
/// - The permutation map is the minor identity map (neither permutation nor | ||
/// broadcasting is allowed). | ||
struct TransferWriteToVectorStoreLowering | ||
: public OpRewritePattern<vector::TransferWriteOp> { | ||
: public MaskableOpRewritePattern<vector::TransferWriteOp> { | ||
TransferWriteToVectorStoreLowering(MLIRContext *context, | ||
std::optional<unsigned> maxRank, | ||
PatternBenefit benefit = 1) | ||
: OpRewritePattern<vector::TransferWriteOp>(context, benefit), | ||
: MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit), | ||
maxTransferRank(maxRank) {} | ||
|
||
LogicalResult matchAndRewrite(vector::TransferWriteOp write, | ||
PatternRewriter &rewriter) const override { | ||
FailureOr<mlir::Value> | ||
matchAndRewriteMaskableOp(vector::TransferWriteOp write, | ||
MaskingOpInterface maskOp, | ||
PatternRewriter &rewriter) const override { | ||
if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) { | ||
return rewriter.notifyMatchFailure( | ||
write, "vector type is greater than max transfer rank"); | ||
} | ||
if (maskOp) | ||
return rewriter.notifyMatchFailure(write, "Masked case not supported"); | ||
Comment on lines
+589
to
+590
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could it be supported? |
||
|
||
// Permutations are handled by VectorToSCF or | ||
// populateVectorTransferPermutationMapLoweringPatterns. | ||
|
@@ -617,14 +639,17 @@ struct TransferWriteToVectorStoreLowering | |
<< write; | ||
}); | ||
|
||
rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>( | ||
write, write.getSource(), write.getIndices(), write.getMask(), | ||
write.getVector()); | ||
return rewriter | ||
.create<vector::MaskedStoreOp>(write.getLoc(), write.getSource(), | ||
write.getIndices(), write.getMask(), | ||
write.getVector()) | ||
.getBase(); | ||
} else { | ||
rewriter.replaceOpWithNewOp<vector::StoreOp>( | ||
write, write.getVector(), write.getSource(), write.getIndices()); | ||
return rewriter | ||
.create<vector::StoreOp>(write.getLoc(), write.getVector(), | ||
write.getSource(), write.getIndices()) | ||
.getBase(); | ||
} | ||
return success(); | ||
} | ||
|
||
std::optional<unsigned> maxTransferRank; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could it be supported?