#include <llvm/Support/Debug.h>
+ #include <iostream>
#include <memory>

namespace mlir {
@@ -45,16 +46,161 @@ namespace gc {
namespace {
+ struct FlashAttentionConfig {
+   int RowBlock, ColumnBlock;
+ };
+
+ static FlashAttentionConfig
+ getDefaultFlashAttentionConfig(linalgx::ScaledDotProductAttentionOp &sdpaOp) {
+   // TODO: allow tuning
+   auto seqLen = sdpaOp.getShape(sdpaOp.getDpsInputOperand(0))[2];
+   FlashAttentionConfig cfg;
+
+   // cfg.RowBlock = seqLen / 64;
+   // cfg.ColumnBlock = seqLen / 64;
+   return cfg;
+ }
+
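+ // MHAToFlashAttention rewrites linalgx::ScaledDotProductAttentionOp by
+ // decomposing it through the AggregatedOpInterface. FlashAttentionConfig is
+ // not consumed by matchAndRewrite in this revision; it appears reserved for a
+ // future tiled lowering.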
struct MHAToFlashAttention
-     : public OpInterfaceRewritePattern<linalg::LinalgOp> {
-   using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+     : public OpRewritePattern<linalgx::ScaledDotProductAttentionOp> {
+   using OpRewritePattern<
+       linalgx::ScaledDotProductAttentionOp>::OpRewritePattern;
+
+   struct OuterLoopGenerationResult {
+     /// Tiled operations that are generated during tiling. The order does not
+     /// matter except the last op. The replacements are expected to be the
+     /// results of the last op.
+     SmallVector<Operation *> tiledOps;
+     /// The `scf.for` operations that iterate over the tiles.
+     SmallVector<LoopLikeOpInterface> loops;
+     SmallVector<LoopLikeOpInterface> reductionLoops;
+   };
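+   // NOTE: the outerLoopGeneration code below is currently commented out; it
+   // sketches a possible tiled lowering that would populate
+   // OuterLoopGenerationResult.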
-   LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
+ // FailureOr<OuterLoopGenerationResult>
+ // outerLoopGeneration(RewriterBase &rewriter, linalg::LinalgOp linalgOp)
+ // const {
+ // SmallVector<unsigned> RowDimPos, ColDimPos;
+ // linalgOp.getReductionDims(KDimPos);
+ // getMatmulParallelDims(linalgOp, 0, MDimPos);
+ // getMatmulParallelDims(linalgOp, 1, NDimPos);
+
+ // OuterLoopGenerationOption option;
+ // auto iteratorTypes = linalgOp.getIteratorTypesArray();
+ // auto KFirstDim = (int)getOprandDim(linalgOp, KDimPos[0], 1);
+ // auto MFirstDim = (int)getOprandDim(linalgOp, MDimPos[0], 0);
+ // auto NFirstDim = (int)getOprandDim(linalgOp, NDimPos[0], 1);
+ // auto KParallelBlockSize =
+ // KDimPos.size() > 1
+ // ? divAndCeil(KFirstDim, cfg.KThreads)
+ // : divAndCeil(divAndCeil(KFirstDim, cfg.KBlock), cfg.KThreads) *
+ // cfg.KBlock;
+ // auto MParallelBlockSize =
+ // MDimPos.size() > 1
+ // ? divAndCeil(MFirstDim, cfg.MThreads)
+ // : divAndCeil(divAndCeil(MFirstDim, cfg.MBlock), cfg.MThreads) *
+ // cfg.MBlock;
+ // auto NParallelBlockSize =
+ // NDimPos.size() > 1
+ // ? divAndCeil(NFirstDim, cfg.NThreads)
+ // : divAndCeil(divAndCeil(NFirstDim, cfg.NBlock), cfg.NThreads) *
+ // cfg.NBlock;
+ // auto KOuterBlockSize = KDimPos.size() > 1
+ // ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1
+ // : cfg.KBlock;
+ // auto MOuterBlockSize = MDimPos.size() > 1
+ // ? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1
+ // : cfg.MBlock;
+ // auto NOuterBlockSize = NDimPos.size() > 1
+ // ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1
+ // : cfg.NBlock;
+ // // Outer
+ // option.nestedTileSizes.emplace_back(SmallVector<int>{
+ // MParallelBlockSize, NParallelBlockSize, KParallelBlockSize});
+ // option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp);
+ // option.loopDim.emplace_back(
+ // SmallVector<int>{(int)MDimPos[0], (int)NDimPos[0], (int)KDimPos[0]});
+ // // Middle
+ // for (auto [tile, dim] :
+ // llvm::zip(SmallVector<int>{MOuterBlockSize, NOuterBlockSize,
+ // KOuterBlockSize},
+ // SmallVector<int>{(int)MDimPos[0], (int)NDimPos[0],
+ // (int)KDimPos[0]})) {
+ // option.nestedTileSizes.emplace_back(SmallVector<int>{tile});
+ // option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+ // option.loopDim.emplace_back(SmallVector<int>{dim});
+ // }
+ // // Inner
+ // if (KDimPos.size() == 1) {
+ // option.nestedTileSizes.emplace_back(SmallVector<int>{cfg.KBlock});
+ // option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+ // option.loopDim.emplace_back(SmallVector<int>{(int)KDimPos.back()});
+ // }
+ // if (MDimPos.size() == 1) {
+ // option.nestedTileSizes.emplace_back(
+ // SmallVector<int>{cfg.innerMostMBlock});
+ // option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+ // option.loopDim.emplace_back(SmallVector<int>{(int)MDimPos.back()});
+ // }
+ // if (NDimPos.size() == 1) {
+ // option.nestedTileSizes.emplace_back(
+ // SmallVector<int>{cfg.innerMostNBlock});
+ // option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+ // option.loopDim.emplace_back(SmallVector<int>{(int)NDimPos.back()});
+ // }
+ // for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) {
+ // if (dim != MDimPos.back() && dim != NDimPos.back() &&
+ // iteratorTypes[dim] != mlir::utils::IteratorType::reduction) {
+ // option.nestedTileSizes.emplace_back(SmallVector<int>{1});
+ // option.loopType.emplace_back(
+ // OuterLoopGenerationOption::LoopType::ForOp);
+ // option.loopDim.emplace_back(SmallVector<int>{(int)dim});
+ // }
+ // }
+
+ // auto lowPrecisionCast =
+ // [&](RewriterBase &rewriter, Location loc,
+ // linalg::LinalgOp linalgop) -> FailureOr<linalg::LinalgOp> {
+ // auto legalizedResult = matmulDtypeLegalize(
+ // rewriter, linalgop.getOperation(), !hasFillOp, true);
+ // if (legalizedResult->castOp && legalizedResult->linalgOp) {
+ // auto linalgOp = legalizedResult->linalgOp;
+ // rewriter.replaceOp(linalgop,
+ // linalgOp->getResult(linalgOp->getNumResults() -
+ // 1));
+ // return dyn_cast<linalg::LinalgOp>(linalgOp);
+ // }
+ // return failure();
+ // };
+ // option.innermostFullResultCallBacks.push_back(lowPrecisionCast);
+
+ // if (hasFillOp) {
+ // auto removeReduncantFill =
+ // [&](RewriterBase &rewriter, Location loc,
+ // const linalg::ForallReductionTilingResult &result)
+ // -> FailureOr<linalg::LinalgOp> {
+ // auto initValue = result.initialValues;
+ // if (initValue.size() == 1 &&
+ // isa<linalg::FillOp>(initValue[0].getDefiningOp())) {
+ // rewriter.replaceOp(initValue[0].getDefiningOp(),
+ // dyn_cast<DestinationStyleOpInterface>(
+ // initValue[0].getDefiningOp())
+ // .getDpsInits()[0]);
+ // }
+ // return dyn_cast<linalg::LinalgOp>(result.parallelTiledOp);
+ // };
+ // option.finalReduceCallBacks.push_back(removeReduncantFill);
+ // }
+ // return generateOuterLoop(rewriter, linalgOp, option);
+ // }
+
+   LogicalResult matchAndRewrite(linalgx::ScaledDotProductAttentionOp sdpaOp,
                                  PatternRewriter &rewriter) const override {
-     if (!llvm::isa<linalgx::ScaledDotProductAttentionOp>(linalgOp))
-       return failure();
-     if (linalgOp.hasPureBufferSemantics())
-       return failure();
+     // Decompose the fused SDPA op into its constituent linalg ops via the
+     // AggregatedOpInterface and replace it with the decomposed results.
+     auto decomposableOp =
+         dyn_cast<mlir::linalg::AggregatedOpInterface>(sdpaOp.getOperation());
+     FailureOr<SmallVector<Value>> maybeNewResults =
+         decomposableOp.decomposeOperation(rewriter);
+     if (failed(maybeNewResults))
+       return failure();
+     rewriter.replaceOp(decomposableOp, *maybeNewResults);
+     return success();
  }
};
@@ -65,19 +211,7 @@ struct FlashAttentionConversion
  auto &ctx = getContext();
  IRRewriter rewriter(&ctx);
  RewritePatternSet patterns(&ctx);
-
  patterns.add<MHAToFlashAttention>(patterns.getContext());
- // linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
- // linalg::ControlDropUnitDims options;
- // options.rankReductionStrategy =
- // linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
- // linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
- // tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
-
- // for (auto *dialect : ctx.getLoadedDialects())
- // dialect->getCanonicalizationPatterns(patterns);
- // for (RegisteredOperationName op : ctx.getRegisteredOperations())
- // op.getCanonicalizationPatterns(patterns, &ctx);
  if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                          std::move(patterns)))) {
    return signalPassFailure();