Skip to content

Commit 4022ff1

Browse files
authored
[Executorch][llm] quantized sdpa. update attn_scores @ v gemm
Differential Revision: D71833065 Pull Request resolved: #10108
1 parent c352672 commit 4022ff1

File tree

1 file changed

+78
-30
lines changed

1 file changed

+78
-30
lines changed

extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,49 @@ void dequantize_per_channel_optimized(
202202
}
203203
}
204204

205+
void dequant_and_gemm(
206+
const int64_t m,
207+
const int64_t n,
208+
const int64_t k,
209+
float* qk_data,
210+
const int64_t qk_stride_m,
211+
const MaybeQuantizedMatrixData& v_data,
212+
const int64_t v_stride_n,
213+
float* o_data,
214+
const int64_t o_stride_m,
215+
const float beta) {
216+
std::vector<float> dequantized_v_data(v_data.m * v_data.n);
217+
dequantize_per_channel_optimized(
218+
static_cast<const int8_t*>(v_data.data),
219+
static_cast<const float*>(v_data.scales),
220+
static_cast<const int8_t*>(v_data.zero_points),
221+
dequantized_v_data.data(),
222+
-128,
223+
127,
224+
1,
225+
0,
226+
0,
227+
v_data.m,
228+
v_stride_n,
229+
v_data.n,
230+
v_data.n,
231+
v_data.zero_points_stride);
232+
::executorch::cpublas::gemm(
233+
::executorch::cpublas::TransposeType::NoTranspose,
234+
::executorch::cpublas::TransposeType::NoTranspose,
235+
n,
236+
m,
237+
k,
238+
static_cast<float>(1),
239+
dequantized_v_data.data(),
240+
v_data.n,
241+
qk_data,
242+
qk_stride_m,
243+
beta,
244+
o_data,
245+
o_stride_m);
246+
}
247+
205248
template <typename accum_t>
206249
void _qk_at_v_gemm(
207250
const int64_t m,
@@ -216,36 +259,41 @@ void _qk_at_v_gemm(
216259
const accum_t beta) {
217260
if (v_data.dtype == ScalarType::Char) {
218261
if constexpr (std::is_same<accum_t, float>::value) {
219-
std::vector<float> dequantized_v_data(v_data.m * v_data.n);
220-
dequantize_per_channel_optimized(
221-
static_cast<const int8_t*>(v_data.data),
222-
static_cast<const float*>(v_data.scales),
223-
static_cast<const int8_t*>(v_data.zero_points),
224-
dequantized_v_data.data(),
225-
-128,
226-
127,
227-
1,
228-
0,
229-
0,
230-
v_data.m,
231-
v_stride_n,
232-
v_data.n,
233-
v_data.n,
234-
v_data.zero_points_stride);
235-
::executorch::cpublas::gemm(
236-
::executorch::cpublas::TransposeType::NoTranspose,
237-
::executorch::cpublas::TransposeType::NoTranspose,
238-
n,
239-
m,
240-
k,
241-
static_cast<accum_t>(1),
242-
dequantized_v_data.data(),
243-
v_data.n,
244-
qk_data,
245-
qk_stride_m,
246-
beta,
247-
o_data,
248-
o_stride_m);
262+
if (m > 4) {
263+
// For larger batch sizes, dequantize and use BLAS for better
264+
// performance
265+
dequant_and_gemm(
266+
m,
267+
n,
268+
k,
269+
const_cast<float*>(qk_data),
270+
qk_stride_m,
271+
v_data,
272+
v_stride_n,
273+
o_data,
274+
o_stride_m,
275+
beta);
276+
} else {
277+
// For smaller batch sizes, use quantized gemm
278+
int a_stride_m_tmp, b_stride_n_tmp;
279+
auto kernel = torchao::kernels::cpu::quantized_matmul::
280+
get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
281+
m, n, k, false, false, a_stride_m_tmp, b_stride_n_tmp);
282+
kernel(
283+
m,
284+
n,
285+
k,
286+
qk_data,
287+
qk_stride_m /*lhs_stride_m*/,
288+
static_cast<const int8_t*>(v_data.data),
289+
v_stride_n /*rhs_stride_n*/,
290+
o_data,
291+
o_stride_m /*out_stride_n*/,
292+
static_cast<const int8_t*>(v_data.zero_points),
293+
static_cast<const float*>(v_data.scales),
294+
beta,
295+
v_data.zero_points_stride);
296+
}
249297
} else {
250298
ET_CHECK_MSG(
251299
false, "Accumulation in dtype other than float not supported yet");

0 commit comments

Comments
 (0)