Skip to content

[MLIR][Vector] Fix transferOps optimization inside maskOp #90835

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,6 @@ add_mlir_dialect_library(MLIRTensorTransforms
MLIRTilingInterface
MLIRTransforms
MLIRVectorDialect
MLIRVectorUtils
MLIRValueBoundsOpInterface
)
29 changes: 18 additions & 11 deletions mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down Expand Up @@ -48,12 +49,14 @@ static Value getTensorOperand(tensor::InsertSliceOp op) {
namespace {
/// Merge extract_slice operation with load/transferRead operation.
class TransferReadOfExtractSliceOpFolder final
: public OpRewritePattern<vector::TransferReadOp> {
: public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
public:
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
using MaskableOpRewritePattern::MaskableOpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override;
FailureOr<mlir::Value>
matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
vector::MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override;
};

/// Merge insert_slice operation with store/transferWriteOp operation.
Expand Down Expand Up @@ -84,8 +87,10 @@ static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
return success();
}

LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
vector::TransferReadOp readOp, PatternRewriter &rewriter) const {
FailureOr<mlir::Value>
TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
PatternRewriter &rewriter) const {
auto extractSliceOp =
getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
if (!extractSliceOp)
Expand All @@ -95,7 +100,7 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
extractSliceOp);
if (failed(preconditionResult))
return preconditionResult;
return rewriter.notifyMatchFailure(readOp, "Failed preconditions");

SmallVector<Value> indices(readOp.getIndices().begin(),
readOp.getIndices().end());
Expand All @@ -105,15 +110,17 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
indices, sourceIndices);

rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices,
Operation *newOp = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(),
sourceIndices,
AffineMapAttr::get(expandDimsToRank(
readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
extractSliceOp.getDroppedDims())),
readOp.getPadding(),
/*mask=*/Value(), readOp.getInBoundsAttr());

return success();
if (maskOp)
newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
return newOp->getResults()[0];
}

LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
Expand Down
145 changes: 85 additions & 60 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,18 @@ namespace {
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
: public MaskableOpRewritePattern<vector::TransferReadOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
FailureOr<mlir::Value>
matchAndRewriteMaskableOp(vector::TransferReadOp op,
MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
if (maskOp)
return rewriter.notifyMatchFailure(op, "Masked case not supported");
Comment on lines +103 to +104
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be supported?


SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
Expand Down Expand Up @@ -142,9 +146,9 @@ struct TransferReadPermutationLowering

// Transpose result of transfer_read.
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
transposePerm);
return success();
return rewriter
.create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
.getResult();
}
};

Expand All @@ -165,14 +169,18 @@ struct TransferReadPermutationLowering
/// %v = vector.transfer_write %tmp ...
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
: public MaskableOpRewritePattern<vector::TransferWriteOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferWriteOp op,
PatternRewriter &rewriter) const override {
FailureOr<mlir::Value>
matchAndRewriteMaskableOp(vector::TransferWriteOp op,
MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
if (maskOp)
return rewriter.notifyMatchFailure(op, "Masked case not supported");
Comment on lines +182 to +183
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be supported?


SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
Expand Down Expand Up @@ -207,11 +215,11 @@ struct TransferWritePermutationLowering
op.getLoc(), op.getVector(), indices);
auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
op.getMask(), newInBoundsAttr);

return success();
return rewriter
.create<vector::TransferWriteOp>(
op.getLoc(), newVec, op.getSource(), op.getIndices(),
AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr)
.getResult();
}
};

Expand All @@ -231,14 +239,18 @@ struct TransferWritePermutationLowering
/// vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
: public MaskableOpRewritePattern<vector::TransferWriteOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferWriteOp op,
PatternRewriter &rewriter) const override {
FailureOr<mlir::Value>
matchAndRewriteMaskableOp(vector::TransferWriteOp op,
MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
if (maskOp)
return rewriter.notifyMatchFailure(op, "Masked case not supported");
Comment on lines +252 to +253
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be suppported?


SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
Expand Down Expand Up @@ -285,10 +297,11 @@ struct TransferWriteNonPermutationLowering
newInBoundsValues.push_back(op.isDimInBounds(i));
}
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
newMask, newInBoundsAttr);
return success();
return rewriter
.create<vector::TransferWriteOp>(
op.getLoc(), newVec, op.getSource(), op.getIndices(),
AffineMapAttr::get(newMap), newMask, newInBoundsAttr)
.getResult();
}
};

Expand All @@ -300,14 +313,19 @@ struct TransferWriteNonPermutationLowering
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
/// vector.broadcast %v
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
struct TransferOpReduceRank
: public MaskableOpRewritePattern<vector::TransferReadOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;

FailureOr<mlir::Value>
matchAndRewriteMaskableOp(vector::TransferReadOp op,
MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
if (maskOp)
return rewriter.notifyMatchFailure(op, "Masked case not supported");
Comment on lines +327 to +328
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be supported?


AffineMap map = op.getPermutationMap();
unsigned numLeadingBroadcast = 0;
Expand Down Expand Up @@ -347,9 +365,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
op.getLoc(), originalVecType.getElementType(), op.getSource(),
op.getIndices());
}
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
newRead);
return success();
return rewriter
.create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
.getVector();
}

