
llama : remove all_pos_0, all_pos_1, all_seq_id from llama_batch #9745


Merged · 15 commits · Oct 18, 2024
Changes from 7 commits
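In short, the helper llama_batch_get_one() loses its trailing position and sequence-id arguments, the all_pos_0, all_pos_1, and all_seq_id fields are removed from llama_batch, and call sites that need explicit positions or sequence ids now construct a llama_batch themselves. A minimal before/after sketch of the typical call-site change, illustrative only and based on the diffs below:

```cpp
// before: the helper carried a starting position and a sequence id
llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0));

// after: the helper takes only the token pointer and count;
// anything beyond the default single-sequence case uses an explicit llama_batch
llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()));
```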
4 changes: 2 additions & 2 deletions common/common.cpp
@@ -939,7 +939,7 @@ struct common_init_result common_init_from_params(common_params & params) {
}

if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == -1) {
decoder_start_token_id = bos;
@@ -948,7 +948,7 @@ struct common_init_result common_init_from_params(common_params & params) {
tmp.push_back(decoder_start_token_id);
}
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
}
llama_kv_cache_clear(lctx);
llama_synchronize(lctx);
1 change: 0 additions & 1 deletion examples/batched-bench/batched-bench.cpp
@@ -74,7 +74,6 @@ int main(int argc, char ** argv) {
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};

const int ret = llama_decode(ctx, batch_view);
2 changes: 1 addition & 1 deletion examples/cvector-generator/cvector-generator.cpp
@@ -339,7 +339,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {

static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
llama_kv_cache_clear(ctx);
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
2 changes: 1 addition & 1 deletion examples/eval-callback/eval-callback.cpp
@@ -131,7 +131,7 @@ static bool run(llama_context * ctx, const common_params & params) {

std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);

if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
13 changes: 11 additions & 2 deletions examples/imatrix/imatrix.cpp
@@ -508,12 +508,21 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
}

// TODO: use batch.logits to save computations instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
llama_batch batch = llama_batch_init(batch_size, 0, 1);
for (int i = 0; i < batch_size; i++) {
batch.token[i] = tokens[batch_start + i];
batch.pos[i] = j*n_batch + i;
batch.logits[i] = true;
batch.seq_id[i][0] = 0;
}

if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}

llama_batch_free(batch);

// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;

4 changes: 2 additions & 2 deletions examples/infill/infill.cpp
@@ -376,7 +376,7 @@ int main(int argc, char ** argv) {
n_past, n_left, n_ctx, params.n_keep, n_discard);

llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past + 1, -n_discard);
@ngxson (Collaborator, Author) commented on Oct 12, 2024:
A small explanation of what's happening: we are supposed to shift all tokens from n_keep + n_discard + 1, so the end of the range must be n_past + 1 (or we can simply set it to -1, which means [p0, inf)).

@ggerganov (Member) replied on Oct 14, 2024:

Hm, I don't think n_past + 1 is needed here. There shouldn't be a token with pos == n_past in the KV cache.

But yes, using either n_past or -1 would achieve the same thing. I think using n_past is more illustrative.

@ngxson (Collaborator, Author) replied:
OK, thanks. I figured out that I was counting tokens from 1 instead of 0. I fixed that in 5d99ae4.
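For reference, a minimal sketch of the context-shift pattern being discussed, assuming the usual half-open [p0, p1) convention of the KV-cache calls with -1 meaning "until the end of the sequence"; this mirrors the discussion above, not the final commit:

```cpp
#include "llama.h"

// Hypothetical helper: drop n_discard tokens after the first n_keep + 1,
// then slide the remaining tokens back by n_discard positions.
static void context_shift(llama_context * ctx, int n_keep, int n_discard, int & n_past) {
    // remove the tokens at positions [n_keep + 1, n_keep + 1 + n_discard)
    llama_kv_cache_seq_rm (ctx, 0, n_keep + 1, n_keep + 1 + n_discard);
    // shift everything after the removed range back by n_discard;
    // passing -1 as p1 (i.e. [p0, inf)) would be equivalent to passing n_past here
    llama_kv_cache_seq_add(ctx, 0, n_keep + 1 + n_discard, n_past, -n_discard);
    n_past -= n_discard;
}
```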


n_past -= n_discard;

@@ -396,7 +396,7 @@ int main(int argc, char ** argv) {

LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
16 changes: 8 additions & 8 deletions examples/llama-bench/llama-bench.cpp
@@ -1428,7 +1428,7 @@ struct sql_printer : public printer {
}
};

static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);

const llama_model * model = llama_get_model(ctx);
@@ -1444,14 +1444,14 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
for (int i = 1; i < n_tokens; i++) {
tokens[i] = std::rand() % n_vocab;
}
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
n_processed += n_tokens;
}

llama_synchronize(ctx);
}

static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);

const llama_model * model = llama_get_model(ctx);
@@ -1460,7 +1460,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads)
llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;

for (int i = 0; i < n_gen; i++) {
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
llama_decode(ctx, llama_batch_get_one(&token, 1));
llama_synchronize(ctx);
token = std::rand() % n_vocab;
}
@@ -1596,13 +1596,13 @@ int main(int argc, char ** argv) {
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup prompt run\n", params_idx, params_count);
}
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
}
if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup generation run\n", params_idx, params_count);
}
test_gen(ctx, 1, 0, t.n_threads);
test_gen(ctx, 1, t.n_threads);
}

for (int i = 0; i < params.reps; i++) {
@@ -1614,13 +1614,13 @@ int main(int argc, char ** argv) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%ld: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps);
}
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
}
if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%ld: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps);
}
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
test_gen(ctx, t.n_gen, t.n_threads);
}

uint64_t t_ns = get_time_ns() - t_start;
3 changes: 0 additions & 3 deletions examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -283,9 +283,6 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
nullptr,
nullptr,
nullptr,
0,
0,
0,
};

if (embd) {
2 changes: 1 addition & 1 deletion examples/llava/llava-cli.cpp
@@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;
}
38 changes: 36 additions & 2 deletions examples/llava/llava.cpp
@@ -401,6 +401,39 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
return true;
}

struct llava_embd_batch {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos.resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids.resize(n_tokens + 1);
logits.resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
};

bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama));

@@ -409,8 +442,9 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) {
float * embd = image_embed->embed+i*n_embd;
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
if (llama_decode(ctx_llama, llava_batch.batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
2 changes: 1 addition & 1 deletion examples/llava/minicpmv-cli.cpp
@@ -97,7 +97,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;
}
4 changes: 2 additions & 2 deletions examples/lookahead/lookahead.cpp
@@ -89,8 +89,8 @@ int main(int argc, char ** argv) {
const auto t_enc_start = ggml_time_us();

// eval the prompt
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));

for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
4 changes: 2 additions & 2 deletions examples/lookup/lookup.cpp
@@ -89,8 +89,8 @@ int main(int argc, char ** argv){

const auto t_enc_start = ggml_time_us();

llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));

const auto t_enc_end = ggml_time_us();

6 changes: 3 additions & 3 deletions examples/main/main.cpp
@@ -528,7 +528,7 @@ int main(int argc, char ** argv) {
int enc_input_size = embd_inp.size();
llama_token * enc_input_buf = embd_inp.data();

if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
@@ -582,7 +582,7 @@ int main(int argc, char ** argv) {
n_past, n_left, n_ctx, params.n_keep, n_discard);

llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past + 1 , -n_discard);

n_past -= n_discard;

@@ -648,7 +648,7 @@ int main(int argc, char ** argv) {

LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
1 change: 0 additions & 1 deletion examples/parallel/parallel.cpp
@@ -308,7 +308,6 @@ int main(int argc, char ** argv) {
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};

const int ret = llama_decode(ctx, batch_view);
27 changes: 22 additions & 5 deletions examples/perplexity/perplexity.cpp
@@ -412,13 +412,22 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);

llama_batch batch = llama_batch_init(batch_size, 0, 1);
Reviewer (Member) commented:
Move the llama_batch outside the loop and reuse it. Maybe utilize the common_batch_ API to make it a little less cumbersome.

The author (Collaborator) replied:
Fixed in 734f9e2 and 4be7ecf
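For illustration, a sketch of the suggested pattern: one llama_batch allocated up front and refilled per chunk via the common_batch_ helpers from common.h. This is an assumption-level sketch of the idea, not the actual follow-up commits:

```cpp
#include <algorithm>
#include <vector>
#include "common.h"   // common_batch_clear / common_batch_add
#include "llama.h"

// Evaluate tokens[start, end) in chunks of n_batch, reusing a single batch.
static bool eval_chunks(llama_context * ctx, const std::vector<llama_token> & tokens,
                        int start, int end, int n_batch) {
    llama_batch batch = llama_batch_init(n_batch, 0, 1);

    for (int batch_start = start; batch_start < end; batch_start += n_batch) {
        const int batch_size = std::min(end - batch_start, n_batch);

        common_batch_clear(batch);
        for (int i = 0; i < batch_size; i++) {
            // position relative to the start of the window (matches j*n_batch + i above),
            // sequence 0, logits requested for every token
            common_batch_add(batch, tokens[batch_start + i], batch_start - start + i, { 0 }, true);
        }

        if (llama_decode(ctx, batch)) {
            llama_batch_free(batch);
            return false;
        }
    }

    llama_batch_free(batch);
    return true;
}
```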

for (int i = 0; i < batch_size; i++) {
batch.token[i] = tokens[batch_start + i];
batch.pos[i] = j*n_batch + i;
batch.logits[i] = true;
batch.seq_id[i][0] = 0;
}

//LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
// TODO: use llama_batch.logits instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
if (llama_decode(ctx, batch)) {
//LOG_ERR("%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}

llama_batch_free(batch);

// save original token and restore it after eval
const auto token_org = tokens[batch_start];

@@ -704,7 +713,6 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};

const int ret = llama_decode(ctx, batch_view);
@@ -1803,12 +1811,21 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
}

// TODO: use llama_batch.logits instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
llama_batch batch = llama_batch_init(batch_size, 0, 1);
for (int i = 0; i < batch_size; i++) {
batch.token[i] = tokens[batch_start + i];
batch.pos[i] = j*n_batch + i;
batch.logits[i] = true;
batch.seq_id[i][0] = 0;
}

if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return;
}

llama_batch_free(batch);

// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;

8 changes: 4 additions & 4 deletions examples/save-load-state/save-load-state.cpp
@@ -49,7 +49,7 @@ int main(int argc, char ** argv) {
auto tokens = common_tokenize(ctx, params.prompt, true);

// evaluate prompt
llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0));
llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()));
n_past += tokens.size();

// save state (rng, logits, embedding and kv_cache) to file
@@ -77,7 +77,7 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str());
result0 += next_token_str;

if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1))) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx);
llama_free_model(model);
@@ -133,7 +133,7 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str());
result1 += next_token_str;

if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0))) {
if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1))) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx2);
llama_free_model(model);
@@ -221,7 +221,7 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str());
result2 += next_token_str;

if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1))) {
Reviewer (Member) commented:
This will generate a batch for seq_id == 0, but it needs to be seq_id == 1:

make -j && ./llama-save-load-state -m ${some_model}

@ngxson (Collaborator, Author) replied on Oct 12, 2024:
Thanks for spotting that! Fixed in 6395174
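For illustration, one way to build that single-token batch with an explicit seq_id of 1, sketched with the common_batch_ helpers and the surrounding variables (next_token, n_past, ctx3); the actual fix is in 6395174 and may differ:

```cpp
#include <cstdio>
#include "common.h"   // common_batch_clear / common_batch_add
#include "llama.h"

// llama_batch_get_one() implies seq_id 0, so the sequence that was copied into
// seq_id 1 needs an explicitly constructed batch instead.
llama_batch batch = llama_batch_init(1, 0, 1);
common_batch_clear(batch);
common_batch_add(batch, next_token, n_past, { 1 }, true); // pos = n_past, seq_id = 1, request logits

if (llama_decode(ctx3, batch)) {
    fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
}
llama_batch_free(batch);
```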

fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx3);
llama_free_model(model);
1 change: 0 additions & 1 deletion examples/server/server.cpp
@@ -2268,7 +2268,6 @@ struct server_context {
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};

const int ret = llama_decode(ctx, batch_view);