
Commit 46ee25a

fix for all

committed · 1 parent e13ec10 · commit 46ee25a

File tree

2 files changed: +38 -76 lines changed

lib/gc/Transforms/FlashAttentionConversion.cpp

Lines changed: 25 additions & 63 deletions
@@ -73,64 +73,33 @@ struct MHAToFlashAttention
     FlashAttentionConfig cfg = getDefaultFlashAttentionConfig(sdpaOp);
     Location loc = sdpaOp.getLoc();
     OpBuilder::InsertionGuard guard(rewriter);
-    SmallVector<LoopLikeOpInterface> loops;
-    SmallVector<Value> ivs;
     auto shape =
         dyn_cast<RankedTensorType>(sdpaOp.getOperand(0).getType()).getShape();
     auto dtype = dyn_cast<RankedTensorType>(sdpaOp.getOperand(0).getType())
                      .getElementType();
     int64_t seqLen = shape[2], headDim = shape[3];
     auto Q = sdpaOp.getOperand(0), K = sdpaOp.getOperand(1),
          V = sdpaOp.getOperand(2), mask = sdpaOp.getOperand(3);
-    // construct 2 parallel outermost loops for batchSize and numHeads
-    SmallVector<Range> loopRanges;
-    for (size_t i = 0; i < 2; ++i) {
-      Range curRange;
-      curRange.offset = getAsIndexOpFoldResult(rewriter.getContext(), 0UL);
-      curRange.size = getAsIndexOpFoldResult(rewriter.getContext(), shape[i]);
-      curRange.stride = getAsIndexOpFoldResult(rewriter.getContext(), 1UL);
-      loopRanges.push_back(curRange);
-    }
-    SmallVector<OpFoldResult> tileSizes(
-        2, getAsIndexOpFoldResult(rewriter.getContext(), 1UL));
+    // construct 3 parallel outermost loops for
+    // batchSize/numHeads/(seqLen/rowBlockSize)
     SmallVector<Value> destinationTensors;
     tensor::getOrCreateDestinations(rewriter, sdpaOp.getLoc(), sdpaOp,
                                     destinationTensors);
-    for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
-      Value lb =
-          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
-      Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
-      Value step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
-      auto loop = rewriter.create<scf::ForOp>(
-          loc, lb, ub, step, destinationTensors,
-          [](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
-             ValueRange /*iterArgs*/) {});
-      loops.push_back(loop);
-      rewriter.setInsertionPointToEnd(loop.getBody());
-      ivs.push_back(loop.getInductionVar());
-      destinationTensors.clear();
-      destinationTensors.insert(destinationTensors.begin(),
-                                loop.getRegionIterArgs().begin(),
-                                loop.getRegionIterArgs().end());
+    SmallVector<OpFoldResult> lbs, ubs, tileSizes;
+    for (size_t i = 0; i < 3; ++i) {
+      lbs.push_back(getAsIndexOpFoldResult(rewriter.getContext(), 0));
+      ubs.push_back(getAsIndexOpFoldResult(rewriter.getContext(), shape[i]));
+      tileSizes.push_back(getAsIndexOpFoldResult(
+          rewriter.getContext(), i == 2 ? cfg.RowBlockSize : 1));
     }
-    // create rowBlockLoop
-    auto rowBlockLoop = rewriter.create<scf::ForOp>(
-        loc,
-        getValueOrCreateConstantIndexOp(
-            rewriter, loc, getAsIndexOpFoldResult(rewriter.getContext(), 0UL)),
-        getValueOrCreateConstantIndexOp(
-            rewriter, loc,
-            getAsIndexOpFoldResult(rewriter.getContext(), seqLen)),
-        getValueOrCreateConstantIndexOp(
-            rewriter, loc,
-            getAsIndexOpFoldResult(rewriter.getContext(), cfg.RowBlockSize)),
-        destinationTensors,
-        [](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
-           ValueRange /*iterArgs*/) {});
-    loops.push_back(rowBlockLoop);
-    ivs.push_back(rowBlockLoop.getInductionVar());
-    rewriter.setInsertionPointToEnd(rowBlockLoop.getBody());
-    // inserting body for rowBlockLoop
+    // create forall loop
+    auto forallOp = rewriter.create<scf::ForallOp>(
+        loc, lbs, ubs, tileSizes, destinationTensors,
+        /*mapping=*/std::nullopt,
+        /*bodyBuilderFn=*/[](OpBuilder &, Location, ValueRange) {});
+    rewriter.setInsertionPointToEnd(forallOp.getBody());
+    SmallVector<Value> ivs = forallOp.getInductionVars();
+    // inserting body for forall loop
     SmallVector<OpFoldResult> offsets;
     offsets.push_back(getAsOpFoldResult(ivs[0]));
     offsets.push_back(getAsOpFoldResult(ivs[1]));
@@ -182,7 +151,7 @@ struct MHAToFlashAttention
         [](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
            ValueRange /*iterArgs*/) {});
     ivs.push_back(columnBlockLoop.getInductionVar());
-    rewriter.setInsertionPointToEnd(columnBlockLoop.getBody());
+    rewriter.setInsertionPointToStart(columnBlockLoop.getBody());
     // innermost computations
     Value prevOSlice = columnBlockLoop.getRegionIterArgs()[0],
           prevMaxSlice = columnBlockLoop.getRegionIterArgs()[1],
@@ -386,8 +355,7 @@ struct MHAToFlashAttention
     // yield all the results of the innermost loop.
     rewriter.create<scf::YieldOp>(
         loc, ValueRange{newOSlice, newMaxSlice, newSumSlice});
-    // yield rowBlockLoop results
-    rewriter.setInsertionPointToEnd(rowBlockLoop.getBody());
+    // yield parallel loop results
     auto innermostLoopResults = columnBlockLoop->getResults();
     Value OSliceFinal = innermostLoopResults[0];
     SmallVector<OpFoldResult> outputOffsets;
@@ -398,20 +366,14 @@ struct MHAToFlashAttention
     SmallVector<OpFoldResult> outputSizes(4, rewriter.getIndexAttr(1));
     outputSizes[2] = rewriter.getIndexAttr(cfg.RowBlockSize);
     outputSizes[3] = rewriter.getIndexAttr(headDim);
-    Value insertedRescaledOSlice = rewriter.create<tensor::InsertSliceOp>(
-        loc, OSliceFinal, rowBlockLoop.getRegionIterArgs()[0], outputOffsets,
+    // Add the scf.forall.in_parallel operations for the forall op
+    rewriter.setInsertionPointToEnd(forallOp.getBody());
+    auto term = rewriter.create<scf::InParallelOp>(loc);
+    rewriter.setInsertionPointToStart(term.getBody());
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        loc, OSliceFinal, forallOp.getRegionIterArgs()[0], outputOffsets,
         outputSizes, strides);
-    rewriter.create<scf::YieldOp>(loc, ValueRange{insertedRescaledOSlice});
-    // Add the scf.yield operations for all the outer loops.
-    for (auto [outerLoop, innerLoop] :
-         llvm::zip_equal(MutableArrayRef(loops).drop_back(),
-                         MutableArrayRef(loops).drop_front())) {
-      rewriter.setInsertionPointToEnd(
-          cast<scf::ForOp>(outerLoop.getOperation()).getBody());
-      rewriter.create<scf::YieldOp>(outerLoop.getLoc(),
-                                    innerLoop->getResults());
-    }
-    rewriter.replaceOp(sdpaOp, loops.front()->getResults());
+    rewriter.replaceOp(sdpaOp, forallOp->getResults());
     return success();
   }
 };
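
For context, a rough sketch of the loop structure the rewritten pattern now builds (value names and symbolic bounds are illustrative, not the literal pass output; the bounds and row block size come from the SDPA operand shape and cfg.RowBlockSize): the three parallel dimensions become a single scf.forall, the sequential column-block scf.for stays inside its body, and the final output tile is committed through the scf.forall.in_parallel terminator instead of a chain of scf.yield ops.

    // Illustrative structure only, with symbolic bounds and sizes.
    %res = scf.forall (%batch, %head, %row) = (0, 0, 0)
               to (%batchSize, %numHeads, %seqLen) step (1, 1, %rowBlockSize)
               shared_outs(%out = %dest) -> (tensor<?x?x?x?xf32>) {
      // extract the Q/K/V/mask tiles for this (batch, head, row-block) ...
      // run the inner scf.for over column blocks, updating the running
      // O / max / sum slices ...
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %oTile into %out[%batch, %head, %row, 0]
            [1, 1, %rowBlockSize, %headDim] [1, 1, 1, 1]
            : tensor<?x?x?x?xf32> into tensor<?x?x?x?xf32>
      }
    }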

test/gc/Transform/flashAttention.mlir

Lines changed: 13 additions & 13 deletions
@@ -1,29 +1,29 @@
 // RUN: gc-opt --split-input-file --flash-attention-conversion --gc-cpu-pipeline %s | gc-cpu-runner -e main -entry-point-result=void
 // | FileCheck --allow-empty

-func.func @flash_attention(%arg0: tensor<4x4x384x64xf32>, %arg1: tensor<4x4x384x64xf32>, %arg2: tensor<4x4x384x64xf32>, %arg3: tensor<4x4x384x384xf32>) -> tensor<4x4x384x64xf32> {
-  %0 = tensor.empty() : tensor<4x4x384x64xf32>
-  %1 = linalgx.scaled_dot_product_attention ins(%arg0, %arg1, %arg2, %arg3: tensor<4x4x384x64xf32>, tensor<4x4x384x64xf32>, tensor<4x4x384x64xf32>, tensor<4x4x384x384xf32>) outs(%0 : tensor<4x4x384x64xf32>) -> tensor<4x4x384x64xf32>
-  return %1 : tensor<4x4x384x64xf32>
+func.func @flash_attention(%arg0: tensor<56x16x384x64xf32>, %arg1: tensor<56x16x384x64xf32>, %arg2: tensor<56x16x384x64xf32>, %arg3: tensor<56x16x384x384xf32>) -> tensor<56x16x384x64xf32> {
+  %0 = tensor.empty() : tensor<56x16x384x64xf32>
+  %1 = linalgx.scaled_dot_product_attention ins(%arg0, %arg1, %arg2, %arg3: tensor<56x16x384x64xf32>, tensor<56x16x384x64xf32>, tensor<56x16x384x64xf32>, tensor<56x16x384x384xf32>) outs(%0 : tensor<56x16x384x64xf32>) -> tensor<56x16x384x64xf32>
+  return %1 : tensor<56x16x384x64xf32>
 }

 func.func @main() {
   %cst = arith.constant 4.000000e+00 : f32

-  %QKVShape = tensor.empty() : tensor<4x4x384x64xf32>
-  %maskShape = tensor.empty() : tensor<4x4x384x384xf32>
+  %QKVShape = tensor.empty() : tensor<56x16x384x64xf32>
+  %maskShape = tensor.empty() : tensor<56x16x384x384xf32>

-  %Q = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<4x4x384x64xf32>) -> tensor<4x4x384x64xf32>
-  %K = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<4x4x384x64xf32>) -> tensor<4x4x384x64xf32>
-  %V = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<4x4x384x64xf32>) -> tensor<4x4x384x64xf32>
-  %mask = linalg.fill ins(%cst : f32) outs(%maskShape : tensor<4x4x384x384xf32>) -> tensor<4x4x384x384xf32>
+  %Q = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<56x16x384x64xf32>) -> tensor<56x16x384x64xf32>
+  %K = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<56x16x384x64xf32>) -> tensor<56x16x384x64xf32>
+  %V = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<56x16x384x64xf32>) -> tensor<56x16x384x64xf32>
+  %mask = linalg.fill ins(%cst : f32) outs(%maskShape : tensor<56x16x384x384xf32>) -> tensor<56x16x384x384xf32>

   %out = func.call @flash_attention(%Q, %K, %V, %mask) :
-         (tensor<4x4x384x64xf32>, tensor<4x4x384x64xf32>, tensor<4x4x384x64xf32>, tensor<4x4x384x384xf32>)
-         -> (tensor<4x4x384x64xf32>)
+         (tensor<56x16x384x64xf32>, tensor<56x16x384x64xf32>, tensor<56x16x384x64xf32>, tensor<56x16x384x384xf32>)
+         -> (tensor<56x16x384x64xf32>)

   %idx = arith.constant 0 : index
-  %val = tensor.extract %out[%idx, %idx, %idx, %idx] : tensor<4x4x384x64xf32>
+  %val = tensor.extract %out[%idx, %idx, %idx, %idx] : tensor<56x16x384x64xf32>
   cpuruntime.printf "output[0, 0, 0, 0]: %f\n" %val : f32

   return
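
Tying the new test shapes back to the pass: batchSize 56, numHeads 16, seqLen 384 and headDim 64 give a 56x16x384 parallel iteration space. Below is a minimal, self-contained skeleton of the kind of scf.forall the conversion produces for these shapes, assuming a hypothetical row block size of 64; the body is reduced to an identity copy of the output tile purely so the sketch stays verifiable, and is not the pass's actual flash-attention body.

    // Hypothetical skeleton; @forall_skeleton and the block size 64 are illustrative.
    func.func @forall_skeleton(%dest: tensor<56x16x384x64xf32>) -> tensor<56x16x384x64xf32> {
      %res = scf.forall (%batch, %head, %row) = (0, 0, 0) to (56, 16, 384) step (1, 1, 64)
                 shared_outs(%out = %dest) -> (tensor<56x16x384x64xf32>) {
        // in the real pass this tile would hold the rescaled flash-attention output block
        %tile = tensor.extract_slice %out[%batch, %head, %row, 0] [1, 1, 64, 64] [1, 1, 1, 1]
            : tensor<56x16x384x64xf32> to tensor<1x1x64x64xf32>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %tile into %out[%batch, %head, %row, 0] [1, 1, 64, 64] [1, 1, 1, 1]
              : tensor<1x1x64x64xf32> into tensor<56x16x384x64xf32>
        }
      }
      return %res : tensor<56x16x384x64xf32>
    }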
