Skip to content

Commit 226d0cd

Browse files
committed
fix performance
1 parent 46ee25a commit 226d0cd

File tree

1 file changed

+51
-43
lines changed

1 file changed

+51
-43
lines changed

lib/gc/Transforms/FlashAttentionConversion.cpp

Lines changed: 51 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -53,8 +53,8 @@ static FlashAttentionConfig
5353
getDefaultFlashAttentionConfig(linalgx::ScaledDotProductAttentionOp &sdpaOp) {
5454
// TODO: allow tuning
5555
FlashAttentionConfig cfg;
56-
cfg.RowBlockSize = 32;
57-
cfg.ColumnBlockSize = 32;
56+
cfg.RowBlockSize = 64;
57+
cfg.ColumnBlockSize = 64;
5858
return cfg;
5959
}
6060

@@ -109,15 +109,19 @@ struct MHAToFlashAttention
109109
sizes[2] = rewriter.getIndexAttr(cfg.RowBlockSize);
110110
sizes[3] = rewriter.getIndexAttr(headDim);
111111
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
112-
Value QSlice = rewriter.create<tensor::ExtractSliceOp>(loc, Q, offsets,
113-
sizes, strides);
114-
SmallVector<ReassociationIndices> reassocIndices{{0, 1, 2}, {3}};
115-
Value collapsedQSlice =
116-
rewriter.create<tensor::CollapseShapeOp>(loc, QSlice, reassocIndices);
112+
SmallVector<int64_t> QSliceShape{1, cfg.RowBlockSize, headDim};
113+
SmallVector<int64_t> KVSliceShape{1, cfg.ColumnBlockSize, headDim};
114+
Value QSliceShapeOut =
115+
rewriter.create<tensor::EmptyOp>(loc, QSliceShape, dtype);
116+
Value KVSliceShapeOut =
117+
rewriter.create<tensor::EmptyOp>(loc, KVSliceShape, dtype);
118+
Value QSlice = rewriter.create<tensor::ExtractSliceOp>(
119+
loc, cast<RankedTensorType>(QSliceShapeOut.getType()), Q, offsets,
120+
sizes, strides);
117121
Value OSlice = rewriter.create<tensor::ExtractSliceOp>(
118122
loc, destinationTensors[0], offsets, sizes, strides);
119-
Value collapsedOSlice =
120-
rewriter.create<tensor::CollapseShapeOp>(loc, OSlice, reassocIndices);
123+
Value collapsedOSlice = rewriter.create<tensor::CollapseShapeOp>(
124+
loc, OSlice, SmallVector<ReassociationIndices>{{0, 1, 2}, {3}});
121125
SmallVector<int64_t> blockShape(1, cfg.RowBlockSize);
122126
Value maxSlice = rewriter.create<tensor::EmptyOp>(loc, blockShape, dtype);
123127
Value sumSlice = rewriter.create<tensor::EmptyOp>(loc, blockShape, dtype);
@@ -159,44 +163,40 @@ struct MHAToFlashAttention
159163
// adjust offsets and sizes
160164
offsets[2] = getAsOpFoldResult(ivs[3]);
161165
sizes[2] = rewriter.getIndexAttr(cfg.ColumnBlockSize);
162-
Value KSlice = rewriter.create<tensor::ExtractSliceOp>(loc, K, offsets,
163-
sizes, strides);
164-
Value VSlice = rewriter.create<tensor::ExtractSliceOp>(loc, V, offsets,
165-
sizes, strides);
166+
Value KSlice = rewriter.create<tensor::ExtractSliceOp>(
167+
loc, cast<RankedTensorType>(KVSliceShapeOut.getType()), K, offsets,
168+
sizes, strides);
169+
Value VSlice = rewriter.create<tensor::ExtractSliceOp>(
170+
loc, cast<RankedTensorType>(KVSliceShapeOut.getType()), V, offsets,
171+
sizes, strides);
166172
offsets[2] = getAsOpFoldResult(ivs[2]);
167173
offsets[3] = getAsOpFoldResult(ivs[3]);
168174
sizes[2] = rewriter.getIndexAttr(cfg.RowBlockSize);
169175
sizes[3] = rewriter.getIndexAttr(cfg.ColumnBlockSize);
176+
SmallVector<int64_t> maskSliceShape{cfg.RowBlockSize, cfg.ColumnBlockSize};
177+
Value QKShapeOut =
178+
rewriter.create<tensor::EmptyOp>(loc, maskSliceShape, dtype);
170179
Value maskSlice = rewriter.create<tensor::ExtractSliceOp>(
171-
loc, mask, offsets, sizes, strides);
172-
// collapse
173-
Value collapsedKSlice =
174-
rewriter.create<tensor::CollapseShapeOp>(loc, KSlice, reassocIndices);
175-
Value collapsedVSlice =
176-
rewriter.create<tensor::CollapseShapeOp>(loc, VSlice, reassocIndices);
177-
Value collapsedMaskSlice = rewriter.create<tensor::CollapseShapeOp>(
178-
loc, maskSlice, reassocIndices);
180+
loc, cast<RankedTensorType>(QKShapeOut.getType()), mask, offsets, sizes,
181+
strides);
179182
// transpose K
180-
SmallVector<int64_t> transposedShape{headDim, cfg.RowBlockSize};
183+
SmallVector<int64_t> transposedShape{1, headDim, cfg.RowBlockSize};
181184
Value transposedShapeOut =
182185
rewriter.create<tensor::EmptyOp>(loc, transposedShape, dtype);
183-
SmallVector<int64_t> transPerm{1, 0};
184-
Value transposedK =
185-
rewriter
186-
.create<linalg::TransposeOp>(loc, collapsedKSlice,
187-
transposedShapeOut, transPerm)
188-
->getResult(0);
186+
SmallVector<int64_t> transPerm{0, 2, 1};
187+
Value transposedKSlice = rewriter
188+
.create<linalg::TransposeOp>(
189+
loc, KSlice, transposedShapeOut, transPerm)
190+
->getResult(0);
189191
// matmul QK
190-
SmallVector<int64_t> QKShape{cfg.RowBlockSize, cfg.ColumnBlockSize};
191-
Value QKShapeOut = rewriter.create<tensor::EmptyOp>(loc, QKShape, dtype);
192192
Value matmulQKOutFilled =
193193
rewriter.create<linalg::FillOp>(loc, zero, QKShapeOut).getResult(0);
194-
Value matmulQK =
195-
rewriter
196-
.create<linalg::MatmulOp>(loc, matmulQKOutFilled.getType(),
197-
ValueRange{collapsedQSlice, transposedK},
198-
ValueRange{matmulQKOutFilled})
199-
.getResult(0);
194+
Value matmulQK = rewriter
195+
.create<linalg::BatchReduceMatmulOp>(
196+
loc, matmulQKOutFilled.getType(),
197+
ValueRange{QSlice, transposedKSlice},
198+
ValueRange{matmulQKOutFilled})
199+
.getResult(0);
200200
// scale & add mask
201201
float rsqrtHead = 1 / sqrt(headDim);
202202
SmallVector<AffineMap, 2> indexingMaps;
@@ -220,7 +220,7 @@ struct MHAToFlashAttention
220220
.getResult(0);
221221
Value add = rewriter
222222
.create<linalg::AddOp>(loc, QKShapeOut.getType(),
223-
ValueRange{mul, collapsedMaskSlice},
223+
ValueRange{mul, maskSlice},
224224
ValueRange{QKShapeOut})
225225
.getResult(0);
226226
// tiling softmax
@@ -310,12 +310,20 @@ struct MHAToFlashAttention
310310
Value VShapeOut = rewriter.create<tensor::EmptyOp>(loc, VShape, dtype);
311311
Value matmulVOutFilled =
312312
rewriter.create<linalg::FillOp>(loc, zero, VShapeOut).getResult(0);
313-
Value matmulV =
314-
rewriter
315-
.create<linalg::MatmulOp>(loc, matmulVOutFilled.getType(),
316-
ValueRange{PSlice, collapsedVSlice},
317-
ValueRange{matmulVOutFilled})
318-
.getResult(0);
313+
SmallVector<OpFoldResult> expandedPSliceShape{
314+
rewriter.getIndexAttr(1), rewriter.getIndexAttr(cfg.RowBlockSize),
315+
rewriter.getIndexAttr(cfg.ColumnBlockSize)};
316+
Value expandedPSliceShapeOut =
317+
rewriter.create<tensor::EmptyOp>(loc, expandedPSliceShape, dtype);
318+
Value expandedPSlice = rewriter.create<tensor::ExpandShapeOp>(
319+
loc, expandedPSliceShapeOut.getType(), PSlice,
320+
SmallVector<ReassociationIndices>{{0, 1}, {2}}, expandedPSliceShape);
321+
Value matmulV = rewriter
322+
.create<linalg::BatchReduceMatmulOp>(
323+
loc, matmulVOutFilled.getType(),
324+
ValueRange{expandedPSlice, VSlice},
325+
ValueRange{matmulVOutFilled})
326+
.getResult(0);
319327
Value newSumSliceRecipBroadcasted =
320328
rewriter
321329
.create<linalg::BroadcastOp>(loc, newSumSliceRecip, VShapeOut,

0 commit comments

Comments (0)