[mlir][vector] Patterns to convert to shape_cast, where possible #138777
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Changes

These are all semantically just copies, and can be rewritten as shape_casts:

```
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>
```

Currently the vector dialect has no strict specification of which of two equivalent forms is canonical. The unwritten rule seems to be that if it is not 'obvious' that a transformation results in something more canonical, it shouldn't be in an op's canonicalization method. So it's probably not worthwhile discussing here whether these conversions to shape_cast should be part of op canonicalizers! Nonetheless I've found these particular patterns useful in my work, so maybe they're a good addition upstream?

Full diff: https://github.com/llvm/llvm-project/pull/138777.diff

6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 98fb6075cbf32..be9839ce26339 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -50,6 +50,7 @@ namespace vector {
class ContractionOp;
class TransferReadOp;
class TransferWriteOp;
+class TransposeOp;
class VectorDialect;
namespace detail {
@@ -171,6 +172,12 @@ SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
/// `std::nullopt`.
std::optional<int64_t> getConstantVscaleMultiplier(Value value);
+/// Return true if `transpose` does not permute a pair of non-unit dims.
+/// By `order preserving` we mean that the flattened versions of the input and
+/// output vectors are (numerically) identical. In other words `transpose` is
+/// effectively a shape cast.
+bool isOrderPreserving(TransposeOp transpose);
+
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f1100d5cf8b68..3344765f4818a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -406,6 +406,26 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
+/// Add patterns that convert operations that are semantically equivalent to
+/// shape_cast, to shape_cast. Currently this includes patterns for converting
+/// transpose, extract and broadcast to shape_cast. Examples that will be
+/// converted to shape_cast are:
+///
+/// ```
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// %1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
+/// %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>
+/// ```
+///
+/// Note that there is no pattern for vector.extract_strided_slice, because the
+/// only extract_strided_slice that is semantically equivalent to shape_cast is
+/// one that has identical input and output shapes, which is already folded.
+///
+/// These patterns can be useful to expose more folding opportunities by
+/// creating pairs of shape_casts that cancel.
+void populateConvertToShapeCastPatterns(RewritePatternSet &,
+ PatternBenefit = 1);
+
/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
/// This registers (1) which operations are legal and hence should not be
/// linearized, (2) what converted types are (rank-1 vectors) and how to
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f9c7fb7799eb0..562fc7d6ca110 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5574,13 +5574,11 @@ LogicalResult ShapeCastOp::verify() {
return success();
}
-namespace {
-
/// Return true if `transpose` does not permute a pair of non-unit dims.
/// By `order preserving` we mean that the flattened versions of the input and
/// output vectors are (numerically) identical. In other words `transpose` is
/// effectively a shape cast.
-bool isOrderPreserving(TransposeOp transpose) {
+bool mlir::vector::isOrderPreserving(TransposeOp transpose) {
ArrayRef<int64_t> permutation = transpose.getPermutation();
VectorType sourceType = transpose.getSourceVectorType();
ArrayRef<int64_t> inShape = sourceType.getShape();
@@ -5600,8 +5598,6 @@ bool isOrderPreserving(TransposeOp transpose) {
return true;
}
-} // namespace
-
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
VectorType resultType = getType();
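To make the `isOrderPreserving` condition concrete for readers outside the diff: a transpose is effectively a shape_cast exactly when the non-unit source dims keep their relative order under the permutation. A standalone sketch of that check (plain Python modeling the logic, not the MLIR C++ above):

```python
def is_order_preserving(permutation, in_shape):
    # Keep only the permutation entries that index non-unit source dims;
    # the transpose is a pure relabeling iff that subsequence is strictly
    # increasing, i.e. no pair of non-unit dims is swapped.
    non_unit = [p for p in permutation if in_shape[p] != 1]
    return all(a < b for a, b in zip(non_unit, non_unit[1:]))

# [0, 2, 1] on vector<2x1x2xf32> only moves a unit dim: shape_cast-equivalent.
print(is_order_preserving([0, 2, 1], [2, 1, 2]))  # True
# [2, 0, 1] swaps the two non-unit dims: real data movement.
print(is_order_preserving([2, 0, 1], [2, 1, 2]))  # False
```

This matches the positive and negative transpose cases in the test file below.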
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b94c5fce64f83..efcde8e97c0cd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2182,6 +2182,92 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
}
};
+/// For example,
+/// ```
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+/// vector<2x1x2xf32> to vector<2x2x1xf32>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 :
+/// vector<2x1x2xf32> to vector<2x2x1xf32>
+/// ```
+struct TransposeToShapeCast final
+ : public OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+ PatternRewriter &rewriter) const override {
+ if (!isOrderPreserving(transpose)) {
+ return rewriter.notifyMatchFailure(
+ transpose, "not order preserving, so not semantically a 'copy'");
+ }
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+ transpose, transpose.getType(), transpose.getVector());
+ return success();
+ }
+};
+
+/// For example,
+/// ```
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+struct BroadcastToShapeCast final
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+ if (!sourceType) {
+ return rewriter.notifyMatchFailure(
+ broadcast, "source is a scalar, shape_cast doesn't support scalar");
+ }
+
+ VectorType outType = broadcast.getType();
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+ broadcast.getSource());
+ return success();
+ }
+};
+
+/// For example,
+/// ```
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// ```
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = extractOp.getSourceVectorType();
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+ if (!outType)
+ return failure();
+
+    // Negative values in `position` indicate poison, so the extract cannot
+    // be converted to a shape_cast.
+ if (llvm::any_of(extractOp.getMixedPosition(),
+ [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+ return failure();
+
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+ extractOp.getVector());
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -2285,6 +2371,13 @@ void mlir::vector::populateElementwiseToVectorOpsPatterns(
patterns.getContext());
}
+void mlir::vector::populateConvertToShapeCastPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns
+ .insert<TransposeToShapeCast, BroadcastToShapeCast, ExtractToShapeCast>(
+ patterns.getContext(), benefit);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
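To summarize the eligibility conditions the three patterns above check, here is a hypothetical standalone sketch (plain Python with invented helper names, not the MLIR C++); `None` stands in for a scalar broadcast source:

```python
from math import prod

def broadcast_is_shape_cast(src_shape, dst_shape):
    # shape_cast has no scalar form, and a broadcast that duplicates
    # elements is not a copy: require a vector source with the same
    # total element count as the result.
    return src_shape is not None and prod(src_shape) == prod(dst_shape)

def extract_is_shape_cast(positions, src_shape, out_shape):
    # Every index must be the constant 0 (a negative index would mean
    # poison), and the extract must keep all of the source's elements.
    return (all(p == 0 for p in positions)
            and prod(src_shape) == prod(out_shape))

print(broadcast_is_shape_cast([4], [1, 1, 4]))     # True
print(broadcast_is_shape_cast(None, [1, 1, 4]))    # False: scalar source
print(broadcast_is_shape_cast([1, 4], [2, 3, 4]))  # False: duplication
print(extract_is_shape_cast([0], [1, 4], [4]))     # True
```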
diff --git a/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
new file mode 100644
index 0000000000000..0ad6b3ff7d541
--- /dev/null
+++ b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast | FileCheck %s
+
+
+// CHECK-LABEL: @transpose_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
+// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
+func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+ %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+ return %0 : vector<2x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_transpose_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
+// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
+func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+ %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+ return %0 : vector<2x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
+// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8>
+func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+ return %0 : vector<1x1x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_broadcast_to_shape_cast
+// CHECK-NOT: shape_cast
+// CHECK: return
+func.func @negative_broadcast_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
+ return %0 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
+// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SCAST]] : vector<4xf32>
+func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
+ %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// In this example, arg1 might be negative indicating poison.
+// CHECK-LABEL: @negative_extract_to_shape_cast
+// CHECK-NOT: shape_cast
+func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
+ %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32>
+ return %0 : vector<4xf32>
+}
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index b73c40adcffa7..aa97d6fc5dc69 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -1022,6 +1022,26 @@ struct TestEliminateVectorMasks
VscaleRange{vscaleMin, vscaleMax});
}
};
+
+struct TestConvertToShapeCast
+ : public PassWrapper<TestConvertToShapeCast, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertToShapeCast)
+
+ TestConvertToShapeCast() = default;
+
+ StringRef getArgument() const final { return "test-convert-to-shape-cast"; }
+ StringRef getDescription() const final {
+ return "Test conversion to shape_cast of semantically equivalent ops";
+ }
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<vector::VectorDialect>();
+ }
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateConvertToShapeCastPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
} // namespace
namespace mlir {
@@ -1072,6 +1092,8 @@ void registerTestVectorLowerings() {
PassRegistration<vendor::TestVectorBitWidthLinearize>();
PassRegistration<TestEliminateVectorMasks>();
+
+ PassRegistration<TestConvertToShapeCast>();
}
} // namespace test
} // namespace mlir
/// %0 = vector.shape_cast %arg0 :
///                  vector<2x1x2xf32> to vector<2x2x1xf32>
/// ```
struct TransposeToShapeCast final
We might want to clean up other patterns because we have a similar pattern but just for 2D vectors. Please coordinate with @banach-space about scalable vector for the support.
llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
Lines 385 to 441 in 8810595
/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
/// to 2D vectors with at least one unit dim. For example:
///
/// Replace:
///   vector.transpose %0, [1, 0] : vector<4x1xi32> to
///                                 vector<1x4xi32>
/// with:
///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
///
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
/// be fixed. Non-unit dim can be scalable.
///
/// TODO: This pattern was introduced specifically to help lower scalable
/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
/// to cancel out) would be preferable:
///
/// BEFORE:
///   %0 = some_op
///   %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
///   %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
/// AFTER:
///   %0 = some_op
///   %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
///
/// Given the context above, we may want to consider (re-)moving this pattern
/// at some later time. I am leaving it for now in case there are other users
/// that I am not aware of.
class Transpose2DWithUnitDimToShapeCast
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
                                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getVector();
    VectorType resType = op.getResultVectorType();

    // Set up convenience transposition table.
    ArrayRef<int64_t> transp = op.getPermutation();

    if (resType.getRank() == 2 &&
        ((resType.getShape().front() == 1 &&
          !resType.getScalableDims().front()) ||
         (resType.getShape().back() == 1 &&
          !resType.getScalableDims().back())) &&
        transp == ArrayRef<int64_t>({1, 0})) {
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
      return success();
    }

    return failure();
  }
};
Hey, thanks for bringing this up!
I understand by "copies" you mean "data movements"? These ops are mostly changing the "view" without any actual data movement.
I think it's important to have the canonicalization discussions, even if they are difficult. Otherwise, downstream projects adopting these transformations broadly and contributing to other vector passes upstream would create an "unseen" dependency with these transformations, which would impact other users that are not using or are not aware (most likely the latter) of them. I think there's a general consensus that
I found 06dbb28, which removes patterns exactly like these from the canonicalizers. They were removed because there was a plan to restrict shape_cast to be for collapsing (expanding) to (from) rank-1 only:
Instead of using shape_cast op in the pattern removing leading unit dimensions we use extract/broadcast ops. This is part of the effort to restrict ShapeCastOp further in the future and only allow it to convert to or from 1D vector. This also adds extra canonicalization to fill the gaps in simplifying broadcast/extract ops. Differential Revision: https://reviews.llvm.org/D114205
I think some people might think that converting

```
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
```

to

```
%0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
```

is controversial, and maybe harmful in some pipelines. There's no obvious asymmetry here, because neither one is a subset of the other. So the only difference is our intuition that the lowering of

```
%0 = vector.broadcast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
```

is

```
%0 = ub.poison : vector<1x2x3xf32>
%1 = vector.insert %arg0, %0 [0] : vector<2x3xf32> into vector<1x2x3xf32>
```

while

```
%0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
```

gets lowered to

```
%0 = ub.poison : vector<1x2x3xf32>
%1 = vector.extract %arg0[0, 0] : f32 from vector<2x3xf32>
%2 = vector.insert %1, %0 [0, 0, 0] : f32 into vector<1x2x3xf32>
%3 = vector.extract %arg0[0, 1] : f32 from vector<2x3xf32>
%4 = vector.insert %3, %2 [0, 0, 1] : f32 into vector<1x2x3xf32>
%5 = vector.extract %arg0[0, 2] : f32 from vector<2x3xf32>
%6 = vector.insert %5, %4 [0, 0, 2] : f32 into vector<1x2x3xf32>
%7 = vector.extract %arg0[1, 0] : f32 from vector<2x3xf32>
%8 = vector.insert %7, %6 [0, 1, 0] : f32 into vector<1x2x3xf32>
%9 = vector.extract %arg0[1, 1] : f32 from vector<2x3xf32>
%10 = vector.insert %9, %8 [0, 1, 1] : f32 into vector<1x2x3xf32>
%11 = vector.extract %arg0[1, 2] : f32 from vector<2x3xf32>
%12 = vector.insert %11, %10 [0, 1, 2] : f32 into vector<1x2x3xf32>
```

While both will hopefully have vanished in the final assembly (just reuse whatever register the data was in), I can sympathize with a person who prefers the lowering of broadcast, and doesn't want their broadcast to be changed to a shape_cast!
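The "order preserving means the flattened data is untouched" claim above can be sanity-checked with a small standalone model (plain Python, not MLIR): a unit-dim-only transpose returns the row-major data unchanged, while a genuine transpose does not.

```python
from itertools import product

def transpose_flat(flat, shape, perm):
    """Apply a vector.transpose-style permutation to row-major data and
    return the new row-major flattening."""
    # Row-major strides of the source shape.
    strides = [1] * len(shape)
    for i in reversed(range(len(shape) - 1)):
        strides[i] = strides[i + 1] * shape[i + 1]
    out_shape = [shape[p] for p in perm]
    out = []
    for idx in product(*(range(d) for d in out_shape)):
        # Map the output index back to a source index via the permutation.
        src_idx = [0] * len(shape)
        for out_dim, p in enumerate(perm):
            src_idx[p] = idx[out_dim]
        out.append(flat[sum(i * s for i, s in zip(src_idx, strides))])
    return out

data = list(range(4))  # vector<2x1x2xf32>, row-major
print(transpose_flat(data, [2, 1, 2], [0, 2, 1]) == data)  # True: shape_cast
print(transpose_flat(data, [2, 1, 2], [2, 0, 1]) == data)  # False: data moves
```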
Thanks for finding and sharing that thread :) I wasn't aware of those earlier design discussions about restricting shape_cast. That's a fairly old discussion (given MLIR's age), and likely dates back to when there was an assumption we'd be using LLVM's matrix intrinsics.
However, these matrix intrinsics are not widely used, and - as far as I know - the corresponding LLVM work has been on hold for quite some time (a few years?). From what I can tell, those requirements are now outdated. In the last 2–3 years that I’ve been working in this space, the consistent takeaway has been:
In general, the assumption is that shape_cast does not involve data movement. Regarding the example you mentioned:

```
%0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
```

...and its lowering to:

```
%0 = ub.poison : vector<1x2x3xf32>
%1 = vector.extract %arg0[0, 0] : f32 from vector<2x3xf32>
%2 = vector.insert %1, %0 [0, 0, 0] : f32 into vector<1x2x3xf32>
%3 = vector.extract %arg0[0, 1] : f32 from vector<2x3xf32>
%4 = vector.insert %3, %2 [0, 0, 1] : f32 into vector<1x2x3xf32>
%5 = vector.extract %arg0[0, 2] : f32 from vector<2x3xf32>
%6 = vector.insert %5, %4 [0, 0, 2] : f32 into vector<1x2x3xf32>
%7 = vector.extract %arg0[1, 0] : f32 from vector<2x3xf32>
%8 = vector.insert %7, %6 [0, 1, 0] : f32 into vector<1x2x3xf32>
%9 = vector.extract %arg0[1, 1] : f32 from vector<2x3xf32>
%10 = vector.insert %9, %8 [0, 1, 1] : f32 into vector<1x2x3xf32>
%11 = vector.extract %arg0[1, 2] : f32 from vector<2x3xf32>
%12 = vector.insert %11, %10 [0, 1, 2] : f32 into vector<1x2x3xf32>
```

This lowering clearly involves data movement, and to me violates the design intent of shape_cast. So in principle: yes, I'm in favor of turning these patterns into canonicalizations. We should just be thoughtful about which constraints we want to enforce.
OK, I might give that a try, starting with transpose -> shape_cast on the transpose canonicalizer. I'm not entirely sure what a formal definition of 'no data movement' in SSA might be. The shape_cast still has to lower to something in LLVM; it's only during register allocation that it will vanish. I guess ideally all shape_casts will have been canonicalized away, especially when using linearization/flattening. I mentioned this previously, and it's probably over-the-top at the moment, but at some point it'd be nice to add something to the docs describing the canonicalization rules.
Thanks for sharing your thoughts! Replying to both threads here. I think part of the challenge with canonicalization stems from the expectations we've built around it. We often assume the IR after canonicalization will be ideal for all patterns and analyses and, ultimately, lead to the most efficient lowering/ASM/runtime performance and whatnot. But realistically, canonicalization should only focus on eliminating redundancy from the IR and offering a unique and consistent representation, to make passes consuming such a canonical form simpler, regardless of everything else. For me, a pass running on canonical IR that must deal with both:

```
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
```

and

```
%0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
```

indicates that we have a gap in the canonical form, because both operations are representing the same thing.
These are all semantically just copies, and can be rewritten as shape_casts:

```
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>
```

Currently the vector dialect has no strict specification of which of two equivalent forms is more canonical. The unwritten rule seems to be that if it is not 'obvious' that a transformation results in something more canonical, it shouldn't be in an op's canonicalization method. So it's probably not worthwhile discussing here whether these conversions to shape_cast should be part of op canonicalizers! Nonetheless I've found these particular patterns useful in my work, so maybe they're a good addition upstream?