Skip to content

Commit 88accd9

Browse files
authored
[mlir][ArmSME] Pattern to swap shape_cast(tranpose) with transpose(shape_cast) (#100731)
This applies when the shape_cast is simply for dropping unit dims, and the result rank is >= 2. This simplifies the transpose making it possible for other ArmSME legalization patterns to handle it. Example: ```mlir %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32> %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32> ``` ```mlir %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32> %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> ```
1 parent 49cb170 commit 88accd9

File tree

2 files changed

+116
-1
lines changed

2 files changed

+116
-1
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,94 @@ struct ConvertIllegalShapeCastOpsToTransposes
774774
}
775775
};
776776

777+
/// Returns an iterator over the dims (inc scalability) of a VectorType.
778+
static auto getDims(VectorType vType) {
779+
return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
780+
}
781+
782+
/// Helper to drop (fixed-size) unit dims from a VectorType.
783+
static VectorType dropUnitDims(VectorType vType) {
784+
SmallVector<bool> scalableFlags;
785+
SmallVector<int64_t> dimSizes;
786+
for (auto dim : getDims(vType)) {
787+
if (dim == std::make_tuple(1, false))
788+
continue;
789+
auto [size, scalableFlag] = dim;
790+
dimSizes.push_back(size);
791+
scalableFlags.push_back(scalableFlag);
792+
}
793+
return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
794+
}
795+
796+
/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
797+
/// shape_cast only drops unit dimensions.
798+
///
799+
/// This simplifies the transpose making it possible for other legalization
800+
/// rewrites to handle it.
801+
///
802+
/// Example:
803+
///
804+
/// BEFORE:
805+
/// ```mlir
806+
/// %0 = vector.transpose %vector, [3, 0, 1, 2]
807+
/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
808+
/// %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
809+
/// ```
810+
///
811+
/// AFTER:
812+
/// ```mlir
813+
/// %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
814+
/// %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
815+
/// ```
816+
struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
817+
using OpRewritePattern::OpRewritePattern;
818+
819+
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
820+
PatternRewriter &rewriter) const override {
821+
auto transposeOp =
822+
shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
823+
if (!transposeOp)
824+
return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
825+
826+
auto resultType = shapeCastOp.getResultVectorType();
827+
if (resultType.getRank() <= 1)
828+
return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
829+
830+
if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
831+
return rewriter.notifyMatchFailure(
832+
shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
833+
834+
auto transposeSourceVectorType = transposeOp.getSourceVectorType();
835+
auto transposeSourceDims =
836+
llvm::to_vector(getDims(transposeSourceVectorType));
837+
838+
// Construct a map from dimIdx -> number of dims dropped before dimIdx.
839+
SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
840+
int64_t droppedDims = 0;
841+
for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
842+
droppedDimsBefore[i] = droppedDims;
843+
if (dim == std::make_tuple(1, false))
844+
++droppedDims;
845+
}
846+
847+
// Drop unit dims from transpose permutation.
848+
auto perm = transposeOp.getPermutation();
849+
SmallVector<int64_t> newPerm;
850+
for (int64_t idx : perm) {
851+
if (transposeSourceDims[idx] == std::make_tuple(1, false))
852+
continue;
853+
newPerm.push_back(idx - droppedDimsBefore[idx]);
854+
}
855+
856+
auto loc = shapeCastOp.getLoc();
857+
auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
858+
loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
859+
rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
860+
newShapeCastOp, newPerm);
861+
return success();
862+
}
863+
};
864+
777865
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
778866
/// the ZA state. This workaround rewrite to support these transposes when ZA is
779867
/// available.
@@ -939,7 +1027,8 @@ struct VectorLegalizationPass
9391027
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
9401028
LiftIllegalVectorTransposeToMemory,
9411029
ConvertIllegalShapeCastOpsToTransposes,
942-
LowerIllegalTransposeStoreViaZA>(context);
1030+
SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>(
1031+
context);
9431032
// Note: These two patterns are added with a high benefit to ensure:
9441033
// - Masked outer products are handled before unmasked ones
9451034
// - Multi-tile writes are lowered as a store loop (if possible)

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,3 +646,29 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect
646646
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref<?x?xf32>
647647
return
648648
}
649+
650+
// -----
651+
652+
// CHECK-LABEL: @swap_shape_cast_of_transpose(
653+
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
654+
func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
655+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
656+
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
657+
// CHECK: return %[[TRANSPOSE]]
658+
%0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
659+
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
660+
return %1 : vector<[4]x4xf32>
661+
}
662+
663+
// -----
664+
665+
// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
666+
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
667+
func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
668+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
669+
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
670+
// CHECK: return %[[TRANSPOSE]]
671+
%0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
672+
%1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
673+
return %1 : vector<[4]x4xf32>
674+
}

0 commit comments

Comments
 (0)