@@ -2901,8 +2901,10 @@ struct llama_beam {
2901
2901
}
2902
2902
// Shift off first n tokens and discard them.
2903
2903
void shift_tokens (size_t const n) {
2904
- std::copy (tokens.begin () + n, tokens.end (), tokens.begin ());
2905
- tokens.resize (tokens.size () - n);
2904
+ if (n) {
2905
+ std::copy (tokens.begin () + n, tokens.end (), tokens.begin ());
2906
+ tokens.resize (tokens.size () - n);
2907
+ }
2906
2908
}
2907
2909
llama_beam_view view () const { return {tokens.data (), tokens.size (), p, eos}; }
2908
2910
};
@@ -2963,8 +2965,6 @@ struct beam_search {
2963
2965
2964
2966
// Re-calculated on each loop iteration
2965
2967
size_t common_prefix_length;
2966
- // true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
2967
- bool common_prefix_evaluated;
2968
2968
2969
2969
// Used to communicate to/from callback on beams state.
2970
2970
std::vector<llama_beam_view> beam_views;
@@ -2996,11 +2996,6 @@ struct beam_search {
2996
2996
void fill_next_beams_by_top_probabilities (llama_beam& beam) {
2997
2997
// Min-heaps use a greater-than comparator.
2998
2998
auto const comp = [](llama_beam const & a, llama_beam const & b) { return a.p > b.p ; };
2999
- if (common_prefix_evaluated) {
3000
- // llama_eval was already called during this iteration
3001
- // with the common token prefix, so shift it off this beam.
3002
- beam.shift_tokens (common_prefix_length);
3003
- }
3004
2999
if (beam.eos ) {
3005
3000
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
3006
3001
if (next_beams.size () < n_beams) {
@@ -3017,11 +3012,6 @@ struct beam_search {
3017
3012
// beam is not at end-of-sentence, so branch with next top_k tokens.
3018
3013
if (!beam.tokens .empty ()) {
3019
3014
llama_eval (ctx, beam.tokens .data (), beam.tokens .size (), n_past, n_threads);
3020
- if (!common_prefix_evaluated && common_prefix_length) {
3021
- beam.shift_tokens (common_prefix_length);
3022
- n_past += common_prefix_length;
3023
- common_prefix_evaluated = true ;
3024
- }
3025
3015
}
3026
3016
logit_info logit_info (ctx);
3027
3017
std::vector<llama_token_data> next_tokens = logit_info.top_k (n_beams);
@@ -3076,9 +3066,7 @@ struct beam_search {
3076
3066
// Side effect: set common_prefix_length = find_common_prefix_length();
3077
3067
llama_beams_state get_beams_state (bool const last_call) {
3078
3068
for (size_t i=0 ; i<beams.size () ; ++i) {
3079
- // beam_views[i] = beams[i].view();
3080
- auto view = beams.at (i).view ();
3081
- beam_views.at (i) = view; // capacity 0
3069
+ beam_views[i] = beams[i].view ();
3082
3070
}
3083
3071
common_prefix_length = find_common_prefix_length ();
3084
3072
return {beam_views.data (), beams.size (), common_prefix_length, last_call};
@@ -3094,12 +3082,16 @@ struct beam_search {
3094
3082
auto const not_eos = [](llama_beam const & beam) { return !beam.eos ; };
3095
3083
for (int i=0 ; i<n_predict && std::any_of (beams.begin (),beams.end (),not_eos) &&
3096
3084
!beams[top_beam_index ()].eos ; ++i) {
3097
- callback (callback_state, get_beams_state (false ));
3085
+ callback (callback_state, get_beams_state (false )); // Sets common_prefix_length
3098
3086
update_beams_from_beam_views (); // Update values (p,eos) that callback may have changed.
3099
- common_prefix_evaluated = false ; // Any common prefix has not yet been llama_eval()ed.
3087
+ if (common_prefix_length) {
3088
+ llama_eval (ctx, beams[0 ].tokens .data (), common_prefix_length, n_past, n_threads);
3089
+ n_past += common_prefix_length;
3090
+ }
3100
3091
// Zero-out next_beam probabilities to place them last in following min-heap.
3101
3092
std::for_each (next_beams.begin (), next_beams.end (), [](llama_beam& beam) { beam.p = 0 .0f ; });
3102
3093
for (llama_beam& beam : beams) {
3094
+ beam.shift_tokens (common_prefix_length);
3103
3095
fill_next_beams_by_top_probabilities (beam);
3104
3096
}
3105
3097
// next_beams become the beams of next/final iteration. Swap them to re-use memory.
0 commit comments