@@ -73,64 +73,33 @@ struct MHAToFlashAttention
73
73
FlashAttentionConfig cfg = getDefaultFlashAttentionConfig (sdpaOp);
74
74
Location loc = sdpaOp.getLoc ();
75
75
OpBuilder::InsertionGuard guard (rewriter);
76
- SmallVector<LoopLikeOpInterface> loops;
77
- SmallVector<Value> ivs;
78
76
auto shape =
79
77
dyn_cast<RankedTensorType>(sdpaOp.getOperand (0 ).getType ()).getShape ();
80
78
auto dtype = dyn_cast<RankedTensorType>(sdpaOp.getOperand (0 ).getType ())
81
79
.getElementType ();
82
80
int64_t seqLen = shape[2 ], headDim = shape[3 ];
83
81
auto Q = sdpaOp.getOperand (0 ), K = sdpaOp.getOperand (1 ),
84
82
V = sdpaOp.getOperand (2 ), mask = sdpaOp.getOperand (3 );
85
- // construct 2 parallel outermost loops for batchSize and numHeads
86
- SmallVector<Range> loopRanges;
87
- for (size_t i = 0 ; i < 2 ; ++i) {
88
- Range curRange;
89
- curRange.offset = getAsIndexOpFoldResult (rewriter.getContext (), 0UL );
90
- curRange.size = getAsIndexOpFoldResult (rewriter.getContext (), shape[i]);
91
- curRange.stride = getAsIndexOpFoldResult (rewriter.getContext (), 1UL );
92
- loopRanges.push_back (curRange);
93
- }
94
- SmallVector<OpFoldResult> tileSizes (
95
- 2 , getAsIndexOpFoldResult (rewriter.getContext (), 1UL ));
83
+ // construct 3 parallel outermost loops for
84
+ // batchSize/numHeads/(seqLen/rowBlockSize)
96
85
SmallVector<Value> destinationTensors;
97
86
tensor::getOrCreateDestinations (rewriter, sdpaOp.getLoc (), sdpaOp,
98
87
destinationTensors);
99
- for (auto [loopRange, tileSize] : llvm::zip_equal (loopRanges, tileSizes)) {
100
- Value lb =
101
- getValueOrCreateConstantIndexOp (rewriter, loc, loopRange.offset );
102
- Value ub = getValueOrCreateConstantIndexOp (rewriter, loc, loopRange.size );
103
- Value step = getValueOrCreateConstantIndexOp (rewriter, loc, tileSize);
104
- auto loop = rewriter.create <scf::ForOp>(
105
- loc, lb, ub, step, destinationTensors,
106
- [](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
107
- ValueRange /* iterArgs*/ ) {});
108
- loops.push_back (loop);
109
- rewriter.setInsertionPointToEnd (loop.getBody ());
110
- ivs.push_back (loop.getInductionVar ());
111
- destinationTensors.clear ();
112
- destinationTensors.insert (destinationTensors.begin (),
113
- loop.getRegionIterArgs ().begin (),
114
- loop.getRegionIterArgs ().end ());
88
+ SmallVector<OpFoldResult> lbs, ubs, tileSizes;
89
+ for (size_t i = 0 ; i < 3 ; ++i) {
90
+ lbs.push_back (getAsIndexOpFoldResult (rewriter.getContext (), 0 ));
91
+ ubs.push_back (getAsIndexOpFoldResult (rewriter.getContext (), shape[i]));
92
+ tileSizes.push_back (getAsIndexOpFoldResult (
93
+ rewriter.getContext (), i == 2 ? cfg.RowBlockSize : 1 ));
115
94
}
116
- // create rowBlockLoop
117
- auto rowBlockLoop = rewriter.create <scf::ForOp>(
118
- loc,
119
- getValueOrCreateConstantIndexOp (
120
- rewriter, loc, getAsIndexOpFoldResult (rewriter.getContext (), 0UL )),
121
- getValueOrCreateConstantIndexOp (
122
- rewriter, loc,
123
- getAsIndexOpFoldResult (rewriter.getContext (), seqLen)),
124
- getValueOrCreateConstantIndexOp (
125
- rewriter, loc,
126
- getAsIndexOpFoldResult (rewriter.getContext (), cfg.RowBlockSize )),
127
- destinationTensors,
128
- [](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
129
- ValueRange /* iterArgs*/ ) {});
130
- loops.push_back (rowBlockLoop);
131
- ivs.push_back (rowBlockLoop.getInductionVar ());
132
- rewriter.setInsertionPointToEnd (rowBlockLoop.getBody ());
133
- // inserting body for rowBlockLoop
95
+ // create forall loop
96
+ auto forallOp = rewriter.create <scf::ForallOp>(
97
+ loc, lbs, ubs, tileSizes, destinationTensors,
98
+ /* mapping=*/ std::nullopt,
99
+ /* bodyBuilderFn =*/ [](OpBuilder &, Location, ValueRange) {});
100
+ rewriter.setInsertionPointToEnd (forallOp.getBody ());
101
+ SmallVector<Value> ivs = forallOp.getInductionVars ();
102
+ // inserting body for forall loop
134
103
SmallVector<OpFoldResult> offsets;
135
104
offsets.push_back (getAsOpFoldResult (ivs[0 ]));
136
105
offsets.push_back (getAsOpFoldResult (ivs[1 ]));
@@ -182,7 +151,7 @@ struct MHAToFlashAttention
182
151
[](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
183
152
ValueRange /* iterArgs*/ ) {});
184
153
ivs.push_back (columnBlockLoop.getInductionVar ());
185
- rewriter.setInsertionPointToEnd (columnBlockLoop.getBody ());
154
+ rewriter.setInsertionPointToStart (columnBlockLoop.getBody ());
186
155
// innermost computations
187
156
Value prevOSlice = columnBlockLoop.getRegionIterArgs ()[0 ],
188
157
prevMaxSlice = columnBlockLoop.getRegionIterArgs ()[1 ],
@@ -386,8 +355,7 @@ struct MHAToFlashAttention
386
355
// yield all the results of the innermost loop.
387
356
rewriter.create <scf::YieldOp>(
388
357
loc, ValueRange{newOSlice, newMaxSlice, newSumSlice});
389
- // yield rowBlockLoop results
390
- rewriter.setInsertionPointToEnd (rowBlockLoop.getBody ());
358
+ // yield parallel loop results
391
359
auto innermostLoopResults = columnBlockLoop->getResults ();
392
360
Value OSliceFinal = innermostLoopResults[0 ];
393
361
SmallVector<OpFoldResult> outputOffsets;
@@ -398,20 +366,14 @@ struct MHAToFlashAttention
398
366
SmallVector<OpFoldResult> outputSizes (4 , rewriter.getIndexAttr (1 ));
399
367
outputSizes[2 ] = rewriter.getIndexAttr (cfg.RowBlockSize );
400
368
outputSizes[3 ] = rewriter.getIndexAttr (headDim);
401
- Value insertedRescaledOSlice = rewriter.create <tensor::InsertSliceOp>(
402
- loc, OSliceFinal, rowBlockLoop.getRegionIterArgs ()[0 ], outputOffsets,
369
+ // Add the scf.forall.in_parallel operations for the forall op
370
+ rewriter.setInsertionPointToEnd (forallOp.getBody ());
371
+ auto term = rewriter.create <scf::InParallelOp>(loc);
372
+ rewriter.setInsertionPointToStart (term.getBody ());
373
+ rewriter.create <tensor::ParallelInsertSliceOp>(
374
+ loc, OSliceFinal, forallOp.getRegionIterArgs ()[0 ], outputOffsets,
403
375
outputSizes, strides);
404
- rewriter.create <scf::YieldOp>(loc, ValueRange{insertedRescaledOSlice});
405
- // Add the scf.yield operations for all the outer loops.
406
- for (auto [outerLoop, innerLoop] :
407
- llvm::zip_equal (MutableArrayRef (loops).drop_back (),
408
- MutableArrayRef (loops).drop_front ())) {
409
- rewriter.setInsertionPointToEnd (
410
- cast<scf::ForOp>(outerLoop.getOperation ()).getBody ());
411
- rewriter.create <scf::YieldOp>(outerLoop.getLoc (),
412
- innerLoop->getResults ());
413
- }
414
- rewriter.replaceOp (sdpaOp, loops.front ()->getResults ());
376
+ rewriter.replaceOp (sdpaOp, forallOp->getResults ());
415
377
return success ();
416
378
}
417
379
};
0 commit comments