Skip to content

[vector][mlir] Canonicalize to shape_cast where possible #140583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 129 additions & 118 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2344,11 +2344,45 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
return success();
}

/// Rewrite a vector.extract that retains every element of its source as a
/// vector.shape_cast.
///
/// For example,
/// ```
/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
/// ```
/// becomes
/// ```
/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
/// ```
struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only vector-valued extracts can be rewritten; shape_cast cannot
    // produce a scalar.
    auto resultType = dyn_cast<VectorType>(extractOp.getType());
    if (!resultType)
      return failure();

    // Every position must be the constant 0. In particular this rejects
    // negative ("poison") positions, which shape_cast cannot represent.
    for (OpFoldResult position : extractOp.getMixedPosition()) {
      if (!isConstantIntValue(position, 0))
        return failure();
    }

    // The rewrite is only a pure reshape when no elements are dropped.
    VectorType srcType = extractOp.getSourceVectorType();
    if (srcType.getNumElements() != resultType.getNumElements())
      return failure();

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, resultType,
                                                     extractOp.getVector());
    return success();
  }
};

} // namespace

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Class-based rewrite patterns.
  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
              ExtractToShapeCast>(context);
  // Function-based folders registered as patterns.
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
}
Expand Down Expand Up @@ -2651,13 +2685,40 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};

/// Rewrite a vector.broadcast that does not replicate any elements (i.e. it
/// only adds unit dimensions) as a vector.shape_cast.
///
/// For example,
/// ```
/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
/// ```
/// becomes
/// ```
/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
/// ```
struct BroadcastToShapeCast final
    : public OpRewritePattern<vector::BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
                                PatternRewriter &rewriter) const override {
    // shape_cast only operates on vectors, so a scalar source cannot be
    // handled here.
    auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
    if (!srcType) {
      return rewriter.notifyMatchFailure(
          broadcast, "source is a scalar, shape_cast doesn't support scalar");
    }

    // If the element counts match, the broadcast duplicates nothing and is
    // effectively a reshape.
    VectorType resultType = broadcast.getType();
    if (srcType.getNumElements() != resultType.getNumElements())
      return failure();

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, resultType,
                                                     broadcast.getSource());
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // BroadcastToShapeCast is registered as a default canonicalization here:
  // broadcasts that preserve the element count (only insert unit dims) are
  // rewritten to shape_cast. The previous comment claiming it was opt-in via
  // `populateCastAwayVectorLeadingOneDimPatterns` was stale and contradicted
  // this registration.
  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -5573,30 +5634,6 @@ LogicalResult ShapeCastOp::verify() {
return success();
}

/// Return true if `transpose` does not permute a pair of non-unit dims.
/// By `order preserving` we mean that the flattened versions of the input and
/// output vectors are (numerically) identical. In other words `transpose` is
/// effectively a shape cast.
static bool isOrderPreserving(TransposeOp transpose) {
  ArrayRef<int64_t> permutation = transpose.getPermutation();
  VectorType sourceType = transpose.getSourceVectorType();
  ArrayRef<int64_t> inShape = sourceType.getShape();
  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
  // A static unit dimension can be moved anywhere without reordering data.
  // Scalable dims are excluded: a scalable "unit" dim may have runtime
  // extent > 1.
  auto isNonScalableUnitDim = [&](int64_t dim) {
    return inShape[dim] == 1 && !inDimIsScalable[dim];
  };
  // The remaining (non-unit) dimensions must appear in increasing order of
  // their original index for the flattened element order to be unchanged.
  int64_t current = 0;
  for (auto p : permutation) {
    if (!isNonScalableUnitDim(p)) {
      if (p < current) {
        return false;
      }
      current = p;
    }
  }
  return true;
}

OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {

VectorType resultType = getType();
Expand All @@ -5611,33 +5648,6 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
return getResult();
}

// shape_cast(transpose(x)) -> shape_cast(x)
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
// This folder does
// shape_cast(transpose) -> shape_cast
// But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
// shape_cast -> shape_cast(transpose)
// i.e. the complete opposite. When paired, these 2 patterns can cause
// infinite cycles in pattern rewriting.
// ConvertIllegalShapeCastOpsToTransposes only matches on scalable
// vectors, so by disabling this folder for scalable vectors the
// cycle is avoided.
// TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
// still needed. If it's not, then we can fold here.
if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
setOperand(transpose.getVector());
return getResult();
}
return {};
}

// Y = shape_cast(broadcast(X))
// -> X, if X and Y have same type
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
if (bcastOp.getSourceType() == resultType)
return bcastOp.getSource();
}