SmallVector<int64_t> newShape(
Expand All @@ -371,9 +389,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
newInBoundsAttr);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
newRead);
return success();
return rewriter
.create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
.getVector();
}
};

Expand Down Expand Up @@ -401,20 +419,24 @@ namespace {
/// result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
struct TransferReadToVectorLoadLowering
: public OpRewritePattern<vector::TransferReadOp> {
: public MaskableOpRewritePattern<vector::TransferReadOp> {
TransferReadToVectorLoadLowering(MLIRContext *context,
std::optional<unsigned> maxRank,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransferReadOp>(context, benefit),
: MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
maxTransferRank(maxRank) {}

LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
FailureOr<mlir::Value>
matchAndRewriteMaskableOp(vector::TransferReadOp read,
MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override {
if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
return rewriter.notifyMatchFailure(
read, "vector type is greater than max transfer rank");
}

if (maskOp)
return rewriter.notifyMatchFailure(read, "Masked case not supported");
Comment on lines +438 to +439
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be supported?

SmallVector<unsigned> broadcastedDims;
// Permutations are handled by VectorToSCF or
// populateVectorTransferPermutationMapLoweringPatterns.
Expand Down Expand Up @@ -457,7 +479,7 @@ struct TransferReadToVectorLoadLowering
return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

// Create vector load op.
Operation *loadOp;
Operation *res;
if (read.getMask()) {
if (read.getVectorType().getRank() != 1)
// vector.maskedload operates on 1-D vectors.
Expand All @@ -467,24 +489,20 @@ struct TransferReadToVectorLoadLowering

Value fill = rewriter.create<vector::SplatOp>(
read.getLoc(), unbroadcastedVectorType, read.getPadding());
loadOp = rewriter.create<vector::MaskedLoadOp>(
res = rewriter.create<vector::MaskedLoadOp>(
read.getLoc(), unbroadcastedVectorType, read.getSource(),
read.getIndices(), read.getMask(), fill);
} else {
loadOp = rewriter.create<vector::LoadOp>(
res = rewriter.create<vector::LoadOp>(
read.getLoc(), unbroadcastedVectorType, read.getSource(),
read.getIndices());
}

// Insert a broadcasting op if required.
if (!broadcastedDims.empty()) {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
read, read.getVectorType(), loadOp->getResult(0));
} else {
rewriter.replaceOp(read, loadOp->getResult(0));
}

return success();
if (!broadcastedDims.empty())
res = rewriter.create<vector::BroadcastOp>(
read.getLoc(), read.getVectorType(), res->getResult(0));
return res->getResults()[0];
}

std::optional<unsigned> maxTransferRank;
Expand Down Expand Up @@ -553,19 +571,23 @@ struct VectorStoreToMemrefStoreLowering
/// - The permutation map is the minor identity map (neither permutation nor
/// broadcasting is allowed).
struct TransferWriteToVectorStoreLowering
: public OpRewritePattern<vector::TransferWriteOp> {
: public MaskableOpRewritePattern<vector::TransferWriteOp> {
TransferWriteToVectorStoreLowering(MLIRContext *context,
std::optional<unsigned> maxRank,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransferWriteOp>(context, benefit),
: MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
maxTransferRank(maxRank) {}

LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
FailureOr<mlir::Value>
matchAndRewriteMaskableOp(vector::TransferWriteOp write,
MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override {
if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
return rewriter.notifyMatchFailure(
write, "vector type is greater than max transfer rank");
}
if (maskOp)
return rewriter.notifyMatchFailure(write, "Masked case not supported");
Comment on lines +589 to +590
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be supported?


// Permutations are handled by VectorToSCF or
// populateVectorTransferPermutationMapLoweringPatterns.
Expand Down Expand Up @@ -617,14 +639,17 @@ struct TransferWriteToVectorStoreLowering
<< write;
});

rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
write, write.getSource(), write.getIndices(), write.getMask(),
write.getVector());
return rewriter
.create<vector::MaskedStoreOp>(write.getLoc(), write.getSource(),
write.getIndices(), write.getMask(),
write.getVector())
.getBase();
} else {
rewriter.replaceOpWithNewOp<vector::StoreOp>(
write, write.getVector(), write.getSource(), write.getIndices());
return rewriter
.create<vector::StoreOp>(write.getLoc(), write.getVector(),
write.getSource(), write.getIndices())
.getBase();
}
return success();
}

std::optional<unsigned> maxTransferRank;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,18 @@ func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32
%1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
return %1 : tensor<?x?x12xf32>
}

// CHECK-LABEL: func @masked_transfer_read_of_extract_slice
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
// CHECK-DAG: %[[m:.*]] = vector.create_mask{{.*}} : vector<5x6xi1>
// CHECK-DAG: %[[a:.*]] = affine.apply {{.*}}[[s1]]
// CHECK: vector.mask %[[m]] { vector.transfer_read %[[t]]{{.*}}: tensor<?x?xf32>, vector<5x6xf32> } : vector<5x6xi1> -> vector<5x6xf32>
func.func @masked_transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%cst = arith.constant 0.0 : f32
%0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
%mask = vector.create_mask %c3, %c4 : vector<5x6xi1>
%1 = vector.mask %mask {vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>} : vector<5x6xi1> -> vector<5x6xf32>
return %1 : vector<5x6xf32>
}
Loading
Loading