@@ -774,6 +774,94 @@ struct ConvertIllegalShapeCastOpsToTransposes
774
774
}
775
775
};
776
776
777
+ // / Returns an iterator over the dims (inc scalability) of a VectorType.
778
+ static auto getDims (VectorType vType) {
779
+ return llvm::zip_equal (vType.getShape (), vType.getScalableDims ());
780
+ }
781
+
782
+ // / Helper to drop (fixed-size) unit dims from a VectorType.
783
+ static VectorType dropUnitDims (VectorType vType) {
784
+ SmallVector<bool > scalableFlags;
785
+ SmallVector<int64_t > dimSizes;
786
+ for (auto dim : getDims (vType)) {
787
+ if (dim == std::make_tuple (1 , false ))
788
+ continue ;
789
+ auto [size, scalableFlag] = dim;
790
+ dimSizes.push_back (size);
791
+ scalableFlags.push_back (scalableFlag);
792
+ }
793
+ return VectorType::get (dimSizes, vType.getElementType (), scalableFlags);
794
+ }
795
+
796
+ // / A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
797
+ // / shape_cast only drops unit dimensions.
798
+ // /
799
+ // / This simplifies the transpose making it possible for other legalization
800
+ // / rewrites to handle it.
801
+ // /
802
+ // / Example:
803
+ // /
804
+ // / BEFORE:
805
+ // / ```mlir
806
+ // / %0 = vector.transpose %vector, [3, 0, 1, 2]
807
+ // / : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
808
+ // / %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
809
+ // / ```
810
+ // /
811
+ // / AFTER:
812
+ // / ```mlir
813
+ // / %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
814
+ // / %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
815
+ // / ```
816
+ struct SwapShapeCastOfTranspose : public OpRewritePattern <vector::ShapeCastOp> {
817
+ using OpRewritePattern::OpRewritePattern;
818
+
819
+ LogicalResult matchAndRewrite (vector::ShapeCastOp shapeCastOp,
820
+ PatternRewriter &rewriter) const override {
821
+ auto transposeOp =
822
+ shapeCastOp.getSource ().getDefiningOp <vector::TransposeOp>();
823
+ if (!transposeOp)
824
+ return rewriter.notifyMatchFailure (shapeCastOp, " not TransposeOp" );
825
+
826
+ auto resultType = shapeCastOp.getResultVectorType ();
827
+ if (resultType.getRank () <= 1 )
828
+ return rewriter.notifyMatchFailure (shapeCastOp, " result rank too low" );
829
+
830
+ if (resultType != dropUnitDims (shapeCastOp.getSourceVectorType ()))
831
+ return rewriter.notifyMatchFailure (
832
+ shapeCastOp, " ShapeCastOp changes non-unit dimension(s)" );
833
+
834
+ auto transposeSourceVectorType = transposeOp.getSourceVectorType ();
835
+ auto transposeSourceDims =
836
+ llvm::to_vector (getDims (transposeSourceVectorType));
837
+
838
+ // Construct a map from dimIdx -> number of dims dropped before dimIdx.
839
+ SmallVector<int64_t > droppedDimsBefore (transposeSourceVectorType.getRank ());
840
+ int64_t droppedDims = 0 ;
841
+ for (auto [i, dim] : llvm::enumerate (transposeSourceDims)) {
842
+ droppedDimsBefore[i] = droppedDims;
843
+ if (dim == std::make_tuple (1 , false ))
844
+ ++droppedDims;
845
+ }
846
+
847
+ // Drop unit dims from transpose permutation.
848
+ auto perm = transposeOp.getPermutation ();
849
+ SmallVector<int64_t > newPerm;
850
+ for (int64_t idx : perm) {
851
+ if (transposeSourceDims[idx] == std::make_tuple (1 , false ))
852
+ continue ;
853
+ newPerm.push_back (idx - droppedDimsBefore[idx]);
854
+ }
855
+
856
+ auto loc = shapeCastOp.getLoc ();
857
+ auto newShapeCastOp = rewriter.create <vector::ShapeCastOp>(
858
+ loc, dropUnitDims (transposeSourceVectorType), transposeOp.getVector ());
859
+ rewriter.replaceOpWithNewOp <vector::TransposeOp>(shapeCastOp,
860
+ newShapeCastOp, newPerm);
861
+ return success ();
862
+ }
863
+ };
864
+
777
865
// / Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
778
866
// / the ZA state. This workaround rewrite to support these transposes when ZA is
779
867
// / available.
@@ -939,7 +1027,8 @@ struct VectorLegalizationPass
939
1027
patterns.add <FoldExtractFromVectorOfSMELikeCreateMasks,
940
1028
LiftIllegalVectorTransposeToMemory,
941
1029
ConvertIllegalShapeCastOpsToTransposes,
942
- LowerIllegalTransposeStoreViaZA>(context);
1030
+ SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>(
1031
+ context);
943
1032
// Note: These two patterns are added with a high benefit to ensure:
944
1033
// - Masked outer products are handled before unmasked ones
945
1034
// - Multi-tile writes are lowered as a store loop (if possible)
0 commit comments