Skip to content

[mlir][vector] Patterns to convert to shape_cast, where possible #138777

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ namespace vector {
class ContractionOp;
class TransferReadOp;
class TransferWriteOp;
class TransposeOp;
class VectorDialect;

namespace detail {
Expand Down Expand Up @@ -171,6 +172,12 @@ SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
/// `std::nullopt`.
std::optional<int64_t> getConstantVscaleMultiplier(Value value);

/// Return true if `transpose` does not permute a pair of non-unit dims.
/// By `order preserving` we mean that the flattened versions of the input and
/// output vectors are (numerically) identical. In other words `transpose` is
/// effectively a shape cast.
bool isOrderPreserving(TransposeOp transpose);

//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,26 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Add patterns that convert operations that are semantically equivalent to
/// shape_cast, to shape_cast. Currently this includes patterns for converting
/// transpose, extract and broadcast to shape_cast. Examples that will be
/// converted to shape_cast are:
///
/// ```
/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
/// %1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
/// %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>
/// ```
///
/// Note that there is no pattern for vector.extract_strided_slice, because the
/// only extract_strided_slice that is semantically equivalent to shape_cast is
/// one that has idential input and output shapes, which is already folded.
///
/// These patterns can be useful to expose more folding opportunities by
/// creating pairs of shape_casts that cancel.
void populateConvertToShapeCastPatterns(RewritePatternSet &,
PatternBenefit = 1);

/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
/// This registers (1) which operations are legal and hence should not be
/// linearized, (2) what converted types are (rank-1 vectors) and how to
Expand Down
6 changes: 1 addition & 5 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5574,13 +5574,11 @@ LogicalResult ShapeCastOp::verify() {
return success();
}

namespace {

/// Return true if `transpose` does not permute a pair of non-unit dims.
/// By `order preserving` we mean that the flattened versions of the input and
/// output vectors are (numerically) identical. In other words `transpose` is
/// effectively a shape cast.
bool isOrderPreserving(TransposeOp transpose) {
bool mlir::vector::isOrderPreserving(TransposeOp transpose) {
ArrayRef<int64_t> permutation = transpose.getPermutation();
VectorType sourceType = transpose.getSourceVectorType();
ArrayRef<int64_t> inShape = sourceType.getShape();
Expand All @@ -5600,8 +5598,6 @@ bool isOrderPreserving(TransposeOp transpose) {
return true;
}

} // namespace

OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {

VectorType resultType = getType();
Expand Down
93 changes: 93 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2182,6 +2182,92 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
}
};

/// For example,
/// ```
/// %0 = vector.transpose %arg0, [0, 2, 1] :
/// vector<2x1x2xf32> to vector<2x2x1xf32>
/// ```
/// becomes
/// ```
/// %0 = vector.shape_cast %arg0 :
/// vector<2x1x2xf32> to vector<2x2x1xf32>
/// ```
struct TransposeToShapeCast final
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to clean up other patterns because we have a similar pattern but just for 2D vectors. Please coordinate with @banach-space about scalable vector for the support.

/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
/// to 2D vectors with at least one unit dim. For example:
///
/// Replace:
/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
/// vector<1x4xi32>
/// with:
/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
///
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
/// be fixed. Non-unit dim can be scalable.
///
/// TODO: This pattern was introduced specifically to help lower scalable
/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
/// to cancel out) would be preferable:
///
/// BEFORE:
/// %0 = some_op
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
/// AFTER:
/// %0 = some_op
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
///
/// Given the context above, we may want to consider (re-)moving this pattern
/// at some later time. I am leaving it for now in case there are other users
/// that I am not aware of.
class Transpose2DWithUnitDimToShapeCast
: public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
Value input = op.getVector();
VectorType resType = op.getResultVectorType();
// Set up convenience transposition table.
ArrayRef<int64_t> transp = op.getPermutation();
if (resType.getRank() == 2 &&
((resType.getShape().front() == 1 &&
!resType.getScalableDims().front()) ||
(resType.getShape().back() == 1 &&
!resType.getScalableDims().back())) &&
transp == ArrayRef<int64_t>({1, 0})) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
return success();
}
return failure();
}
};

