Skip to content

Commit 06ca6e9

Browse files
committed
[mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes
As a follow-up to PR #135841 (see discussion for context), this patch removes `ConvertIllegalShapeCastOpsToTransposes` from the SME legalization pass and unblocks `ShapeCastOp::fold` for scalable vectors. As far as I know, `ConvertIllegalShapeCastOpsToTransposes` was originally needed because we were generating `vector.shape_cast` ops that couldn't be lowered otherwise. To confirm that it is no longer required, I tested this patch locally using end-to-end tests. Notably, this also removes a special case from `ShapeCastOp::fold`: the fold was disabled for scalable vectors only to avoid an infinite rewrite cycle with `ConvertIllegalShapeCastOpsToTransposes`.
1 parent ff87b87 commit 06ca6e9

File tree

4 files changed

+38
-114
lines changed

4 files changed

+38
-114
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory
724724
}
725725
};
726726

727-
/// A rewrite to turn unit dim transpose-like vector.shape_casts into
728-
/// vector.transposes. The shape_cast has to be from an illegal vector type to a
729-
/// legal one (as defined by isLegalVectorType).
730-
///
731-
/// The reasoning for this is if we've got to this pass and we still have
732-
/// shape_casts of illegal types, then they likely will not cancel out. Turning
733-
/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
734-
/// eliminate them.
735-
///
736-
/// Example:
737-
///
738-
/// BEFORE:
739-
/// ```mlir
740-
/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
741-
/// ```
742-
///
743-
/// AFTER:
744-
/// ```mlir
745-
/// %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
746-
/// ```
747-
struct ConvertIllegalShapeCastOpsToTransposes
748-
: public OpRewritePattern<vector::ShapeCastOp> {
749-
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
750-
751-
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
752-
PatternRewriter &rewriter) const override {
753-
auto sourceType = shapeCastOp.getSourceVectorType();
754-
auto resultType = shapeCastOp.getResultVectorType();
755-
if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
756-
return rewriter.notifyMatchFailure(shapeCastOp,
757-
kMatchFailureNotIllegalToLegal);
758-
759-
// Note: If we know that `sourceType` is an illegal vector type (and 2D)
760-
// then dim 0 is scalable and dim 1 is fixed.
761-
if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
762-
return rewriter.notifyMatchFailure(
763-
shapeCastOp, "expected source to be a 2D scalable vector with a "
764-
"trailing unit dim");
765-
766-
auto loc = shapeCastOp.getLoc();
767-
auto transpose = rewriter.create<vector::TransposeOp>(
768-
loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
769-
770-
if (resultType.getRank() == 1)
771-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
772-
transpose);
773-
else
774-
rewriter.replaceOp(shapeCastOp, transpose);
775-
776-
return success();
777-
}
778-
};
779-
780727
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
781728
/// the ZA state. This workaround rewrite to support these transposes when ZA is
782729
/// available.
@@ -943,7 +890,6 @@ struct VectorLegalizationPass
943890
RewritePatternSet rewritePatterns(context);
944891
rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
945892
LiftIllegalVectorTransposeToMemory,
946-
ConvertIllegalShapeCastOpsToTransposes,
947893
LowerIllegalTransposeStoreViaZA>(context);
948894
if (failed(
949895
applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5617,18 +5617,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56175617

56185618
// shape_cast(transpose(x)) -> shape_cast(x)
56195619
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
5620-
// This folder does
5621-
// shape_cast(transpose) -> shape_cast
5622-
// But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
5623-
// shape_cast -> shape_cast(transpose)
5624-
// i.e. the complete opposite. When paired, these 2 patterns can cause
5625-
// infinite cycles in pattern rewriting.
5626-
// ConvertIllegalShapeCastOpsToTransposes only matches on scalable
5627-
// vectors, so by disabling this folder for scalable vectors the
5628-
// cycle is avoided.
5629-
// TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
5630-
// still needed. If it's not, then we can fold here.
5631-
if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
5620+
if (isOrderPreserving(transpose)) {
56325621
setOperand(transpose.getVector());
56335622
return getResult();
56345623
}

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -491,51 +491,6 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
491491

492492
// -----
493493

494-
// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
495-
// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
496-
func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
497-
// CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
498-
%0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
499-
return %0 : vector<1x[4]xf32>
500-
}
501-
502-
// -----
503-
504-
// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
505-
// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
506-
func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
507-
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
508-
// CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
509-
%0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
510-
return %0 : vector<[4]xf32>
511-
}
512-
513-
// -----
514-
515-
// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
516-
func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
517-
// CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
518-
// CHECK-NOT: vector.shape_cast
519-
%pad = arith.constant 0.0 : f32
520-
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
521-
%cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
522-
return %cast : vector<1x[4]xf32>
523-
}
524-
525-
// -----
526-
527-
// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
528-
func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
529-
// CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
530-
// CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
531-
%pad = arith.constant 0.0 : f32
532-
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
533-
%cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
534-
return %cast : vector<[4]xf32>
535-
}
536-
537-
// -----
538-
539494
// CHECK-LABEL: @multi_tile_splat
540495
func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
541496
{

mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,25 @@ func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8
161161

162162
// -----
163163

164+
// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
165+
// 1 -> 0
166+
// 2 -> 4
167+
// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
168+
// (same as the example above, but one of the dims is scalable)
169+
// CHECK-LABEL: @transpose_shape_cast_scalable
170+
// CHECK-SAME: %[[ARG:.*]]: vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
171+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
172+
// CHECK-SAME: vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
173+
// CHECK: return %[[SHAPE_CAST]] : vector<[4]x4xi8>
174+
func.func @transpose_shape_cast_scalable(%arg : vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
175+
%0 = vector.transpose %arg, [1, 0, 3, 4, 2]
176+
: vector<1x[4]x4x1x1xi8> to vector<[4]x1x1x1x4xi8>
177+
%1 = vector.shape_cast %0 : vector<[4]x1x1x1x4xi8> to vector<[4]x4xi8>
178+
return %1 : vector<[4]x4xi8>
179+
}
180+
181+
// -----
182+
164183
// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
165184
// 1 -> 2
166185
// 2 -> 1
@@ -225,11 +244,26 @@ func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi
225244

226245
// -----
227246

228-
// Scalable dimensions should be treated as non-unit dimensions.
229-
// CHECK-LABEL: @transpose_of_shape_cast_scalable
247+
// CHECK-LABEL: @shape_cast_transpose_scalable
248+
// CHECK-SAME: %[[ARG:.*]]: vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
249+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
250+
// CHECK-SAME: vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
251+
// CHECK: return %[[SHAPE_CAST]] : vector<[6]x1x1xi8>
252+
func.func @shape_cast_transpose_scalable(%arg : vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
253+
%0 = vector.shape_cast %arg : vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
254+
%1 = vector.transpose %0, [0, 2, 1]
255+
: vector<[6]x1x1xi8> to vector<[6]x1x1xi8>
256+
return %1 : vector<[6]x1x1xi8>
257+
}
258+
259+
// -----
260+
261+
// Scalable unit dimensions (i.e. [1]) should be treated as non-unit dimensions
262+
// (hence no folding).
263+
// CHECK-LABEL: @negative_shape_cast_transpose_scalable_unit
230264
// CHECK: vector.shape_cast
231265
// CHECK: vector.transpose
232-
func.func @transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
266+
func.func @negative_shape_cast_transpose_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
233267
%0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
234268
%1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
235269
return %1 : vector<4x[1]xi8>

0 commit comments

Comments
 (0)