
Commit db9657c

Add beam search to server non-streaming completion posts.

1 parent 7470edd commit db9657c

1 file changed: +78 -13 lines changed

examples/server/server.cpp

Lines changed: 78 additions & 13 deletions
@@ -1157,6 +1157,63 @@ static void log_server_request(const Request &req, const Response &res)
     });
 }
 
+bool is_at_eos(llama_server_context&, llama_token const* tokens, size_t const n_tokens) {
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos();
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// This custom callback example is called each time the beam lengths increase:
+//  * Show progress by printing ',' followed by the number of convergent beam tokens, if any.
+//  * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+//    This is also called when the stop condition is met.
+//    Collect tokens into std::vector<llama_token> response which is pointed to by callback_state.
+void beam_search_callback(void* callback_state, llama_beams_state beams_state) {
+    auto& llama = *static_cast<llama_server_context*>(callback_state);
+    // Mark beams as EOS as needed.
+    for (size_t i = 0; i < beams_state.n_beams; ++i) {
+        llama_beam_view& beam_view = beams_state.beam_views[i];
+        if (!beam_view.eos && is_at_eos(llama, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eos = true;
+        }
+    }
+    printf(","); // Show progress
+    if (size_t const n = beams_state.common_prefix_length) {
+        llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
+        assert(0u < beams_state.n_beams);
+        llama_token const* tokens = beams_state.beam_views[0].tokens;
+        //std::copy(tokens, tokens + n, llama.generated_token_probs.end() - n);
+        auto const map = [](llama_token tok) { return completion_token_output{{},tok}; };
+        std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
+        printf("%zu", n);
+    }
+    fflush(stdout);
+#if 0 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams:\n";
+    for (size_t i = 0; i < beams_state.n_beams; ++i) {
+        std::cout << "beams[" << i << "]: " << ostream_beam_view{llama.ctx, beams_state.beam_views[i]} << std::endl;
+    }
+#endif
+}
+
+struct token_translator {
+    llama_context* ctx;
+    char const* operator()(llama_token tok) const { return llama_token_to_str(ctx, tok); }
+    char const* operator()(completion_token_output cto) const { return (*this)(cto.tok); }
+};
+
+void append_to_generated_text_from_generated_token_probs(llama_server_context& llama) {
+    auto& gtps = llama.generated_token_probs;
+    auto translator = token_translator{llama.ctx};
+    auto add_strlen = [=](size_t sum, completion_token_output const& cto) { return sum + strlen(translator(cto)); };
+    size_t const len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
+    if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
+        llama.generated_text.reserve(llama.generated_text.size() + len);
+    }
+    for (completion_token_output const& cto : gtps) {
+        llama.generated_text += translator(cto);
+    }
+}
+
 int main(int argc, char **argv)
 {
     // own arguments required by this example
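
For reference, here is a minimal standalone sketch of the callback contract used above, assuming only the llama_beam_search API as this commit calls it (llama_beams_state, llama_beam_view, llama_token_eos). The beam_search_state struct and collect_common_prefix function are hypothetical names for illustration, not part of the commit:

#include "llama.h"
#include <cassert>
#include <vector>

// Hypothetical caller-owned state: collects the tokens that all beams agree on.
struct beam_search_state {
    std::vector<llama_token> response;
};

// Matches llama_beam_search_callback_fn_t: invoked each time the beam lengths
// grow, and once more when the stop condition is met.
void collect_common_prefix(void* callback_state, llama_beams_state beams_state) {
    auto& state = *static_cast<beam_search_state*>(callback_state);
    if (size_t const n = beams_state.common_prefix_length) {
        // All beams share their first n tokens; read them from beam 0.
        assert(0u < beams_state.n_beams);
        llama_token const* tokens = beams_state.beam_views[0].tokens;
        state.response.insert(state.response.end(), tokens, tokens + n);
    }
}

// Usage, mirroring the call added in main() below:
//   beam_search_state state;
//   llama_beam_search(ctx, collect_common_prefix, &state, n_beams,
//                     n_past, n_remain, n_threads);
//   // state.response now holds the converged completion tokens.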
@@ -1233,22 +1290,30 @@ int main(int argc, char **argv)
         llama.beginCompletion();
 
         if (!llama.stream) {
-            size_t stop_pos = std::string::npos;
+            if (llama.params.n_beams) {
+                // Fill llama.generated_token_probs vector with final beam.
+                llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
+                                  llama.n_past, llama.n_remain, llama.params.n_threads);
+                // Translate llama.generated_token_probs to llama.generated_text.
+                append_to_generated_text_from_generated_token_probs(llama);
+            } else {
+                size_t stop_pos = std::string::npos;
 
-            while (llama.has_next_token) {
-                const completion_token_output token_with_probs = llama.doCompletion();
-                const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
+                while (llama.has_next_token) {
+                    const completion_token_output token_with_probs = llama.doCompletion();
+                    const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
 
-                stop_pos = llama.findStoppingStrings(llama.generated_text,
-                    token_text.size(), STOP_FULL);
-            }
+                    stop_pos = llama.findStoppingStrings(llama.generated_text,
+                        token_text.size(), STOP_FULL);
+                }
 
-            if (stop_pos == std::string::npos) {
-                stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
-            }
-            if (stop_pos != std::string::npos) {
-                llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
-                                           llama.generated_text.end());
+                if (stop_pos == std::string::npos) {
+                    stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
+                }
+                if (stop_pos != std::string::npos) {
+                    llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
+                                               llama.generated_text.end());
+                }
             }
         }
 
         const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
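
Putting the two hunks together, the non-streaming beam path reduces to: run llama_beam_search with a collecting callback, then detokenize the collected tokens, which the commit does in append_to_generated_text_from_generated_token_probs via token_translator. A hedged end-to-end sketch, reusing the hypothetical beam_search_state and collect_common_prefix from the earlier sketch (run_beam_completion is likewise an illustrative name, not server code):

#include "llama.h"
#include <string>

std::string run_beam_completion(llama_context* ctx, size_t n_beams,
                                int n_past, int n_remain, int n_threads) {
    beam_search_state state;
    // Same call shape as the server's llama_beam_search invocation above.
    llama_beam_search(ctx, collect_common_prefix, &state, n_beams,
                      n_past, n_remain, n_threads);
    // Per-token detokenization, as token_translator does in the diff.
    std::string text;
    for (llama_token tok : state.response) {
        text += llama_token_to_str(ctx, tok);
    }
    return text;
}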