// shape_cast(constant) -> constant
if (auto splatAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
Expand Down Expand Up @@ -5759,10 +5769,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
};

/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
/// i) Y = ShapeCast(X), or
/// ii) Y = Broadcast(X)
/// If both (i) and (ii) are possible, (i) is chosen.
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
Expand All @@ -5777,22 +5784,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
bool srcIsScalar = !srcVectorType;

// Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
// Example:
// %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
// %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
// to
// %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
if (srcVectorType) {
if (srcVectorType.getNumElements() ==
shapeCastOp.getResultVectorType().getNumElements()) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
shapeCastOp, shapeCastOp.getResultVectorType(),
broadcastOp.getSource());
return success();
}
}

// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
// Example
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
Expand Down Expand Up @@ -5993,21 +5984,6 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
return ub::PoisonAttr::get(getContext());

// Eliminate identity transposes, and more generally any transposes that
// preserves the shape without permuting elements.
//
// Examples of what to fold:
// %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
// %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
// %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
//
// Example of what NOT to fold:
// %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
//
if (getSourceVectorType() == getResultVectorType() &&
isOrderPreserving(*this))
return getVector();

return {};
}

Expand Down Expand Up @@ -6127,32 +6103,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};

/// Folds transpose(shape_cast) into a new shape_cast.
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Only fires when the transpose input is produced by a shape_cast.
    auto shapeCastOp =
        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
    if (!shapeCastOp)
      return failure();
    // The transpose must not reorder data, i.e. only move unit dims.
    if (!isOrderPreserving(transposeOp))
      return failure();

    VectorType resultType = transposeOp.getType();

    // We don't need to check isValidShapeCast at this point, because it is
    // guaranteed that merging the transpose into the shape_cast is a valid
    // shape_cast, because the transpose just inserts/removes ones.

    // Replace both ops with a single shape_cast from the original source.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
                                                     shapeCastOp.getSource());
    return success();
  }
};

/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
Expand Down Expand Up @@ -6248,12 +6198,73 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
}
};

/// Return true if `transpose` does not permute any pair of non-unit dims,
/// i.e. the flattened input and output vectors hold identical values and the
/// transpose is effectively a shape cast.
static bool isOrderPreserving(TransposeOp transpose) {
  VectorType srcType = transpose.getSourceVectorType();
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<bool> scalableDims = srcType.getScalableDims();
  // Only a static (non-scalable) unit dimension may move freely without
  // reordering the flattened data.
  auto isMovableUnitDim = [&](int64_t dim) {
    return srcShape[dim] == 1 && !scalableDims[dim];
  };
  // All other dimensions must keep their relative order: their original
  // indices must be non-decreasing as we walk the permutation.
  int64_t highestSeen = 0;
  for (int64_t dim : transpose.getPermutation()) {
    if (isMovableUnitDim(dim))
      continue;
    if (dim < highestSeen)
      return false;
    highestSeen = dim;
  }
  return true;
}

/// Rewrite an order-preserving, fixed-size vector.transpose as a
/// vector.shape_cast.
///
/// For example,
/// ```
///  %0 = vector.transpose %arg0, [0, 2, 1] :
///                   vector<2x1x2xf32> to vector<2x2x1xf32>
/// ```
///  becomes
/// ```
///  %0 = vector.shape_cast %arg0 :
///                   vector<2x1x2xf32> to vector<2x2x1xf32>
/// ```
struct TransposeToShapeCast final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {

    // The transpose is only a data-preserving "copy" if it does not permute
    // any pair of non-unit dimensions.
    if (!isOrderPreserving(transpose)) {
      return rewriter.notifyMatchFailure(
          transpose, "not order preserving, so not semantically a 'copy'");
    }

    // This pattern does
    //    transpose -> shape_cast
    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
    //    shape_cast -> shape_cast(transpose)
    // i.e. the complete opposite. When paired, these 2 patterns can cause
    // infinite cycles in pattern rewriting.
    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
    // vectors, so by disabling this pattern for scalable vectors the
    // cycle is avoided.
    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
    // still needed. If it's not, then we can match scalable vectors here too.
    if (transpose.getType().isScalable()) {
      return rewriter.notifyMatchFailure(
          transpose,
          "scalable vectors skipped to avoid a rewrite cycle with "
          "ConvertIllegalShapeCastOpsToTransposes");
    }

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transpose, transpose.getType(), transpose.getVector());
    return success();
  }
};

} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Register all transpose canonicalizations (same set, same order).
  results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat>(
      context);
  results.add<FoldTransposeBroadcast, TransposeToShapeCast>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading