Skip to content

Commit fbbf0eb

Browse files
committed
Use llama_ prefix for struct names.
1 parent 35d7514 commit fbbf0eb

File tree

3 files changed

+41
-43
lines changed

3 files changed

+41
-43
lines changed

examples/beam_search/beam_search.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@
3030
// Used for debugging to print out beam tokens.
3131
struct ostream_beam_view {
3232
llama_context* ctx;
33-
beam_view bv;
33+
llama_beam_view beam_view;
3434
};
3535
std::ostream& operator<<(std::ostream& os, ostream_beam_view const& obv) {
36-
os << "p(" << obv.bv.p << ") eos(" << std::boolalpha << obv.bv.eos() << ") tokens(";
37-
for (size_t i=0 ; i<obv.bv.n_tokens ; ++i) {
38-
os << llama_token_to_str(obv.ctx, obv.bv.tokens[i]);
36+
os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos() << ") tokens(";
37+
for (size_t i=0 ; i<obv.beam_view.n_tokens ; ++i) {
38+
os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
3939
}
4040
return os << ')';
4141
}
@@ -52,7 +52,7 @@ struct beam_search_callback_state {
5252
// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
5353
// This is also called when the stop condition is met.
5454
// Collect tokens into std::vector<llama_token> response which is pointed to by callback_state.
55-
beam_search_control beam_search_callback(void* callback_state, beams_state const beams_state) {
55+
llama_beam_search_control beam_search_callback(void* callback_state, llama_beams_state const beams_state) {
5656
auto const state = *static_cast<beam_search_callback_state*>(callback_state);
5757
printf(","); // Show progress
5858
if (size_t const n = beams_state.common_prefix_length) {
@@ -69,11 +69,10 @@ beam_search_control beam_search_callback(void* callback_state, beams_state const
6969
std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
7070
}
7171
#endif
72-
beam_search_control control {
72+
return llama_beam_search_control{
7373
beams_state.n_beams, // = collapse_to. Any index out of range means do not collapse beams.
7474
false // = stop. Don't stop beam search.
7575
};
76-
return control;
7776
}
7877

7978
int main(int argc, char ** argv)

llama.cpp

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
#include <ctime>
3737
#include <cinttypes>
3838
#include <fstream>
39-
#include <functional>
4039
#include <random>
4140
#include <map>
4241
#include <unordered_map>
@@ -2876,7 +2875,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
28762875
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
28772876
}
28782877

2879-
struct beam {
2878+
struct llama_beam {
28802879
std::vector<llama_token> tokens;
28812880
float p; // Cumulative beam probability (renormalized relative to all beams)
28822881
// end-of-sentence
@@ -2939,16 +2938,16 @@ struct beam_search {
29392938
int n_past;
29402939
int n_predict;
29412940
int n_threads;
2942-
std::vector<beam> beams;
2943-
std::vector<beam> next_beams;
2941+
std::vector<llama_beam> beams;
2942+
std::vector<llama_beam> next_beams;
29442943

29452944
// Re-calculated on each loop iteration
29462945
size_t common_prefix_length;
29472946
// true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
29482947
bool common_prefix_evaluated;
29492948

2950-
// Temporary memory used by beams_state to pass back via callback.
2951-
std::vector<beam_view> beam_views;
2949+
// Temporary memory used by llama_beams_state to pass back via callback.
2950+
std::vector<llama_beam_view> beam_views;
29522951

29532952
beam_search(llama_context * ctx, size_t beam_width, int n_past, int n_predict, int n_threads)
29542953
: ctx(ctx)
@@ -2974,32 +2973,32 @@ struct beam_search {
29742973
// * Gather elements until the vector is full, then call std::make_heap() on it.
29752974
// * If the heap is full and a new element is found that should be included, pop the
29762975
// least element to the back(), replace it with the new, then push it into the heap.
2977-
void fill_next_beams_by_top_probabilities(beam& b) {
2976+
void fill_next_beams_by_top_probabilities(llama_beam& beam) {
29782977
// Min-heaps use a greater-than comparator.
2979-
auto const comp = [](beam const& a, beam const& b) { return a.p > b.p; };
2978+
auto const comp = [](llama_beam const& a, llama_beam const& b) { return a.p > b.p; };
29802979
if (common_prefix_evaluated) {
29812980
// llama_eval was already called during this iteration
29822981
// with the common token prefix, so shift it off this beam.
2983-
b.shift_tokens(common_prefix_length);
2982+
beam.shift_tokens(common_prefix_length);
29842983
}
2985-
if (b.eos()) {
2984+
if (beam.eos()) {
29862985
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
29872986
if (next_beams.size() < beam_width) {
2988-
next_beams.push_back(std::move(b));
2987+
next_beams.push_back(std::move(beam));
29892988
if (next_beams.size() == beam_width) {
29902989
std::make_heap(next_beams.begin(), next_beams.end(), comp);
29912990
}
2992-
} else if (next_beams.front().p < b.p) {
2991+
} else if (next_beams.front().p < beam.p) {
29932992
std::pop_heap(next_beams.begin(), next_beams.end(), comp);
2994-
next_beams.back() = std::move(b);
2993+
next_beams.back() = std::move(beam);
29952994
std::push_heap(next_beams.begin(), next_beams.end(), comp);
29962995
}
29972996
} else {
29982997
// beam is not at end-of-sentence, so branch with next top_k tokens.
2999-
if (!b.tokens.empty()) {
3000-
llama_eval(ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
2998+
if (!beam.tokens.empty()) {
2999+
llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads);
30013000
if (!common_prefix_evaluated && common_prefix_length) {
3002-
b.shift_tokens(common_prefix_length);
3001+
beam.shift_tokens(common_prefix_length);
30033002
n_past += common_prefix_length;
30043003
common_prefix_evaluated = true;
30053004
}
@@ -3009,7 +3008,7 @@ struct beam_search {
30093008
size_t i=0;
30103009
if (next_beams.size() < beam_width) {
30113010
for (; next_beams.size() < beam_width ; ++i) {
3012-
beam next_beam = b;
3011+
llama_beam next_beam = beam;
30133012
next_beam.tokens.push_back(next_tokens[i].id);
30143013
next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit);
30153014
next_beams.push_back(std::move(next_beam));
@@ -3018,17 +3017,17 @@ struct beam_search {
30183017
} else {
30193018
for (; next_beams.front().p == 0.0f ; ++i) {
30203019
std::pop_heap(next_beams.begin(), next_beams.end(), comp);
3021-
next_beams.back() = b;
3020+
next_beams.back() = beam;
30223021
next_beams.back().tokens.push_back(next_tokens[i].id);
30233022
next_beams.back().p *= logit_info.probability_from_logit(next_tokens[i].logit);
30243023
std::push_heap(next_beams.begin(), next_beams.end(), comp);
30253024
}
30263025
}
30273026
for (; i < beam_width ; ++i) {
3028-
float const next_p = b.p * logit_info.probability_from_logit(next_tokens[i].logit);
3027+
float const next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit);
30293028
if (next_beams.front().p < next_p) {
30303029
std::pop_heap(next_beams.begin(), next_beams.end(), comp);
3031-
next_beams.back() = b;
3030+
next_beams.back() = beam;
30323031
next_beams.back().tokens.push_back(next_tokens[i].id);
30333032
next_beams.back().p = next_p;
30343033
std::push_heap(next_beams.begin(), next_beams.end(), comp);
@@ -3055,9 +3054,9 @@ struct beam_search {
30553054

30563055
// Construct beams_state to send back to caller via the callback function.
30573056
// Side effect: set common_prefix_length = find_common_prefix_length();
3058-
beams_state get_beams_state(bool const last_call) {
3057+
llama_beams_state get_beams_state(bool const last_call) {
30593058
for (size_t i=0 ; i<beams.size() ; ++i) {
3060-
beam_views[i] = beam_view{beams[i].tokens.data(), beams[i].tokens.size(), beams[i].p};
3059+
beam_views[i] = llama_beam_view{beams[i].tokens.data(), beams[i].tokens.size(), beams[i].p};
30613060
}
30623061
common_prefix_length = find_common_prefix_length();
30633062
return {beam_views.data(), beams.size(), common_prefix_length, last_call};
@@ -3070,10 +3069,10 @@ struct beam_search {
30703069
// (since all other beam probabilities can only decrease)
30713070
void loop(llama_beam_search_callback_fn_t const callback, void* const callback_state) {
30723071
beams.push_back({{}, 1.0f}); // Start with one empty beam w/ probability = 1.0.
3073-
auto const not_eos = [](beam const& beam) { return !beam.eos(); };
3072+
auto const not_eos = [](llama_beam const& beam) { return !beam.eos(); };
30743073
for (int i=0 ; i<n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
30753074
!beams[top_beam_index()].eos() ; ++i) {
3076-
beam_search_control const control = callback(callback_state, get_beams_state(false));
3075+
llama_beam_search_control const control = callback(callback_state, get_beams_state(false));
30773076
if (control.collapse_to < beams.size()) {
30783077
// Caller has manually selected a specific beam. Collapse beams into it.
30793078
collapse_beams(control.collapse_to);
@@ -3082,30 +3081,30 @@ struct beam_search {
30823081
break;
30833082
}
30843083
common_prefix_evaluated = false;
3085-
for (beam& beam : beams) {
3084+
for (llama_beam& beam : beams) {
30863085
fill_next_beams_by_top_probabilities(beam);
30873086
}
30883087
beams.swap(next_beams);
30893088
renormalize_beam_probabilities(beams);
3090-
std::for_each(next_beams.begin(), next_beams.end(), [](beam& beam) { beam.p = 0.0f; });
3089+
std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam& beam) { beam.p = 0.0f; });
30913090
}
30923091
collapse_beams(top_beam_index());
30933092
callback(callback_state, get_beams_state(true));
30943093
}
30953094

30963095
// As beams grow, the cumulative probabilities decrease.
30973096
// Renormalize them to avoid floating point underflow.
3098-
static void renormalize_beam_probabilities(std::vector<beam>& beams) {
3099-
auto const sum_p = [](float sum, beam& beam) { return sum + beam.p; };
3097+
static void renormalize_beam_probabilities(std::vector<llama_beam>& beams) {
3098+
auto const sum_p = [](float sum, llama_beam& beam) { return sum + beam.p; };
31003099
float const inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
3101-
std::for_each(beams.begin(), beams.end(), [=](beam& beam) { beam.p *= inv_sum; });
3100+
std::for_each(beams.begin(), beams.end(), [=](llama_beam& beam) { beam.p *= inv_sum; });
31023101
}
31033102

31043103
// Return index of highest ranking beam by (probability,eos()).
31053104
// In other words choose most probable beam. In case of ties, choose beam at end-of-sentence.
31063105
// Assumes beams is non-empty.
31073106
size_t top_beam_index() {
3108-
auto const by_p_and_eos = [](beam const& a, beam const& b) {
3107+
auto const by_p_and_eos = [](llama_beam const& a, llama_beam const& b) {
31093108
return a.p < b.p || (a.p == b.p && a.eos() < b.eos()); };
31103109
return std::max_element(beams.begin(), beams.end(), by_p_and_eos) - beams.begin();
31113110
}

llama.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ extern "C" {
444444
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
445445

446446
// Lightweight view of a beam
447-
struct beam_view {
447+
struct llama_beam_view {
448448
llama_token const* tokens;
449449
size_t n_tokens;
450450
float p; // Cumulative beam probability (renormalized relative to all beams)
@@ -456,27 +456,27 @@ extern "C" {
456456
// Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
457457
// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
458458
// These pointers are valid only during the synchronous callback, so should not be saved.
459-
struct beams_state {
460-
beam_view* beam_views; // View of each beam.
459+
struct llama_beams_state {
460+
llama_beam_view* beam_views; // View of each beam.
461461
size_t n_beams; // Number of elements in beam_views[].
462462
size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
463463
bool last_call; // True iff this is the last callback invocation.
464464
};
465465
// Must be returned by beam_search_callback function.
466-
struct beam_search_control {
466+
struct llama_beam_search_control {
467467
size_t collapse_to; // Collapse to a beam index. Ignored if n_beams <= collapse_to.
468468
bool stop; // Stop beam search. Set to false to continue.
469469
};
470470
// Type of pointer to the beam_search_callback function.
471471
// void* callback_state is any custom data passed to llama_beam_search, that is subsequently
472472
// passed back to beam_search_callback. This avoids having to use global variables in the callback.
473-
typedef beam_search_control (*llama_beam_search_callback_fn_t)(void* callback_state, beams_state);
473+
typedef llama_beam_search_control (*llama_beam_search_callback_fn_t)(void* callback_state, llama_beams_state);
474474

475475
/// @details Deterministically returns entire sentence constructed by a beam search.
476476
/// @param ctx Pointer to the llama_context.
477477
/// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
478478
/// The return beam_search_control can be used to control the beam_search execution.
479-
/// @param callback_state A pointer that is passed back to callback and nothing more.
479+
/// @param callback_state A pointer that is simply passed back to callback.
480480
/// @param beam_width The number of parallel beams to use.
481481
/// @param n_past The number of tokens already evaluated.
482482
/// @param n_predict The maximum number of tokens to predict.

0 commit comments

Comments
 (0)