1
+ /* ***********************************************************************************
2
+ * Copyright (c) 2023, xeus-cpp contributors *
3
+ * Copyright (c) 2023, Johan Mabille, Loic Gouarin, Sylvain Corlay, Wolf Vollprecht *
4
+ * *
5
+ * Distributed under the terms of the BSD 3-Clause License. *
6
+ * *
7
+ * The full license is in the file LICENSE, distributed with this software. *
8
+ ************************************************************************************/
9
+ #include " xassist.hpp"
10
+
11
+ #include < curl/curl.h>
12
+ #include < fstream>
13
+ #include < iostream>
14
+ #include < nlohmann/json.hpp>
15
+ #include < string>
16
+ #include < unordered_set>
17
+
18
+ using json = nlohmann::json;
19
+
20
+ namespace xcpp
21
+ {
22
+ class APIKeyManager
23
+ {
24
+ public:
25
+ static void saveApiKey (const std::string& model, const std::string& apiKey)
26
+ {
27
+ std::string apiKeyFilePath = model + " _api_key.txt" ; // File to store the API key, named after the model
28
+ std::ofstream out (apiKeyFilePath);
29
+ if (out)
30
+ {
31
+ out << apiKey;
32
+ out.close ();
33
+ std::cout << " API key saved for model " << model << std::endl;
34
+ }
35
+ else
36
+ {
37
+ std::cerr << " Failed to open file for writing API key for model " << model << std::endl;
38
+ }
39
+ }
40
+
41
+ // Method to load the API key for a specific model
42
+ static std::string loadApiKey (const std::string& model)
43
+ {
44
+ std::string apiKeyFilePath = model + " _api_key.txt" ; // File to read the API key from, named after the model
45
+ std::ifstream in (apiKeyFilePath);
46
+ std::string apiKey;
47
+ if (in)
48
+ {
49
+ std::getline (in, apiKey);
50
+ in.close ();
51
+ return apiKey;
52
+ }
53
+
54
+ std::cerr << " Failed to open file for reading API key for model " << model << std::endl;
55
+ return " " ;
56
+ }
57
+
58
+ };
59
+
60
+ class CurlHelper
61
+ {
62
+ private:
63
+ CURL* m_curl;
64
+ curl_slist* m_headers;
65
+
66
+ public:
67
+ CurlHelper ()
68
+ : m_curl(curl_easy_init())
69
+ , m_headers(curl_slist_append(nullptr , " Content-Type: application/json" ))
70
+ {
71
+ }
72
+
73
+ ~CurlHelper ()
74
+ {
75
+ if (m_curl)
76
+ {
77
+ curl_easy_cleanup (m_curl);
78
+ }
79
+ if (m_headers)
80
+ {
81
+ curl_slist_free_all (m_headers);
82
+ }
83
+ }
84
+
85
+ // Delete copy constructor and copy assignment operator
86
+ CurlHelper (const CurlHelper&) = delete ;
87
+ CurlHelper& operator =(const CurlHelper&) = delete ;
88
+
89
+ // Delete move constructor and move assignment operator
90
+ CurlHelper (CurlHelper&&) = delete ;
91
+ CurlHelper& operator =(CurlHelper&&) = delete ;
92
+
93
+ std::string
94
+ performRequest (const std::string& url, const std::string& postData, const std::string& authHeader = " " )
95
+ {
96
+ if (!authHeader.empty ())
97
+ {
98
+ m_headers = curl_slist_append (m_headers, authHeader.c_str ());
99
+ }
100
+
101
+ curl_easy_setopt (m_curl, CURLOPT_URL, url.c_str ());
102
+ curl_easy_setopt (m_curl, CURLOPT_HTTPHEADER, m_headers);
103
+ curl_easy_setopt (m_curl, CURLOPT_POSTFIELDS, postData.c_str ());
104
+
105
+ std::string response;
106
+ curl_easy_setopt (
107
+ m_curl,
108
+ CURLOPT_WRITEFUNCTION,
109
+ +[](const char * in, size_t size, size_t num, std::string* out)
110
+ {
111
+ const size_t totalBytes (size * num);
112
+ out->append (in, totalBytes);
113
+ return totalBytes;
114
+ }
115
+ );
116
+ curl_easy_setopt (m_curl, CURLOPT_WRITEDATA, &response);
117
+
118
+ CURLcode res = curl_easy_perform (m_curl);
119
+ if (res != CURLE_OK)
120
+ {
121
+ std::cerr << " CURL request failed: " << curl_easy_strerror (res) << std::endl;
122
+ return " " ;
123
+ }
124
+
125
+ return response;
126
+ }
127
+ };
128
+
129
+ std::string gemini (const std::string& cell, const std::string& key)
130
+ {
131
+ CurlHelper curlHelper;
132
+ const std::string url = " https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key="
133
+ + key;
134
+ const std::string postData = R"( {"contents": [{"parts":[{"text": ")" + cell + R"( "}]}]})" ;
135
+
136
+ std::string response = curlHelper.performRequest (url, postData);
137
+
138
+ json j = json::parse (response);
139
+ if (j.find (" error" ) != j.end ())
140
+ {
141
+ std::cerr << " Error: " << j[" error" ][" message" ] << std::endl;
142
+ return " " ;
143
+ }
144
+
145
+ return j[" candidates" ][0 ][" content" ][" parts" ][0 ][" text" ];
146
+ }
147
+
148
+ std::string openai (const std::string& cell, const std::string& key)
149
+ {
150
+ CurlHelper curlHelper;
151
+ const std::string url = " https://api.openai.com/v1/chat/completions" ;
152
+ const std::string postData = R"( {
153
+ "model": "gpt-3.5-turbo-16k",
154
+ "messages": [{"role": "user", "content": ")"
155
+ + cell + R"( "}],
156
+ "temperature": 0.7
157
+ })" ;
158
+ std::string authHeader = " Authorization: Bearer " + key;
159
+
160
+ std::string response = curlHelper.performRequest (url, postData, authHeader);
161
+
162
+ json j = json::parse (response);
163
+
164
+ if (j.find (" error" ) != j.end ())
165
+ {
166
+ std::cerr << " Error: " << j[" error" ][" message" ] << std::endl;
167
+ return " " ;
168
+ }
169
+
170
+ return j[" choices" ][0 ][" message" ][" content" ];
171
+ }
172
+
173
+ void xassist::operator ()(const std::string& line, const std::string& cell)
174
+ {
175
+ try
176
+ {
177
+ std::istringstream iss (line);
178
+ std::vector<std::string> tokens (
179
+ std::istream_iterator<std::string>{iss},
180
+ std::istream_iterator<std::string>()
181
+ );
182
+
183
+ std::vector<std::string> models = {" gemini" , " openai" };
184
+ std::string model = tokens[1 ];
185
+
186
+ if (std::find (models.begin (), models.end (), model) == models.end ())
187
+ {
188
+ std::cerr << " Model not found." << std::endl;
189
+ return ;
190
+ }
191
+
192
+ APIKeyManager api;
193
+ if (tokens[2 ] == " --save-key" )
194
+ {
195
+ xcpp::APIKeyManager::saveApiKey (model, cell);
196
+ return ;
197
+ }
198
+
199
+ std::string key = xcpp::APIKeyManager::loadApiKey (model);
200
+ if (key.empty ())
201
+ {
202
+ std::cerr << " API key for model " << model << " is not available." << std::endl;
203
+ return ;
204
+ }
205
+
206
+ std::string response;
207
+ if (model == " gemini" )
208
+ {
209
+ response = gemini (cell, key);
210
+ }
211
+ else if (model == " openai" )
212
+ {
213
+ response = openai (cell, key);
214
+ }
215
+
216
+ std::cout << response;
217
+ }
218
+ catch (const std::runtime_error& e)
219
+ {
220
+ std::cerr << " Caught an exception: " << e.what () << std::endl;
221
+ }
222
+ catch (...)
223
+ {
224
+ std::cerr << " Caught an unknown exception" << std::endl;
225
+ }
226
+ }
227
+ } // namespace xcpp
0 commit comments