Skip to content

[mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes #139706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented May 13, 2025

As a follow-up to PR #135841 (see discussion for context), this patch
removes ConvertIllegalShapeCastOpsToTransposes from the SME legalization
pass and unblocks ShapeCastOp::fold for scalable vectors.

AFAIK, ConvertIllegalShapeCastOpsToTransposes was originally needed
because we were generating vector.shape_cast ops that couldn't be
lowered otherwise. To confirm it's no longer required, I tested this
patch locally using end-to-end tests.

Notably, this also removes a special case from ShapeCastOp::fold.

@banach-space banach-space changed the base branch from main to users/banach-space/vector/document_tests May 13, 2025 10:54
@llvmbot
Copy link
Member

llvmbot commented May 13, 2025

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][vector][nfc] Update comments in vector-transpose.mlir
  • [mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes

Full diff: https://github.com/llvm/llvm-project/pull/139706.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (-54)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-12)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (-45)
  • (modified) mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir (+37-3)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 95965872f4098..51750f0bb9694 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory
   }
 };
 
-/// A rewrite to turn unit dim transpose-like vector.shape_casts into
-/// vector.transposes. The shape_cast has to be from an illegal vector type to a
-/// legal one (as defined by isLegalVectorType).
-///
-/// The reasoning for this is if we've got to this pass and we still have
-/// shape_casts of illegal types, then they likely will not cancel out. Turning
-/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
-/// eliminate them.
-///
-/// Example:
-///
-///  BEFORE:
-///  ```mlir
-///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  ```
-///
-///  AFTER:
-///  ```mlir
-///  %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  ```
-struct ConvertIllegalShapeCastOpsToTransposes
-    : public OpRewritePattern<vector::ShapeCastOp> {
-  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
-    auto sourceType = shapeCastOp.getSourceVectorType();
-    auto resultType = shapeCastOp.getResultVectorType();
-    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
-      return rewriter.notifyMatchFailure(shapeCastOp,
-                                         kMatchFailureNotIllegalToLegal);
-
-    // Note: If we know that `sourceType` is an illegal vector type (and 2D)
-    // then dim 0 is scalable and dim 1 is fixed.
-    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
-      return rewriter.notifyMatchFailure(
-          shapeCastOp, "expected source to be a 2D scalable vector with a "
-                       "trailing unit dim");
-
-    auto loc = shapeCastOp.getLoc();
-    auto transpose = rewriter.create<vector::TransposeOp>(
-        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
-
-    if (resultType.getRank() == 1)
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
-                                                       transpose);
-    else
-      rewriter.replaceOp(shapeCastOp, transpose);
-
-    return success();
-  }
-};
-
 /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
 /// the ZA state. This workaround rewrite to support these transposes when ZA is
 /// available.
@@ -943,7 +890,6 @@ struct VectorLegalizationPass
     RewritePatternSet rewritePatterns(context);
     rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                         LiftIllegalVectorTransposeToMemory,
