@@ -202,6 +202,49 @@ void dequantize_per_channel_optimized(
  }
}

+void dequant_and_gemm(
+    const int64_t m,
+    const int64_t n,
+    const int64_t k,
+    float* qk_data,
+    const int64_t qk_stride_m,
+    const MaybeQuantizedMatrixData& v_data,
+    const int64_t v_stride_n,
+    float* o_data,
+    const int64_t o_stride_m,
+    const float beta) {
+  std::vector<float> dequantized_v_data(v_data.m * v_data.n);
+  dequantize_per_channel_optimized(
+      static_cast<const int8_t*>(v_data.data),
+      static_cast<const float*>(v_data.scales),
+      static_cast<const int8_t*>(v_data.zero_points),
+      dequantized_v_data.data(),
+      -128,
+      127,
+      1,
+      0,
+      0,
+      v_data.m,
+      v_stride_n,
+      v_data.n,
+      v_data.n,
+      v_data.zero_points_stride);
+  ::executorch::cpublas::gemm(
+      ::executorch::cpublas::TransposeType::NoTranspose,
+      ::executorch::cpublas::TransposeType::NoTranspose,
+      n,
+      m,
+      k,
+      static_cast<float>(1),
+      dequantized_v_data.data(),
+      v_data.n,
+      qk_data,
+      qk_stride_m,
+      beta,
+      o_data,
+      o_stride_m);
+}
+
template <typename accum_t>
void _qk_at_v_gemm(
    const int64_t m,
@@ -216,36 +259,41 @@ void _qk_at_v_gemm(
    const accum_t beta) {
  if (v_data.dtype == ScalarType::Char) {
    if constexpr (std::is_same<accum_t, float>::value) {
-      std::vector<float> dequantized_v_data(v_data.m * v_data.n);
-      dequantize_per_channel_optimized(
-          static_cast<const int8_t*>(v_data.data),
-          static_cast<const float*>(v_data.scales),
-          static_cast<const int8_t*>(v_data.zero_points),
-          dequantized_v_data.data(),
-          -128,
-          127,
-          1,
-          0,
-          0,
-          v_data.m,
-          v_stride_n,
-          v_data.n,
-          v_data.n,
-          v_data.zero_points_stride);
-      ::executorch::cpublas::gemm(
-          ::executorch::cpublas::TransposeType::NoTranspose,
-          ::executorch::cpublas::TransposeType::NoTranspose,
-          n,
-          m,
-          k,
-          static_cast<accum_t>(1),
-          dequantized_v_data.data(),
-          v_data.n,
-          qk_data,
-          qk_stride_m,
-          beta,
-          o_data,
-          o_stride_m);
+      if (m > 4) {
+        // For larger batch sizes, dequantize and use BLAS for better
+        // performance
+        dequant_and_gemm(
+            m,
+            n,
+            k,
+            const_cast<float*>(qk_data),
+            qk_stride_m,
+            v_data,
+            v_stride_n,
+            o_data,
+            o_stride_m,
+            beta);
+      } else {
+        // For smaller batch sizes, use quantized gemm
+        int a_stride_m_tmp, b_stride_n_tmp;
+        auto kernel = torchao::kernels::cpu::quantized_matmul::
+            get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
+                m, n, k, false, false, a_stride_m_tmp, b_stride_n_tmp);
+        kernel(
+            m,
+            n,
+            k,
+            qk_data,
+            qk_stride_m /*lhs_stride_m*/,
+            static_cast<const int8_t*>(v_data.data),
+            v_stride_n /*rhs_stride_n*/,
+            o_data,
+            o_stride_m /*out_stride_n*/,
+            static_cast<const int8_t*>(v_data.zero_points),
+            static_cast<const float*>(v_data.scales),
+            beta,
+            v_data.zero_points_stride);
+      }
    } else {
      ET_CHECK_MSG(
          false, "Accumulation in dtype other than float not supported yet");