
Commit 1528660

Add llama_beam_search().

1 parent cc34dbd · commit 1528660

7 files changed: 554 additions & 13 deletions

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ struct gpt_params {
     int32_t main_gpu                        = 0;    // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0};  // how split tensors should be distributed across GPUs
     int32_t n_probs                         = 0;    // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t n_beams                         = 0;    // if non-zero then use beam search of given width.
     float   rope_freq_base                  = 10000.0f; // RoPE base frequency
     float   rope_freq_scale                 = 1.0f;     // RoPE frequency scaling factor
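The new n_beams field is the only switch callers need to flip. A minimal sketch (not part of this diff; the model path and width here are illustrative) of how a caller opts in to beam search through it, with 0 (the default) keeping the existing sampling path:

// Sketch only: a non-zero n_beams requests beam search of that width.
#include "common.h"   // gpt_params
#include <string>

static gpt_params make_beam_params(const std::string & model_path) {
    gpt_params params;
    params.model   = model_path; // e.g. a hypothetical "models/7B/ggml-model.bin"
    params.n_beams = 4;          // use beam search with a width of 4 beams
    return params;
}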

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ else()
     add_subdirectory(simple)
     add_subdirectory(embd-input)
     add_subdirectory(llama-bench)
+    add_subdirectory(beam_search)
     if (LLAMA_METAL)
         add_subdirectory(metal)
     endif()

examples/beam_search/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+set(TARGET beam_search)
+add_executable(${TARGET} beam_search.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
+endif()

examples/beam_search/beam_search.cpp

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include "common.h"
+#include "llama.h"
+#include "build-info.h"
+
+#include <cassert>
+#include <cinttypes>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#include <signal.h>
+#include <unistd.h>
+#elif defined (_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+#include <windows.h>
+#include <signal.h>
+#endif
+
+// Used for debugging to print out beam tokens.
+struct ostream_beam_view {
+    llama_context * ctx;
+    llama_beam_view beam_view;
+};
+std::ostream& operator<<(std::ostream& os, ostream_beam_view const & obv) {
+    os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos << ") tokens(";
+    for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
+        os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
+    }
+    return os << ')';
+}
+
+// Put here anything you want back in beam_search_callback().
+struct beam_search_callback_data {
+    llama_context * ctx;
+    std::vector<llama_token> response;
+};
+
+bool is_at_eos(beam_search_callback_data const & callback_data, llama_token const * tokens, size_t const n_tokens) {
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// Custom callback example is called each time the beam lengths increase:
+//  * Show progress by printing ',' followed by the number of convergent beam tokens if any.
+//  * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+//    This is also called when the stop condition is met.
+//    Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
+void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_state) {
+    auto& callback_data = *static_cast<beam_search_callback_data*>(callback_data_ptr);
+    // Mark beams as EOS as needed.
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+        llama_beam_view& beam_view = beams_state.beam_views[i];
+        if (!beam_view.eos && is_at_eos(callback_data, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eos = true;
+        }
+    }
+    printf(",");  // Show progress
+    if (size_t const n = beams_state.common_prefix_length) {
+        callback_data.response.resize(callback_data.response.size() + n);
+        assert(0u < beams_state.n_beams);
+        llama_token const * tokens = beams_state.beam_views[0].tokens;
+        std::copy(tokens, tokens + n, callback_data.response.end() - n);
+        printf("%lu", n);
+    }
+    fflush(stdout);
+#if 1 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams (last_call=" << beams_state.last_call << "):\n";
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+        std::cout << "beams["<<i<<"]: " << ostream_beam_view{callback_data.ctx,beams_state.beam_views[i]} << std::endl;
+    }
+#endif
+}
+
+int main(int argc, char ** argv)
+{
+    gpt_params params;
+    //params.n_gpu_layers = 200;
+
+    //---------------------------------
+    // Print help :
+    //---------------------------------
+
+    if ( argc < 2 || argv[1][0] == '-' )
+    {
+        printf( "Usage: %s MODEL_PATH [BEAM_WIDTH=2] [PROMPT]\n" , argv[0] );
+        return 1 ;
+    }
+
+    //---------------------------------
+    // Load parameters :
+    //---------------------------------
+
+    params.model = argv[1];
+
+    params.n_beams = 2 < argc ? std::stoi(argv[2]) : 2;
+
+    if ( argc > 3 )
+    {
+        params.prompt = argv[3];
+    }
+
+    if ( params.prompt.empty() )
+    {
+        params.prompt = "### Request:\nHow many countries are there?\n\n### Response:\n";
+    }
+
+    //---------------------------------
+    // Init LLM :
+    //---------------------------------
+
+    llama_backend_init(params.numa);
+
+    llama_model * model;
+    llama_context * ctx;
+
+    std::tie(model, ctx) = llama_init_from_gpt_params( params );
+
+    if ( model == NULL )
+    {
+        fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
+        return 1;
+    }
+
+    //---------------------------------
+    // Tokenize the prompt :
+    //---------------------------------
+
+    std::vector<llama_token> tokens_list = llama_tokenize(ctx, params.prompt, true);
+
+    const size_t max_context_size     = llama_n_ctx( ctx );
+    const size_t max_tokens_list_size = max_context_size - 4 ;
+
+    if (tokens_list.size() > max_tokens_list_size)
+    {
+        fprintf( stderr , "%s: error: prompt too long (%lu tokens, max %lu)\n" ,
+             __func__ , tokens_list.size() , max_tokens_list_size );
+        return 1;
+    }
+
+    fprintf( stderr, "\n\n" );
+
+    // Print the tokens from the prompt :
+
+    for( auto id : tokens_list )
+    {
+        std::cout << llama_token_to_str(ctx, id);
+    }
+    std::cout << std::flush;
+
+    int n_past = llama_get_kv_cache_token_count(ctx);
+    if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
+    {
+        fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
+        return 1;
+    }
+    n_past += tokens_list.size();
+
+    beam_search_callback_data callback_data{ctx, {}};
+    size_t const beam_width = static_cast<size_t>(params.n_beams);
+    int const n_predict = 256;
+    llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads);
+
+    std::cout << "\n\n";
+    for (llama_token const token_id : callback_data.response) {
+        std::cout << llama_token_to_str(ctx,token_id);
+    }
+    std::cout << std::endl;
+
+    llama_free( ctx );
+    llama_free_model( model );
+
+    llama_backend_free();
+
+    return 0;
+}
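The llama.h / llama.cpp hunks that declare the beam-search API belong to this commit but are not reproduced above. For reference, the shape of that API can be read off the calls in beam_search.cpp; the declarations below are a sketch reconstructed from this usage only (field order and exact types are assumptions, not a copy of llama.h):

// Sketch reconstructed from usage in beam_search.cpp -- not copied from llama.h.
struct llama_beam_view {
    const llama_token * tokens;   // tokens of this beam so far
    size_t              n_tokens;
    float               p;        // cumulative beam probability, printed by ostream_beam_view
    bool                eos;      // the callback sets this when a beam ends in the EOS token
};

// Handed to the callback on every iteration; only valid for the duration of the call.
struct llama_beams_state {
    llama_beam_view * beam_views;            // one view per beam, n_beams entries
    size_t            n_beams;
    size_t            common_prefix_length;  // length of the prefix currently shared by all beams
    bool              last_call;             // true on the final invocation (stop condition met)
};

// Callback type; callback_data is passed back untouched (here: a beam_search_callback_data*).
typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, llama_beams_state beams_state);

// Entry point as invoked by the example:
// llama_beam_search(ctx, beam_search_callback, &callback_data,
//                   beam_width, n_past, n_predict, n_threads);

Under this contract the library owns the beam bookkeeping; the callback only marks finished beams and copies out the tokens that all beams already agree on, which is why the example appends exactly common_prefix_length tokens per call.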

examples/server/server.cpp

Lines changed: 78 additions & 13 deletions
@@ -1208,6 +1208,63 @@ static void log_server_request(const Request &req, const Response &res)
     });
 }
 
+bool is_at_eos(llama_server_context& server_context, llama_token const* tokens, size_t const n_tokens) {
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// Custom callback example is called each time the beam lengths increase:
+//  * Show progress by printing ',' followed by the number of convergent beam tokens if any.
+//  * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+//    This is also called when the stop condition is met.
+//    Collect tokens into llama.generated_token_probs, reached through callback_data.
+void beam_search_callback(void* callback_data, llama_beams_state beams_state) {
+    auto& llama = *static_cast<llama_server_context*>(callback_data);
+    // Mark beams as EOS as needed.
+    for (size_t i=0 ; i<beams_state.n_beams ; ++i) {
+        llama_beam_view& beam_view = beams_state.beam_views[i];
+        if (!beam_view.eos && is_at_eos(llama, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eos = true;
+        }
+    }
+    printf(",");  // Show progress
+    if (size_t const n = beams_state.common_prefix_length) {
+        llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
+        assert(0u < beams_state.n_beams);
+        llama_token const* tokens = beams_state.beam_views[0].tokens;
+        //std::copy(tokens, tokens + n, llama->generated_token_probs.end() - n);
+        auto const map = [](llama_token tok) { return completion_token_output{{},tok}; };
+        std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
+        printf("%lu", n);
+    }
+    fflush(stdout);
+#if 0 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams:\n";
+    for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
+        std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
+    }
+#endif
+}
+
+struct token_translator {
+    llama_context* ctx;
+    std::string operator()(llama_token tok)             const { return llama_token_to_str(ctx, tok); }
+    std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); }
+};
+
+void append_to_generated_text_from_generated_token_probs(llama_server_context& llama) {
+    auto& gtps = llama.generated_token_probs;
+    auto translator = token_translator{llama.ctx};
+    auto add_strlen = [=](size_t sum, completion_token_output const& cto) { return sum + translator(cto).size(); };
+    size_t const len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
+    if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
+        llama.generated_text.reserve(llama.generated_text.size() + len);
+    }
+    for (completion_token_output const& cto : gtps) {
+        llama.generated_text += translator(cto);
+    }
+}
+
 int main(int argc, char **argv)
 {
     // own arguments required by this example
@@ -1290,22 +1347,30 @@ int main(int argc, char **argv)
         llama.beginCompletion();
 
         if (!llama.stream) {
-            size_t stop_pos = std::string::npos;
+            if (llama.params.n_beams) {
+                // Fill llama.generated_token_probs vector with final beam.
+                llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
+                                  llama.n_past, llama.n_remain, llama.params.n_threads);
+                // Translate llama.generated_token_probs to llama.generated_text.
+                append_to_generated_text_from_generated_token_probs(llama);
+            } else {
+                size_t stop_pos = std::string::npos;
 
-            while (llama.has_next_token) {
-                const completion_token_output token_with_probs = llama.doCompletion();
-                const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
+                while (llama.has_next_token) {
+                    const completion_token_output token_with_probs = llama.doCompletion();
+                    const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
 
-                stop_pos = llama.findStoppingStrings(llama.generated_text,
-                    token_text.size(), STOP_FULL);
-            }
+                    stop_pos = llama.findStoppingStrings(llama.generated_text,
+                        token_text.size(), STOP_FULL);
+                }
 
-            if (stop_pos == std::string::npos) {
-                stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
-            }
-            if (stop_pos != std::string::npos) {
-                llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
-                    llama.generated_text.end());
+                if (stop_pos == std::string::npos) {
+                    stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
+                }
+                if (stop_pos != std::string::npos) {
+                    llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
+                        llama.generated_text.end());
+                }
             }
 
             const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
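append_to_generated_text_from_generated_token_probs() sizes the destination before appending: it sums the decoded token lengths with std::accumulate, reserves once, then concatenates. A standalone sketch of that reserve-then-append pattern, with plain strings standing in for llama_token_to_str() output (the names and data here are illustrative, not part of the commit):

#include <cstdio>
#include <numeric>
#include <string>
#include <vector>

int main() {
    // Stand-ins for detokenized pieces of the final beam.
    std::vector<std::string> pieces = {"How", " many", " countries", "?"};

    // Pass 1: total length, so the destination buffer grows at most once.
    auto add_strlen = [](size_t sum, std::string const & s) { return sum + s.size(); };
    size_t const len = std::accumulate(pieces.begin(), pieces.end(), size_t(0), add_strlen);

    std::string text;
    if (text.capacity() < text.size() + len) {
        text.reserve(text.size() + len);
    }

    // Pass 2: append each piece; no reallocation happens here.
    for (std::string const & s : pieces) {
        text += s;
    }
    printf("%s (%zu bytes)\n", text.c_str(), text.size());
    return 0;
}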
