[mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes
#139706
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/139706.diff
4 Files Affected:
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 95965872f4098..51750f0bb9694 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory
}
};
-/// A rewrite to turn unit dim transpose-like vector.shape_casts into
-/// vector.transposes. The shape_cast has to be from an illegal vector type to a
-/// legal one (as defined by isLegalVectorType).
-///
-/// The reasoning for this is if we've got to this pass and we still have
-/// shape_casts of illegal types, then they likely will not cancel out. Turning
-/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
-/// eliminate them.
-///
-/// Example:
-///
-/// BEFORE:
-/// ```mlir
-/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-/// ```
-struct ConvertIllegalShapeCastOpsToTransposes
- : public OpRewritePattern<vector::ShapeCastOp> {
- using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
- PatternRewriter &rewriter) const override {
- auto sourceType = shapeCastOp.getSourceVectorType();
- auto resultType = shapeCastOp.getResultVectorType();
- if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
- return rewriter.notifyMatchFailure(shapeCastOp,
- kMatchFailureNotIllegalToLegal);
-
- // Note: If we know that `sourceType` is an illegal vector type (and 2D)
- // then dim 0 is scalable and dim 1 is fixed.
- if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
- return rewriter.notifyMatchFailure(
- shapeCastOp, "expected source to be a 2D scalable vector with a "
- "trailing unit dim");
-
- auto loc = shapeCastOp.getLoc();
- auto transpose = rewriter.create<vector::TransposeOp>(
- loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
-
- if (resultType.getRank() == 1)
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
- transpose);
- else
- rewriter.replaceOp(shapeCastOp, transpose);
-
- return success();
- }
-};
-
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
/// the ZA state. This workaround rewrite to support these transposes when ZA is
/// available.
@@ -943,7 +890,6 @@ struct VectorLegalizationPass
RewritePatternSet rewritePatterns(context);
rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
- ConvertIllegalShapeCastOpsToTransposes,
LowerIllegalTransposeStoreViaZA>(context);
if (failed(
applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..83a287d29d773 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5617,18 +5617,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// shape_cast(transpose(x)) -> shape_cast(x)
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
- // This folder does
- // shape_cast(transpose) -> shape_cast
- // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
- // shape_cast -> shape_cast(transpose)
- // i.e. the complete opposite. When paired, these 2 patterns can cause
- // infinite cycles in pattern rewriting.
- // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
- // vectors, so by disabling this folder for scalable vectors the
- // cycle is avoided.
- // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
- // still needed. If it's not, then we can fold here.
- if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
+ if (isOrderPreserving(transpose)) {
setOperand(transpose.getVector());
return getResult();
}
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index d56df9814f173..6e6615c243d2a 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -491,51 +491,6 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
// -----
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
-// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
- // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
- %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
- return %0 : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
-// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
- // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
- // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
- %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
- return %0 : vector<[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
-func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
- // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
- // CHECK-NOT: vector.shape_cast
- %pad = arith.constant 0.0 : f32
- %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
- %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
- return %cast : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
-func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
- // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
- // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
- %pad = arith.constant 0.0 : f32
- %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
- %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
- return %cast : vector<[4]xf32>
-}
-
-// -----
-
// CHECK-LABEL: @multi_tile_splat
func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
{
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index e47578bc80719..625b4a9c53e42 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -161,6 +161,25 @@ func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8
// -----
+// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// (same as the example above, but one of the dims is scalable)
+// CHECK-LABEL: @transpose_shape_cast_scalable
+// CHECK-SAME: %[[ARG:.*]]: vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<[4]x4xi8>
+func.func @transpose_shape_cast_scalable(%arg : vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+ %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+ : vector<1x[4]x4x1x1xi8> to vector<[4]x1x1x1x4xi8>
+ %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4xi8> to vector<[4]x4xi8>
+ return %1 : vector<[4]x4xi8>
+}
+
+// -----
+
// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
// 1 -> 2
// 2 -> 1
@@ -225,11 +244,26 @@ func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi
// -----
-// Scalable dimensions should be treated as non-unit dimensions.
-// CHECK-LABEL: @transpose_of_shape_cast_scalable
+// CHECK-LABEL: @transpose_shape_cast_scalable
+// CHECK-SAME: %[[ARG:.*]]: vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<[6]x1x1xi8>
+func.func @shape_cast_transpose_scalable(%arg : vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
+ %0 = vector.shape_cast %arg : vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+ %1 = vector.transpose %0, [0, 2, 1]
+ : vector<[6]x1x1xi8> to vector<[6]x1x1xi8>
+ return %1 : vector<[6]x1x1xi8>
+}
+
+// -----
+
+// Scalable 1 dimensions (i.e. [1]) should be treated as non-unit dimensions
+// (hence no folding).
+// CHECK-LABEL: @negative_shape_cast_transpose_scalable_unit
// CHECK: vector.shape_cast
// CHECK: vector.transpose
-func.func @transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
+func.func @negative_shape_cast_transpose_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
%0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
%1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
return %1 : vector<4x[1]xi8>
As a follow-up to PR #135841 (see discussion for context), this patch removes `ConvertIllegalShapeCastOpsToTransposes` from the SME legalization pass and unblocks `ShapeCastOp::fold` for scalable vectors. AFAIK, `ConvertIllegalShapeCastOpsToTransposes` was originally needed because we were generating `vector.shape_cast` ops that couldn't be lowered otherwise. To confirm it's no longer required, I tested this patch locally using end-to-end tests. Notably, this also removes a special case from `ShapeCastOp::fold`.
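For readers unfamiliar with the fold being unblocked here: `shape_cast(transpose(x)) -> shape_cast(x)` only fires when the transpose is order preserving, i.e. the non-unit dimensions keep their relative order, and scalable dimensions (including `[1]`) count as non-unit. The following is a minimal standalone sketch of that check as the test comments in this patch describe it. It is not the MLIR implementation: the `Dim` struct and the function `isOrderPreservingSketch` are hypothetical names introduced only for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a single vector dimension (not an MLIR type).
struct Dim {
  int64_t size;
  bool scalable;
};

// Sketch: returns true if `perm` keeps the non-unit source dims in their
// original relative order. Result dim i is taken from source dim perm[i].
// A scalable dim, even [1], is treated as non-unit.
bool isOrderPreservingSketch(const std::vector<Dim> &source,
                             const std::vector<int64_t> &perm) {
  assert(source.size() == perm.size() && "permutation rank mismatch");
  int64_t lastSeen = -1;
  for (int64_t srcIdx : perm) {
    const Dim &d = source[srcIdx];
    bool isFixedUnit = d.size == 1 && !d.scalable;
    if (isFixedUnit)
      continue; // Fixed unit dims can move freely.
    if (srcIdx < lastSeen)
      return false; // Two non-unit dims swapped their relative order.
    lastSeen = srcIdx;
  }
  return true;
}
```

Under these assumptions, the permutation `[1, 0, 3, 4, 2]` applied to `vector<1x[4]x4x1x1xi8>` is order preserving (source dims 1 and 2 land at result positions 0 and 4), whereas `[1, 0]` applied to `vector<[1]x4xi8>` is not, because the scalable `[1]` dim is treated as non-unit.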
Force-pushed from 06ca6e9 to 0bf798d.
// CHECK-NOT: vector.shape_cast
%pad = arith.constant 0.0 : f32
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
%cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
return %cast : vector<1x[4]xf32>
I'm not sure what you've tested, but to know whether this rewrite is still needed, this test case should still be possible to lower to LLVM.
Thanks Ben!
> I'm not sure what you've tested
I used our e2e tests - from what I can tell, we don't generate such code anymore.
> this test case should still be possible to lower to LLVM
Indeed. @momchil-velikov, since you are working on a generic pattern for "xfer_read with non-trailing scalable dims", could you make sure that this example lowers with your patch?
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
I will wait for Momchil to upload his patch before progressing this one.