#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+ #include "mlir/Dialect/Index/IR/IndexDialect.h"
+ #include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+ #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/OneToNTypeConversion.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"
@@ -140,11 +143,11 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
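// For example, decomposing vector<[8]x[8]xf32> into vector<[4]x[4]xf32> SME
// tiles visits the tile offsets (0, 0), (0, 4), (4, 0), (4, 4). The clamping
// below also allows sub-tile types such as vector<2x[4]xf32>, which yield a
// single offset at (0, 0) (previously rejected by the assert).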
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {
-   assert(isMultipleOfSMETileVectorType(type) &&
-          "`type` not multiple of SME tiles");
  return llvm::map_range(
-     StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
-                                             smeTileType.getDimSize(1)}),
+     StaticTileOffsetRange(
+         type.getShape(),
+         {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
+          std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
      [=](auto indices) {
        int row = int(indices[0]);
        int col = int(indices[1]);
@@ -440,12 +443,8 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
                                         kMatchFailureUnsupportedMaskOp);

    auto loc = writeOp.getLoc();
-   auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
-   auto createVscaleMultiple = [&](int64_t multiplier) {
-     return rewriter.create<arith::MulIOp>(
-         loc, vscale,
-         rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
-   };
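+   // makeVscaleConstantBuilder (from Vector/Utils/VectorUtils.h) returns a
+   // callback equivalent to the lambda above: given an int64_t multiplier,
+   // it emits `vscale * multiplier` as an index Value.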
+   auto createVscaleMultiple =
+       vector::makeVscaleConstantBuilder(rewriter, loc);

    // Get SME tile and slice types.
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
@@ -775,6 +774,149 @@ struct ConvertIllegalShapeCastOpsToTransposes
  }
};

+ /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead
+ /// use the ZA state. This is a workaround to support these transposes when
+ /// ZA is available.
+ ///
+ /// Example:
+ ///
+ ///  BEFORE:
+ ///  ```mlir
+ ///  %transpose = vector.transpose %vec, [1, 0]
+ ///    : vector<2x[4]xf32> to vector<[4]x2xf32>
+ ///  vector.transfer_write %transpose, %dest[%y, %x]
+ ///    : vector<[4]x2xf32>, memref<?x?xf32>
+ ///  ```
+ ///
+ ///  AFTER:
+ ///  ```mlir
+ ///  %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
+ ///  %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
+ ///  %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
+ ///  %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
+ ///  %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+ ///  %c4_vscale = arith.muli %vscale, %c4 : index
+ ///  %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
+ ///  vector.transfer_write %4, %dest[%y, %x], %mask
+ ///    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
+ ///    : vector<[4]x[4]xf32>, memref<?x?xf32>
+ ///  ```
+ ///
+ /// Values larger than a single tile are supported via decomposition.
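+ ///
+ /// For example (illustrative shapes): transposing vector<8x[8]xf32> to
+ /// vector<[8]x8xf32> decomposes into four vector<[4]x[4]xf32> tile writes,
+ /// one per SME-tile offset of the source vector.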
+ struct LowerIllegalTransposeStoreViaZA
+     : public OpRewritePattern<vector::TransferWriteOp> {
+   using OpRewritePattern::OpRewritePattern;
+
+   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                 PatternRewriter &rewriter) const override {
+     if (!isSupportedMaskOp(writeOp.getMask()))
+       return rewriter.notifyMatchFailure(writeOp,
+                                          kMatchFailureUnsupportedMaskOp);
+
+     auto permutationMap = writeOp.getPermutationMap();
+     if (!permutationMap.isIdentity())
+       return rewriter.notifyMatchFailure(writeOp,
+                                          kMatchFailureNonPermutationMap);
+
+     auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
+     if (!transposeOp)
+       return failure();
+
+     auto sourceType = transposeOp.getSourceVectorType();
+     auto resultType = transposeOp.getResultVectorType();
+
+     if (resultType.getRank() != 2)
+       return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");
+
+     if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
+       return rewriter.notifyMatchFailure(
+           transposeOp, "not illegal/unsupported SVE transpose");
+
+     auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
+     VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);
+
+     if (sourceType.getDimSize(0) <= 1 ||
+         sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
+       return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");
+
+     auto loc = writeOp.getLoc();
+     auto createVscaleMultiple =
+         vector::makeVscaleConstantBuilder(rewriter, loc);
+
+     // Map (d0, d1) -> (d1, d0): the tile is stored transposed.
+     auto transposeMap = AffineMapAttr::get(
+         AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));
+
+     // Note: We need to use `get_tile` as there's no vector-level `undef`.
+     Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
+     Value destTensorOrMemref = writeOp.getSource();
+     auto numSlicesPerTile =
+         std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
+     auto numSlices =
+         rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
+     for (auto [index, smeTile] : llvm::enumerate(
+              decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
+       // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
+       // of slices from the source type into the SME tile. Without checking
+       // vscale (and emitting multiple implementations) we can't make use of
+       // the rows of the tile after 1*vscale rows.
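+       // (For example, with a vector<2x[4]xf32> source only slices 0 and 1
+       // of the [4]x[4] tile are initialized; the mask computed in step 3
+       // ensures the uninitialized slices are never stored.)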
+       Value tile = undefTile;
+       for (int d = 0; d < numSlicesPerTile; ++d) {
+         Value vector = rewriter.create<vector::ExtractOp>(
+             loc, transposeOp.getVector(),
+             rewriter.getIndexAttr(d + smeTile.row));
+         if (vector.getType() != smeSliceType) {
+           vector = rewriter.create<vector::ScalableExtractOp>(
+               loc, smeSliceType, vector, smeTile.col);
+         }
+         tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
+       }
+
+       // 2. Transpose the tile position.
+       auto transposedRow = createVscaleMultiple(smeTile.col);
+       auto transposedCol =
+           rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);
+
+       // 3. Compute mask for tile store.
+       Value maskRows;
+       Value maskCols;
+       if (auto mask = writeOp.getMask()) {
+         auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
+         maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
+                                                   transposedRow);
+         maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
+                                                   transposedCol);
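+         // Clamp the second mask dimension to numSlices, as any slices of
+         // the tile beyond those initialized in step 1 hold undefined values
+         // and must not be stored.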
+         maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
+       } else {
+         maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
+         maskCols = numSlices;
+       }
+       auto subMask = rewriter.create<vector::CreateMaskOp>(
+           loc, smeTileType.clone(rewriter.getI1Type()),
+           ValueRange{maskRows, maskCols});
+
+       // 4. Emit a transposed tile write.
+       auto writeIndices = writeOp.getIndices();
+       Value destRow =
+           rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
+       Value destCol =
+           rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
+       auto smeWrite = rewriter.create<vector::TransferWriteOp>(
+           loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
+           transposeMap, subMask, writeOp.getInBounds());
+
+       // Chain tensor writes: each tile write consumes the result of the
+       // previous one.
+       if (writeOp.hasPureTensorSemantics())
+         destTensorOrMemref = smeWrite.getResult();
+     }
+
+     if (writeOp.hasPureTensorSemantics())
+       rewriter.replaceOp(writeOp, destTensorOrMemref);
+     else
+       rewriter.eraseOp(writeOp);
+
+     return success();
+   }
+ };
+
struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
  void runOnOperation() override {
@@ -796,7 +938,8 @@ struct VectorLegalizationPass

    patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                 LiftIllegalVectorTransposeToMemory,
-                ConvertIllegalShapeCastOpsToTransposes>(context);
+                ConvertIllegalShapeCastOpsToTransposes,
+                LowerIllegalTransposeStoreViaZA>(context);
    // Note: These two patterns are added with a high benefit to ensure:
    // - Masked outer products are handled before unmasked ones
    // - Multi-tile writes are lowered as a store loop (if possible)