Skip to content

[MLIR][Vector] Implement TransferReadOfExtractSliceOp as MaskableOpRewritePattern #91960

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into llvm:main from hugo.TfOfExtract
May 16, 2024

Conversation

nujaa
Copy link
Contributor

@nujaa nujaa commented May 13, 2024

Split off from #90835.
Adds support for TransferReadOfExtractSliceOpFolder when the TransferReadOp is inside a MaskOp.

@nujaa
Copy link
Contributor Author

nujaa commented May 13, 2024

CC @banach-space

@llvmbot
Copy link
Member

llvmbot commented May 13, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Hugo Trachino (nujaa)

Changes

Split off from #90835.
Adds support for TransferReadOfExtractSliceOpFolder when the TransferReadOp is inside a MaskOp.


Full diff: https://github.com/llvm/llvm-project/pull/91960.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp (+18-11)
  • (modified) mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir (+15)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index c6ef6ed86e0d9..0aabdaf667b9d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -39,5 +39,6 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRTilingInterface
   MLIRTransforms
   MLIRVectorDialect
+  MLIRVectorUtils
   MLIRValueBoundsOpInterface
 )
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 3b8d3708bb731..5396531922aab 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -48,12 +49,14 @@ static Value getTensorOperand(tensor::InsertSliceOp op) {
 namespace {
 /// Merge extract_slice operation with load/transferRead operation.
 class TransferReadOfExtractSliceOpFolder final
-    : public OpRewritePattern<vector::TransferReadOp> {
+    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
 public:
-  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
-                                PatternRewriter &rewriter) const override;
+  FailureOr<mlir::Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
+                            vector::MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
 };
 
 /// Merge insert_slice operation with store/transferWriteOp operation.
@@ -84,8 +87,10 @@ static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
   return success();
 }
 
-LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
-    vector::TransferReadOp readOp, PatternRewriter &rewriter) const {
+FailureOr<mlir::Value>
+TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
+    vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
   auto extractSliceOp =
       getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
   if (!extractSliceOp)
@@ -95,7 +100,7 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
       preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
                                                      extractSliceOp);
   if (failed(preconditionResult))
-    return preconditionResult;
+    return rewriter.notifyMatchFailure(readOp, "Failed preconditions");
 
   SmallVector<Value> indices(readOp.getIndices().begin(),
                              readOp.getIndices().end());
@@ -105,15 +110,17 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
       extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
       indices, sourceIndices);
 
-  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
-      readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices,
+  Operation *newOp = rewriter.create<vector::TransferReadOp>(
+      readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(),
+      sourceIndices,
       AffineMapAttr::get(expandDimsToRank(
           readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
           extractSliceOp.getDroppedDims())),
       readOp.getPadding(),
       /*mask=*/Value(), readOp.getInBoundsAttr());
-
-  return success();
+  if (maskOp)
+    newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
+  return newOp->getResults()[0];
 }
 
 LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
index 6213db3956f9a..214b41461b98f 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
@@ -111,3 +111,18 @@ func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32
   %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
   return %1 : tensor<?x?x12xf32>
 }
+
+// CHECK-LABEL: func @masked_transfer_read_of_extract_slice
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//       CHECK-DAG: %[[m:.*]] = vector.create_mask{{.*}} : vector<5x6xi1>
+//       CHECK-DAG: %[[a:.*]] = affine.apply {{.*}}[[s1]]
+//       CHECK: vector.mask %[[m]] { vector.transfer_read %[[t]]{{.*}}: tensor<?x?xf32>, vector<5x6xf32> } : vector<5x6xi1> -> vector<5x6xf32>
+func.func @masked_transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+  %mask = vector.create_mask %c3, %c4 : vector<5x6xi1>
+  %1 = vector.mask %mask {vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>} : vector<5x6xi1> -> vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}

@nujaa nujaa force-pushed the hugo.TfOfExtract branch from d1e339f to 0464fb1 Compare May 13, 2024 13:32
@banach-space banach-space requested a review from dcaballe May 14, 2024 09:11
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! LGTM 🥳 (left one nit)

@nujaa nujaa force-pushed the hugo.TfOfExtract branch from 7f48b64 to 37b7f95 Compare May 14, 2024 13:51
@banach-space banach-space merged commit 1ede503 into llvm:main May 16, 2024
4 checks passed
keith added a commit that referenced this pull request May 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants