[MLIR][Vector] Implement XferOp To {Load|Store}Lowering as MaskableOpRewritePattern #92892
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir
Author: Hugo Trachino (nujaa)
Changes: Implements `TransferReadToVectorLoadLowering` and `TransferWriteToVectorStoreLowering` as a `MaskableOpRewritePattern`. Split of #90835
Full diff: https://github.com/llvm/llvm-project/pull/92892.diff. 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index c59012266ceb3..9418a087c4367 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -423,20 +423,24 @@ namespace {
/// result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
struct TransferReadToVectorLoadLowering
- : public OpRewritePattern<vector::TransferReadOp> {
+ : public MaskableOpRewritePattern<vector::TransferReadOp> {
TransferReadToVectorLoadLowering(MLIRContext *context,
std::optional<unsigned> maxRank,
PatternBenefit benefit = 1)
- : OpRewritePattern<vector::TransferReadOp>(context, benefit),
+ : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
maxTransferRank(maxRank) {}
- LogicalResult matchAndRewrite(vector::TransferReadOp read,
- PatternRewriter &rewriter) const override {
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferReadOp read,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
return rewriter.notifyMatchFailure(
read, "vector type is greater than max transfer rank");
}
+ if (maskOp)
+ return rewriter.notifyMatchFailure(read, "Masked case not supported");
SmallVector<unsigned> broadcastedDims;
// Permutations are handled by VectorToSCF or
// populateVectorTransferPermutationMapLoweringPatterns.
@@ -479,7 +483,7 @@ struct TransferReadToVectorLoadLowering
return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");
// Create vector load op.
- Operation *loadOp;
+ Operation *res;
if (read.getMask()) {
if (read.getVectorType().getRank() != 1)
// vector.maskedload operates on 1-D vectors.
@@ -489,24 +493,20 @@ struct TransferReadToVectorLoadLowering
Value fill = rewriter.create<vector::SplatOp>(
read.getLoc(), unbroadcastedVectorType, read.getPadding());
- loadOp = rewriter.create<vector::MaskedLoadOp>(
+ res = rewriter.create<vector::MaskedLoadOp>(
read.getLoc(), unbroadcastedVectorType, read.getSource(),
read.getIndices(), read.getMask(), fill);
} else {
- loadOp = rewriter.create<vector::LoadOp>(
+ res = rewriter.create<vector::LoadOp>(
read.getLoc(), unbroadcastedVectorType, read.getSource(),
read.getIndices());
}
// Insert a broadcasting op if required.
- if (!broadcastedDims.empty()) {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- read, read.getVectorType(), loadOp->getResult(0));
- } else {
- rewriter.replaceOp(read, loadOp->getResult(0));
- }
-
- return success();
+ if (!broadcastedDims.empty())
+ res = rewriter.create<vector::BroadcastOp>(
+ read.getLoc(), read.getVectorType(), res->getResult(0));
+ return res->getResults()[0];
}
std::optional<unsigned> maxTransferRank;
@@ -575,19 +575,23 @@ struct VectorStoreToMemrefStoreLowering
/// - The permutation map is the minor identity map (neither permutation nor
/// broadcasting is allowed).
struct TransferWriteToVectorStoreLowering
- : public OpRewritePattern<vector::TransferWriteOp> {
+ : public MaskableOpRewritePattern<vector::TransferWriteOp> {
TransferWriteToVectorStoreLowering(MLIRContext *context,
std::optional<unsigned> maxRank,
PatternBenefit benefit = 1)
- : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
+ : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
maxTransferRank(maxRank) {}
- LogicalResult matchAndRewrite(vector::TransferWriteOp write,
- PatternRewriter &rewriter) const override {
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferWriteOp write,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
return rewriter.notifyMatchFailure(
write, "vector type is greater than max transfer rank");
}
+ if (maskOp)
+ return rewriter.notifyMatchFailure(write, "Masked case not supported");
// Permutations are handled by VectorToSCF or
// populateVectorTransferPermutationMapLoweringPatterns.
@@ -639,14 +643,19 @@ struct TransferWriteToVectorStoreLowering
<< write;
});
- rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
- write, write.getSource(), write.getIndices(), write.getMask(),
- write.getVector());
+ rewriter
+ .create<vector::MaskedStoreOp>(write.getLoc(), write.getSource(),
+ write.getIndices(), write.getMask(),
+ write.getVector())
+ .getBase();
+ return Value();
} else {
- rewriter.replaceOpWithNewOp<vector::StoreOp>(
- write, write.getVector(), write.getSource(), write.getIndices());
+ rewriter
+ .create<vector::StoreOp>(write.getLoc(), write.getVector(),
+ write.getSource(), write.getIndices())
+ .getBase();
+ return Value();
}
- return success();
}
std::optional<unsigned> maxTransferRank;
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 2f2bdcaab5b3e..a789aac717dab 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -392,6 +392,41 @@ func.func @transfer_2D_masked(%mem : memref<?x?xf32>, %mask : vector<2x4xi1>) ->
return %res : vector<2x4xf32>
}
+// transfer_read/write are lowered to vector.load/store
+// CHECK-LABEL: func @masked_transfer_to_load(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index,
+// CHECK-SAME: %[[MASK:.*]]: vector<4xi1>) -> memref<8x8xf32>
+// CHECK-NOT: vector.load
+// CHECK-NOT: vector.store
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %arg0[%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1>
+
+
+func.func @masked_transfer_to_load(%mem : memref<8x8xf32>, %i : index, %mask : vector<4xi1>) -> memref<8x8xf32> {
+ %cf0 = arith.constant 0.0 : f32
+ %read = vector.mask %mask { vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>} : vector<4xi1> -> vector<4xf32>
+ vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1>
+ return %mem : memref<8x8xf32>
+}
+
+// n-D results are also supported.
+// CHECK-LABEL: func @masked_transfer_2D(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index,
+// CHECK-SAME: %[[MASK:.*]]: vector<2x4xi1>) -> memref<8x8xf32>
+// CHECK-NOT: vector.load
+// CHECK-NOT: vector.store
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<2x4xf32> } : vector<2x4xi1> -> vector<2x4xf32>
+// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}}: vector<2x4xf32>, memref<8x8xf32> } : vector<2x4xi1>
+
+func.func @masked_transfer_2D(%mem : memref<8x8xf32>, %i : index, %mask : vector<2x4xi1>) -> memref<8x8xf32> {
+ %cf0 = arith.constant 0.0 : f32
+ %read = vector.mask %mask { vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true, true]} : memref<8x8xf32>, vector<2x4xf32> } : vector<2x4xi1> -> vector<2x4xf32>
+ vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true, true]} : vector<2x4xf32>, memref<8x8xf32> } : vector<2x4xi1>
+ return %mem : memref<8x8xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
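For readers who want to exercise the updated patterns outside the transform-dialect test above, here is a minimal C++ sketch of registering and applying them. It assumes the upstream helper populateVectorTransferLoweringPatterns adds the two patterns touched by this diff; it is illustrative only and not part of the PR.

// Sketch: register the transfer_read/write lowering patterns and apply them
// greedily to a function. maxTransferRank mirrors the knob used by both
// patterns in this diff.
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult lowerXferOps(mlir::Operation *funcOp) {
  mlir::RewritePatternSet patterns(funcOp->getContext());
  // Restrict to 1-D transfers; higher-rank transfers are left untouched.
  mlir::vector::populateVectorTransferLoweringPatterns(
      patterns, /*maxTransferRank=*/1);
  return mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}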
Final split of this original MR! Thanks @banach-space and @MacDue.
A few comments/nits:
@@ -479,7 +483,7 @@ struct TransferReadToVectorLoadLowering
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create vector load op.
-   Operation *loadOp;
+   Operation *res;
Why rename this from loadOp?
Suggested change:
-   Operation *res;
+   Operation *loadOp;
Because it can be a `BroadcastOp`, from l.507.
Thanks Hugo! I see 2 new tests, but it’s not obvious which corresponds to which pattern.
Ping 🏓
Same question here: is it an NFC, or is it adding support for scalable vectors? Can you add the information to the PR description?
      rewriter.create<vector::MaskedStoreOp>(
          write.getLoc(), write.getSource(), write.getIndices(),
          write.getMask(), write.getVector());
      return Value();
It is confusing when it returns `Value()`. I think it is better to return a `FailureOr<Operation *>` from the `matchAndRewriteMaskableOp` method, because we eventually just replace the op with the new operation. And perhaps we don't need the check anymore:
llvm-project/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Lines 160 to 167 in d4ff961

    // Rewriting succeeded but there are no values to replace.
    if (rootOp->getNumResults() == 0) {
      rewriter.eraseOp(rootOp);
    } else {
      assert(*newOp != Value() &&
             "Cannot replace an op's use with an empty value.");
      rewriter.replaceOp(rootOp, *newOp);
    }
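For context on how the base pattern consumes that return value, here is an illustrative sketch of the surrounding `matchAndRewrite` driver, simplified from VectorUtils.h and consistent with the snippet quoted above; local names and the masking query are approximations, not the exact upstream code.

  // Illustrative sketch only (not the verbatim upstream implementation).
  LogicalResult matchAndRewrite(SourceOp sourceOp,
                                PatternRewriter &rewriter) const final {
    // When the op is wrapped in a vector.mask, the masking op becomes the
    // root to replace, and it is forwarded so patterns can bail out on it.
    MaskingOpInterface maskOp;
    Operation *rootOp = sourceOp;
    if (auto maskableOp =
            dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
        maskableOp && maskableOp.isMasked()) {
      maskOp = maskableOp.getMaskingOp();
      rootOp = maskOp.getOperation();
    }

    FailureOr<Value> newOp =
        matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
    if (failed(newOp))
      return failure();

    // Rewriting succeeded but there are no values to replace (the store case
    // discussed in this thread), so the root op is simply erased.
    if (rootOp->getNumResults() == 0)
      rewriter.eraseOp(rootOp);
    else
      rewriter.replaceOp(rootOp, *newOp);
    return success();
  }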
+1 to your comment @hanhanW, but this is a bit more nuanced 😅
I implemented it this way because there are quite a few places that simply return `Value`, e.g.:
llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
Lines 432 to 455 in 77db8b0
  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }
Lambdas like that get used a lot :) I recall trying to update that and other similar examples, but never got round to it. Sounds like a worthwhile TODO!
Comments updated and second test removed as duplicate.
Force-pushed from c63692c to eb9f46c.
No objections from me for landing this :)
Thanks!
…RewritePattern (llvm#92892): Implements `TransferReadToVectorLoadLowering` and `TransferWriteToVectorStoreLowering` as a `MaskableOpRewritePattern`, allowing them to exit gracefully when run on an xferOp located inside a `vector::MaskOp` instead of breaking because the pattern generated multiple ops in the MaskOp with `error: 'vector.mask' op expects only one operation to mask`. Split of llvm#90835
Implements `TransferReadToVectorLoadLowering` and `TransferWriteToVectorStoreLowering` as a `MaskableOpRewritePattern`, allowing them to exit gracefully when the xferOp is located inside a `vector::MaskOp` instead of breaking because the pattern generated multiple ops in the MaskOp with `error: 'vector.mask' op expects only one operation to mask`.

Split of #90835
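For downstream pattern authors who want to follow the same recipe, here is a minimal skeleton distilled from the two patterns in this diff; `MyXferOp` and `buildReplacement` are placeholders for illustration, not upstream code.

  // Placeholder skeleton of a MaskableOpRewritePattern following this PR.
  struct MyXferOpLowering : public MaskableOpRewritePattern<MyXferOp> {
    using MaskableOpRewritePattern<MyXferOp>::MaskableOpRewritePattern;

    FailureOr<Value>
    matchAndRewriteMaskableOp(MyXferOp op, MaskingOpInterface maskOp,
                              PatternRewriter &rewriter) const override {
      // Exit gracefully instead of emitting several ops inside the
      // vector.mask region, which would otherwise trigger
      // "'vector.mask' op expects only one operation to mask".
      if (maskOp)
        return rewriter.notifyMatchFailure(op, "masked case not supported");

      // Build the replacement and return its result; the base pattern then
      // calls replaceOp (or eraseOp when there is nothing to replace).
      Value replacement = buildReplacement(rewriter, op);
      return replacement;
    }
  };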