-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[MLIR][Vector] Implement TransferReadOfExtractSliceOp as MaskableOpRewritePattern #91960
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir Author: Hugo Trachino (nujaa) ChangesSplit of #90835 Full diff: https://github.com/llvm/llvm-project/pull/91960.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index c6ef6ed86e0d9..0aabdaf667b9d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -39,5 +39,6 @@ add_mlir_dialect_library(MLIRTensorTransforms
MLIRTilingInterface
MLIRTransforms
MLIRVectorDialect
+ MLIRVectorUtils
MLIRValueBoundsOpInterface
)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 3b8d3708bb731..5396531922aab 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -48,12 +49,14 @@ static Value getTensorOperand(tensor::InsertSliceOp op) {
namespace {
/// Merge extract_slice operation with load/transferRead operation.
class TransferReadOfExtractSliceOpFolder final
- : public OpRewritePattern<vector::TransferReadOp> {
+ : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
public:
- using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
- PatternRewriter &rewriter) const override;
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
+ vector::MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override;
};
/// Merge insert_slice operation with store/transferWriteOp operation.
@@ -84,8 +87,10 @@ static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
return success();
}
-LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
- vector::TransferReadOp readOp, PatternRewriter &rewriter) const {
+FailureOr<mlir::Value>
+TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
+ vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const {
auto extractSliceOp =
getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
if (!extractSliceOp)
@@ -95,7 +100,7 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
extractSliceOp);
if (failed(preconditionResult))
- return preconditionResult;
+ return rewriter.notifyMatchFailure(readOp, "Failed preconditions");
SmallVector<Value> indices(readOp.getIndices().begin(),
readOp.getIndices().end());
@@ -105,15 +110,17 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
indices, sourceIndices);
- rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
- readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices,
+ Operation *newOp = rewriter.create<vector::TransferReadOp>(
+ readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(),
+ sourceIndices,
AffineMapAttr::get(expandDimsToRank(
readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
extractSliceOp.getDroppedDims())),
readOp.getPadding(),
/*mask=*/Value(), readOp.getInBoundsAttr());
-
- return success();
+ if (maskOp)
+ newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
+ return newOp->getResults()[0];
}
LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
index 6213db3956f9a..214b41461b98f 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
@@ -111,3 +111,18 @@ func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32
%1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
return %1 : tensor<?x?x12xf32>
}
+
+// CHECK-LABEL: func @masked_transfer_read_of_extract_slice
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[m:.*]] = vector.create_mask{{.*}} : vector<5x6xi1>
+// CHECK-DAG: %[[a:.*]] = affine.apply {{.*}}[[s1]]
+// CHECK: vector.mask %[[m]] { vector.transfer_read %[[t]]{{.*}}: tensor<?x?xf32>, vector<5x6xf32> } : vector<5x6xi1> -> vector<5x6xf32>
+func.func @masked_transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+ %mask = vector.create_mask %c3, %c4 : vector<5x6xi1>
+ %1 = vector.mask %mask {vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>} : vector<5x6xi1> -> vector<5x6xf32>
+ return %1 : vector<5x6xf32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! LGTM 🥳 (left one nit)
mlir/test/Dialect/Tensor/fold-tensor-subset-ops-into-vector-transfers.mlir
Outdated
Show resolved
Hide resolved
…leOpRewritePattern
…ansfers.mlir Co-authored-by: Andrzej Warzyński <[email protected]>
Split of #90835
Adds support for
TransferReadOfExtractSliceOpFolder
when theTransferReadOp
is inside aMaskOp
.