[MLIR][Vector] Implement XferOp To {Load|Store}Lowering as MaskableOpRewritePattern #92892


Merged
merged 4 commits into llvm:main from nujaa:hugo.fixVectorization
Jun 18, 2024

Conversation

nujaa
Contributor

@nujaa nujaa commented May 21, 2024

Implements `TransferReadToVectorLoadLowering` and `TransferWriteToVectorStoreLowering` as `MaskableOpRewritePattern`s, allowing the patterns to exit gracefully when run on an xferOp located inside a `vector::MaskOp`, instead of breaking because the pattern generated multiple ops in the MaskOp with `error: 'vector.mask' op expects only one operation to mask`.

Split of #90835
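
In outline, both patterns now derive from MaskableOpRewritePattern and receive the enclosing masking op as an argument, bailing out gracefully when one is present. A condensed sketch of the read-side pattern (the full diff below is authoritative):

struct TransferReadToVectorLoadLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  // The driver passes the enclosing MaskingOpInterface (null when the xferOp
  // is unmasked) and performs the final replaceOp itself using the returned
  // Value, so the pattern body only creates ops and never erases the root.
  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp read,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maskOp)
      return rewriter.notifyMatchFailure(read, "Masked case not supported");
    // ... build vector.load / vector.maskedload and return its result ...
  }
};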

@llvmbot
Member

llvmbot commented May 21, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Hugo Trachino (nujaa)

Changes

Implements `TransferReadToVectorLoadLowering` and `TransferWriteToVectorStoreLowering` as `MaskableOpRewritePattern`s, allowing the patterns to exit gracefully when the xferOp they are run on is located inside a `vector::MaskOp`.

Split of #90835


Full diff: https://github.com/llvm/llvm-project/pull/92892.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+34-25)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+35)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index c59012266ceb3..9418a087c4367 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -423,20 +423,24 @@ namespace {
 ///   result type.
 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
 struct TransferReadToVectorLoadLowering
-    : public OpRewritePattern<vector::TransferReadOp> {
+    : public MaskableOpRewritePattern<vector::TransferReadOp> {
   TransferReadToVectorLoadLowering(MLIRContext *context,
                                    std::optional<unsigned> maxRank,
                                    PatternBenefit benefit = 1)
-      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
+      : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
         maxTransferRank(maxRank) {}
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp read,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp read,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
       return rewriter.notifyMatchFailure(
           read, "vector type is greater than max transfer rank");
     }
 
+    if (maskOp)
+      return rewriter.notifyMatchFailure(read, "Masked case not supported");
     SmallVector<unsigned> broadcastedDims;
     // Permutations are handled by VectorToSCF or
     // populateVectorTransferPermutationMapLoweringPatterns.
@@ -479,7 +483,7 @@ struct TransferReadToVectorLoadLowering
       return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");
 
     // Create vector load op.
-    Operation *loadOp;
+    Operation *res;
     if (read.getMask()) {
       if (read.getVectorType().getRank() != 1)
         // vector.maskedload operates on 1-D vectors.
@@ -489,24 +493,20 @@ struct TransferReadToVectorLoadLowering
 
       Value fill = rewriter.create<vector::SplatOp>(
           read.getLoc(), unbroadcastedVectorType, read.getPadding());
-      loadOp = rewriter.create<vector::MaskedLoadOp>(
+      res = rewriter.create<vector::MaskedLoadOp>(
           read.getLoc(), unbroadcastedVectorType, read.getSource(),
           read.getIndices(), read.getMask(), fill);
     } else {
-      loadOp = rewriter.create<vector::LoadOp>(
+      res = rewriter.create<vector::LoadOp>(
           read.getLoc(), unbroadcastedVectorType, read.getSource(),
           read.getIndices());
     }
 
     // Insert a broadcasting op if required.
-    if (!broadcastedDims.empty()) {
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          read, read.getVectorType(), loadOp->getResult(0));
-    } else {
-      rewriter.replaceOp(read, loadOp->getResult(0));
-    }
-
-    return success();
+    if (!broadcastedDims.empty())
+      res = rewriter.create<vector::BroadcastOp>(
+          read.getLoc(), read.getVectorType(), res->getResult(0));
+    return res->getResults()[0];
   }
 
   std::optional<unsigned> maxTransferRank;
@@ -575,19 +575,23 @@ struct VectorStoreToMemrefStoreLowering
 /// - The permutation map is the minor identity map (neither permutation nor
 ///   broadcasting is allowed).
 struct TransferWriteToVectorStoreLowering
-    : public OpRewritePattern<vector::TransferWriteOp> {
+    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
   TransferWriteToVectorStoreLowering(MLIRContext *context,
                                      std::optional<unsigned> maxRank,
                                      PatternBenefit benefit = 1)
-      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
+      : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
         maxTransferRank(maxRank) {}
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferWriteOp write,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
       return rewriter.notifyMatchFailure(
           write, "vector type is greater than max transfer rank");
     }
+    if (maskOp)
+      return rewriter.notifyMatchFailure(write, "Masked case not supported");
 
     // Permutations are handled by VectorToSCF or
     // populateVectorTransferPermutationMapLoweringPatterns.
@@ -639,14 +643,19 @@ struct TransferWriteToVectorStoreLowering
                    << write;
             });
 
-      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
-          write, write.getSource(), write.getIndices(), write.getMask(),
-          write.getVector());
+      rewriter
+          .create<vector::MaskedStoreOp>(write.getLoc(), write.getSource(),
+                                         write.getIndices(), write.getMask(),
+                                         write.getVector())
+          .getBase();
+      return Value();
     } else {
-      rewriter.replaceOpWithNewOp<vector::StoreOp>(
-          write, write.getVector(), write.getSource(), write.getIndices());
+      rewriter
+          .create<vector::StoreOp>(write.getLoc(), write.getVector(),
+                                   write.getSource(), write.getIndices())
+          .getBase();
+      return Value();
     }
-    return success();
   }
 
   std::optional<unsigned> maxTransferRank;
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 2f2bdcaab5b3e..a789aac717dab 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -392,6 +392,41 @@ func.func @transfer_2D_masked(%mem : memref<?x?xf32>, %mask : vector<2x4xi1>) ->
   return %res : vector<2x4xf32>
 }
 
+// transfer_read/write are lowered to vector.load/store
+// CHECK-LABEL:   func @masked_transfer_to_load(
+//  CHECK-SAME:                                %[[MEM:.*]]: memref<8x8xf32>,
+//  CHECK-SAME:                                %[[IDX:.*]]: index,
+//  CHECK-SAME:                                %[[MASK:.*]]: vector<4xi1>) -> memref<8x8xf32>
+//   CHECK-NOT:      vector.load 
+//   CHECK-NOT:      vector.store
+//       CHECK:      %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %arg0[%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+//       CHECK:      vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1>
+
+
+func.func @masked_transfer_to_load(%mem : memref<8x8xf32>, %i : index, %mask : vector<4xi1>) -> memref<8x8xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %read = vector.mask %mask { vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>} : vector<4xi1> -> vector<4xf32>
+  vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1> 
+  return %mem : memref<8x8xf32>
+}
+
+// n-D results are also supported.
+// CHECK-LABEL:   func @masked_transfer_2D(
+//  CHECK-SAME:                           %[[MEM:.*]]: memref<8x8xf32>,
+//  CHECK-SAME:                           %[[IDX:.*]]: index,
+//  CHECK-SAME:                           %[[MASK:.*]]: vector<2x4xi1>) -> memref<8x8xf32>
+//   CHECK-NOT:      vector.load 
+//   CHECK-NOT:      vector.store
+//       CHECK:      %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<2x4xf32> } : vector<2x4xi1> -> vector<2x4xf32>
+//       CHECK:      vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}}: vector<2x4xf32>, memref<8x8xf32> } : vector<2x4xi1>
+
+func.func @masked_transfer_2D(%mem : memref<8x8xf32>, %i : index, %mask : vector<2x4xi1>) -> memref<8x8xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %read = vector.mask %mask { vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true, true]} : memref<8x8xf32>, vector<2x4xf32> } : vector<2x4xi1> -> vector<2x4xf32>
+  vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true, true]} : vector<2x4xf32>, memref<8x8xf32> } : vector<2x4xi1> 
+  return %mem : memref<8x8xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">

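For anyone exercising these lowerings outside the FileCheck tests: both patterns are registered through vector::populateVectorTransferLoweringPatterns, so a minimal driver looks roughly like the sketch below (funcOp and the error handling are illustrative assumptions, not part of this PR):

// Sketch: register and greedily apply the transfer-lowering patterns.
// maxTransferRank is the same optional threshold checked at the top of
// both patterns above; std::nullopt disables the rank limit.
RewritePatternSet patterns(funcOp.getContext());
vector::populateVectorTransferLoweringPatterns(
    patterns, /*maxTransferRank=*/std::nullopt);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  funcOp.emitError("greedy pattern application failed");
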
@nujaa
Contributor Author

nujaa commented May 21, 2024

Final split of the original MR! Thanks @banach-space and @MacDue.

Member

@MacDue MacDue left a comment


A few comments/nits:

@@ -479,7 +483,7 @@ struct TransferReadToVectorLoadLowering
       return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");
 
     // Create vector load op.
-    Operation *loadOp;
+    Operation *res;
Member


Why rename this from loadOp?

Suggested change:
-    Operation *res;
+    Operation *loadOp;

Contributor Author


Because it can be a BroadcastOp; see l.507.

Contributor

@banach-space banach-space left a comment


Thanks Hugo! I see 2 new tests, but it’s not obvious which corresponds to which pattern.

@nujaa
Contributor Author

nujaa commented May 30, 2024

Ping 🏓

Contributor

@hanhanW hanhanW left a comment


Same question here: is this an NFC, or does it add support for scalable vectors? Can you add that information to the PR description?

      rewriter.create<vector::MaskedStoreOp>(
          write.getLoc(), write.getSource(), write.getIndices(),
          write.getMask(), write.getVector());
      return Value();
Contributor


Returning Value() here is confusing. I think it would be better to return a FailureOr<Operation *> from the matchAndRewriteMaskableOp method, because we eventually just replace the op with the new operation.

And perhaps we don't need the check anymore.

    // Rewriting succeeded but there are no values to replace.
    if (rootOp->getNumResults() == 0) {
      rewriter.eraseOp(rootOp);
    } else {
      assert(*newOp != Value() &&
             "Cannot replace an op's use with an empty value.");
      rewriter.replaceOp(rootOp, *newOp);
    }
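
For illustration, the suggested alternative would change the hook roughly as follows (a hypothetical sketch, not what this PR implements):

// Hypothetical variant: return the newly created op instead of a Value, so
// the driver can uniformly call rewriter.replaceOp(rootOp,
// newOp->getResults()) and rewriter.eraseOp(rootOp) when the root has no
// results, avoiding the Value() sentinel in the store pattern.
FailureOr<Operation *>
matchAndRewriteMaskableOp(vector::TransferWriteOp write,
                          MaskingOpInterface maskOp,
                          PatternRewriter &rewriter) const override;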

Contributor


+1 to your comment @hanhanW, but this is a bit more nuanced 😅

I implemented it this way because there are quite a few places that simply return Value, e.g.:

FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                           VectorType lhsType, int reductionSize,
                           std::optional<Value> maybeMask = std::nullopt) {
  // Incremental support for masking.
  if (mask && !maybeMask.has_value())
    return failure();
  Type resElementType = cast<VectorType>(res.getType()).getElementType();
  for (int64_t k = 0; k < reductionSize; ++k) {
    Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
    Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
    extractA = promote(extractA, resElementType);
    extractB = promote(extractB, resElementType);
    Value extractMask;
    if (maybeMask.has_value() && maybeMask.value())
      extractMask =
          rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
    Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
        loc, res.getType(), extractA, extractB, res, kind);
    res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
  }
  return res;
}

Lambdas like that get used a lot :) I recall trying to update that and other similar examples, but never got round to it. Sounds like a worthwhile TODO!

@nujaa nujaa force-pushed the hugo.fixVectorization branch from c63692c to eb9f46c on May 31, 2024 09:35
@nujaa
Contributor Author

nujaa commented Jun 17, 2024

Hi, are you satisfied with the changes or answers to your suggestions, @hanhanW @MacDue?

Member

@MacDue MacDue left a comment


No objections from me for landing this :)

Contributor

@banach-space banach-space left a comment


Thanks!

@nujaa nujaa merged commit 74941d0 into llvm:main Jun 18, 2024
7 checks passed
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…RewritePattern (llvm#92892)

Implements `TransferReadToVectorLoadLowering` and
`TransferWriteToVectorStoreLowering` as a `MaskableOpRewritePattern`.
Allowing to exit gracefully when run on an xferOp located inside a
`vector::MaskOp` instead of breaking because the pattern generated
multiple ops in the MaskOp with `error: 'vector.mask' op expects only
one operation to mask`.

Split of llvm#90835
5 participants