Skip to content

Commit d52357f

Browse files
[Executorch][custo ops] Add prototype defs for custom op (#9815)
Pull Request resolved: #9786 To fix mac buck builds ghstack-source-id: 275283538 Differential Revision: [D71370605](https://our.internmc.facebook.com/intern/diff/D71370605/) Co-authored-by: Kimish Patel <[email protected]>
1 parent 6af0428 commit d52357f

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,77 @@ namespace torch {
1717
namespace executor {
1818

1919
namespace native {
20+
Tensor& sdpa_with_kv_cache_out_no_context(
21+
const Tensor& q_projected,
22+
const Tensor& k_projected,
23+
const Tensor& v_projected,
24+
Tensor& key_cache,
25+
Tensor& value_cache,
26+
const int64_t start_pos,
27+
const int64_t seq_len,
28+
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
29+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
30+
const optional<Tensor> attn_mask,
31+
const double dropout_p,
32+
const bool is_causal,
33+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
34+
const optional<double> scale,
35+
Tensor& output);
36+
37+
at::Tensor sdpa_with_kv_cache_aten(
38+
const at::Tensor& q_projected,
39+
const at::Tensor& k_projected,
40+
const at::Tensor& v_projected,
41+
at::Tensor& key_cache,
42+
at::Tensor& value_cache,
43+
const int64_t start_pos,
44+
const int64_t seq_len,
45+
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
46+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
47+
const std::optional<at::Tensor> attn_mask,
48+
const double dropout_p,
49+
const bool is_causal,
50+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
51+
const std::optional<double> scale);
52+
53+
Tensor& custom_sdpa_out_no_context(
54+
const Tensor& q,
55+
const Tensor& k,
56+
const Tensor& v,
57+
const int64_t start_pos,
58+
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
59+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
60+
const optional<Tensor> attn_mask,
61+
const double dropout_p,
62+
const bool is_causal,
63+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
64+
const optional<double> scale,
65+
Tensor& output);
66+
67+
at::Tensor custom_sdpa_aten(
68+
const at::Tensor& q,
69+
const at::Tensor& k,
70+
const at::Tensor& v,
71+
const int64_t start_pos,
72+
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
73+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
74+
const std::optional<at::Tensor> attn_mask,
75+
const double dropout_p,
76+
const bool is_causal,
77+
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
78+
const std::optional<double> scale);
79+
80+
Tensor& update_cache_out_no_context(
81+
const Tensor& value,
82+
Tensor& cache,
83+
const int64_t start_pos,
84+
Tensor& output);
85+
86+
at::Tensor update_cache_aten(
87+
const at::Tensor& value,
88+
at::Tensor& cache,
89+
const int64_t start_pos);
90+
2091
Tensor& sdpa_with_kv_cache_out_no_context(
2192
const Tensor& q_projected,
2293
const Tensor& k_projected,

extension/llm/custom_ops/op_tile_crop_aot.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@ namespace executor {
1717

1818
namespace native {
1919

20+
Tensor&
21+
tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out);
22+
2023
Tensor&
2124
tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out) {
2225
executorch::aten::RuntimeContext context{};
2326
return tile_crop_out_impl(context, input, tile_size, out);
2427
}
2528

29+
at::Tensor tile_crop_aten(const at::Tensor& input, int64_t tile_size);
30+
2631
at::Tensor tile_crop_aten(const at::Tensor& input, int64_t tile_size) {
2732
// max_num_tiles = 4, num_channels = 3.
2833
auto output = at::empty({4, 3, tile_size, tile_size});

0 commit comments

Comments
 (0)