-                        ConvertIllegalShapeCastOpsToTransposes,
                         LowerIllegalTransposeStoreViaZA>(context);
     if (failed(
             applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..83a287d29d773 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5617,18 +5617,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // shape_cast(transpose(x)) -> shape_cast(x)
   if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
-    // This folder does
-    //    shape_cast(transpose) -> shape_cast
-    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
-    //    shape_cast -> shape_cast(transpose)
-    // i.e. the complete opposite. When paired, these 2 patterns can cause
-    // infinite cycles in pattern rewriting.
-    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
-    // vectors, so by disabling this folder for scalable vectors the
-    // cycle is avoided.
-    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
-    // still needed. If it's not, then we can fold here.
-    if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
+    if (isOrderPreserving(transpose)) {
       setOperand(transpose.getVector());
       return getResult();
     }
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index d56df9814f173..6e6615c243d2a 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -491,51 +491,6 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
 
 // -----
 
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
-// CHECK-SAME:                                      %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
-  // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %0 : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
-// CHECK-SAME:                                      %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
-  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
-  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
-  return %0 : vector<[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
-func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
-  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
-  // CHECK-NOT: vector.shape_cast
-  %pad = arith.constant 0.0 : f32
-  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
-  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %cast : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
-func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
-  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
-  // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
-  %pad = arith.constant 0.0 : f32
-  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
-  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
-  return %cast : vector<[4]xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @multi_tile_splat
 func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
 {
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index e47578bc80719..625b4a9c53e42 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -161,6 +161,25 @@ func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8
 
 // -----
 
+// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// (same as the example above, but one of the dims is scalable)
+// CHECK-LABEL: @transpose_shape_cast_scalable
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<[4]x4xi8>
+func.func @transpose_shape_cast_scalable(%arg : vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+  %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+     : vector<1x[4]x4x1x1xi8> to vector<[4]x1x1x1x4xi8>
+  %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4xi8> to vector<[4]x4xi8>
+  return %1 : vector<[4]x4xi8>
+}
+
+// -----
+
 // In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
 // 1 -> 2
 // 2 -> 1
@@ -225,11 +244,26 @@ func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) ->  vector<6x1x1xi
 
 // -----
 
-// Scalable dimensions should be treated as non-unit dimensions.
-// CHECK-LABEL: @transpose_of_shape_cast_scalable
+// CHECK-LABEL: @transpose_shape_cast_scalable
+//  CHECK-SAME:   %[[ARG:.*]]: vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<[6]x1x1xi8>
+func.func @shape_cast_transpose_scalable(%arg : vector<[2]x3x1x1xi8>) ->  vector<[6]x1x1xi8> {
+  %0 = vector.shape_cast %arg : vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+  %1 = vector.transpose %0, [0, 2, 1]
+     : vector<[6]x1x1xi8> to vector<[6]x1x1xi8>
+  return %1 : vector<[6]x1x1xi8>
+}
+
+// -----
+
+// Scalable 1 dimensions (i.e. [1]) should be treated as non-unit dimensions
+// (hence no folding).
+// CHECK-LABEL: @negative_shape_cast_transpose_scalable_unit
 //       CHECK: vector.shape_cast
 //       CHECK: vector.transpose
-func.func @transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
+func.func @negative_shape_cast_transpose_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
   %0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
   %1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
   return %1 : vector<4x[1]xi8>

@llvmbot
Copy link
Member

llvmbot commented May 13, 2025

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][vector][nfc] Update comments in vector-transpose.mlir
  • [mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes

Full diff: https://github.com/llvm/llvm-project/pull/139706.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (-54)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-12)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (-45)
  • (modified) mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir (+37-3)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 95965872f4098..51750f0bb9694 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory
   }
 };
 
-/// A rewrite to turn unit dim transpose-like vector.shape_casts into
-/// vector.transposes. The shape_cast has to be from an illegal vector type to a
-/// legal one (as defined by isLegalVectorType).
-///
-/// The reasoning for this is if we've got to this pass and we still have
-/// shape_casts of illegal types, then they likely will not cancel out. Turning
-/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
-/// eliminate them.
-///
-/// Example:
-///
-///  BEFORE:
-///  ```mlir
-///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  ```
-///
-///  AFTER:
-///  ```mlir
-///  %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  ```
-struct ConvertIllegalShapeCastOpsToTransposes
-    : public OpRewritePattern<vector::ShapeCastOp> {
-  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
-    auto sourceType = shapeCastOp.getSourceVectorType();
-    auto resultType = shapeCastOp.getResultVectorType();
-    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
-      return rewriter.notifyMatchFailure(shapeCastOp,
-                                         kMatchFailureNotIllegalToLegal);
-
-    // Note: If we know that `sourceType` is an illegal vector type (and 2D)
-    // then dim 0 is scalable and dim 1 is fixed.
-    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
-      return rewriter.notifyMatchFailure(
-          shapeCastOp, "expected source to be a 2D scalable vector with a "
-                       "trailing unit dim");
-
-    auto loc = shapeCastOp.getLoc();
-    auto transpose = rewriter.create<vector::TransposeOp>(
-        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
-
-    if (resultType.getRank() == 1)
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
-                                                       transpose);
-    else
-      rewriter.replaceOp(shapeCastOp, transpose);
-
-    return success();
-  }
-};
-
 /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
 /// the ZA state. This workaround rewrite to support these transposes when ZA is
 /// available.
@@ -943,7 +890,6 @@ struct VectorLegalizationPass
     RewritePatternSet rewritePatterns(context);
     rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                         LiftIllegalVectorTransposeToMemory,
