@@ -136,13 +136,14 @@ MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) {
   cfg.KBlock = 64;
   cfg.MThreads = 2;
   cfg.NThreads = 2;
-  cfg.KThreads = 2;
+  cfg.KThreads = 1;
   return cfg;
 }
 
-static Value tensorViewRankedTensor(RewriterBase &rewriter,
-                                    RankedTensorType outTensorType,
-                                    Value value) {
+static Value
+tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType,
+                       Value value,
+                       ArrayRef<int64_t> permutation = SmallVector<int64_t>{}) {
   // TODO: add support for plain layout transpose
   Value result, currentValue = value;
   auto loc = currentValue.getLoc();
@@ -175,33 +176,57 @@ static Value tensorViewRankedTensor(RewriterBase &rewriter,
 
   if (outShape.size() < inShape.size()) {
     SmallVector<ReassociationIndices> reassocIndices;
-    ReassociationIndices firstEntry;
-    for (auto i = 0UL; i < inShape.size() - outShape.size() + 1; i++) {
-      firstEntry.push_back(i);
-    }
-    reassocIndices.push_back(firstEntry);
-    for (auto i = inShape.size() - outShape.size() + 1UL; i < inShape.size();
-         i++) {
-      reassocIndices.push_back({(int)i});
+    uint64_t outIdx = 0UL, inIdx = 0UL;
+    while (inIdx < inShape.size() && outIdx < outShape.size()) {
+      ReassociationIndices firstEntry;
+      auto remaining = outShape[outIdx++];
+      if (remaining == 1) {
+        firstEntry.push_back(inIdx++);
+        reassocIndices.push_back(firstEntry);
+        continue;
+      }
+      while (remaining > 1) {
+        remaining /= inShape[inIdx];
+        firstEntry.push_back(inIdx++);
+      }
+      reassocIndices.push_back(firstEntry);
     }
     result = rewriter.create<tensor::CollapseShapeOp>(
         loc, outTensorType, currentValue, reassocIndices);
   } else if (outShape.size() > inShape.size()) {
     SmallVector<ReassociationIndices> reassocIndices;
-    ReassociationIndices firstEntry;
-    for (auto i = 0UL; i < outShape.size() - inShape.size() + 1; i++) {
-      firstEntry.push_back((int)i);
-    }
-    reassocIndices.push_back(firstEntry);
-    for (auto i = outShape.size() - inShape.size() + 1UL; i < outShape.size();
-         i++) {
-      reassocIndices.push_back({(int)i});
+    uint64_t outIdx = 0UL, inIdx = 0UL;
+    while (outIdx < outShape.size() && inIdx < inShape.size()) {
+      ReassociationIndices firstEntry;
+      auto remaining = inShape[inIdx++];
+      if (remaining == 1) {
+        firstEntry.push_back(outIdx++);
+        reassocIndices.push_back(firstEntry);
+        continue;
+      }
+      while (remaining > 1) {
+        remaining /= outShape[outIdx];
+        firstEntry.push_back(outIdx++);
+      }
+      reassocIndices.push_back(firstEntry);
     }
     result = rewriter.create<tensor::ExpandShapeOp>(
         loc, outTensorType, currentValue, reassocIndices);
   } else {
     result = rewriter.create<tensor::CastOp>(loc, outTensorType, currentValue);
   }
+
+  if (!permutation.empty()) {
+    SmallVector<int64_t> transposeShape;
+    for (auto idx : permutation) {
+      transposeShape.push_back(outShape[idx]);
+    }
+    auto initOp = rewriter.create<tensor::EmptyOp>(loc, transposeShape,
+                                                   tensorElementType);
+    auto transposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, result, initOp->getResult(0), permutation);
+    result = transposeOp->getResult(0);
+  }
   return result;
 }
 
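
Note on the rewritten grouping loop: instead of packing all leading dims into a single reassociation group, it greedily multiplies source dims until they fill each target dim. A minimal standalone sketch of the collapse-side logic (hypothetical shapes, plain std::vector in place of MLIR's ReassociationIndices):

#include <cstdint>
#include <cstdio>
#include <vector>

// Greedily multiply input dims until they cover one output dim, mirroring
// the CollapseShapeOp branch above.
static std::vector<std::vector<int64_t>>
collapseGroups(const std::vector<int64_t> &inShape,
               const std::vector<int64_t> &outShape) {
  std::vector<std::vector<int64_t>> groups;
  uint64_t inIdx = 0, outIdx = 0;
  while (inIdx < inShape.size() && outIdx < outShape.size()) {
    std::vector<int64_t> group;
    int64_t remaining = outShape[outIdx++];
    if (remaining == 1) {
      group.push_back(inIdx++);
    } else {
      while (remaining > 1) {
        remaining /= inShape[inIdx];
        group.push_back(inIdx++);
      }
    }
    groups.push_back(group);
  }
  return groups;
}

int main() {
  // Collapsing {2, 16, 32, 32} into {32, 32, 32} groups the first two
  // input dims (2 * 16 == 32) and maps the rest one-to-one: {{0, 1}, {2}, {3}}.
  for (const auto &group : collapseGroups({2, 16, 32, 32}, {32, 32, 32})) {
    for (int64_t idx : group)
      std::printf("%lld ", (long long)idx);
    std::printf("\n");
  }
  return 0;
}

As in the patch, this assumes each target dim is an exact product of consecutive source dims; a non-divisible shape would mis-group.
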
@@ -345,6 +370,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
     return b.notifyMatchFailure(
         linalgOp, "currentOp should not has pure buffer semantics");
   linalg::LinalgOp currentOp = linalgOp;
+
   for (auto loopTypeIter : llvm::enumerate(loopType)) {
     auto [i, loopType] = loopTypeIter;
     auto currentDim = loopDim[i];
@@ -486,6 +512,8 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
                                                     bool isExtract,
                                                     SmallVector<int64_t> size,
                                                     int shrinDimNum = 0) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
   if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
     SmallVector<OpFoldResult> mixedOffsets = extractSlice.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = extractSlice.getMixedSizes();
@@ -514,6 +542,8 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
 static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter,
                                                    Operation *op, Value source,
                                                    SmallVector<int64_t> size) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
   if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
     SmallVector<OpFoldResult> mixedOffsets = insertSlice.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = insertSlice.getMixedSizes();
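
Both hunks above add the same two lines; InsertionGuard is MLIR's RAII helper for scoping a builder's insertion point. A minimal sketch of the idiom (hypothetical helper name):

#include "mlir/IR/PatternMatch.h"

// InsertionGuard records the rewriter's current insertion point and restores
// it when the guard goes out of scope, so ops created here land right before
// `op` without disturbing the caller's position.
static void rewriteJustBefore(mlir::RewriterBase &rewriter,
                              mlir::Operation *op) {
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);
  // ... create the replacement offset/size ops here ...
} // original insertion point restored here
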
@@ -575,35 +605,34 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
     linalgOp.getReductionDims(KDimPos);
     getMatmulParallelDims(linalgOp, 0, MDimPos);
     getMatmulParallelDims(linalgOp, 1, NDimPos);
-    bool useBlockedLayout = KDimPos.size() > 1;
 
     OuterLoopGenerationOption option;
     auto iteratorTypes = linalgOp.getIteratorTypesArray();
     auto KFirstDim = (int)getOprandDim(linalgOp, KDimPos[0], 1);
     auto MFirstDim = (int)getOprandDim(linalgOp, MDimPos[0], 0);
     auto NFirstDim = (int)getOprandDim(linalgOp, NDimPos[0], 1);
     auto KParallelBlockSize =
-        useBlockedLayout
+        KDimPos.size() > 1
             ? divAndCeil(KFirstDim, cfg.KThreads)
             : divAndCeil(divAndCeil(KFirstDim, cfg.KBlock), cfg.KThreads) *
                   cfg.KBlock;
     auto MParallelBlockSize =
-        useBlockedLayout
+        MDimPos.size() > 1
             ? divAndCeil(MFirstDim, cfg.MThreads)
             : divAndCeil(divAndCeil(MFirstDim, cfg.MBlock), cfg.MThreads) *
                   cfg.MBlock;
     auto NParallelBlockSize =
-        useBlockedLayout
+        NDimPos.size() > 1
             ? divAndCeil(NFirstDim, cfg.NThreads)
             : divAndCeil(divAndCeil(NFirstDim, cfg.NBlock), cfg.NThreads) *
                   cfg.NBlock;
-    auto KOuterBlockSize = useBlockedLayout
+    auto KOuterBlockSize = KDimPos.size() > 1
                                ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1
                                : cfg.KBlock;
-    auto MOuterBlockSize = useBlockedLayout
+    auto MOuterBlockSize = MDimPos.size() > 1
                                ? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1
                                : cfg.MBlock;
-    auto NOuterBlockSize = useBlockedLayout
+    auto NOuterBlockSize = NDimPos.size() > 1
                                ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1
                                : cfg.NBlock;
     // Outer
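
A worked example of the plain-layout branch above, with hypothetical sizes and an assumed divAndCeil of (a + b - 1) / b: the per-thread share of K is rounded up to whole KBlocks.

#include <cassert>

static int divAndCeil(int a, int b) { return (a + b - 1) / b; } // assumed impl

int main() {
  int KFirstDim = 200, KBlock = 64, KThreads = 2; // hypothetical sizes
  // ceil(200 / 64) = 4 K-blocks; ceil(4 / 2) = 2 blocks per thread;
  // 2 * 64 = 128, i.e. each thread's K tile is a whole number of KBlocks.
  int KParallelBlockSize =
      divAndCeil(divAndCeil(KFirstDim, KBlock), KThreads) * KBlock;
  assert(KParallelBlockSize == 128);
  // The blocked branch splits the outer K dim directly instead:
  // divAndCeil(KFirstDim, KThreads).
  return 0;
}
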
@@ -631,11 +660,23 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
       option.loopDim.emplace_back(SmallVector<int>{dim});
     }
     // Inner
-    if (!useBlockedLayout) {
+    if (KDimPos.size() == 1) {
       option.nestedTileSizes.emplace_back(SmallVector<int>{cfg.KBlock});
       option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
       option.loopDim.emplace_back(SmallVector<int>{(int)KDimPos.back()});
     }
+    if (MDimPos.size() == 1) {
+      option.nestedTileSizes.emplace_back(
+          SmallVector<int>{cfg.innerMostMBlock});
+      option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+      option.loopDim.emplace_back(SmallVector<int>{(int)MDimPos.back()});
+    }
+    if (NDimPos.size() == 1) {
+      option.nestedTileSizes.emplace_back(
+          SmallVector<int>{cfg.innerMostNBlock});
+      option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+      option.loopDim.emplace_back(SmallVector<int>{(int)NDimPos.back()});
+    }
     for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) {
       if (dim != MDimPos.back() && dim != NDimPos.back() &&
           iteratorTypes[dim] != mlir::utils::IteratorType::reduction) {
@@ -658,17 +699,24 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
                                   linalg::LinalgOp originOp,
                                   linalg::LinalgOp currentOp,
                                   innerBodyGenerationOption &option) const {
+
     mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()};
     auto operandDimTypes = getOprandDimType(originOp);
     MatmulConfig cfg = getDefaultMatmulConfig(originOp);
     auto AShape = originOp.getShape(originOp.getDpsInputOperand(0));
     auto BShape = originOp.getShape(originOp.getDpsInputOperand(1));
     auto CShape = originOp.getShape(originOp.getDpsInitOperand(0));
-    bool useBlockedLayout = BShape.size() > 2;
+
+    auto MDimNum = std::count_if((*operandDimTypes)[0].begin(),
+                                 (*operandDimTypes)[0].end(),
+                                 [](DimType d) { return d == DimType::M; });
+    auto NDimNum = std::count_if((*operandDimTypes)[1].begin(),
+                                 (*operandDimTypes)[1].end(),
+                                 [](DimType d) { return d == DimType::N; });
     // TODO: support plain in/block out format
     SmallVector<int64_t> AInnermostDims, BInnermostDims, CInnermostDims;
-    if (useBlockedLayout) {
-      bool firstM = true, firstK = true, firstN = true;
+    bool firstM = true, firstK = true, firstN = true;
+    if (MDimNum > 1) {
      for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[0])) {
        if (iter == DimType::M && firstM) {
          AInnermostDims.push_back(1);
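
The single useBlockedLayout flag is replaced by per-operand counts: an operand is blocked along M (or N) when its dim-type list contains that dim more than once. A standalone sketch with a stand-in DimType enum (hypothetical layout):

#include <algorithm>
#include <vector>

enum class DimType { Batch, M, N, K }; // stand-in for the pass's DimType

int main() {
  // A blocked A operand laid out as [MOuter, KOuter, MInner, KInner] lists
  // DimType::M twice, so MDimNum == 2 and the blocked path is taken.
  std::vector<DimType> aDims = {DimType::M, DimType::K, DimType::M,
                                DimType::K};
  auto MDimNum = std::count_if(aDims.begin(), aDims.end(),
                               [](DimType d) { return d == DimType::M; });
  return MDimNum == 2 ? 0 : 1;
}
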
@@ -682,21 +730,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
           AInnermostDims.push_back(AShape[idx]);
         }
       }
-      firstN = true;
-      firstK = true;
-      for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
-        if (iter == DimType::N && firstN) {
-          BInnermostDims.push_back(1);
-          firstN = false;
-        } else if (iter == DimType::Batch) {
-          BInnermostDims.push_back(1);
-        } else if (iter == DimType::K && firstK) {
-          BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock);
-          firstK = false;
-        } else {
-          BInnermostDims.push_back(BShape[idx]);
-        }
-      }
       firstM = true;
       firstN = true;
       for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[2])) {
@@ -716,74 +749,94 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
       AInnermostDims = SmallVector<int64_t>{cfg.innerMostMBlock,
                                             cfg.KBlock / cfg.innerMostKBlock *
                                                 cfg.innerMostKBlock};
+      CInnermostDims =
+          SmallVector<int64_t>{cfg.innerMostMBlock, cfg.innerMostNBlock};
+    }
+    if (NDimNum > 1) {
+      firstN = true;
+      firstK = true;
+      for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
+        if (iter == DimType::N && firstN) {
+          BInnermostDims.push_back(1);
+          firstN = false;
+        } else if (iter == DimType::Batch) {
+          BInnermostDims.push_back(1);
+        } else if (iter == DimType::K && firstK) {
+          BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock);
+          firstK = false;
+        } else {
+          BInnermostDims.push_back(BShape[idx]);
+        }
+      }
+    } else {
       BInnermostDims = SmallVector<int64_t>{cfg.KBlock / cfg.innerMostKBlock *
                                                 cfg.innerMostKBlock,
                                             cfg.innerMostNBlock};
-      CInnermostDims =
-          SmallVector<int64_t>{cfg.innerMostMBlock, cfg.innerMostNBlock};
     }
 
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(currentOp);
     auto dataType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[0].getType());
+        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[0].getType())
+            .getElementType();
     auto weightType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[1].getType());
+        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[1].getType())
+            .getElementType();
     auto resultType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInits()[0].getType());
-    // use shrink layout when it is able to be converted to brgemm
-    bool useShrinkedLayout = (BInnermostDims.size() == 4);
+        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInits()[0].getType())
+            .getElementType();
 
     // update the extractSlice to static size, replace it with
     // useBlockedLayout when
     if (failed(setStaticSizeForExtractSliceOp(
             rewriter, currentOp.getDpsInits()[0].getDefiningOp(), true,
-            CInnermostDims, useShrinkedLayout ? 2 : 0)) ||
+            CInnermostDims, MDimNum > 1 ? 2 : 0)) ||
         failed(setStaticSizeForExtractSliceOp(
             rewriter, currentOp.getDpsInputs()[1].getDefiningOp(), true,
-            BInnermostDims, useShrinkedLayout)) ||
+            BInnermostDims, NDimNum > 1)) ||
         failed(setStaticSizeForExtractSliceOp(
             rewriter, currentOp.getDpsInputs()[0].getDefiningOp(), true,
-            AInnermostDims, useShrinkedLayout))) {
+            AInnermostDims, MDimNum > 1))) {
       return failure();
     }
-
     // View the tensor to brgemm required format
     Value dataOprand = tensorViewRankedTensor(
         rewriter,
         mlir::RankedTensorType::get(
-            useBlockedLayout
-                ? SmallVector<int64_t>(AInnermostDims.begin() + 1,
-                                       AInnermostDims.end())
-                : SmallVector<int64_t>{1, AInnermostDims[0], AInnermostDims[1]},
-            dataType.getElementType()),
-        currentOp.getDpsInputs()[0]);
+            MDimNum > 1 ? SmallVector<int64_t>(AInnermostDims.begin() + 1,
+                                               AInnermostDims.end())
+                        : SmallVector<int64_t>{cfg.innerMostMBlock,
+                                               cfg.KBlock / cfg.innerMostKBlock,
+                                               cfg.innerMostKBlock},
+            dataType),
+        currentOp.getDpsInputs()[0],
+        MDimNum == 1 ? SmallVector<int64_t>{1, 0, 2} : SmallVector<int64_t>{});
     Value weightOprand = tensorViewRankedTensor(
         rewriter,
         mlir::RankedTensorType::get(
-            useBlockedLayout
-                ? SmallVector<int64_t>(BInnermostDims.begin() + 1,
-                                       BInnermostDims.end())
-                : SmallVector<int64_t>{1, BInnermostDims[0], BInnermostDims[1]},
-            weightType.getElementType()),
+            NDimNum > 1 ? SmallVector<int64_t>(BInnermostDims.begin() + 1,
+                                               BInnermostDims.end())
+                        : SmallVector<int64_t>{cfg.KBlock / cfg.innerMostKBlock,
+                                               cfg.innerMostKBlock,
+                                               cfg.innerMostNBlock},
+            weightType),
         currentOp.getDpsInputs()[1]);
     Value resultOprand = tensorViewRankedTensor(
         rewriter,
         mlir::RankedTensorType::get(
-            SmallVector<int64_t>(CInnermostDims.begin() +
-                                     (useBlockedLayout ? 2 : 0),
+            SmallVector<int64_t>(CInnermostDims.begin() + (MDimNum > 1 ? 2 : 0),
                                  CInnermostDims.end()),
-            resultType.getElementType()),
+            resultType),
         currentOp.getDpsInits()[0]);
-
     // Create the brgemm op and replace the origin linalg op
     linalg::LinalgOp matmul;
-    if (BInnermostDims.size() == 4 || BInnermostDims.size() == 2) {
+    if (dyn_cast<mlir::RankedTensorType>(weightOprand.getType())
+            .getShape()
+            .size() == 3) {
       matmul = rewriter.create<linalg::BatchReduceMatmulOp>(
           resultOprand.getLoc(), resultOprand.getType(),
          ValueRange{dataOprand, weightOprand}, resultOprand);
     } else {
-      IRMapping mapping;
       matmul = rewriter.create<linalgx::BatchReduceMatmulVnniOp>(
           resultOprand.getLoc(), resultOprand.getType(),
           ValueRange{dataOprand, weightOprand}, resultOprand);
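
The kernel choice now keys off the rank of the weight view rather than BInnermostDims.size(). Restated as a sketch using the patch's own calls (getRank() swapped in for brevity; the weight layouts in the comments are assumptions, not taken from the patch):

auto weightRank =
    dyn_cast<mlir::RankedTensorType>(weightOprand.getType()).getRank();
if (weightRank == 3) {
  // plain brgemm weight, e.g. KOuter x innerK x innerN
  matmul = rewriter.create<linalg::BatchReduceMatmulOp>(
      resultOprand.getLoc(), resultOprand.getType(),
      ValueRange{dataOprand, weightOprand}, resultOprand);
} else {
  // rank-4, VNNI-packed weight, e.g. KOuter x innerK/vnniStep x innerN x vnniStep
  matmul = rewriter.create<linalgx::BatchReduceMatmulVnniOp>(
      resultOprand.getLoc(), resultOprand.getType(),
      ValueRange{dataOprand, weightOprand}, resultOprand);
}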