Skip to content

Commit 2d9fdbd

Browse files
committed
Cleanup: Encapsulate beam search functions into struct beam_search.
1 parent f57f7c2 commit 2d9fdbd

File tree

1 file changed

+141
-118
lines changed

1 file changed

+141
-118
lines changed

llama.cpp

Lines changed: 141 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include <ctime>
3737
#include <cinttypes>
3838
#include <fstream>
39+
#include <functional>
3940
#include <random>
4041
#include <map>
4142
#include <unordered_map>
@@ -2882,10 +2883,10 @@ struct beam {
28822883
float p; // Cumulative beam probability (renormalized with each token)
28832884
// end-of-sentence
28842885
bool eos() const { return !tokens.empty() && tokens.back() == llama_token_eos(); }
2885-
// Shift off first n tokens to the end of trunk.
2886-
void shift_tokens(std::vector<llama_token>& trunk, int const n) {
2887-
trunk.resize(trunk.size() + n);
2888-
std::copy(tokens.begin(), tokens.begin() + n, trunk.end() - n);
2886+
// Shift off first n tokens to the end of dest.
2887+
void shift_tokens(std::vector<llama_token>& dest, int const n) {
2888+
dest.resize(dest.size() + n);
2889+
std::copy(tokens.begin(), tokens.begin() + n, dest.end() - n);
28892890
shift_tokens(n);
28902891
}
28912892
// Shift off first n tokens and discard them.
@@ -2895,9 +2896,9 @@ struct beam {
28952896
}
28962897
};
28972898

2898-
void out_beam(std::ostream& os, llama_context* ctx, beam const& b) {
2899-
os << "p(" << b.p << ") eos(" << std::boolalpha << b.eos() << ") tokens(";
2900-
for (llama_token const token_id : b.tokens) {
2899+
void out_beam(std::ostream& os, llama_context* ctx, beam const& beam) {
2900+
os << "p(" << beam.p << ") eos(" << std::boolalpha << beam.eos() << ") tokens(";
2901+
for (llama_token const token_id : beam.tokens) {
29012902
os << llama_token_to_str(ctx, token_id);
29022903
}
29032904
os << ')';
@@ -2948,139 +2949,163 @@ struct logit_info {
29482949
}
29492950
};
29502951

2951-
// Track beams' common prefix and when llama_eval has been applied with it.
2952-
struct beams_state {
2952+
struct beam_search {
2953+
llama_context * ctx;
29532954
int beam_width;
2955+
int n_past;
2956+
int n_predict;
2957+
int n_threads;
2958+
std::vector<beam> beams;
2959+
std::vector<beam> next_beams;
2960+
2961+
// Re-calculated on each loop iteration
29542962
int common_prefix_length;
2955-
int& n_past;
2956-
bool shifted;
2957-
std::vector<llama_token> trunk; // Save token prefix common to all beams here
2958-
beams_state(int beam_width, int& n_past, int n_predict)
2959-
: beam_width(beam_width)
2960-
, n_past(n_past) {
2961-
trunk.reserve(n_predict);
2962-
}
2963-
2964-
// Set common_prefix_length based on beams.
2965-
void find_common_prefix(std::vector<beam>& beams) {
2966-
shifted = false;
2967-
common_prefix_length = int(beams[0].tokens.size());
2963+
// true iff llama_eval() has been called with common prefix in current loop iteration.
2964+
bool common_prefix_evaluated;
2965+
// Save token prefix common to all beams here
2966+
std::vector<llama_token> response;
2967+
2968+
beam_search(llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
2969+
: ctx(ctx)
2970+
, beam_width(beam_width)
2971+
, n_past(n_past)
2972+
, n_predict(n_predict)
2973+
, n_threads(n_threads) {
2974+
beams.reserve(beam_width);
2975+
next_beams.reserve(beam_width);
2976+
}
2977+
2978+
// Find common_prefix_length based on beams.
2979+
// Requires beams is not empty.
2980+
int find_common_prefix_length() {
2981+
int common_prefix_length = int(beams[0].tokens.size());
29682982
for (int i=1 ; i<int(beams.size()) ; ++i) {
29692983
int const j_max = std::min(common_prefix_length, int(beams[i].tokens.size()));
29702984
for (int j=0 ; j<j_max ; ++j) {
29712985
if (beams[0].tokens[j] != beams[i].tokens[j]) {
2972-
common_prefix_length = j;
2973-
break;
2986+
return j;
29742987
}
29752988
}
29762989
}
2977-
}
2978-
};
2979-
2980-
void fill_next_beams_by_top_probabilities(llama_context* ctx, std::vector<beam>& next_beams,
2981-
beam& b, beams_state& beams_state, int const n_threads) {
2982-
auto const comp = [](beam const& a, beam const& b) { return a.p > b.p; };
2983-
if (beams_state.shifted) {
2984-
// llama_eval was already called during this iteration
2985-
// with the common token prefix, so shift it off this beam.
2986-
b.shift_tokens(beams_state.common_prefix_length);
2987-
}
2988-
if (b.eos()) {
2989-
// b is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
2990-
if (next_beams.size() < static_cast<size_t>(beams_state.beam_width)) {
2991-
next_beams.push_back(b);
2992-
if (next_beams.size() == static_cast<size_t>(beams_state.beam_width)) {
2993-
std::make_heap(next_beams.begin(), next_beams.end(), comp);
2994-
}
2995-
} else if (next_beams.front().p < b.p) {
2996-
std::pop_heap(next_beams.begin(), next_beams.end(), comp);
2997-
next_beams.back() = std::move(b);
2998-
std::push_heap(next_beams.begin(), next_beams.end(), comp);
2990+
return common_prefix_length;
2991+
}
2992+
2993+
// Min-heaps are used to efficiently gather the top-k elements (k=beam_width).
2994+
// The repetitive patterns below reflect the 2 stages of heaps:
2995+
// * Gather elements until the vector is full, then call std::make_heap() on it.
2996+
// * If the heap is full and a new element is found that should be included,
2997+
// pop off the least element, replace it with the new, then push it into the heap.
2998+
void fill_next_beams_by_top_probabilities(beam& b) {
2999+
// Min-heaps use a greater-than comparator.
3000+
auto const comp = [](beam const& a, beam const& b) { return a.p > b.p; };
3001+
if (common_prefix_evaluated) {
3002+
// llama_eval was already called during this iteration
3003+
// with the common token prefix, so shift it off this beam.
3004+
b.shift_tokens(common_prefix_length);
29993005
}
3000-
} else {
3001-
// b is not at end-of-sentence, so branch with next top_k tokens.
3002-
if (!b.tokens.empty()) {
3003-
llama_eval(ctx, b.tokens.data(), b.tokens.size(), beams_state.n_past, n_threads);
3004-
if (!beams_state.shifted && beams_state.common_prefix_length) {
3005-
b.shift_tokens(beams_state.trunk, beams_state.common_prefix_length);
3006-
beams_state.n_past += beams_state.common_prefix_length;
3007-
beams_state.shifted = true;
3008-
}
3009-
}
3010-
logit_info li(ctx);
3011-
std::vector<llama_token_data> next_tokens = li.top_k(beams_state.beam_width);
3012-
int i=0;
3013-
if (next_beams.size() < static_cast<size_t>(beams_state.beam_width)) {
3014-
for (; next_beams.size() < static_cast<size_t>(beams_state.beam_width) ; ++i) {
3015-
beam next_beam = b;
3016-
next_beam.tokens.push_back(next_tokens[i].id);
3017-
next_beam.p *= li.probability_from_logit(next_tokens[i].logit);
3018-
next_beams.push_back(std::move(next_beam));
3019-
}
3020-
std::make_heap(next_beams.begin(), next_beams.end(), comp);
3021-
} else {
3022-
for (; next_beams.front().p == 0.0f ; ++i) {
3006+
if (b.eos()) {
3007+
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
3008+
if (next_beams.size() < static_cast<size_t>(beam_width)) {
3009+
next_beams.push_back(b);
3010+
if (next_beams.size() == static_cast<size_t>(beam_width)) {
3011+
std::make_heap(next_beams.begin(), next_beams.end(), comp);
3012+
}
3013+
} else if (next_beams.front().p < b.p) {
30233014
std::pop_heap(next_beams.begin(), next_beams.end(), comp);
3024-
next_beams.back() = b;
3025-
next_beams.back().tokens.push_back(next_tokens[i].id);
3026-
next_beams.back().p *= li.probability_from_logit(next_tokens[i].logit);
3015+
next_beams.back() = std::move(b);
30273016
std::push_heap(next_beams.begin(), next_beams.end(), comp);
30283017
}
3018+
} else {
3019+
// beam is not at end-of-sentence, so branch with next top_k tokens.
3020+
if (!b.tokens.empty()) {
3021+
llama_eval(ctx, b.tokens.data(), b.tokens.size(), n_past, n_threads);
3022+
if (!common_prefix_evaluated && common_prefix_length) {
3023+
b.shift_tokens(response, common_prefix_length);
3024+
n_past += common_prefix_length;
3025+
common_prefix_evaluated = true;
3026+
}
3027+
}
3028+
logit_info logit_info(ctx);
3029+
std::vector<llama_token_data> next_tokens = logit_info.top_k(beam_width);
3030+
int i=0;
3031+
if (next_beams.size() < static_cast<size_t>(beam_width)) {
3032+
for (; next_beams.size() < static_cast<size_t>(beam_width) ; ++i) {
3033+
beam next_beam = b;
3034+
next_beam.tokens.push_back(next_tokens[i].id);
3035+
next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit);
3036+
next_beams.push_back(std::move(next_beam));
3037+
}
3038+
std::make_heap(next_beams.begin(), next_beams.end(), comp);
3039+
} else {
3040+
for (; next_beams.front().p == 0.0f ; ++i) {
3041+
std::pop_heap(next_beams.begin(), next_beams.end(), comp);
3042+
next_beams.back() = b;
3043+
next_beams.back().tokens.push_back(next_tokens[i].id);
3044+
next_beams.back().p *= logit_info.probability_from_logit(next_tokens[i].logit);
3045+
std::push_heap(next_beams.begin(), next_beams.end(), comp);
3046+
}
3047+
}
3048+
for (; i < beam_width ; ++i) {
3049+
float const next_p = b.p * logit_info.probability_from_logit(next_tokens[i].logit);
3050+
if (next_beams.front().p < next_p) {
3051+
std::pop_heap(next_beams.begin(), next_beams.end(), comp);
3052+
next_beams.back() = b;
3053+
next_beams.back().tokens.push_back(next_tokens[i].id);
3054+
next_beams.back().p = next_p;
3055+
std::push_heap(next_beams.begin(), next_beams.end(), comp);
3056+
}
3057+
}
30293058
}
3030-
for (; i < beams_state.beam_width ; ++i) {
3031-
float const next_p = b.p * li.probability_from_logit(next_tokens[i].logit);
3032-
if (next_beams.front().p < next_p) {
3033-
std::pop_heap(next_beams.begin(), next_beams.end(), comp);
3034-
next_beams.back() = b;
3035-
next_beams.back().tokens.push_back(next_tokens[i].id);
3036-
next_beams.back().p = next_p;
3037-
std::push_heap(next_beams.begin(), next_beams.end(), comp);
3059+
}
3060+
3061+
// Loop:
3062+
// * while i < n_predict
3063+
// * until all of the beams have reached end-of-sentence
3064+
// * until the highest probability beam is at end-of-sentence
3065+
// (since all other beam probabilities can only decrease)
3066+
void loop(std::function<void(std::vector<beam>&)> const callback) {
3067+
beams.push_back({{}, 1.0f});
3068+
auto const eos = [](beam const& beam) { return beam.eos(); };
3069+
for (int i=0 ; i<n_predict && !std::all_of(beams.begin(),beams.end(),eos) && !eos(top_beam()) ; ++i) {
3070+
common_prefix_evaluated = false;
3071+
common_prefix_length = find_common_prefix_length();
3072+
for (beam& beam : beams) {
3073+
fill_next_beams_by_top_probabilities(beam);
30383074
}
3075+
beams.swap(next_beams);
3076+
renormalize_beam_probabilities(beams);
3077+
std::for_each(next_beams.begin(), next_beams.end(), [](beam& beam) { beam.p = 0.0f; });
3078+
callback(beams);
30393079
}
3080+
beam& top_b = top_beam();
3081+
top_b.shift_tokens(response, top_b.tokens.size());
30403082
}
3041-
}
30423083

3043-
// As beams grow, the cumulative probabilities decrease.
3044-
// Renormalize them to avoid floating point underflow.
3045-
void renormalize_beam_probabilities(std::vector<beam>& beams) {
3046-
auto const sum_p = [](float sum, beam& b) { return sum + b.p; };
3047-
float const inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
3048-
std::for_each(beams.begin(), beams.end(), [inv_sum](beam& b) { b.p *= inv_sum; });
3049-
}
3084+
// As beams grow, the cumulative probabilities decrease.
3085+
// Renormalize them to avoid floating point underflow.
3086+
static void renormalize_beam_probabilities(std::vector<beam>& beams) {
3087+
auto const sum_p = [](float sum, beam& beam) { return sum + beam.p; };
3088+
float const inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
3089+
std::for_each(beams.begin(), beams.end(), [inv_sum](beam& beam) { beam.p *= inv_sum; });
3090+
}
30503091

3051-
// Return beam with highest probability.
3052-
beam& top_beam(std::vector<beam>& beams) {
3053-
auto const by_p = [](beam const& a, beam const& b) { return a.p < b.p; };
3054-
return *std::max_element(beams.begin(), beams.end(), by_p);
3055-
}
3092+
// Return beam with highest probability.
3093+
beam& top_beam() {
3094+
auto const by_p = [](beam const& a, beam const& b) { return a.p < b.p; };
3095+
return *std::max_element(beams.begin(), beams.end(), by_p);
3096+
}
3097+
};
30563098

3057-
// This is deterministic, but can be made probabilistic in
3058-
// fill_next_beams_by_top_probabilities() by randomly selecting from all next_beams.
30593099
// Not thread-safe.
3060-
const char* llama_beam_search(llama_context * ctx, int const beam_width,
3100+
const char* llama_beam_search(llama_context * ctx, int beam_width,
30613101
int n_past, int const n_predict, int const n_threads) {
30623102
static std::string beam_search_response;
30633103
assert(ctx);
30643104
const int64_t t_start_sample_us = ggml_time_us();
30653105

3066-
std::vector<beam> beams;
3067-
beams.reserve(beam_width);
3068-
beams.push_back({{}, 1.0});
3069-
std::vector<beam> next_beams;
3070-
next_beams.reserve(beam_width);
3071-
beams_state beams_state(beam_width, n_past, n_predict);
3072-
// Loop while there are any beams that have not yet reached end-of-sentence.
3073-
// If the highest probability beam is at end-of-sentence, then finish since all other
3074-
// beam probabilities can only decrease.
3075-
auto const eos = [](beam const& b) { return b.eos(); };
3076-
for (int i=0 ; i<n_predict && !eos(top_beam(beams)) && !std::all_of(beams.begin(),beams.end(),eos) ; ++i) {
3077-
beams_state.find_common_prefix(beams);
3078-
for (beam& b : beams) {
3079-
fill_next_beams_by_top_probabilities(ctx, next_beams, b, beams_state, n_threads);
3080-
}
3081-
beams.swap(next_beams);
3082-
std::for_each(next_beams.begin(), next_beams.end(), [](beam& b) { b.p = 0.0f; });
3083-
renormalize_beam_probabilities(beams);
3106+
beam_search beam_search(ctx, beam_width, n_past, n_predict, n_threads);
3107+
3108+
beam_search.loop([&](std::vector<beam>& beams) {
30843109
#if 1 // DEBUG: print current beams for this iteration
30853110
std::cout << "\n\nCurrent beams:\n";
30863111
for (size_t j=0 ; j < beams.size() ; ++j) {
@@ -3091,13 +3116,11 @@ const char* llama_beam_search(llama_context * ctx, int const beam_width,
30913116
#else
30923117
std::cout << '.' << std::flush; // Show progress
30933118
#endif
3094-
}
3119+
});
30953120

3096-
beam& top_b = top_beam(beams);
3097-
top_b.shift_tokens(beams_state.trunk, top_b.tokens.size());
30983121
// Save beam sentence to beam_search_response. Is there a better way?
30993122
std::ostringstream oss;
3100-
for (llama_token const token : beams_state.trunk) {
3123+
for (llama_token const token : beam_search.response) {
31013124
oss << llama_token_to_str(ctx, token);
31023125
}
31033126
beam_search_response = oss.str();

0 commit comments

Comments
 (0)