36
36
#include < ctime>
37
37
#include < cinttypes>
38
38
#include < fstream>
39
+ #include < functional>
39
40
#include < random>
40
41
#include < map>
41
42
#include < unordered_map>
@@ -2882,10 +2883,10 @@ struct beam {
2882
2883
float p; // Cumulative beam probability (renormalized with each token)
2883
2884
// end-of-sentence
2884
2885
bool eos () const { return !tokens.empty () && tokens.back () == llama_token_eos (); }
2885
- // Shift off first n tokens to the end of trunk .
2886
- void shift_tokens (std::vector<llama_token>& trunk , int const n) {
2887
- trunk .resize (trunk .size () + n);
2888
- std::copy (tokens.begin (), tokens.begin () + n, trunk .end () - n);
2886
+ // Shift off first n tokens to the end of dest .
2887
+ void shift_tokens (std::vector<llama_token>& dest , int const n) {
2888
+ dest .resize (dest .size () + n);
2889
+ std::copy (tokens.begin (), tokens.begin () + n, dest .end () - n);
2889
2890
shift_tokens (n);
2890
2891
}
2891
2892
// Shift off first n tokens and discard them.
@@ -2895,9 +2896,9 @@ struct beam {
2895
2896
}
2896
2897
};
2897
2898
2898
- void out_beam (std::ostream& os, llama_context* ctx, beam const & b ) {
2899
- os << " p(" << b .p << " ) eos(" << std::boolalpha << b .eos () << " ) tokens(" ;
2900
- for (llama_token const token_id : b .tokens ) {
2899
+ void out_beam (std::ostream& os, llama_context* ctx, beam const & beam ) {
2900
+ os << " p(" << beam .p << " ) eos(" << std::boolalpha << beam .eos () << " ) tokens(" ;
2901
+ for (llama_token const token_id : beam .tokens ) {
2901
2902
os << llama_token_to_str (ctx, token_id);
2902
2903
}
2903
2904
os << ' )' ;
@@ -2948,139 +2949,163 @@ struct logit_info {
2948
2949
}
2949
2950
};
2950
2951
2951
- // Track beams common prefix and when llama_eval has been applied with it.
2952
- struct beams_state {
2952
+ struct beam_search {
2953
+ llama_context * ctx;
2953
2954
int beam_width;
2955
+ int n_past;
2956
+ int n_predict;
2957
+ int n_threads;
2958
+ std::vector<beam> beams;
2959
+ std::vector<beam> next_beams;
2960
+
2961
+ // Re-calculated on each loop iteration
2954
2962
int common_prefix_length;
2955
- int & n_past;
2956
- bool shifted;
2957
- std::vector<llama_token> trunk; // Save token prefix common to all beams here
2958
- beams_state (int beam_width, int & n_past, int n_predict)
2959
- : beam_width(beam_width)
2960
- , n_past(n_past) {
2961
- trunk.reserve (n_predict);
2962
- }
2963
-
2964
- // Set common_prefix_length based on beams.
2965
- void find_common_prefix (std::vector<beam>& beams) {
2966
- shifted = false ;
2967
- common_prefix_length = int (beams[0 ].tokens .size ());
2963
+ // true iff llama_eval() has been called with common prefix in current loop iteration.
2964
+ bool common_prefix_evaluated;
2965
+ // Save token prefix common to all beams here
2966
+ std::vector<llama_token> response;
2967
+
2968
+ beam_search (llama_context * ctx, int beam_width, int n_past, int n_predict, int n_threads)
2969
+ : ctx(ctx)
2970
+ , beam_width(beam_width)
2971
+ , n_past(n_past)
2972
+ , n_predict(n_predict)
2973
+ , n_threads(n_threads) {
2974
+ beams.reserve (beam_width);
2975
+ next_beams.reserve (beam_width);
2976
+ }
2977
+
2978
+ // Find common_prefix_length based on beams.
2979
+ // Requires beams is not empty.
2980
+ int find_common_prefix_length () {
2981
+ int common_prefix_length = int (beams[0 ].tokens .size ());
2968
2982
for (int i=1 ; i<int (beams.size ()) ; ++i) {
2969
2983
int const j_max = std::min (common_prefix_length, int (beams[i].tokens .size ()));
2970
2984
for (int j=0 ; j<j_max ; ++j) {
2971
2985
if (beams[0 ].tokens [j] != beams[i].tokens [j]) {
2972
- common_prefix_length = j;
2973
- break ;
2986
+ return j;
2974
2987
}
2975
2988
}
2976
2989
}
2977
- }
2978
- };
2979
-
2980
- void fill_next_beams_by_top_probabilities (llama_context* ctx, std::vector<beam>& next_beams,
2981
- beam& b, beams_state& beams_state, int const n_threads) {
2982
- auto const comp = [](beam const & a, beam const & b) { return a.p > b.p ; };
2983
- if (beams_state.shifted ) {
2984
- // llama_eval was already called during this iteration
2985
- // with the common token prefix, so shift it off this beam.
2986
- b.shift_tokens (beams_state.common_prefix_length );
2987
- }
2988
- if (b.eos ()) {
2989
- // b is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
2990
- if (next_beams.size () < static_cast <size_t >(beams_state.beam_width )) {
2991
- next_beams.push_back (b);
2992
- if (next_beams.size () == static_cast <size_t >(beams_state.beam_width )) {
2993
- std::make_heap (next_beams.begin (), next_beams.end (), comp);
2994
- }
2995
- } else if (next_beams.front ().p < b.p ) {
2996
- std::pop_heap (next_beams.begin (), next_beams.end (), comp);
2997
- next_beams.back () = std::move (b);
2998
- std::push_heap (next_beams.begin (), next_beams.end (), comp);
2990
+ return common_prefix_length;
2991
+ }
2992
+
2993
+ // Min-heaps are used to efficiently gather the top-k elements (k=beam_width).
2994
+ // The repetitive patterns below reflect the 2 stages of heaps:
2995
+ // * Gather elements until the vector is full, then call std::make_heap() on it.
2996
+ // * If the heap is full and a new element is found that should be included,
2997
+ // pop off the least element, replace it with the new, then push it into the heap.
2998
+ void fill_next_beams_by_top_probabilities (beam& b) {
2999
+ // Min-heaps use a greater-than comparator.
3000
+ auto const comp = [](beam const & a, beam const & b) { return a.p > b.p ; };
3001
+ if (common_prefix_evaluated) {
3002
+ // llama_eval was already called during this iteration
3003
+ // with the common token prefix, so shift it off this beam.
3004
+ b.shift_tokens (common_prefix_length);
2999
3005
}
3000
- } else {
3001
- // b is not at end-of-sentence, so branch with next top_k tokens.
3002
- if (!b.tokens .empty ()) {
3003
- llama_eval (ctx, b.tokens .data (), b.tokens .size (), beams_state.n_past , n_threads);
3004
- if (!beams_state.shifted && beams_state.common_prefix_length ) {
3005
- b.shift_tokens (beams_state.trunk , beams_state.common_prefix_length );
3006
- beams_state.n_past += beams_state.common_prefix_length ;
3007
- beams_state.shifted = true ;
3008
- }
3009
- }
3010
- logit_info li (ctx);
3011
- std::vector<llama_token_data> next_tokens = li.top_k (beams_state.beam_width );
3012
- int i=0 ;
3013
- if (next_beams.size () < static_cast <size_t >(beams_state.beam_width )) {
3014
- for (; next_beams.size () < static_cast <size_t >(beams_state.beam_width ) ; ++i) {
3015
- beam next_beam = b;
3016
- next_beam.tokens .push_back (next_tokens[i].id );
3017
- next_beam.p *= li.probability_from_logit (next_tokens[i].logit );
3018
- next_beams.push_back (std::move (next_beam));
3019
- }
3020
- std::make_heap (next_beams.begin (), next_beams.end (), comp);
3021
- } else {
3022
- for (; next_beams.front ().p == 0 .0f ; ++i) {
3006
+ if (b.eos ()) {
3007
+ // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
3008
+ if (next_beams.size () < static_cast <size_t >(beam_width)) {
3009
+ next_beams.push_back (b);
3010
+ if (next_beams.size () == static_cast <size_t >(beam_width)) {
3011
+ std::make_heap (next_beams.begin (), next_beams.end (), comp);
3012
+ }
3013
+ } else if (next_beams.front ().p < b.p ) {
3023
3014
std::pop_heap (next_beams.begin (), next_beams.end (), comp);
3024
- next_beams.back () = b;
3025
- next_beams.back ().tokens .push_back (next_tokens[i].id );
3026
- next_beams.back ().p *= li.probability_from_logit (next_tokens[i].logit );
3015
+ next_beams.back () = std::move (b);
3027
3016
std::push_heap (next_beams.begin (), next_beams.end (), comp);
3028
3017
}
3018
+ } else {
3019
+ // beam is not at end-of-sentence, so branch with next top_k tokens.
3020
+ if (!b.tokens .empty ()) {
3021
+ llama_eval (ctx, b.tokens .data (), b.tokens .size (), n_past, n_threads);
3022
+ if (!common_prefix_evaluated && common_prefix_length) {
3023
+ b.shift_tokens (response, common_prefix_length);
3024
+ n_past += common_prefix_length;
3025
+ common_prefix_evaluated = true ;
3026
+ }
3027
+ }
3028
+ logit_info logit_info (ctx);
3029
+ std::vector<llama_token_data> next_tokens = logit_info.top_k (beam_width);
3030
+ int i=0 ;
3031
+ if (next_beams.size () < static_cast <size_t >(beam_width)) {
3032
+ for (; next_beams.size () < static_cast <size_t >(beam_width) ; ++i) {
3033
+ beam next_beam = b;
3034
+ next_beam.tokens .push_back (next_tokens[i].id );
3035
+ next_beam.p *= logit_info.probability_from_logit (next_tokens[i].logit );
3036
+ next_beams.push_back (std::move (next_beam));
3037
+ }
3038
+ std::make_heap (next_beams.begin (), next_beams.end (), comp);
3039
+ } else {
3040
+ for (; next_beams.front ().p == 0 .0f ; ++i) {
3041
+ std::pop_heap (next_beams.begin (), next_beams.end (), comp);
3042
+ next_beams.back () = b;
3043
+ next_beams.back ().tokens .push_back (next_tokens[i].id );
3044
+ next_beams.back ().p *= logit_info.probability_from_logit (next_tokens[i].logit );
3045
+ std::push_heap (next_beams.begin (), next_beams.end (), comp);
3046
+ }
3047
+ }
3048
+ for (; i < beam_width ; ++i) {
3049
+ float const next_p = b.p * logit_info.probability_from_logit (next_tokens[i].logit );
3050
+ if (next_beams.front ().p < next_p) {
3051
+ std::pop_heap (next_beams.begin (), next_beams.end (), comp);
3052
+ next_beams.back () = b;
3053
+ next_beams.back ().tokens .push_back (next_tokens[i].id );
3054
+ next_beams.back ().p = next_p;
3055
+ std::push_heap (next_beams.begin (), next_beams.end (), comp);
3056
+ }
3057
+ }
3029
3058
}
3030
- for (; i < beams_state.beam_width ; ++i) {
3031
- float const next_p = b.p * li.probability_from_logit (next_tokens[i].logit );
3032
- if (next_beams.front ().p < next_p) {
3033
- std::pop_heap (next_beams.begin (), next_beams.end (), comp);
3034
- next_beams.back () = b;
3035
- next_beams.back ().tokens .push_back (next_tokens[i].id );
3036
- next_beams.back ().p = next_p;
3037
- std::push_heap (next_beams.begin (), next_beams.end (), comp);
3059
+ }
3060
+
3061
+ // Loop:
3062
+ // * while i < n_predict
3063
+ // * until all of the beams have reached end-of-sentence
3064
+ // * until the highest probability beam is at end-of-sentence
3065
+ // (since all other beam probabilities can only decrease)
3066
+ void loop (std::function<void (std::vector<beam>&)> const callback) {
3067
+ beams.push_back ({{}, 1 .0f });
3068
+ auto const eos = [](beam const & beam) { return beam.eos (); };
3069
+ for (int i=0 ; i<n_predict && !std::all_of (beams.begin (),beams.end (),eos) && !eos (top_beam ()) ; ++i) {
3070
+ common_prefix_evaluated = false ;
3071
+ common_prefix_length = find_common_prefix_length ();
3072
+ for (beam& beam : beams) {
3073
+ fill_next_beams_by_top_probabilities (beam);
3038
3074
}
3075
+ beams.swap (next_beams);
3076
+ renormalize_beam_probabilities (beams);
3077
+ std::for_each (next_beams.begin (), next_beams.end (), [](beam& beam) { beam.p = 0 .0f ; });
3078
+ callback (beams);
3039
3079
}
3080
+ beam& top_b = top_beam ();
3081
+ top_b.shift_tokens (response, top_b.tokens .size ());
3040
3082
}
3041
- }
3042
3083
3043
- // As beams grow, the cumulative probabilities decrease.
3044
- // Renormalize them to avoid floating point underflow.
3045
- void renormalize_beam_probabilities (std::vector<beam>& beams) {
3046
- auto const sum_p = [](float sum, beam& b ) { return sum + b .p ; };
3047
- float const inv_sum = 1 .0f / std::accumulate (beams.begin (), beams.end (), 0 .0f , sum_p);
3048
- std::for_each (beams.begin (), beams.end (), [inv_sum](beam& b ) { b .p *= inv_sum; });
3049
- }
3084
+ // As beams grow, the cumulative probabilities decrease.
3085
+ // Renormalize them to avoid floating point underflow.
3086
+ static void renormalize_beam_probabilities (std::vector<beam>& beams) {
3087
+ auto const sum_p = [](float sum, beam& beam ) { return sum + beam .p ; };
3088
+ float const inv_sum = 1 .0f / std::accumulate (beams.begin (), beams.end (), 0 .0f , sum_p);
3089
+ std::for_each (beams.begin (), beams.end (), [inv_sum](beam& beam ) { beam .p *= inv_sum; });
3090
+ }
3050
3091
3051
- // Return beam with highest probability.
3052
- beam& top_beam (std::vector<beam>& beams) {
3053
- auto const by_p = [](beam const & a, beam const & b) { return a.p < b.p ; };
3054
- return *std::max_element (beams.begin (), beams.end (), by_p);
3055
- }
3092
+ // Return beam with highest probability.
3093
+ beam& top_beam () {
3094
+ auto const by_p = [](beam const & a, beam const & b) { return a.p < b.p ; };
3095
+ return *std::max_element (beams.begin (), beams.end (), by_p);
3096
+ }
3097
+ };
3056
3098
3057
- // This is deterministic, but can be made probabilistic in
3058
- // fill_next_beams_by_top_probabilities() by randomly selecting from all next_beams.
3059
3099
// Not thread-safe.
3060
- const char * llama_beam_search (llama_context * ctx, int const beam_width,
3100
+ const char * llama_beam_search (llama_context * ctx, int beam_width,
3061
3101
int n_past, int const n_predict, int const n_threads) {
3062
3102
static std::string beam_search_response;
3063
3103
assert (ctx);
3064
3104
const int64_t t_start_sample_us = ggml_time_us ();
3065
3105
3066
- std::vector<beam> beams;
3067
- beams.reserve (beam_width);
3068
- beams.push_back ({{}, 1.0 });
3069
- std::vector<beam> next_beams;
3070
- next_beams.reserve (beam_width);
3071
- beams_state beams_state (beam_width, n_past, n_predict);
3072
- // Loop while there are any beams that have not yet reached end-of-sentence.
3073
- // If the highest probability beam is at end-of-sentence, then finish since all other
3074
- // beam probabilities can only decrease.
3075
- auto const eos = [](beam const & b) { return b.eos (); };
3076
- for (int i=0 ; i<n_predict && !eos (top_beam (beams)) && !std::all_of (beams.begin (),beams.end (),eos) ; ++i) {
3077
- beams_state.find_common_prefix (beams);
3078
- for (beam& b : beams) {
3079
- fill_next_beams_by_top_probabilities (ctx, next_beams, b, beams_state, n_threads);
3080
- }
3081
- beams.swap (next_beams);
3082
- std::for_each (next_beams.begin (), next_beams.end (), [](beam& b) { b.p = 0 .0f ; });
3083
- renormalize_beam_probabilities (beams);
3106
+ beam_search beam_search (ctx, beam_width, n_past, n_predict, n_threads);
3107
+
3108
+ beam_search.loop ([&](std::vector<beam>& beams) {
3084
3109
#if 1 // DEBUG: print current beams for this iteration
3085
3110
std::cout << " \n\n Current beams:\n " ;
3086
3111
for (size_t j=0 ; j < beams.size () ; ++j) {
@@ -3091,13 +3116,11 @@ const char* llama_beam_search(llama_context * ctx, int const beam_width,
3091
3116
#else
3092
3117
std::cout << '.' << std::flush; // Show progress
3093
3118
#endif
3094
- }
3119
+ });
3095
3120
3096
- beam& top_b = top_beam (beams);
3097
- top_b.shift_tokens (beams_state.trunk , top_b.tokens .size ());
3098
3121
// Save beam sentence to beam_search_response. Is there a better way?
3099
3122
std::ostringstream oss;
3100
- for (llama_token const token : beams_state. trunk ) {
3123
+ for (llama_token const token : beam_search. response ) {
3101
3124
oss << llama_token_to_str (ctx, token);
3102
3125
}
3103
3126
beam_search_response = oss.str ();
0 commit comments