Skip to content

Add a path to use quantized gemm from torchao in sdpa #9933

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions extension/llm/custom_ops/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,17 @@ runtime.python_test(
"//caffe2:torch",
],
)

runtime.python_test(
name = "test_quantized_sdpa",
srcs = [
"test_quantized_sdpa.py",
],
preload_deps = [
":custom_ops_aot_lib_mkl_noomp",
":custom_ops_aot_py",
],
deps = [
"//caffe2:torch",
],
)
94 changes: 72 additions & 22 deletions extension/llm/custom_ops/op_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ bool validate_flash_attention_args(
"scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");

ET_CHECK_OR_RETURN_FALSE(
(query.scalar_type() == ScalarType::Float), "Query must be Float type");
(query.scalar_type() == ScalarType::Float) ||
(query.scalar_type() == ScalarType::Char),
"Query must be Float type");

ET_CHECK_OR_RETURN_FALSE(
(query.scalar_type() == key.scalar_type()) &&
Expand Down Expand Up @@ -354,9 +356,14 @@ Tensor& custom_sdpa_out_impl(
output,
"Invalid arguments");

int64_t seq_len = q.size(1);
auto q_seq_len = q.size(1);

bool is_seq_at_dim_1{true};
if (q.scalar_type() == ScalarType::Char) {
is_seq_at_dim_1 = false;
seq_len = q.size(2);
q_seq_len = q.size(2);
ET_KERNEL_CHECK_MSG(
ctx,
q_scales.has_value() && q_zero_points.has_value() &&
Expand Down Expand Up @@ -390,9 +397,6 @@ Tensor& custom_sdpa_out_impl(

ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");

const int64_t seq_len = q.size(1);
auto q_seq_len = q.size(1);

const int64_t num_keys_for_causal_attention = start_pos + seq_len;

ET_KERNEL_CHECK(
Expand All @@ -418,12 +422,12 @@ Tensor& custom_sdpa_out_impl(
is_causal,
attn_mask,
scale,
nullopt, // q_zero_points
nullopt, // q_scales
nullopt, // k_zero_points
nullopt, // k_scales
nullopt, // v_zero_points
nullopt, // v_scales
q_zero_points, // q_zero_points
q_scales, // q_scales
k_zero_points, // k_zero_points
k_scales, // k_scales
v_zero_points, // v_zero_points
v_scales, // v_scales
is_seq_at_dim_1, /* is_seq_at_dim_1 */
start_pos,
num_keys_for_causal_attention);
Expand All @@ -437,12 +441,12 @@ Tensor& custom_sdpa_out_impl(
is_causal,
attn_mask,
scale,
nullopt, // q_zero_points
nullopt, // q_scales
nullopt, // k_zero_points
nullopt, // k_scales
nullopt, // v_zero_points
nullopt, // v_scales
q_zero_points, // q_zero_points
q_scales, // q_scales
k_zero_points, // k_zero_points
k_scales, // k_scales
v_zero_points, // v_zero_points
v_scales, // v_scales
is_seq_at_dim_1, /* is_seq_at_dim_1 */
start_pos,
num_keys_for_causal_attention);
Expand All @@ -456,12 +460,12 @@ Tensor& custom_sdpa_out_impl(
is_causal,
attn_mask,
scale,
nullopt, // q_zero_points
nullopt, // q_scales
nullopt, // k_zero_points
nullopt, // k_scales
nullopt, // v_zero_points
nullopt, // v_scales
q_zero_points, // q_zero_points
q_scales, // q_scales
k_zero_points, // k_zero_points
k_scales, // k_scales
v_zero_points, // v_zero_points
v_scales, // v_scales
is_seq_at_dim_1, /* is_seq_at_dim_1 */
start_pos,
num_keys_for_causal_attention);
Expand All @@ -470,6 +474,45 @@ Tensor& custom_sdpa_out_impl(
return output;
}

#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
Tensor& custom_quantized_sdpa_out(
RuntimeContext& ctx,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const int64_t start_pos,
const optional<Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
const optional<Tensor>& q_zero_points,
const optional<Tensor>& q_scales,
const optional<Tensor>& k_zero_points,
const optional<Tensor>& k_scales,
const optional<Tensor>& v_zero_points,
const optional<Tensor>& v_scales,
Tensor& output) {
return custom_sdpa_out_impl(
ctx,
q,
k,
v,
start_pos,
attn_mask,
dropout_p,
is_causal,
scale,
output,
q_zero_points,
q_scales,
k_zero_points,
k_scales,
v_zero_points,
v_scales);
}
#endif // ENABLE_CUSTOM_QUANTIZED_SDPA

/*
Input params
@param[in] q_projected Projected query with query weights.
Expand Down Expand Up @@ -570,3 +613,10 @@ EXECUTORCH_LIBRARY(
llama,
"custom_sdpa.out",
torch::executor::native::custom_sdpa_out);

#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
EXECUTORCH_LIBRARY(
llama,
"custom_quantized_sdpa.out",
torch::executor::native::custom_quantized_sdpa_out);
#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
20 changes: 20 additions & 0 deletions extension/llm/custom_ops/op_sdpa.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ Tensor& flash_attention_kernel_out(
const optional<double> scale,
Tensor& output);

#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
Tensor& custom_quantized_sdpa_out(
RuntimeContext& ctx,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const int64_t start_pos,
const optional<Tensor>& attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
const optional<Tensor>& q_zero_points,
const optional<Tensor>& q_scales,
const optional<Tensor>& k_zero_points,
const optional<Tensor>& k_scales,
const optional<Tensor>& v_zero_points,
const optional<Tensor>& v_scales,
Tensor& output);
#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
} // namespace native
} // namespace executor
} // namespace torch
143 changes: 143 additions & 0 deletions extension/llm/custom_ops/op_sdpa_aot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,47 @@ at::Tensor custom_sdpa_aten(
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const std::optional<double> scale);

#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
Tensor& custom_quantized_sdpa_out_no_context(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const int64_t start_pos,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<Tensor> attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
const optional<Tensor> q_zero_points,
const optional<Tensor> q_scales,
const optional<Tensor> k_zero_points,
const optional<Tensor> k_scales,
const optional<Tensor> v_zero_points,
const optional<Tensor> v_scales,
Tensor& output);

at::Tensor custom_quantized_sdpa_aten(
const at::Tensor& q,
const at::Tensor& k,
const at::Tensor& v,
const int64_t start_pos,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const std::optional<at::Tensor> attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const std::optional<double> scale,
const std::optional<at::Tensor>& q_zero_points,
const std::optional<at::Tensor>& q_scales,
const std::optional<at::Tensor>& k_zero_points,
const std::optional<at::Tensor>& k_scales,
const std::optional<at::Tensor>& v_zero_points,
const std::optional<at::Tensor>& v_scales);
#endif // ENABLE_CUSTOM_QUANTIZED_SDPA

Tensor& update_cache_out_no_context(
const Tensor& value,
Tensor& cache,
Expand Down Expand Up @@ -198,6 +239,85 @@ at::Tensor custom_sdpa_aten(
return output;
}

#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
Tensor& custom_quantized_sdpa_out_no_context(
const Tensor& q,
const Tensor& k,
const Tensor& v,
const int64_t start_pos,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<Tensor> attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
const optional<Tensor> q_zero_points,
const optional<Tensor> q_scales,
const optional<Tensor> k_zero_points,
const optional<Tensor> k_scales,
const optional<Tensor> v_zero_points,
const optional<Tensor> v_scales,
Tensor& output) {
executorch::aten::RuntimeContext context{};
return torch::executor::native::custom_quantized_sdpa_out(
context,
q,
k,
v,
start_pos,
attn_mask,
dropout_p,
is_causal,
scale,
q_zero_points,
q_scales,
k_zero_points,
k_scales,
v_zero_points,
v_scales,
output);
}

at::Tensor custom_quantized_sdpa_aten(
const at::Tensor& q,
const at::Tensor& k,
const at::Tensor& v,
const int64_t start_pos,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const std::optional<at::Tensor> attn_mask,
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const std::optional<double> scale,
const std::optional<at::Tensor>& q_zero_points,
const std::optional<at::Tensor>& q_scales,
const std::optional<at::Tensor>& k_zero_points,
const std::optional<at::Tensor>& k_scales,
const std::optional<at::Tensor>& v_zero_points,
const std::optional<at::Tensor>& v_scales) {
auto output = at::empty(q.sizes());
WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 14)
(q,
k,
v,
start_pos,
attn_mask,
dropout_p,
is_causal,
scale,
q_zero_points,
q_scales,
k_zero_points,
k_scales,
v_zero_points,
v_scales,
output);
return output;
}
#endif // ENABLE_CUSTOM_QUANTIZED_SDPA

Tensor& update_cache_out_no_context(
const Tensor& value,
Tensor& cache,
Expand Down Expand Up @@ -245,6 +365,20 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
m.def(
"update_cache.out(Tensor value, Tensor(a!) cache, "
"SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
m.def(
"custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
"Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
"float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
"Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
"Tensor? v_scales=None) -> Tensor");
m.def(
"custom_quantized_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
"Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
"float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
"Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
"Tensor? v_scales=None, *, Tensor(a!) out) -> Tensor(a!)");
#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
}

// TODO: Rename this file to op_custom_ops_aot.cpp
Expand All @@ -263,4 +397,13 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
m.impl(
"update_cache.out",
WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
m.impl(
"custom_quantized_sdpa",
torch::executor::native::custom_quantized_sdpa_aten);
m.impl(
"custom_quantized_sdpa.out",
WRAP_TO_ATEN(
torch::executor::native::custom_quantized_sdpa_out_no_context, 14));
#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
}
Loading
Loading