llama : remove all_pos_0, all_pos_1, all_seq_id from llama_batch #9745
Changes from 7 commits. Commits in the PR: b226c5b, 1c48616, 9970316, 9276950, 59fd6b6, 7740c96, 6a9769a, 0639ff1, b4c9911, 734f9e2, 7264596, 6395174, 4be7ecf, 9dd7e77, 5d99ae4.
First file in the diff (the perplexity example, judging by `perplexity_v2` and `kl_divergence`):

```diff
@@ -412,13 +412,22 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
         const int batch_start = start + j * n_batch;
         const int batch_size  = std::min(end - batch_start, n_batch);

+        llama_batch batch = llama_batch_init(batch_size, 0, 1);
+        for (int i = 0; i < batch_size; i++) {
+            batch.token [i]    = tokens[batch_start + i];
+            batch.pos   [i]    = j*n_batch + i;
+            batch.logits[i]    = true;
+            batch.seq_id[i][0] = 0;
+        }
+
         //LOG_DBG("    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
-        // TODO: use llama_batch.logits instead of relying on logits_all == true
-        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+        if (llama_decode(ctx, batch)) {
             //LOG_ERR("%s : failed to eval\n", __func__);
             return {tokens, -1, logit_history, prob_history};
         }

+        llama_batch_free(batch);
+
         // save original token and restore it after eval
         const auto token_org = tokens[batch_start];
```
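The batch-building pattern above appears twice in this file (here and in `kl_divergence` below). A minimal reusable sketch of it, assuming the post-change `llama.h` API; `make_one_seq_batch` is a hypothetical helper name, and note it also sets `n_seq_id` and `n_tokens` explicitly, since `llama_batch_init` only allocates the arrays:

```cpp
// Sketch only: build a batch of consecutive tokens, all on sequence 0,
// with logits requested for every position.
static llama_batch make_one_seq_batch(const llama_token * tokens, int32_t n_tokens, llama_pos pos_0) {
    llama_batch batch = llama_batch_init(n_tokens, /*embd =*/ 0, /*n_seq_max =*/ 1);
    for (int32_t i = 0; i < n_tokens; i++) {
        batch.token   [i]    = tokens[i];
        batch.pos     [i]    = pos_0 + i;  // explicit positions replace all_pos_0/all_pos_1
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;          // explicit seq id replaces all_seq_id
        batch.logits  [i]    = true;       // request logits for every token
    }
    batch.n_tokens = n_tokens;
    return batch;                          // caller must llama_batch_free() it
}
```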
```diff
@@ -704,7 +713,6 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
             batch.n_seq_id + i,
             batch.seq_id   + i,
             batch.logits   + i,
-            0, 0, 0, // unused
         };

         const int ret = llama_decode(ctx, batch_view);
```
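For reference, the `batch_view` aggregate above has to list fields in the struct's declared order. A sketch of the slimmed-down `llama_batch` as this PR leaves it (comments are mine, paraphrasing `llama.h`):

```cpp
typedef struct llama_batch {
    int32_t         n_tokens;
    llama_token  *  token;     // token ids, used when embd == NULL
    float        *  embd;      // input embeddings, mutually exclusive with token
    llama_pos    *  pos;       // position of each token within its sequence
    int32_t      *  n_seq_id;  // number of sequence ids for each token
    llama_seq_id ** seq_id;    // sequence ids for each token
    int8_t       *  logits;    // if non-zero, compute logits for this token
    // all_pos_0, all_pos_1, all_seq_id: removed by this PR
} llama_batch;
```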
```diff
@@ -1803,12 +1811,21 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
             tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
         }

-        // TODO: use llama_batch.logits instead of relying on logits_all == true
-        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+        llama_batch batch = llama_batch_init(batch_size, 0, 1);
+        for (int i = 0; i < batch_size; i++) {
+            batch.token [i]    = tokens[batch_start + i];
+            batch.pos   [i]    = j*n_batch + i;
+            batch.logits[i]    = true;
+            batch.seq_id[i][0] = 0;
+        }
+
+        if (llama_decode(ctx, batch)) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return;
         }

+        llama_batch_free(batch);
+
         // restore the original token in case it was set to BOS
         tokens[batch_start] = token_org;
```
Second file in the diff (the save-load-state example, judging by the repro command in the review thread):

```diff
@@ -49,7 +49,7 @@ int main(int argc, char ** argv) {
     auto tokens = common_tokenize(ctx, params.prompt, true);

     // evaluate prompt
-    llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0));
+    llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()));
     n_past += tokens.size();

     // save state (rng, logits, embedding and kv_cache) to file
```
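A compact before/after of the call-site change this file exercises. The old 4-argument form is shown for contrast; as I read the PR, positions and the sequence id are now derived inside `llama_decode` when the batch carries none:

```cpp
// before: the caller passed the start position and the sequence id
// llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0));

// after: only the token span; positions are assigned from the current
// KV cache state of the target sequence (assumption based on this PR)
llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()));
n_past += tokens.size();   // n_past is still tracked for the caller's own logic
```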
```diff
@@ -77,7 +77,7 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result0 += next_token_str;

-        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
+        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx);
             llama_free_model(model);
```

```diff
@@ -133,7 +133,7 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result1 += next_token_str;

-        if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0))) {
+        if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
```
```diff
@@ -221,7 +221,7 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result2 += next_token_str;

-        if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
+        if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx3);
             llama_free_model(model);
```

**Review comment** (on the `llama_decode(ctx3, ...)` line): This will generate a batch for […]

```
make -j && ./llama-save-load-state -m ${some_model}
```

**Reply:** Thanks for spotting that! Fixed in 6395174
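The point the reviewer raises is that the old call pinned this decode to sequence 1, while the new two-argument `llama_batch_get_one` has no way to say so. One way to express that explicitly under the new API; a sketch only, not necessarily what commit 6395174 actually does:

```cpp
// Sketch: decode a single token into sequence 1 with an explicit batch.
llama_batch batch = llama_batch_init(/*n_tokens =*/ 1, /*embd =*/ 0, /*n_seq_max =*/ 1);
batch.token   [0]    = next_token;
batch.pos     [0]    = n_past;
batch.n_seq_id[0]    = 1;
batch.seq_id  [0][0] = 1;   // sequence 1, as the old 4-argument call requested
batch.logits  [0]    = true;
batch.n_tokens       = 1;

const int ret = llama_decode(ctx3, batch);
llama_batch_free(batch);
```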
**Review comment:** A small explanation of what's happening: we are supposed to shift all tokens starting from `n_keep + n_discard + 1`, so the end must be `n_past + 1` (or we can simply set it to `-1`, which means `[p0, inf)`).

**Reply:** Hm, I don't think `n_past + 1` is needed here. There shouldn't be a token with `pos == n_past` in the KV cache. But yes, using either `n_past` or `-1` would achieve the same thing. I think using `n_past` is more illustrative.

**Reply:** OK, thanks. I figured out that I was counting tokens from 1, not from 0. I fixed that in 5d99ae4.
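For readers outside the thread: the exchange is about the half-open `[p0, p1)` ranges taken by the KV-cache calls. A sketch of the shift under discussion, using the `llama_kv_cache_seq_*` API as it existed at the time; the concrete values are made up, and `ctx`/`n_past` are assumed from the surrounding example:

```cpp
// Context shift: drop n_discard tokens after the first n_keep, then slide
// the remainder back so positions stay contiguous. Ranges are [p0, p1).
const int n_keep    = 32;   // hypothetical
const int n_discard = 64;   // hypothetical

// remove positions [n_keep, n_keep + n_discard)
llama_kv_cache_seq_rm (ctx, 0, n_keep, n_keep + n_discard);

// shift positions [n_keep + n_discard, n_past) down by n_discard; p1 = -1
// (meaning [p0, inf)) would work too, since no token has pos == n_past
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_past, -n_discard);

n_past -= n_discard;
```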