@@ -53,8 +53,8 @@ static FlashAttentionConfig
53
53
getDefaultFlashAttentionConfig (linalgx::ScaledDotProductAttentionOp &sdpaOp) {
54
54
// TODO: allow tuning
55
55
FlashAttentionConfig cfg;
56
- cfg.RowBlockSize = 32 ;
57
- cfg.ColumnBlockSize = 32 ;
56
+ cfg.RowBlockSize = 64 ;
57
+ cfg.ColumnBlockSize = 64 ;
58
58
return cfg;
59
59
}
60
60
@@ -109,15 +109,19 @@ struct MHAToFlashAttention
109
109
sizes[2 ] = rewriter.getIndexAttr (cfg.RowBlockSize );
110
110
sizes[3 ] = rewriter.getIndexAttr (headDim);
111
111
SmallVector<OpFoldResult> strides (4 , rewriter.getIndexAttr (1 ));
112
- Value QSlice = rewriter.create <tensor::ExtractSliceOp>(loc, Q, offsets,
113
- sizes, strides);
114
- SmallVector<ReassociationIndices> reassocIndices{{0 , 1 , 2 }, {3 }};
115
- Value collapsedQSlice =
116
- rewriter.create <tensor::CollapseShapeOp>(loc, QSlice, reassocIndices);
112
+ SmallVector<int64_t > QSliceShape{1 , cfg.RowBlockSize , headDim};
113
+ SmallVector<int64_t > KVSliceShape{1 , cfg.ColumnBlockSize , headDim};
114
+ Value QSliceShapeOut =
115
+ rewriter.create <tensor::EmptyOp>(loc, QSliceShape, dtype);
116
+ Value KVSliceShapeOut =
117
+ rewriter.create <tensor::EmptyOp>(loc, KVSliceShape, dtype);
118
+ Value QSlice = rewriter.create <tensor::ExtractSliceOp>(
119
+ loc, cast<RankedTensorType>(QSliceShapeOut.getType ()), Q, offsets,
120
+ sizes, strides);
117
121
Value OSlice = rewriter.create <tensor::ExtractSliceOp>(
118
122
loc, destinationTensors[0 ], offsets, sizes, strides);
119
- Value collapsedOSlice =
120
- rewriter. create <tensor::CollapseShapeOp>( loc, OSlice, reassocIndices );
123
+ Value collapsedOSlice = rewriter. create <tensor::CollapseShapeOp>(
124
+ loc, OSlice, SmallVector<ReassociationIndices>{{ 0 , 1 , 2 }, { 3 }} );
121
125
SmallVector<int64_t > blockShape (1 , cfg.RowBlockSize );
122
126
Value maxSlice = rewriter.create <tensor::EmptyOp>(loc, blockShape, dtype);
123
127
Value sumSlice = rewriter.create <tensor::EmptyOp>(loc, blockShape, dtype);
@@ -159,44 +163,40 @@ struct MHAToFlashAttention
159
163
// adjust offsets and sizes
160
164
offsets[2 ] = getAsOpFoldResult (ivs[3 ]);
161
165
sizes[2 ] = rewriter.getIndexAttr (cfg.ColumnBlockSize );
162
- Value KSlice = rewriter.create <tensor::ExtractSliceOp>(loc, K, offsets,
163
- sizes, strides);
164
- Value VSlice = rewriter.create <tensor::ExtractSliceOp>(loc, V, offsets,
165
- sizes, strides);
166
+ Value KSlice = rewriter.create <tensor::ExtractSliceOp>(
167
+ loc, cast<RankedTensorType>(KVSliceShapeOut.getType ()), K, offsets,
168
+ sizes, strides);
169
+ Value VSlice = rewriter.create <tensor::ExtractSliceOp>(
170
+ loc, cast<RankedTensorType>(KVSliceShapeOut.getType ()), V, offsets,
171
+ sizes, strides);
166
172
offsets[2 ] = getAsOpFoldResult (ivs[2 ]);
167
173
offsets[3 ] = getAsOpFoldResult (ivs[3 ]);
168
174
sizes[2 ] = rewriter.getIndexAttr (cfg.RowBlockSize );
169
175
sizes[3 ] = rewriter.getIndexAttr (cfg.ColumnBlockSize );
176
+ SmallVector<int64_t > maskSliceShape{cfg.RowBlockSize , cfg.ColumnBlockSize };
177
+ Value QKShapeOut =
178
+ rewriter.create <tensor::EmptyOp>(loc, maskSliceShape, dtype);
170
179
Value maskSlice = rewriter.create <tensor::ExtractSliceOp>(
171
- loc, mask, offsets, sizes, strides);
172
- // collapse
173
- Value collapsedKSlice =
174
- rewriter.create <tensor::CollapseShapeOp>(loc, KSlice, reassocIndices);
175
- Value collapsedVSlice =
176
- rewriter.create <tensor::CollapseShapeOp>(loc, VSlice, reassocIndices);
177
- Value collapsedMaskSlice = rewriter.create <tensor::CollapseShapeOp>(
178
- loc, maskSlice, reassocIndices);
180
+ loc, cast<RankedTensorType>(QKShapeOut.getType ()), mask, offsets, sizes,
181
+ strides);
179
182
// transpose K
180
- SmallVector<int64_t > transposedShape{headDim, cfg.RowBlockSize };
183
+ SmallVector<int64_t > transposedShape{1 , headDim, cfg.RowBlockSize };
181
184
Value transposedShapeOut =
182
185
rewriter.create <tensor::EmptyOp>(loc, transposedShape, dtype);
183
- SmallVector<int64_t > transPerm{1 , 0 };
184
- Value transposedK =
185
- rewriter
186
- .create <linalg::TransposeOp>(loc, collapsedKSlice,
187
- transposedShapeOut, transPerm)
188
- ->getResult (0 );
186
+ SmallVector<int64_t > transPerm{0 , 2 , 1 };
187
+ Value transposedKSlice = rewriter
188
+ .create <linalg::TransposeOp>(
189
+ loc, KSlice, transposedShapeOut, transPerm)
190
+ ->getResult (0 );
189
191
// matmul QK
190
- SmallVector<int64_t > QKShape{cfg.RowBlockSize , cfg.ColumnBlockSize };
191
- Value QKShapeOut = rewriter.create <tensor::EmptyOp>(loc, QKShape, dtype);
192
192
Value matmulQKOutFilled =
193
193
rewriter.create <linalg::FillOp>(loc, zero, QKShapeOut).getResult (0 );
194
- Value matmulQK =
195
- rewriter
196
- . create <linalg::MatmulOp>( loc, matmulQKOutFilled.getType (),
197
- ValueRange{collapsedQSlice, transposedK },
198
- ValueRange{matmulQKOutFilled})
199
- .getResult (0 );
194
+ Value matmulQK = rewriter
195
+ . create <linalg::BatchReduceMatmulOp>(
196
+ loc, matmulQKOutFilled.getType (),
197
+ ValueRange{QSlice, transposedKSlice },
198
+ ValueRange{matmulQKOutFilled})
199
+ .getResult (0 );
200
200
// scale & add mask
201
201
float rsqrtHead = 1 / sqrt (headDim);
202
202
SmallVector<AffineMap, 2 > indexingMaps;
@@ -220,7 +220,7 @@ struct MHAToFlashAttention
220
220
.getResult (0 );
221
221
Value add = rewriter
222
222
.create <linalg::AddOp>(loc, QKShapeOut.getType (),
223
- ValueRange{mul, collapsedMaskSlice },
223
+ ValueRange{mul, maskSlice },
224
224
ValueRange{QKShapeOut})
225
225
.getResult (0 );
226
226
// tiling softmax
@@ -310,12 +310,20 @@ struct MHAToFlashAttention
310
310
Value VShapeOut = rewriter.create <tensor::EmptyOp>(loc, VShape, dtype);
311
311
Value matmulVOutFilled =
312
312
rewriter.create <linalg::FillOp>(loc, zero, VShapeOut).getResult (0 );
313
- Value matmulV =
314
- rewriter
315
- .create <linalg::MatmulOp>(loc, matmulVOutFilled.getType (),
316
- ValueRange{PSlice, collapsedVSlice},
317
- ValueRange{matmulVOutFilled})
318
- .getResult (0 );
313
+ SmallVector<OpFoldResult> expandedPSliceShape{
314
+ rewriter.getIndexAttr (1 ), rewriter.getIndexAttr (cfg.RowBlockSize ),
315
+ rewriter.getIndexAttr (cfg.ColumnBlockSize )};
316
+ Value expandedPSliceShapeOut =
317
+ rewriter.create <tensor::EmptyOp>(loc, expandedPSliceShape, dtype);
318
+ Value expandedPSlice = rewriter.create <tensor::ExpandShapeOp>(
319
+ loc, expandedPSliceShapeOut.getType (), PSlice,
320
+ SmallVector<ReassociationIndices>{{0 , 1 }, {2 }}, expandedPSliceShape);
321
+ Value matmulV = rewriter
322
+ .create <linalg::BatchReduceMatmulOp>(
323
+ loc, matmulVOutFilled.getType (),
324
+ ValueRange{expandedPSlice, VSlice},
325
+ ValueRange{matmulVOutFilled})
326
+ .getResult (0 );
319
327
Value newSumSliceRecipBroadcasted =
320
328
rewriter
321
329
.create <linalg::BroadcastOp>(loc, newSumSliceRecip, VShapeOut,
0 commit comments