
[mlir][vector] Patterns to convert to shape_cast, where possible #138777

Open · wants to merge 4 commits into main
Conversation

newling
Contributor

@newling newling commented May 6, 2025

These are all semantically just copies, and can be rewritten as shape_casts:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

Currently the vector dialect has no strict specification of which of two equivalent forms is more canonical; the unwritten rule seems to be that if it is not 'obvious' that a transformation results in something more canonical, it shouldn't be in an op's canonicalization method. So it's probably not worth discussing here whether these conversions to shape_cast should be part of op canonicalizers! Nonetheless, I've found these particular patterns useful in my work, so maybe they're a good addition upstream?
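The "semantically just a copy" condition for transpose can be modeled outside MLIR. The sketch below is a hypothetical Python port of the PR's `isOrderPreserving` helper (illustrative only, not part of the patch): a transpose is effectively a shape_cast iff, ignoring unit dims, the permutation visits the remaining dims in their original order.

```python
def is_order_preserving(perm, in_shape):
    """Hypothetical model of isOrderPreserving: True iff `perm` does not
    swap any pair of non-unit dims of `in_shape`."""
    # Collect the non-unit input dims in the order the permutation visits them.
    non_unit = [p for p in perm if in_shape[p] != 1]
    # Order preserving iff that visiting order is already increasing.
    return non_unit == sorted(non_unit)

print(is_order_preserving([1, 0], [2, 1]))        # True: vector<2x1> -> vector<1x2>
print(is_order_preserving([0, 2, 1], [2, 1, 2]))  # True: only a unit dim moves
print(is_order_preserving([2, 0, 1], [2, 1, 2]))  # False: swaps two non-unit dims
```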


github-actions bot commented May 6, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Changes

These are all semantically just copies, and can be rewritten as shape_casts:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

Currently the vector dialect has no strict specification of which of two equivalent forms is canonical; the unwritten rule seems to be that if it is not 'obvious' that a transformation results in something more canonical, it shouldn't be in an op's canonicalization method. So it's probably not worth discussing here whether these conversions to shape_cast should be part of op canonicalizers! Nonetheless, I've found these particular patterns useful in my work, so maybe they're a good addition upstream?


Full diff: https://github.com/llvm/llvm-project/pull/138777.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+7)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+20)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+93)
  • (added) mlir/test/Dialect/Vector/convert-to-shape-cast.mlir (+65)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+22)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 98fb6075cbf32..be9839ce26339 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -50,6 +50,7 @@ namespace vector {
 class ContractionOp;
 class TransferReadOp;
 class TransferWriteOp;
+class TransposeOp;
 class VectorDialect;
 
 namespace detail {
@@ -171,6 +172,12 @@ SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
 /// `std::nullopt`.
 std::optional<int64_t> getConstantVscaleMultiplier(Value value);
 
+/// Return true if `transpose` does not permute a pair of non-unit dims.
+/// By `order preserving` we mean that the flattened versions of the input and
+/// output vectors are (numerically) identical. In other words `transpose` is
+/// effectively a shape cast.
+bool isOrderPreserving(TransposeOp transpose);
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f1100d5cf8b68..3344765f4818a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -406,6 +406,26 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+/// Add patterns that convert operations that are semantically equivalent to
+/// shape_cast, to shape_cast. Currently this includes patterns for converting
+/// transpose, extract and broadcast to shape_cast. Examples that will be
+/// converted to shape_cast are:
+///
+/// ```
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// %1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
+/// %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>
+/// ```
+///
+/// Note that there is no pattern for vector.extract_strided_slice, because the
+/// only extract_strided_slice that is semantically equivalent to shape_cast is
+/// one that has identical input and output shapes, which is already folded.
+///
+/// These patterns can be useful to expose more folding opportunities by
+/// creating pairs of shape_casts that cancel.
+void populateConvertToShapeCastPatterns(RewritePatternSet &,
+                                        PatternBenefit = 1);
+
 /// Initialize `typeConverter` and `conversionTarget` for vector linearization.
 /// This registers (1) which operations are legal and hence should not be
 /// linearized, (2) what converted types are (rank-1 vectors) and how to
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f9c7fb7799eb0..562fc7d6ca110 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5574,13 +5574,11 @@ LogicalResult ShapeCastOp::verify() {
   return success();
 }
 
-namespace {
-
 /// Return true if `transpose` does not permute a pair of non-unit dims.
 /// By `order preserving` we mean that the flattened versions of the input and
 /// output vectors are (numerically) identical. In other words `transpose` is
 /// effectively a shape cast.
-bool isOrderPreserving(TransposeOp transpose) {
+bool mlir::vector::isOrderPreserving(TransposeOp transpose) {
   ArrayRef<int64_t> permutation = transpose.getPermutation();
   VectorType sourceType = transpose.getSourceVectorType();
   ArrayRef<int64_t> inShape = sourceType.getShape();
@@ -5600,8 +5598,6 @@ bool isOrderPreserving(TransposeOp transpose) {
   return true;
 }
 
-} // namespace
-
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b94c5fce64f83..efcde8e97c0cd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2182,6 +2182,92 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
   }
 };
 
+/// For example,
+/// ```
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+/// ```
+struct TransposeToShapeCast final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+    if (!isOrderPreserving(transpose)) {
+      return rewriter.notifyMatchFailure(
+          transpose, "not order preserving, so not semantically a 'copy'");
+    }
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        transpose, transpose.getType(), transpose.getVector());
+    return success();
+  }
+};
+
+/// For example,
+/// ```
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+struct BroadcastToShapeCast final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!sourceType) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "source is a scalar, shape_cast doesn't support scalar");
+    }
+
+    VectorType outType = broadcast.getType();
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+                                                     broadcast.getSource());
+    return success();
+  }
+};
+
+/// For example,
+/// ```
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// ```
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!outType)
+      return failure();
+
+    // Negative values in `position` indicate poison; cannot convert to
+    // shape_cast.
+    if (llvm::any_of(extractOp.getMixedPosition(),
+                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+      return failure();
+
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+                                                     extractOp.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateFoldArithExtensionPatterns(
@@ -2285,6 +2371,13 @@ void mlir::vector::populateElementwiseToVectorOpsPatterns(
       patterns.getContext());
 }
 
+void mlir::vector::populateConvertToShapeCastPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns
+      .insert<TransposeToShapeCast, BroadcastToShapeCast, ExtractToShapeCast>(
+          patterns.getContext(), benefit);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
new file mode 100644
index 0000000000000..0ad6b3ff7d541
--- /dev/null
+++ b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast |  FileCheck %s
+
+
+// CHECK-LABEL: @transpose_to_shape_cast
+//  CHECK-SAME:  %[[ARG0:.*]]: vector<2x1x2xf32>
+//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT:  return %[[SCAST]] : vector<2x2x1xf32>
+func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+  %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+  return %0 : vector<2x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_transpose_to_shape_cast
+//  CHECK-SAME:  %[[ARG0:.*]]: vector<2x1x2xf32>
+//  CHECK-NEXT:  %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
+//  CHECK-NEXT:  return %[[TRANSPOSE]] : vector<2x2x1xf32>
+func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+  %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+  return %0 : vector<2x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_to_shape_cast
+//  CHECK-SAME:  %[[ARG0:.*]]: vector<4xi8>
+//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT:  return %[[SCAST]] : vector<1x1x4xi8>
+func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+  return %0 : vector<1x1x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_broadcast_to_shape_cast
+//   CHECK-NOT: shape_cast
+//       CHECK: return
+func.func @negative_broadcast_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
+  return %0 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_to_shape_cast
+//  CHECK-SAME:  %[[ARG0:.*]]: vector<1x4xf32>
+//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT:  return %[[SCAST]] : vector<4xf32>
+func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
+  %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// In this example, arg1 might be negative indicating poison.
+// CHECK-LABEL: @negative_extract_to_shape_cast
+//   CHECK-NOT: shape_cast
+func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
+  %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32>
+  return %0 : vector<4xf32>
+}
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index b73c40adcffa7..aa97d6fc5dc69 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -1022,6 +1022,26 @@ struct TestEliminateVectorMasks
                          VscaleRange{vscaleMin, vscaleMax});
   }
 };
+
+struct TestConvertToShapeCast
+    : public PassWrapper<TestConvertToShapeCast, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertToShapeCast)
+
+  TestConvertToShapeCast() = default;
+
+  StringRef getArgument() const final { return "test-convert-to-shape-cast"; }
+  StringRef getDescription() const final {
+    return "Test conversion to shape_cast of semantically equivalent ops";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateConvertToShapeCastPatterns(patterns);
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -1072,6 +1092,8 @@ void registerTestVectorLowerings() {
   PassRegistration<vendor::TestVectorBitWidthLinearize>();
 
   PassRegistration<TestEliminateVectorMasks>();
+
+  PassRegistration<TestConvertToShapeCast>();
 }
 } // namespace test
 } // namespace mlir
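The eligibility checks performed by the `BroadcastToShapeCast` and `ExtractToShapeCast` patterns in the diff can be summarized compactly. This is a Python sketch with illustrative names, not the actual C++ implementation:

```python
import math

def broadcast_is_shape_cast(src_shape, dst_shape):
    """A vector.broadcast is semantically a shape_cast only when it
    duplicates no data: source and result element counts must match
    (the pattern additionally requires a vector, not scalar, source)."""
    return math.prod(src_shape) == math.prod(dst_shape)

def extract_is_shape_cast(positions, src_shape, out_shape):
    """A vector.extract is semantically a shape_cast when every index is
    the constant 0 (a dynamic index might be negative, i.e. poison) and
    no elements are dropped."""
    all_zero = all(p == 0 for p in positions)
    return all_zero and math.prod(src_shape) == math.prod(out_shape)

print(broadcast_is_shape_cast([4], [1, 1, 4]))     # True: same 4 elements
print(broadcast_is_shape_cast([1, 4], [2, 3, 4]))  # False: elements duplicated
print(extract_is_shape_cast([0], [1, 4], [4]))     # True: constant-0 index, no drop
```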

@llvmbot
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Contributor (reviewing the TransposeToShapeCast hunk)

We might want to clean up other patterns, because we have a similar pattern that is just for 2D vectors. Please coordinate with @banach-space about scalable vector support.

/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
/// to 2D vectors with at least one unit dim. For example:
///
/// Replace:
///   vector.transpose %0, [1, 0] : vector<4x1xi32> to vector<1x4xi32>
/// with:
///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
///
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
/// be fixed. Non-unit dim can be scalable.
///
/// TODO: This pattern was introduced specifically to help lower scalable
/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
/// to cancel out) would be preferable:
///
/// BEFORE:
///   %0 = some_op
///   %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
///   %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
/// AFTER:
///   %0 = some_op
///   %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
///
/// Given the context above, we may want to consider (re-)moving this pattern
/// at some later time. I am leaving it for now in case there are other users
/// that I am not aware of.
class Transpose2DWithUnitDimToShapeCast
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
                                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getVector();
    VectorType resType = op.getResultVectorType();
    // Set up convenience transposition table.
    ArrayRef<int64_t> transp = op.getPermutation();
    if (resType.getRank() == 2 &&
        ((resType.getShape().front() == 1 &&
          !resType.getScalableDims().front()) ||
         (resType.getShape().back() == 1 &&
          !resType.getScalableDims().back())) &&
        transp == ArrayRef<int64_t>({1, 0})) {
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
      return success();
    }
    return failure();
  }
};
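The match condition of this narrower existing pattern can be contrasted with the PR's more general `isOrderPreserving` check. A Python sketch of the quoted condition (illustrative only):

```python
def matches_2d_unit_dim(res_shape, res_scalable, perm):
    """Model of the quoted pattern's guard: it fires only for a 2-D result
    with permutation [1, 0] and a fixed (non-scalable) unit dim at either end."""
    if len(res_shape) != 2 or perm != [1, 0]:
        return False
    return ((res_shape[0] == 1 and not res_scalable[0]) or
            (res_shape[1] == 1 and not res_scalable[1]))

print(matches_2d_unit_dim([1, 4], [False, False], [1, 0]))  # True
print(matches_2d_unit_dim([1, 4], [True, False], [1, 0]))   # False: unit dim is scalable
print(matches_2d_unit_dim([4, 2], [False, False], [1, 0]))  # False: no unit dim
```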

@dcaballe
Contributor

dcaballe commented May 8, 2025

Hey, thanks for bringing this up!

These are all semantically just copies, and can be rewritten as shape_casts:

I understand by "copies" you mean "data movements"? These ops are mostly changing the "view" without any actual data movement.

Currently the vector dialect has no strict specification of which of two equivalent forms is more canonical; the unwritten rule seems to be that if it is not 'obvious' that a transformation results in something more canonical, it shouldn't be in an op's canonicalization method. So it's probably not worth discussing here whether these conversions to shape_cast should be part of op canonicalizers! Nonetheless, I've found these particular patterns useful in my work, so maybe they're a good addition upstream?

I think it's important to have the canonicalization discussions, even if they are difficult. Otherwise, downstream projects adopting these transformations broadly and contributing to other vector passes upstream would create an "unseen" dependency on these transformations, which would impact other users that are not using or are not aware (most likely the latter) of them.

I think there's a general consensus that vector.shape_cast is the canonical form for reshapes that imply no data movement, with some exceptions. Do people think that canonicalizing the vector.broadcast and vector.extract cases is controversial? Perhaps that's a good first step :)

@newling
Contributor Author

newling commented May 9, 2025

I found 06dbb28, which removes patterns exactly like these from the canonicalizers. They were removed because there was a plan to restrict shape_cast to collapsing (expanding) to (from) rank-1 only.

newling referenced this pull request May 9, 2025
Instead of using shape_cast op in the pattern removing leading unit
dimensions we use extract/broadcast ops. This is part of the effort to
restrict ShapeCastOp further in the future and only allow them to
convert to or from 1D vector.

This also adds extra canonicalization to fill the gaps in simplifying
broadcast/extract ops.

Differential Revision: https://reviews.llvm.org/D114205
@newling
Contributor Author

newling commented May 12, 2025

Do people think that canonicalizing the vector.broadcast and vector.extract cases is controversial?

I think some people might think that having mlir-opt --canonicalize change

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>

to

%0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>

is controversial and maybe harmful in some pipelines. Just formalizing this a bit, let + be set union and \ be set difference, so

broadcast + shape_cast contains vector<4xi8> to <1x4xi8>
broadcast \ shape_cast contains vector<4xi8> to <2x4xi8>
shape_cast \ broadcast contains vector<4xi8> to <2x2xi8>
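
For concreteness, here is a small Python sketch of the legality rules behind the three cases above (shapes as tuples; the `is_valid_*` helpers are hypothetical illustrations, not MLIR code, and `is_valid_broadcast` only models the simple rank-extension/stretch rule):

```python
import math

def is_valid_shape_cast(src, dst):
    # shape_cast only requires that the total element count is unchanged.
    return math.prod(src) == math.prod(dst)

def is_valid_broadcast(src, dst):
    # broadcast matches dimensions from the right: each source dim must equal
    # the corresponding destination dim, or be 1 (a "stretch" dimension).
    if len(src) > len(dst):
        return False
    return all(s == d or s == 1 for s, d in zip(reversed(src), reversed(dst)))

# Neither set of legal (source, destination) shape pairs contains the other:
assert is_valid_broadcast((4,), (1, 4)) and is_valid_shape_cast((4,), (1, 4))
assert is_valid_broadcast((4,), (2, 4)) and not is_valid_shape_cast((4,), (2, 4))
assert not is_valid_broadcast((4,), (2, 2)) and is_valid_shape_cast((4,), (2, 2))
```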

There's no obvious asymmetry here, because neither one is a subset of the other. So the only difference is our intuition that shape_cast is simpler, because it should always be a no-op when lowered. However, there might be situations where the user prefers it to remain a broadcast rather than become a shape_cast; one example is lowering (Transforms/LowerVectorShapeCast.cpp and Transforms/LowerVectorBroadcast.cpp). Consider,

the lowering of

%0 = vector.broadcast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>

is

%0 = ub.poison : vector<1x2x3xf32>
%1 = vector.insert %arg0, %0 [0] : vector<2x3xf32> into vector<1x2x3xf32>

while

%0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>

gets lowered to

%0 = ub.poison : vector<1x2x3xf32>
%1 = vector.extract %arg0[0, 0] : f32 from vector<2x3xf32>
%2 = vector.insert %1, %0 [0, 0, 0] : f32 into vector<1x2x3xf32>
%3 = vector.extract %arg0[0, 1] : f32 from vector<2x3xf32>
%4 = vector.insert %3, %2 [0, 0, 1] : f32 into vector<1x2x3xf32>
%5 = vector.extract %arg0[0, 2] : f32 from vector<2x3xf32>
%6 = vector.insert %5, %4 [0, 0, 2] : f32 into vector<1x2x3xf32>
%7 = vector.extract %arg0[1, 0] : f32 from vector<2x3xf32>
%8 = vector.insert %7, %6 [0, 1, 0] : f32 into vector<1x2x3xf32>
%9 = vector.extract %arg0[1, 1] : f32 from vector<2x3xf32>
%10 = vector.insert %9, %8 [0, 1, 1] : f32 into vector<1x2x3xf32>
%11 = vector.extract %arg0[1, 2] : f32 from vector<2x3xf32>
%12 = vector.insert %11, %10 [0, 1, 2] : f32 into vector<1x2x3xf32>

While both will hopefully have vanished in the final assembly (just reuse whatever register the data was in), I can sympathize with a person who prefers the lowering of broadcast, and doesn't want their broadcast to be changed to a shape_cast!
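
To make it concrete that the two lowerings agree in value, here is a pure-Python model (an illustration only, not MLIR semantics): the broadcast lowering does one whole-slice insert, while the shape_cast lowering does six scalar extract/insert pairs.

```python
# Model vectors as nested Python lists. `src` plays the role of
# %arg0 : vector<2x3xf32>.
src = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]

# Lowered broadcast: a single vector.insert of the whole 2x3 slice.
via_broadcast = [None]          # models `ub.poison : vector<1x2x3xf32>`
via_broadcast[0] = src

# Lowered shape_cast: six scalar extract/insert pairs.
via_shape_cast = [[[None] * 3 for _ in range(2)]]
for i in range(2):
    for j in range(3):
        via_shape_cast[0][i][j] = src[i][j]

# Both produce the same vector<1x2x3xf32> value.
assert via_broadcast == via_shape_cast
```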

@banach-space
Contributor

I found 06dbb28, which removes patterns exactly like these from the canonicalizers.

Thanks for finding and sharing that thread :) I wasn't aware of those earlier design discussions about restricting vector.shape_cast.

That’s a fairly old discussion (given MLIR’s age), and likely dates back to when there was an assumption we'd be using llvm.intr.matrix intrinsics. From the vector.shape_cast documentation:

There is an exception to the folding expectation when targeting llvm.intr.matrix operations. We need a type conversion back and forth from a 2-D MLIR vector to a 1-D flattened LLVM vector. shape_cast lowering to LLVM is supported in that particular case, for now.

However, these matrix intrinsics are not widely used, and - as far as I know - the corresponding LLVM work has been on hold for quite some time (a few years?). From what I can tell, those requirements are now outdated.

In the last 2–3 years that I’ve been working in this space, the consistent takeaway has been:

vector.shape_cast is the canonical form for reshapes that imply no data movement.

In general, the assumption is that vector.shape_cast encodes a reshape with no underlying value movement.

Regarding the example you mentioned:

%0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>

...and its lowering to:

%0 = ub.poison : vector<1x2x3xf32>
%1 = vector.extract %arg0[0, 0] : f32 from vector<2x3xf32>
%2 = vector.insert %1, %0 [0, 0, 0] : f32 into vector<1x2x3xf32>
%3 = vector.extract %arg0[0, 1] : f32 from vector<2x3xf32>
%4 = vector.insert %3, %2 [0, 0, 1] : f32 into vector<1x2x3xf32>
%5 = vector.extract %arg0[0, 2] : f32 from vector<2x3xf32>
%6 = vector.insert %5, %4 [0, 0, 2] : f32 into vector<1x2x3xf32>
%7 = vector.extract %arg0[1, 0] : f32 from vector<2x3xf32>
%8 = vector.insert %7, %6 [0, 1, 0] : f32 into vector<1x2x3xf32>
%9 = vector.extract %arg0[1, 1] : f32 from vector<2x3xf32>
%10 = vector.insert %9, %8 [0, 1, 1] : f32 into vector<1x2x3xf32>
%11 = vector.extract %arg0[1, 2] : f32 from vector<2x3xf32>
%12 = vector.insert %11, %10 [0, 1, 2] : f32 into vector<1x2x3xf32>

This lowering clearly involves data movement, and to me violates the design intent of vector.shape_cast. I wouldn’t use it as a guiding example - in fact, this may be a good time to revisit that lowering and consider disabling it or making it conditional.

So in principle: yes, I’m in favor of turning these patterns into canonicalizations. We should just be thoughtful about which constraints we want to enforce.

@newling
Contributor Author

newling commented May 13, 2025

Ok, I might give that a try, starting with transpose->shape_cast in the transpose canonicalizer. I'm not entirely sure what a formal definition of 'no data movement' in SSA might be. The shape_cast still has to lower to something in LLVM; it's only during register allocation that it will vanish. I guess ideally all shape_casts will have been canonicalized away, especially when using linearization/flattening.

I mentioned this previously and it's probably over-the-top at the moment but at some point it'd be nice to add something to the docs describing the canonicalization rules, like

## Canonicalization rules

To ensure that canonicalization converges to a fixed point, we define a preference 
list on vector dialect operations. Operations higher in the list are preferred. 
- vector.shape_cast
- vector.broadcast
- vector.transpose

Patterns that substitute one operation for another must satisfy the above list.

@dcaballe
Contributor

Thanks for sharing your thoughts! Replying to this and this threads here.

I think part of the challenge with canonicalization stems from the expectations we’ve built around it. We often assume the IR after canonicalization will be ideal for all patterns, analyses and, ultimately, lead to the most efficient lowering/ASM/runtime performance and whatnot. But realistically, canonicalization should only focus on eliminating redundancy from the IR and offering a unique and consistent representation to make passes consuming such a canonical form simpler, regardless of everything else. For me, a pass running on canonical IR that must deal with both:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> 

and

%0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8> 

indicates that we have a gap in the canonical form because both operations are representing the same thing.

IMO, using vector.shape_cast to handle reshape- or view-like operations that don’t involve data rearrangement/movement (at least at this specific virtual vector level) simplifies things significantly, especially compared to heavier operations like transposes. Later, during legalization, a vector.shape_cast may have to be handled differently depending on the specific target, backend compiler or even lowering approach (e.g., legalizing to multi-dim vectors, linearizing multi-dim vectors, unrolling multi-dim vectors...). That is expected but it shouldn’t prevent us from using it as part of the canonical form of the Vector dialect if it brings value at that level. Improving its lowering to LLVM for cases where it’s really needed should be within the expected work that we have to do.