-                        ConvertIllegalShapeCastOpsToTransposes,
                         LowerIllegalTransposeStoreViaZA>(context);
     if (failed(
             applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..83a287d29d773 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5617,18 +5617,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // shape_cast(transpose(x)) -> shape_cast(x)
   if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
-    // This folder does
-    //    shape_cast(transpose) -> shape_cast
-    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
-    //    shape_cast -> shape_cast(transpose)
-    // i.e. the complete opposite. When paired, these 2 patterns can cause
-    // infinite cycles in pattern rewriting.
-    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
-    // vectors, so by disabling this folder for scalable vectors the
-    // cycle is avoided.
-    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
-    // still needed. If it's not, then we can fold here.
-    if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
+    if (isOrderPreserving(transpose)) {
       setOperand(transpose.getVector());
       return getResult();
     }
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index d56df9814f173..6e6615c243d2a 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -491,51 +491,6 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
 
 // -----
 
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
-// CHECK-SAME:                                      %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
-  // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %0 : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
-// CHECK-SAME:                                      %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
-  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
-  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
-  return %0 : vector<[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
-func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
-  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
-  // CHECK-NOT: vector.shape_cast
-  %pad = arith.constant 0.0 : f32
-  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
-  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %cast : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
-func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
-  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
-  // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
-  %pad = arith.constant 0.0 : f32
-  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
-  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
-  return %cast : vector<[4]xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @multi_tile_splat
 func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
 {
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index e47578bc80719..625b4a9c53e42 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -161,6 +161,25 @@ func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8
 
 // -----
 
+// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// (same as the example above, but one of the dims is scalable)
+// CHECK-LABEL: @transpose_shape_cast_scalable
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<[4]x4xi8>
+func.func @transpose_shape_cast_scalable(%arg : vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+  %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+     : vector<1x[4]x4x1x1xi8> to vector<[4]x1x1x1x4xi8>
+  %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4xi8> to vector<[4]x4xi8>
+  return %1 : vector<[4]x4xi8>
+}
+
+// -----
+
 // In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
 // 1 -> 2
 // 2 -> 1
@@ -225,11 +244,26 @@ func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) ->  vector<6x1x1xi
 
 // -----
 
-// Scalable dimensions should be treated as non-unit dimensions.
-// CHECK-LABEL: @transpose_of_shape_cast_scalable
+// CHECK-LABEL: @transpose_shape_cast_scalable
+//  CHECK-SAME:   %[[ARG:.*]]: vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<[6]x1x1xi8>
+func.func @shape_cast_transpose_scalable(%arg : vector<[2]x3x1x1xi8>) ->  vector<[6]x1x1xi8> {
+  %0 = vector.shape_cast %arg : vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+  %1 = vector.transpose %0, [0, 2, 1]
+     : vector<[6]x1x1xi8> to vector<[6]x1x1xi8>
+  return %1 : vector<[6]x1x1xi8>
+}
+
+// -----
+
+// Scalable 1 dimensions (i.e. [1]) should be treated as non-unit dimensions
+// (hence no folding).
+// CHECK-LABEL: @negative_shape_cast_transpose_scalable_unit
 //       CHECK: vector.shape_cast
 //       CHECK: vector.transpose
-func.func @transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
+func.func @negative_shape_cast_transpose_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
   %0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
   %1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
   return %1 : vector<4x[1]xi8>

@banach-space banach-space changed the title users/banach space/sme/remove ConvertIllegalShapeCastOpsToTransposes [mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes May 13, 2025
@banach-space banach-space requested review from newling and MacDue May 13, 2025 10:54
Base automatically changed from users/banach-space/vector/document_tests to main May 14, 2025 09:13
As a follow-up to PR #135841 (see discussion for context), this patch
removes `ConvertIllegalShapeCastOpsToTransposes` from the SME legalization
pass and unblocks `ShapeCastOp::fold` for scalable vectors.

AFAIK, `ConvertIllegalShapeCastOpsToTransposes` was originally needed
because we were generating `vector.shape_cast` ops that couldn't be
lowered otherwise. To confirm it's no longer required, I tested this
patch locally using end-to-end tests.

Notably, this also removes a special case from `ShapeCastOp::fold`.
@banach-space banach-space force-pushed the users/banach-space/sme/remove_ConvertIllegalShapeCastOpsToTransposes branch from 06ca6e9 to 0bf798d Compare May 14, 2025 09:28
Comment on lines -518 to -522
// CHECK-NOT: vector.shape_cast
%pad = arith.constant 0.0 : f32
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
%cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
return %cast : vector<1x[4]xf32>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what you've tested, but to know if this rewrite is still needed or not this test case should still be possible to lower to LLVM.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Ben!

I'm not sure what you've tested

I used our e2e tests - from what I can tell, we don't generate such code anymore.

this test case should still be possible to lower to LLVM

Indeed. @momchil-velikov , since you are working on a generic pattern for "xfer_read with non-trailing scalable dims", could you make sure that this example lowers with your patch?

  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>

I will wait for Momchil to upload his patch before progressing this one.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants