
Commit 14f4918

support 2Dx4D/5D case
1 parent 9af3f96 commit 14f4918

File tree

3 files changed: +159, -111 lines

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 127 additions & 74 deletions
@@ -136,13 +136,14 @@ MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) {
   cfg.KBlock = 64;
   cfg.MThreads = 2;
   cfg.NThreads = 2;
-  cfg.KThreads = 2;
+  cfg.KThreads = 1;
   return cfg;
 }
 
-static Value tensorViewRankedTensor(RewriterBase &rewriter,
-                                    RankedTensorType outTensorType,
-                                    Value value) {
+static Value
+tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType,
+                       Value value,
+                       ArrayRef<int64_t> permutation = SmallVector<int64_t>{}) {
   // TODO: add support for plain layout transpose
   Value result, currentValue = value;
   auto loc = currentValue.getLoc();
@@ -175,33 +176,57 @@ static Value tensorViewRankedTensor(RewriterBase &rewriter,
 
   if (outShape.size() < inShape.size()) {
     SmallVector<ReassociationIndices> reassocIndices;
-    ReassociationIndices firstEntry;
-    for (auto i = 0UL; i < inShape.size() - outShape.size() + 1; i++) {
-      firstEntry.push_back(i);
-    }
-    reassocIndices.push_back(firstEntry);
-    for (auto i = inShape.size() - outShape.size() + 1UL; i < inShape.size();
-         i++) {
-      reassocIndices.push_back({(int)i});
+    uint64_t outIdx = 0UL, inIdx = 0UL;
+    while (inIdx < inShape.size() && outIdx < outShape.size()) {
+      ReassociationIndices firstEntry;
+      auto remaining = outShape[outIdx++];
+      if (remaining == 1) {
+        firstEntry.push_back(inIdx++);
+        reassocIndices.push_back(firstEntry);
+        continue;
+      }
+      while (remaining > 1) {
+        remaining /= inShape[inIdx];
+        firstEntry.push_back(inIdx++);
+      }
+      reassocIndices.push_back(firstEntry);
     }
     result = rewriter.create<tensor::CollapseShapeOp>(
         loc, outTensorType, currentValue, reassocIndices);
   } else if (outShape.size() > inShape.size()) {
     SmallVector<ReassociationIndices> reassocIndices;
-    ReassociationIndices firstEntry;
-    for (auto i = 0UL; i < outShape.size() - inShape.size() + 1; i++) {
-      firstEntry.push_back((int)i);
-    }
-    reassocIndices.push_back(firstEntry);
-    for (auto i = outShape.size() - inShape.size() + 1UL; i < outShape.size();
-         i++) {
-      reassocIndices.push_back({(int)i});
+    uint64_t outIdx = 0UL, inIdx = 0UL;
+    while (outIdx < outShape.size() && inIdx < inShape.size()) {
+      ReassociationIndices firstEntry;
+      auto remaining = inShape[inIdx++];
+      if (remaining == 1) {
+        firstEntry.push_back(outIdx++);
+        reassocIndices.push_back(firstEntry);
+        continue;
+      }
+      while (remaining > 1) {
+        remaining /= outShape[outIdx];
+        firstEntry.push_back(outIdx++);
+      }
+      reassocIndices.push_back(firstEntry);
     }
     result = rewriter.create<tensor::ExpandShapeOp>(
         loc, outTensorType, currentValue, reassocIndices);
   } else {
     result = rewriter.create<tensor::CastOp>(loc, outTensorType, currentValue);
   }
+
+  if (!permutation.empty()) {
+    SmallVector<int64_t> transposeShape;
+    for (auto idx : permutation) {
+      transposeShape.push_back(outShape[idx]);
+    }
+    auto initOp = rewriter.create<tensor::EmptyOp>(loc, transposeShape,
+                                                   tensorElementType);
+    auto transposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, result, initOp->getResult(0), permutation);
+    result = transposeOp->getResult(0);
+  }
   return result;
 }
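Note: the reworked reassociation logic greedily folds consecutive source dims into each target dim until the sizes multiply out, which is what lets a 2D plain operand be viewed against a 4D/5D blocked one. A standalone sketch of that grouping in plain C++ (illustrative shapes, assuming static sizes whose products line up; not the MLIR helper itself):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Greedy grouping used by the collapse path above: each output dim absorbs
// consecutive input dims until their product matches it.
static std::vector<std::vector<int64_t>>
groupDims(const std::vector<int64_t> &outShape,
          const std::vector<int64_t> &inShape) {
  std::vector<std::vector<int64_t>> reassoc;
  std::size_t outIdx = 0, inIdx = 0;
  while (inIdx < inShape.size() && outIdx < outShape.size()) {
    std::vector<int64_t> entry;
    int64_t remaining = outShape[outIdx++];
    if (remaining == 1) {
      entry.push_back(static_cast<int64_t>(inIdx++));
      reassoc.push_back(entry);
      continue;
    }
    while (remaining > 1) {
      remaining /= inShape[inIdx];
      entry.push_back(static_cast<int64_t>(inIdx++));
    }
    reassoc.push_back(entry);
  }
  return reassoc;
}

int main() {
  // Collapsing a blocked 4D tensor [2, 8, 32, 64] into 2D [16, 2048]
  // yields reassociation indices [[0, 1], [2, 3]].
  for (const auto &group : groupDims({16, 2048}, {2, 8, 32, 64})) {
    for (auto idx : group)
      std::cout << idx << ' ';
    std::cout << '\n';
  }
  return 0;
}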

@@ -345,6 +370,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
     return b.notifyMatchFailure(
         linalgOp, "currentOp should not has pure buffer semantics");
   linalg::LinalgOp currentOp = linalgOp;
+
   for (auto loopTypeIter : llvm::enumerate(loopType)) {
     auto [i, loopType] = loopTypeIter;
     auto currentDim = loopDim[i];
@@ -486,6 +512,8 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
                                                     bool isExtract,
                                                     SmallVector<int64_t> size,
                                                     int shrinDimNum = 0) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
   if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
     SmallVector<OpFoldResult> mixedOffsets = extractSlice.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = extractSlice.getMixedSizes();
@@ -514,6 +542,8 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
 static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter,
                                                    Operation *op, Value source,
                                                    SmallVector<int64_t> size) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
   if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
     SmallVector<OpFoldResult> mixedOffsets = insertSlice.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = insertSlice.getMixedSizes();
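Note: both slice helpers above now pin the insertion point to the op being rewritten, and the added OpBuilder::InsertionGuard restores whatever insertion point the rewriter had once the helper returns. A toy RAII sketch of that save/restore idiom (the Builder type here is illustrative; the real guard is MLIR's OpBuilder::InsertionGuard):

#include <iostream>

// Toy stand-in for the rewriter: only the state the guard cares about.
struct Builder {
  int insertionPoint = 0;
};

// RAII guard: remembers the insertion point on construction and puts it
// back on destruction, mirroring what OpBuilder::InsertionGuard does.
struct InsertionGuard {
  Builder &b;
  int saved;
  explicit InsertionGuard(Builder &builder)
      : b(builder), saved(builder.insertionPoint) {}
  ~InsertionGuard() { b.insertionPoint = saved; }
};

static void helperThatMovesInsertionPoint(Builder &b) {
  InsertionGuard guard(b);
  b.insertionPoint = 42; // temporarily move, e.g. to just before a slice op
  // ... rewrite work would happen here ...
}

int main() {
  Builder b;
  helperThatMovesInsertionPoint(b);
  std::cout << b.insertionPoint << '\n'; // back to 0 after the helper returns
  return 0;
}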
@@ -575,35 +605,34 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
     linalgOp.getReductionDims(KDimPos);
     getMatmulParallelDims(linalgOp, 0, MDimPos);
     getMatmulParallelDims(linalgOp, 1, NDimPos);
-    bool useBlockedLayout = KDimPos.size() > 1;
 
     OuterLoopGenerationOption option;
     auto iteratorTypes = linalgOp.getIteratorTypesArray();
     auto KFirstDim = (int)getOprandDim(linalgOp, KDimPos[0], 1);
     auto MFirstDim = (int)getOprandDim(linalgOp, MDimPos[0], 0);
     auto NFirstDim = (int)getOprandDim(linalgOp, NDimPos[0], 1);
     auto KParallelBlockSize =
-        useBlockedLayout
+        KDimPos.size() > 1
             ? divAndCeil(KFirstDim, cfg.KThreads)
             : divAndCeil(divAndCeil(KFirstDim, cfg.KBlock), cfg.KThreads) *
                   cfg.KBlock;
     auto MParallelBlockSize =
-        useBlockedLayout
+        MDimPos.size() > 1
            ? divAndCeil(MFirstDim, cfg.MThreads)
            : divAndCeil(divAndCeil(MFirstDim, cfg.MBlock), cfg.MThreads) *
                  cfg.MBlock;
     auto NParallelBlockSize =
-        useBlockedLayout
+        NDimPos.size() > 1
            ? divAndCeil(NFirstDim, cfg.NThreads)
            : divAndCeil(divAndCeil(NFirstDim, cfg.NBlock), cfg.NThreads) *
                  cfg.NBlock;
-    auto KOuterBlockSize = useBlockedLayout
+    auto KOuterBlockSize = KDimPos.size() > 1
                                ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1
                                : cfg.KBlock;
-    auto MOuterBlockSize = useBlockedLayout
+    auto MOuterBlockSize = MDimPos.size() > 1
                                ? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1
                                : cfg.MBlock;
-    auto NOuterBlockSize = useBlockedLayout
+    auto NOuterBlockSize = NDimPos.size() > 1
                                ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1
                                : cfg.NBlock;
     // Outer
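Note: with useBlockedLayout replaced by per-dimension checks, a dimension that occurs only once (plain layout) takes the second branch, which first rounds the dimension up to whole KBlock/MBlock/NBlock tiles and only then splits it across threads. A small arithmetic check of that branch with illustrative sizes (M = 1000, MBlock = 128, MThreads = 2, none of which come from the patch):

#include <iostream>

// Ceiling division, mirroring the divAndCeil helper used in the pass.
static int divAndCeil(int a, int b) { return (a + b - 1) / b; }

int main() {
  // Illustrative sizes, not taken from the patch.
  int MFirstDim = 1000, MBlock = 128, MThreads = 2;
  // Plain-layout branch: count the MBlock tiles, split them across threads,
  // then scale back to elements so each thread owns whole tiles.
  int MParallelBlockSize =
      divAndCeil(divAndCeil(MFirstDim, MBlock), MThreads) * MBlock;
  std::cout << MParallelBlockSize << '\n'; // 8 tiles / 2 threads -> 4 * 128 = 512
  return 0;
}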
@@ -631,11 +660,23 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
       option.loopDim.emplace_back(SmallVector<int>{dim});
     }
     // Inner
-    if (!useBlockedLayout) {
+    if (KDimPos.size() == 1) {
       option.nestedTileSizes.emplace_back(SmallVector<int>{cfg.KBlock});
       option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
       option.loopDim.emplace_back(SmallVector<int>{(int)KDimPos.back()});
     }
+    if (MDimPos.size() == 1) {
+      option.nestedTileSizes.emplace_back(
+          SmallVector<int>{cfg.innerMostMBlock});
+      option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+      option.loopDim.emplace_back(SmallVector<int>{(int)MDimPos.back()});
+    }
+    if (NDimPos.size() == 1) {
+      option.nestedTileSizes.emplace_back(
+          SmallVector<int>{cfg.innerMostNBlock});
+      option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+      option.loopDim.emplace_back(SmallVector<int>{(int)NDimPos.back()});
+    }
     for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) {
       if (dim != MDimPos.back() && dim != NDimPos.back() &&
           iteratorTypes[dim] != mlir::utils::IteratorType::reduction) {
@@ -658,17 +699,24 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
                                      linalg::LinalgOp originOp,
                                      linalg::LinalgOp currentOp,
                                      innerBodyGenerationOption &option) const {
+
     mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()};
     auto operandDimTypes = getOprandDimType(originOp);
     MatmulConfig cfg = getDefaultMatmulConfig(originOp);
     auto AShape = originOp.getShape(originOp.getDpsInputOperand(0));
     auto BShape = originOp.getShape(originOp.getDpsInputOperand(1));
     auto CShape = originOp.getShape(originOp.getDpsInitOperand(0));
-    bool useBlockedLayout = BShape.size() > 2;
+
+    auto MDimNum = std::count_if((*operandDimTypes)[0].begin(),
+                                 (*operandDimTypes)[0].end(),
+                                 [](DimType d) { return d == DimType::M; });
+    auto NDimNum = std::count_if((*operandDimTypes)[1].begin(),
+                                 (*operandDimTypes)[1].end(),
+                                 [](DimType d) { return d == DimType::N; });
     // TODO: support plain in/block out format
     SmallVector<int64_t> AInnermostDims, BInnermostDims, CInnermostDims;
-    if (useBlockedLayout) {
-      bool firstM = true, firstK = true, firstN = true;
+    bool firstM = true, firstK = true, firstN = true;
+    if (MDimNum > 1) {
       for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[0])) {
         if (iter == DimType::M && firstM) {
           AInnermostDims.push_back(1);
@@ -682,21 +730,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
           AInnermostDims.push_back(AShape[idx]);
         }
       }
-      firstN = true;
-      firstK = true;
-      for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
-        if (iter == DimType::N && firstN) {
-          BInnermostDims.push_back(1);
-          firstN = false;
-        } else if (iter == DimType::Batch) {
-          BInnermostDims.push_back(1);
-        } else if (iter == DimType::K && firstK) {
-          BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock);
-          firstK = false;
-        } else {
-          BInnermostDims.push_back(BShape[idx]);
-        }
-      }
       firstM = true;
       firstN = true;
       for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[2])) {
@@ -716,74 +749,94 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
       AInnermostDims = SmallVector<int64_t>{cfg.innerMostMBlock,
                                             cfg.KBlock / cfg.innerMostKBlock *
                                                 cfg.innerMostKBlock};
+      CInnermostDims =
+          SmallVector<int64_t>{cfg.innerMostMBlock, cfg.innerMostNBlock};
+    }
+    if (NDimNum > 1) {
+      firstN = true;
+      firstK = true;
+      for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
+        if (iter == DimType::N && firstN) {
+          BInnermostDims.push_back(1);
+          firstN = false;
+        } else if (iter == DimType::Batch) {
+          BInnermostDims.push_back(1);
+        } else if (iter == DimType::K && firstK) {
+          BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock);
+          firstK = false;
+        } else {
+          BInnermostDims.push_back(BShape[idx]);
+        }
+      }
+    } else {
       BInnermostDims = SmallVector<int64_t>{cfg.KBlock / cfg.innerMostKBlock *
                                                 cfg.innerMostKBlock,
                                             cfg.innerMostNBlock};
-      CInnermostDims =
-          SmallVector<int64_t>{cfg.innerMostMBlock, cfg.innerMostNBlock};
     }
 
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(currentOp);
     auto dataType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[0].getType());
+        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[0].getType())
+            .getElementType();
     auto weightType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[1].getType());
+        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[1].getType())
+            .getElementType();
     auto resultType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInits()[0].getType());
-    // use shrink layout when it is able to be converted to brgemm
-    bool useShrinkedLayout = (BInnermostDims.size() == 4);
+        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInits()[0].getType())
+            .getElementType();
 
     // update the extractSlice to static size, replace it with
     // useBlockedLayout when
     if (failed(setStaticSizeForExtractSliceOp(
             rewriter, currentOp.getDpsInits()[0].getDefiningOp(), true,
-            CInnermostDims, useShrinkedLayout ? 2 : 0)) ||
+            CInnermostDims, MDimNum > 1 ? 2 : 0)) ||
        failed(setStaticSizeForExtractSliceOp(
            rewriter, currentOp.getDpsInputs()[1].getDefiningOp(), true,
-            BInnermostDims, useShrinkedLayout)) ||
+            BInnermostDims, NDimNum > 1)) ||
        failed(setStaticSizeForExtractSliceOp(
            rewriter, currentOp.getDpsInputs()[0].getDefiningOp(), true,
-            AInnermostDims, useShrinkedLayout))) {
+            AInnermostDims, MDimNum > 1))) {
       return failure();
     }
-
     // View the tensor to brgemm required format
     Value dataOprand = tensorViewRankedTensor(
         rewriter,
         mlir::RankedTensorType::get(
-            useBlockedLayout
-                ? SmallVector<int64_t>(AInnermostDims.begin() + 1,
-                                       AInnermostDims.end())
-                : SmallVector<int64_t>{1, AInnermostDims[0], AInnermostDims[1]},
-            dataType.getElementType()),
-        currentOp.getDpsInputs()[0]);
+            MDimNum > 1 ? SmallVector<int64_t>(AInnermostDims.begin() + 1,
+                                               AInnermostDims.end())
+                        : SmallVector<int64_t>{cfg.innerMostMBlock,
+                                               cfg.KBlock / cfg.innerMostKBlock,
+                                               cfg.innerMostKBlock},
+            dataType),
+        currentOp.getDpsInputs()[0],
+        MDimNum == 1 ? SmallVector<int64_t>{1, 0, 2} : SmallVector<int64_t>{});
     Value weightOprand = tensorViewRankedTensor(
         rewriter,
         mlir::RankedTensorType::get(
-            useBlockedLayout
-                ? SmallVector<int64_t>(BInnermostDims.begin() + 1,
-                                       BInnermostDims.end())
-                : SmallVector<int64_t>{1, BInnermostDims[0], BInnermostDims[1]},
-            weightType.getElementType()),
+            NDimNum > 1 ? SmallVector<int64_t>(BInnermostDims.begin() + 1,
+                                               BInnermostDims.end())
+                        : SmallVector<int64_t>{cfg.KBlock / cfg.innerMostKBlock,
+                                               cfg.innerMostKBlock,
+                                               cfg.innerMostNBlock},
+            weightType),
         currentOp.getDpsInputs()[1]);
     Value resultOprand = tensorViewRankedTensor(
         rewriter,
         mlir::RankedTensorType::get(
-            SmallVector<int64_t>(CInnermostDims.begin() +
-                                     (useBlockedLayout ? 2 : 0),
+            SmallVector<int64_t>(CInnermostDims.begin() + (MDimNum > 1 ? 2 : 0),
                                  CInnermostDims.end()),
-            resultType.getElementType()),
+            resultType),
         currentOp.getDpsInits()[0]);
-
     // Create the brgemm op and replace the origin linalg op
     linalg::LinalgOp matmul;
-    if (BInnermostDims.size() == 4 || BInnermostDims.size() == 2) {
+    if (dyn_cast<mlir::RankedTensorType>(weightOprand.getType())
+            .getShape()
+            .size() == 3) {
       matmul = rewriter.create<linalg::BatchReduceMatmulOp>(
           resultOprand.getLoc(), resultOprand.getType(),
           ValueRange{dataOprand, weightOprand}, resultOprand);
     } else {
-      IRMapping mapping;
       matmul = rewriter.create<linalgx::BatchReduceMatmulVnniOp>(
          resultOprand.getLoc(), resultOprand.getType(),
          ValueRange{dataOprand, weightOprand}, resultOprand);
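Note: for the 2Dx4D case, the plain A operand is viewed as a 3D tensor and then permuted with {1, 0, 2} so that a leading reduction dimension feeds the brgemm, and the rank-3 check on the weight view selects linalg::BatchReduceMatmulOp over the VNNI variant. A small sketch of the resulting view shape under illustrative config values (innerMostMBlock = innerMostKBlock = 32, KBlock = 64, not taken from the patch):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Illustrative config values, not taken from the patch.
  int64_t innerMostMBlock = 32, innerMostKBlock = 32, KBlock = 64;
  // Plain (2D) A operand: first viewed as {M_inner, KBlock/K_inner, K_inner},
  // then permuted with {1, 0, 2} so the K-block dimension leads, giving the
  // (batch, M, K) layout that linalg.batch_reduce_matmul expects.
  std::vector<int64_t> viewShape = {innerMostMBlock, KBlock / innerMostKBlock,
                                    innerMostKBlock};
  std::vector<int64_t> permutation = {1, 0, 2};
  std::vector<int64_t> brgemmShape;
  for (auto idx : permutation)
    brgemmShape.push_back(viewShape[idx]);
  for (auto d : brgemmShape)
    std::cout << d << ' '; // prints: 2 32 32
  std::cout << '\n';
  return 0;
}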
