@@ -1157,6 +1157,63 @@ static void log_server_request(const Request &req, const Response &res)
});
}
+bool is_at_eos(llama_server_context &, llama_token const * tokens, size_t const n_tokens) {
+    return n_tokens && tokens[n_tokens - 1] == llama_token_eos();
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// Custom callback example is called each time the beam lengths increase:
+//  * Show progress by printing ',' followed by the number of convergent beam tokens, if any.
+//  * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+//    This is also called when the stop condition is met.
+//    Collect tokens into std::vector<llama_token> response which is pointed to by callback_state.
+void beam_search_callback(void * callback_state, llama_beams_state beams_state) {
+    auto & llama = *static_cast<llama_server_context *>(callback_state);
+    // Mark beams as EOS as needed.
+    for (size_t i = 0; i < beams_state.n_beams; ++i) {
+        llama_beam_view & beam_view = beams_state.beam_views[i];
+        if (!beam_view.eos && is_at_eos(llama, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eos = true;
+        }
+    }
+    printf(",");  // Show progress
+    if (size_t const n = beams_state.common_prefix_length) {
+        llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
+        assert(0u < beams_state.n_beams);
+        llama_token const * tokens = beams_state.beam_views[0].tokens;
+        // std::copy(tokens, tokens + n, llama.generated_token_probs.end() - n);
+        auto const map = [](llama_token tok) { return completion_token_output{{}, tok}; };
+        std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
+        printf("%lu", n);
+    }
+    fflush(stdout);
+#if 0 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams:\n";
+    for (size_t i = 0; i < beams_state.n_beams; ++i) {
+        std::cout << "beams[" << i << "]: " << ostream_beam_view{llama.ctx, beams_state.beam_views[i]} << std::endl;
+    }
+#endif
+}
+
+struct token_translator {
+    llama_context * ctx;
+    char const * operator()(llama_token tok) const { return llama_token_to_str(ctx, tok); }
+    char const * operator()(completion_token_output cto) const { return (*this)(cto.tok); }
+};
+
+void append_to_generated_text_from_generated_token_probs(llama_server_context & llama) {
+    auto & gtps = llama.generated_token_probs;
+    auto translator = token_translator{llama.ctx};
+    auto add_strlen = [=](size_t sum, completion_token_output const & cto) { return sum + strlen(translator(cto)); };
+    size_t const len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
+    if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
+        llama.generated_text.reserve(llama.generated_text.size() + len);
+    }
+    for (completion_token_output const & cto : gtps) {
+        llama.generated_text += translator(cto);
+    }
+}
+
int main(int argc, char **argv)
{
    // own arguments required by this example
@@ -1233,22 +1290,30 @@ int main(int argc, char **argv)
        llama.beginCompletion();

        if (!llama.stream) {
-            size_t stop_pos = std::string::npos;
+            if (llama.params.n_beams) {
+                // Fill llama.generated_token_probs vector with final beam.
+                llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
+                                  llama.n_past, llama.n_remain, llama.params.n_threads);
+                // Translate llama.generated_token_probs to llama.generated_text.
+                append_to_generated_text_from_generated_token_probs(llama);
+            } else {
+                size_t stop_pos = std::string::npos;

-            while (llama.has_next_token) {
-                const completion_token_output token_with_probs = llama.doCompletion();
-                const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
+                while (llama.has_next_token) {
+                    const completion_token_output token_with_probs = llama.doCompletion();
+                    const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);

-                stop_pos = llama.findStoppingStrings(llama.generated_text,
-                                                     token_text.size(), STOP_FULL);
-            }
+                    stop_pos = llama.findStoppingStrings(llama.generated_text,
+                                                         token_text.size(), STOP_FULL);
+                }

-            if (stop_pos == std::string::npos) {
-                stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
-            }
-            if (stop_pos != std::string::npos) {
-                llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
-                                           llama.generated_text.end());
+                if (stop_pos == std::string::npos) {
+                    stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
+                }
+                if (stop_pos != std::string::npos) {
+                    llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
+                                               llama.generated_text.end());
+                }
            }

            const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
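For reference, the callback pattern added above can also be used outside the server. The sketch below is not part of the patch: it collects the converged tokens into a plain std::vector<llama_token> instead of llama.generated_token_probs, relying only on the llama_beam_search / llama_beams_state / llama_beam_view interface that this change already uses. The beam_search_state struct, the collect_common_prefix name, and the n_beams / n_remain / n_threads values in the usage comment are illustrative, not part of the library.

#include <vector>

#include "llama.h"

// Illustrative state struct: tokens accepted so far across all beams.
struct beam_search_state {
    std::vector<llama_token> response;
};

// Matches llama_beam_search_callback_fn_t, like beam_search_callback above.
static void collect_common_prefix(void * callback_state, llama_beams_state beams_state) {
    auto & state = *static_cast<beam_search_state *>(callback_state);
    // Append the common prefix reported for this call, mirroring how
    // beam_search_callback above extends llama.generated_token_probs.
    if (size_t const n = beams_state.common_prefix_length) {
        llama_token const * tokens = beams_state.beam_views[0].tokens;
        state.response.insert(state.response.end(), tokens, tokens + n);
    }
}

// Possible usage, given an initialized llama_context * ctx that has already
// evaluated n_past prompt tokens (beam count and limits are placeholders):
//
//     beam_search_state state;
//     llama_beam_search(ctx, collect_common_prefix, &state,
//                       /*n_beams=*/4, n_past, /*n_remain=*/128, /*n_threads=*/4);
//     // state.response now holds the tokens of the selected beam.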