Commit 40c9403

Simplify common_prefix_length logic and drop use of common_prefix_evaluated bool.
1 parent 3359308 commit 40c9403

llama.cpp

Lines changed: 11 additions & 19 deletions
@@ -2901,8 +2901,10 @@ struct llama_beam {
     }
     // Shift off first n tokens and discard them.
     void shift_tokens(size_t const n) {
-        std::copy(tokens.begin() + n, tokens.end(), tokens.begin());
-        tokens.resize(tokens.size() - n);
+        if (n) {
+            std::copy(tokens.begin() + n, tokens.end(), tokens.begin());
+            tokens.resize(tokens.size() - n);
+        }
     }
     llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eos}; }
 };
@@ -2963,8 +2965,6 @@ struct beam_search {
 
     // Re-calculated on each loop iteration
     size_t common_prefix_length;
-    // true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
-    bool common_prefix_evaluated;
 
     // Used to communicate to/from callback on beams state.
     std::vector<llama_beam_view> beam_views;
@@ -2996,11 +2996,6 @@ struct beam_search {
     void fill_next_beams_by_top_probabilities(llama_beam& beam) {
         // Min-heaps use a greater-than comparator.
         auto const comp = [](llama_beam const& a, llama_beam const& b) { return a.p > b.p; };
-        if (common_prefix_evaluated) {
-            // llama_eval was already called during this iteration
-            // with the common token prefix, so shift it off this beam.
-            beam.shift_tokens(common_prefix_length);
-        }
         if (beam.eos) {
             // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
             if (next_beams.size() < n_beams) {
@@ -3017,11 +3012,6 @@ struct beam_search {
             // beam is not at end-of-sentence, so branch with next top_k tokens.
             if (!beam.tokens.empty()) {
                 llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads);
-                if (!common_prefix_evaluated && common_prefix_length) {
-                    beam.shift_tokens(common_prefix_length);
-                    n_past += common_prefix_length;
-                    common_prefix_evaluated = true;
-                }
             }
             logit_info logit_info(ctx);
             std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
@@ -3076,9 +3066,7 @@ struct beam_search {
     // Side effect: set common_prefix_length = find_common_prefix_length();
     llama_beams_state get_beams_state(bool const last_call) {
         for (size_t i=0 ; i<beams.size() ; ++i) {
-            //beam_views[i] = beams[i].view();
-            auto view = beams.at(i).view();
-            beam_views.at(i) = view;  // capacity 0
+            beam_views[i] = beams[i].view();
         }
         common_prefix_length = find_common_prefix_length();
         return {beam_views.data(), beams.size(), common_prefix_length, last_call};
@@ -3094,12 +3082,16 @@ struct beam_search {
         auto const not_eos = [](llama_beam const& beam) { return !beam.eos; };
         for (int i=0 ; i<n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
                        !beams[top_beam_index()].eos ; ++i) {
-            callback(callback_state, get_beams_state(false));
+            callback(callback_state, get_beams_state(false));  // Sets common_prefix_length
             update_beams_from_beam_views();  // Update values (p,eos) that callback may have changed.
-            common_prefix_evaluated = false;  // Any common prefix has not yet been llama_eval()ed.
+            if (common_prefix_length) {
+                llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads);
+                n_past += common_prefix_length;
+            }
             // Zero-out next_beam probabilities to place them last in following min-heap.
             std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam& beam) { beam.p = 0.0f; });
             for (llama_beam& beam : beams) {
+                beam.shift_tokens(common_prefix_length);
                 fill_next_beams_by_top_probabilities(beam);
             }
             // next_beams become the beams of next/final iteration. Swap them to re-use memory.
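For reference, here is a small self-contained sketch, not the actual llama.cpp code, of the per-iteration flow this commit settles on: the token prefix shared by all beams is evaluated once (taken from beams[0]), and shift_tokens() is then applied to every beam unconditionally, shift_tokens(0) being a no-op. Beam, find_common_prefix_length, and eval_tokens below are simplified stand-ins introduced only for illustration; shift_tokens mirrors the llama_beam method shown in the diff.

// Sketch only: illustrates "evaluate the common prefix once, then shift it
// off every beam" without the dropped common_prefix_evaluated flag.
#include <algorithm>
#include <cstdio>
#include <vector>

struct Beam {
    std::vector<int> tokens;

    // Shift off first n tokens and discard them; a no-op when n == 0.
    void shift_tokens(size_t const n) {
        if (n) {
            std::copy(tokens.begin() + n, tokens.end(), tokens.begin());
            tokens.resize(tokens.size() - n);
        }
    }
};

// Length of the token prefix shared by every beam (illustrative implementation,
// not the llama.cpp find_common_prefix_length).
size_t find_common_prefix_length(std::vector<Beam> const & beams) {
    size_t len = beams.front().tokens.size();
    for (Beam const & beam : beams) {
        len = std::min(len, beam.tokens.size());
        for (size_t i = 0; i < len; ++i) {
            if (beam.tokens[i] != beams.front().tokens[i]) {
                len = i;
                break;
            }
        }
    }
    return len;
}

// Stand-in for llama_eval() on the shared prefix: just advances n_past here.
void eval_tokens(int const * /*tokens*/, size_t const n, size_t & n_past) {
    n_past += n;
}

int main() {
    std::vector<Beam> beams = {{{1, 2, 3, 4}}, {{1, 2, 3, 5, 6}}, {{1, 2, 3}}};
    size_t n_past = 0;

    // One loop iteration in the post-commit shape: evaluate the common prefix
    // once from beams[0], then shift it off every beam unconditionally.
    size_t const common_prefix_length = find_common_prefix_length(beams);
    if (common_prefix_length) {
        eval_tokens(beams[0].tokens.data(), common_prefix_length, n_past);
    }
    for (Beam & beam : beams) {
        beam.shift_tokens(common_prefix_length);
    }

    std::printf("common prefix = %zu tokens, n_past = %zu, beam[1] keeps %zu tokens\n",
                common_prefix_length, n_past, beams[1].tokens.size());
    return 0;
}

Built with any C++11 compiler, this prints the shared prefix length, the advanced n_past, and how many tokens remain on one beam after the shift.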
