Commit 738986c

temp cache
1 parent 8648723 commit 738986c

2 files changed: +169, -20 lines changed


lib/gc/Dialect/Linalgx/LinalgxOps.cpp

Lines changed: 16 additions & 1 deletion
@@ -635,9 +635,12 @@ ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
   mask = getInputs()[3];
   auto dtype = cast<RankedTensorType>(query.getType()).getElementType();
   auto shape = cast<RankedTensorType>(query.getType()).getShape();
+  float rsqrt_head = 1 / sqrt(shape[3]);

   SmallVector<int64_t> permutation{0, 1, 3, 2};
   SmallVector<int64_t> transposeShape{shape[0], shape[1], shape[3], shape[2]};
+  auto constant =
+      b.create<arith::ConstantOp>(loc, b.getFloatAttr(dtype, rsqrt_head));
   auto transposeOut = b.create<tensor::EmptyOp>(loc, transposeShape, dtype);
   auto transpose = b.create<linalg::TransposeOp>(
       /*location=*/loc,
@@ -652,16 +655,28 @@ ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
       /*inputs=*/ValueRange{query, transpose->getResult(0)},
       /*outputs=*/ValueRange{matmulQKOut.getResult()});

+  auto mulOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
+  auto mul = b.create<linalg::GenericOp>(
+      /*location=*/loc, matmulQKOut.getResult().getType(),
+      /*inputs=*/ValueRange{query, transpose->getResult(0)},
+      /*outputs=*/ValueRange{matmulQKOut.getResult()});
+
   auto addOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
   auto add = b.create<linalg::AddOp>(
       /*location=*/loc, addOut.getResult().getType(),
       /*inputs=*/ValueRange{matmulQK->getResult(0), mask},
       /*outputs=*/ValueRange{addOut.getResult()});

+  auto softmaxOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
+  auto softmax = b.create<linalg::SoftmaxOp>(
+      /*location=*/loc, softmaxOut.getResult().getType(),
+      /*inputs=*/add->getResult(0),
+      /*outputs=*/softmaxOut.getResult(), 3);
+
   auto matmulVOut = b.create<tensor::EmptyOp>(loc, shape, dtype);
   auto matmulV = b.create<linalgx::MultiBatchMatmulOp>(
       /*location=*/loc, matmulVOut.getResult().getType(),
-      /*inputs=*/ValueRange{add->getResult(0), value},
+      /*inputs=*/ValueRange{softmax->getResult(0), value},
       /*outputs=*/ValueRange{matmulVOut.getResult()});
   return SmallVector<Value>{matmulV.getResults()[0]};
 }
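
For reference, the op sequence assembled by this decomposition (transpose K, batched Q·Kᵀ matmul, scale by rsqrt_head, add the mask, softmax over the last dimension, batched matmul with V) targets the standard scaled-dot-product-attention formula. A sketch in LaTeX, assuming shape[3] is the per-head dimension d and M is the mask operand getInputs()[3]:

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}} + M\right) V,
\qquad \texttt{rsqrt\_head} = \frac{1}{\sqrt{d}}, \quad d = \texttt{shape}[3].
\]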

lib/gc/Transforms/FlashAttentionConversion.cpp

Lines changed: 153 additions & 19 deletions
@@ -36,6 +36,7 @@
 
 #include <llvm/Support/Debug.h>
 
+#include <iostream>
 #include <memory>
 
 namespace mlir {
@@ -45,16 +46,161 @@ namespace gc {
 
 namespace {
 
+struct FlashAttentionConfig {
+  int RowBlock, ColumnBlock;
+};
+
+static FlashAttentionConfig
+getDefaultFlashAttentionConfig(linalgx::ScaledDotProductAttentionOp &sdpaOp) {
+  // TODO: allow tuning
+  auto seqLen = sdpaOp.getShape(sdpaOp.getDpsInputOperand(0))[2];
+  FlashAttentionConfig cfg;
+
+  // cfg.RowBlock = seqLen / 64;
+  // cfg.ColBlock = seqLen / 64;
+  return cfg;
+}
+
 struct MHAToFlashAttention
-    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
-  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+    : public OpRewritePattern<linalgx::ScaledDotProductAttentionOp> {
+  using OpRewritePattern<
+      linalgx::ScaledDotProductAttentionOp>::OpRewritePattern;
+
+  struct OuterLoopGenerationResult {
+    /// Tiled operations that are generated during tiling. The order does not
+    /// matter except the last op. The replacements are expected to be the
+    /// results of the last op.
+    SmallVector<Operation *> tiledOps;
+    /// The `scf.for` operations that iterate over the tiles.
+    SmallVector<LoopLikeOpInterface> loops;
+    SmallVector<LoopLikeOpInterface> reductionLoops;
+  };
 
-  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
+  // FailureOr<OuterLoopGenerationResult>
+  // outerLoopGeneration(RewriterBase &rewriter, linalg::LinalgOp linalgOp)
+  // const {
+  //   SmallVector<unsigned> RowDimPos, ColDimPos;
+  //   linalgOp.getReductionDims(KDimPos);
+  //   getMatmulParallelDims(linalgOp, 0, MDimPos);
+  //   getMatmulParallelDims(linalgOp, 1, NDimPos);
+
+  //   OuterLoopGenerationOption option;
+  //   auto iteratorTypes = linalgOp.getIteratorTypesArray();
+  //   auto KFirstDim = (int)getOprandDim(linalgOp, KDimPos[0], 1);
+  //   auto MFirstDim = (int)getOprandDim(linalgOp, MDimPos[0], 0);
+  //   auto NFirstDim = (int)getOprandDim(linalgOp, NDimPos[0], 1);
+  //   auto KParallelBlockSize =
+  //       KDimPos.size() > 1
+  //           ? divAndCeil(KFirstDim, cfg.KThreads)
+  //           : divAndCeil(divAndCeil(KFirstDim, cfg.KBlock), cfg.KThreads) *
+  //                 cfg.KBlock;
+  //   auto MParallelBlockSize =
+  //       MDimPos.size() > 1
+  //           ? divAndCeil(MFirstDim, cfg.MThreads)
+  //           : divAndCeil(divAndCeil(MFirstDim, cfg.MBlock), cfg.MThreads) *
+  //                 cfg.MBlock;
+  //   auto NParallelBlockSize =
+  //       NDimPos.size() > 1
+  //           ? divAndCeil(NFirstDim, cfg.NThreads)
+  //           : divAndCeil(divAndCeil(NFirstDim, cfg.NBlock), cfg.NThreads) *
+  //                 cfg.NBlock;
+  //   auto KOuterBlockSize = KDimPos.size() > 1
+  //                              ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1
+  //                              : cfg.KBlock;
+  //   auto MOuterBlockSize = MDimPos.size() > 1
+  //                              ? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1
+  //                              : cfg.MBlock;
+  //   auto NOuterBlockSize = NDimPos.size() > 1
+  //                              ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1
+  //                              : cfg.NBlock;
+  //   // Outer
+  //   option.nestedTileSizes.emplace_back(SmallVector<int>{
+  //       MParallelBlockSize, NParallelBlockSize, KParallelBlockSize});
+  //   option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp);
+  //   option.loopDim.emplace_back(
+  //       SmallVector<int>{(int)MDimPos[0], (int)NDimPos[0], (int)KDimPos[0]});
+  //   // Middle
+  //   for (auto [tile, dim] :
+  //        llvm::zip(SmallVector<int>{MOuterBlockSize, NOuterBlockSize,
+  //                                   KOuterBlockSize},
+  //                  SmallVector<int>{(int)MDimPos[0], (int)NDimPos[0],
+  //                                   (int)KDimPos[0]})) {
+  //     option.nestedTileSizes.emplace_back(SmallVector<int>{tile});
+  //     option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+  //     option.loopDim.emplace_back(SmallVector<int>{dim});
+  //   }
+  //   // Inner
+  //   if (KDimPos.size() == 1) {
+  //     option.nestedTileSizes.emplace_back(SmallVector<int>{cfg.KBlock});
+  //     option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+  //     option.loopDim.emplace_back(SmallVector<int>{(int)KDimPos.back()});
+  //   }
+  //   if (MDimPos.size() == 1) {
+  //     option.nestedTileSizes.emplace_back(
+  //         SmallVector<int>{cfg.innerMostMBlock});
+  //     option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+  //     option.loopDim.emplace_back(SmallVector<int>{(int)MDimPos.back()});
+  //   }
+  //   if (NDimPos.size() == 1) {
+  //     option.nestedTileSizes.emplace_back(
+  //         SmallVector<int>{cfg.innerMostNBlock});
+  //     option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+  //     option.loopDim.emplace_back(SmallVector<int>{(int)NDimPos.back()});
+  //   }
+  //   for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) {
+  //     if (dim != MDimPos.back() && dim != NDimPos.back() &&
+  //         iteratorTypes[dim] != mlir::utils::IteratorType::reduction) {
+  //       option.nestedTileSizes.emplace_back(SmallVector<int>{1});
+  //       option.loopType.emplace_back(
+  //           OuterLoopGenerationOption::LoopType::ForOp);
+  //       option.loopDim.emplace_back(SmallVector<int>{(int)dim});
+  //     }
+  //   }
+
+  //   auto lowPrecisionCast =
+  //       [&](RewriterBase &rewriter, Location loc,
+  //           linalg::LinalgOp linalgop) -> FailureOr<linalg::LinalgOp> {
+  //     auto legalizedResult = matmulDtypeLegalize(
+  //         rewriter, linalgop.getOperation(), !hasFillOp, true);
+  //     if (legalizedResult->castOp && legalizedResult->linalgOp) {
+  //       auto linalgOp = legalizedResult->linalgOp;
+  //       rewriter.replaceOp(linalgop,
+  //                          linalgOp->getResult(linalgOp->getNumResults() -
+  //                          1));
+  //       return dyn_cast<linalg::LinalgOp>(linalgOp);
+  //     }
+  //     return failure();
+  //   };
+  //   option.innermostFullResultCallBacks.push_back(lowPrecisionCast);
+
+  //   if (hasFillOp) {
+  //     auto removeReduncantFill =
+  //         [&](RewriterBase &rewriter, Location loc,
+  //             const linalg::ForallReductionTilingResult &result)
+  //         -> FailureOr<linalg::LinalgOp> {
+  //       auto initValue = result.initialValues;
+  //       if (initValue.size() == 1 &&
+  //           isa<linalg::FillOp>(initValue[0].getDefiningOp())) {
+  //         rewriter.replaceOp(initValue[0].getDefiningOp(),
+  //                            dyn_cast<DestinationStyleOpInterface>(
+  //                                initValue[0].getDefiningOp())
+  //                                .getDpsInits()[0]);
+  //       }
+  //       return dyn_cast<linalg::LinalgOp>(result.parallelTiledOp);
+  //     };
+  //     option.finalReduceCallBacks.push_back(removeReduncantFill);
+  //   }
+  //   return generateOuterLoop(rewriter, linalgOp, option);
+  // }
+
+  LogicalResult matchAndRewrite(linalgx::ScaledDotProductAttentionOp sdpaOp,
                                 PatternRewriter &rewriter) const override {
-    if (!llvm::isa<linalgx::ScaledDotProductAttentionOp>(linalgOp))
-      return failure();
-    if (linalgOp.hasPureBufferSemantics())
-      return failure();
+    auto decomposableOp =
+        dyn_cast<mlir::linalg::AggregatedOpInterface>(sdpaOp.getOperation());
+    FailureOr<SmallVector<Value>> maybeNewResults =
+        decomposableOp.decomposeOperation(rewriter);
+    rewriter.replaceOp(decomposableOp, *maybeNewResults);
+    return success();
   }
 };

@@ -65,19 +211,7 @@ struct FlashAttentionConversion
     auto &ctx = getContext();
     IRRewriter rewriter(&ctx);
     RewritePatternSet patterns(&ctx);
-
     patterns.add<MHAToFlashAttention>(patterns.getContext());
-    // linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
-    // linalg::ControlDropUnitDims options;
-    // options.rankReductionStrategy =
-    //     linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
-    // linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
-    // tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
-
-    // for (auto *dialect : ctx.getLoadedDialects())
-    //   dialect->getCanonicalizationPatterns(patterns);
-    // for (RegisteredOperationName op : ctx.getRegisteredOperations())
-    //   op.getCanonicalizationPatterns(patterns, &ctx);
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(patterns)))) {
       return signalPassFailure();
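
The rewritten MHAToFlashAttention pattern above delegates entirely to linalg::AggregatedOpInterface::decomposeOperation and replaces the SDPA op with the decomposed results. Below is a minimal sketch of that same rewrite with the interface cast and the FailureOr checked before use; the free-function form and the failure guards are illustrative assumptions, not part of this commit:

// Sketch only: mirrors MHAToFlashAttention::matchAndRewrite, with the
// dyn_cast and the decomposition result guarded (hypothetical hardening).
static LogicalResult
rewriteViaDecomposition(linalgx::ScaledDotProductAttentionOp sdpaOp,
                        PatternRewriter &rewriter) {
  // The SDPA op must implement AggregatedOpInterface to be decomposable.
  auto decomposableOp =
      dyn_cast<mlir::linalg::AggregatedOpInterface>(sdpaOp.getOperation());
  if (!decomposableOp)
    return failure();
  // Decompose into the transpose/matmul/scale/add-mask/softmax/matmul chain
  // built in ScaledDotProductAttentionOp::decomposeOperation.
  FailureOr<SmallVector<Value>> maybeNewResults =
      decomposableOp.decomposeOperation(rewriter);
  if (failed(maybeNewResults))
    return failure();
  // Replace the aggregate op with the results of the decomposed sequence.
  rewriter.replaceOp(decomposableOp, *maybeNewResults);
  return success();
}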
