@@ -17,6 +17,77 @@ namespace torch {
17
17
namespace executor {
18
18
19
19
namespace native {
20
+ Tensor& sdpa_with_kv_cache_out_no_context (
21
+ const Tensor& q_projected,
22
+ const Tensor& k_projected,
23
+ const Tensor& v_projected,
24
+ Tensor& key_cache,
25
+ Tensor& value_cache,
26
+ const int64_t start_pos,
27
+ const int64_t seq_len,
28
+ // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
29
+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
30
+ const optional<Tensor> attn_mask,
31
+ const double dropout_p,
32
+ const bool is_causal,
33
+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
34
+ const optional<double > scale,
35
+ Tensor& output);
36
+
37
+ at::Tensor sdpa_with_kv_cache_aten (
38
+ const at::Tensor& q_projected,
39
+ const at::Tensor& k_projected,
40
+ const at::Tensor& v_projected,
41
+ at::Tensor& key_cache,
42
+ at::Tensor& value_cache,
43
+ const int64_t start_pos,
44
+ const int64_t seq_len,
45
+ // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
46
+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
47
+ const std::optional<at::Tensor> attn_mask,
48
+ const double dropout_p,
49
+ const bool is_causal,
50
+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
51
+ const std::optional<double > scale);
52
+
53
+ Tensor& custom_sdpa_out_no_context (
54
+ const Tensor& q,
55
+ const Tensor& k,
56
+ const Tensor& v,
57
+ const int64_t start_pos,
58
+ // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
59
+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
60
+ const optional<Tensor> attn_mask,
61
+ const double dropout_p,
62
+ const bool is_causal,
63
+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
64
+ const optional<double > scale,
65
+ Tensor& output);
66
+
67
+ at::Tensor custom_sdpa_aten (
68
+ const at::Tensor& q,
69
+ const at::Tensor& k,
70
+ const at::Tensor& v,
71
+ const int64_t start_pos,
72
+ // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
73
+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
74
+ const std::optional<at::Tensor> attn_mask,
75
+ const double dropout_p,
76
+ const bool is_causal,
77
+ // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
78
+ const std::optional<double > scale);
79
+
80
+ Tensor& update_cache_out_no_context (
81
+ const Tensor& value,
82
+ Tensor& cache,
83
+ const int64_t start_pos,
84
+ Tensor& output);
85
+
86
+ at::Tensor update_cache_aten (
87
+ const at::Tensor& value,
88
+ at::Tensor& cache,
89
+ const int64_t start_pos);
90
+
20
91
Tensor& sdpa_with_kv_cache_out_no_context (
21
92
const Tensor& q_projected,
22
93
const Tensor& k_projected,
0 commit comments