: public OpRewritePattern<vector::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
PatternRewriter &rewriter) const override {
if (!isOrderPreserving(transpose)) {
return rewriter.notifyMatchFailure(
transpose, "not order preserving, so not semantically a 'copy'");
}
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
transpose, transpose.getType(), transpose.getVector());
return success();
}
};

/// For example,
/// ```
/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
/// ```
/// becomes
/// ```
/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
/// ```
struct BroadcastToShapeCast final
: public OpRewritePattern<vector::BroadcastOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
PatternRewriter &rewriter) const override {
auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
if (!sourceType) {
return rewriter.notifyMatchFailure(
broadcast, "source is a scalar, shape_cast doesn't support scalar");
}

VectorType outType = broadcast.getType();
if (sourceType.getNumElements() != outType.getNumElements())
return failure();

rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
broadcast.getSource());
return success();
}
};

/// For example,
/// ```
/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
/// ```
/// becomes
/// ```
/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
/// ```
struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
VectorType sourceType = extractOp.getSourceVectorType();
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
if (!outType)
return failure();

// Negative values in `position` indicates poison, cannot convert to
// shape_cast
if (llvm::any_of(extractOp.getMixedPosition(),
[](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
return failure();

if (sourceType.getNumElements() != outType.getNumElements())
return failure();

rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
extractOp.getVector());
return success();
}
};

} // namespace

void mlir::vector::populateFoldArithExtensionPatterns(
Expand Down Expand Up @@ -2285,6 +2371,13 @@ void mlir::vector::populateElementwiseToVectorOpsPatterns(
patterns.getContext());
}

void mlir::vector::populateConvertToShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns
.insert<TransposeToShapeCast, BroadcastToShapeCast, ExtractToShapeCast>(
patterns.getContext(), benefit);
}

//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
Expand Down
65 changes: 65 additions & 0 deletions mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast | FileCheck %s


// CHECK-LABEL: @transpose_to_shape_cast
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
%0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
return %0 : vector<2x2x1xf32>
}

// -----

// CHECK-LABEL: @negative_transpose_to_shape_cast
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
%0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
return %0 : vector<2x2x1xf32>
}

// -----

// CHECK-LABEL: @broadcast_to_shape_cast
// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8>
func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
return %0 : vector<1x1x4xi8>
}

// -----

// CHECK-LABEL: @negative_broadcast_to_shape_cast
// CHECK-NOT: shape_cast
// CHECK: return
func.func @negative_broadcast_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
%0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
return %0 : vector<2x3x4xi8>
}

// -----

// CHECK-LABEL: @extract_to_shape_cast
// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
// CHECK-NEXT: return %[[SCAST]] : vector<4xf32>
func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
%0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
return %0 : vector<4xf32>
}

// -----

// In this example, arg1 might be negative indicating poison.
// CHECK-LABEL: @negative_extract_to_shape_cast
// CHECK-NOT: shape_cast
func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
%0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32>
return %0 : vector<4xf32>
}

22 changes: 22 additions & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,26 @@ struct TestEliminateVectorMasks
VscaleRange{vscaleMin, vscaleMax});
}
};

struct TestConvertToShapeCast
: public PassWrapper<TestConvertToShapeCast, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertToShapeCast)

TestConvertToShapeCast() = default;

StringRef getArgument() const final { return "test-convert-to-shape-cast"; }
StringRef getDescription() const final {
return "Test conversion to shape_cast of semantically equivalent ops";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateConvertToShapeCastPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
} // namespace

namespace mlir {
Expand Down Expand Up @@ -1072,6 +1092,8 @@ void registerTestVectorLowerings() {
PassRegistration<vendor::TestVectorBitWidthLinearize>();

PassRegistration<TestEliminateVectorMasks>();

PassRegistration<TestConvertToShapeCast>();
}
} // namespace test
} // namespace mlir