llama : rename batch.logits to batch.output #10004

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 2 commits into master
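The change is a mechanical rename of the per-token output flag in `llama_batch` (plus the matching helper parameters and local vector names); behavior is unchanged. As a minimal sketch of post-rename caller code, assuming an already-initialized context and a tokenized prompt (hypothetical names, mirroring the examples below):

```cpp
#include <cstdio>
#include <vector>
#include "common.h"
#include "llama.h"

// Sketch only: ctx and tokens are assumed to be set up elsewhere,
// as in the examples changed by this PR.
void decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens) {
    llama_batch batch = llama_batch_init(512, 0, 1);

    for (size_t i = 0; i < tokens.size(); i++) {
        // final argument: whether to compute an output for this token
        common_batch_add(batch, tokens[i], i, { 0 }, false);
    }

    // request an output only for the last prompt token (was batch.logits[...])
    batch.output[batch.n_tokens - 1] = true;

    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "llama_decode() failed\n");
    }

    llama_batch_free(batch);
}
```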
6 changes: 3 additions & 3 deletions common/common.cpp
@@ -607,7 +607,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
            << ", pos " << std::to_string(batch.pos[i])
            << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
            << ", seq_id " << std::to_string(batch.seq_id[i][0])
-           << ", logits " << std::to_string(batch.logits[i]);
+           << ", output " << std::to_string(batch.output[i]);
    }

    buf << " ]";
@@ -1617,7 +1617,7 @@ void common_batch_add(
    llama_token id,
    llama_pos pos,
    const std::vector<llama_seq_id> & seq_ids,
-   bool logits) {
+   bool output) {
    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");

    batch.token [batch.n_tokens] = id;
@@ -1626,7 +1626,7 @@
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
    }
-   batch.logits [batch.n_tokens] = logits;
+   batch.output [batch.n_tokens] = output;

    batch.n_tokens++;
}
4 changes: 2 additions & 2 deletions examples/batched-bench/batched-bench.cpp
@@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
            batch.pos + i,
            batch.n_seq_id + i,
            batch.seq_id + i,
-           batch.logits + i,
+           batch.output + i,
        };

        const int ret = llama_decode(ctx, batch_view);
@@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
                common_batch_add(batch, 0, i, { j }, false);
            }
        }
-       batch.logits[batch.n_tokens - 1] = true;
+       batch.output[batch.n_tokens - 1] = true;

        const auto t_pp_start = ggml_time_us();
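The `batch_view` construction above recurs in parallel.cpp, perplexity.cpp, and server.cpp below: each per-token array of the batch is sliced at offset `i`, and after the rename the output flags are sliced from `batch.output`. A sketch of the complete view, with hypothetical `n_view` and `i`, following the field order of `llama_batch` in `include/llama.h`:

```cpp
// Decode only tokens [i, i + n_view) of an existing batch (sketch;
// n_view and i are hypothetical). Field order matches struct llama_batch.
llama_batch batch_view = {
    n_view,                 // n_tokens in this view
    batch.token    + i,     // token ids
    nullptr,                // embd: unused when token ids are given
    batch.pos      + i,
    batch.n_seq_id + i,
    batch.seq_id   + i,
    batch.output   + i,     // renamed from batch.logits
};
const int ret = llama_decode(ctx, batch_view);
```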
6 changes: 3 additions & 3 deletions examples/batched.swift/Sources/main.swift
@@ -104,11 +104,11 @@ for (i, token) in tokens.enumerated() {
    if let seq_id = batch.seq_id[i] {
        seq_id[0] = 0
    }
-   batch.logits[i] = 0
+   batch.output[i] = 0
}

// llama_decode will output logits only for the last token of the prompt
-batch.logits[Int(batch.n_tokens) - 1] = 1
+batch.output[Int(batch.n_tokens) - 1] = 1

if llama_decode(context, batch) != 0 {
    print("llama_decode() failed")
@@ -171,7 +171,7 @@ while n_cur <= n_len {
        if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
            seq_id[0] = Int32(i)
        }
-       batch.logits[Int(batch.n_tokens)] = 1
+       batch.output[Int(batch.n_tokens)] = 1

        i_batch[i] = batch.n_tokens
2 changes: 1 addition & 1 deletion examples/batched/batched.cpp
@@ -131,7 +131,7 @@ int main(int argc, char ** argv) {
    }

    // llama_decode will output logits only for the last token of the prompt
-   batch.logits[batch.n_tokens - 1] = true;
+   batch.output[batch.n_tokens - 1] = true;

    if (llama_decode(ctx, batch) != 0) {
        LOG_ERR("%s: llama_decode() failed\n", __func__);
2 changes: 1 addition & 1 deletion examples/embedding/embedding.cpp
@@ -54,7 +54,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
    }

    for (int i = 0; i < batch.n_tokens; i++) {
-       if (!batch.logits[i]) {
+       if (!batch.output[i]) {
            continue;
        }
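For the tokens that pass this output-flag check, the embedding example goes on to read the computed embedding. A hedged sketch of the body that typically follows, using the accessors from llama.h:

```cpp
// Sketch of the extraction after the check above: prefer the pooled
// per-sequence embedding, fall back to the per-token embedding.
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == nullptr) {
    embd = llama_get_embeddings_ith(ctx, i);
}
```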
6 changes: 3 additions & 3 deletions examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
        common_batch_add(*batch, 0, i, { 0 }, false);
    }

-   batch->logits[batch->n_tokens - 1] = true;
+   batch->output[batch->n_tokens - 1] = true;
    llama_kv_cache_clear(context);

    const auto t_pp_start = ggml_time_us();
@@ -297,7 +297,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
    for (int i = 0; i < n_tokens; ++i) {
        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
    }
-   batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+   batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);

    return reinterpret_cast<jlong>(batch);
}
@@ -381,7 +381,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
    }

    // llama_decode will output logits only for the last token of the prompt
-   batch->logits[batch->n_tokens - 1] = true;
+   batch->output[batch->n_tokens - 1] = true;

    if (llama_decode(context, *batch) != 0) {
        LOGe("llama_decode() failed");
8 changes: 4 additions & 4 deletions examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -9,14 +9,14 @@ func llama_batch_clear(_ batch: inout llama_batch) {
    batch.n_tokens = 0
}

-func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
+func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ outputs: Bool) {
    batch.token [Int(batch.n_tokens)] = id
    batch.pos [Int(batch.n_tokens)] = pos
    batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
    for i in 0..<seq_ids.count {
        batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
    }
-   batch.logits [Int(batch.n_tokens)] = logits ? 1 : 0
+   batch.output [Int(batch.n_tokens)] = outputs ? 1 : 0

    batch.n_tokens += 1
}
@@ -139,7 +139,7 @@ actor LlamaContext {
            let i = Int(i1)
            llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
        }
-       batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+       batch.output[Int(batch.n_tokens) - 1] = 1 // true

        if llama_decode(context, batch) != 0 {
            print("llama_decode() failed")
@@ -208,7 +208,7 @@ actor LlamaContext {
        for i in 0..<n_tokens {
            llama_batch_add(&batch, 0, Int32(i), [0], false)
        }
-       batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+       batch.output[Int(batch.n_tokens) - 1] = 1 // true

        llama_kv_cache_clear(context)
8 changes: 4 additions & 4 deletions examples/llava/llava.cpp
@@ -441,13 +441,13 @@ struct llava_embd_batch {
    std::vector<int32_t> n_seq_id;
    std::vector<llama_seq_id> seq_id_0;
    std::vector<llama_seq_id *> seq_ids;
-   std::vector<int8_t> logits;
+   std::vector<int8_t> outputs;
    llama_batch batch;
    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
        pos .resize(n_tokens);
        n_seq_id.resize(n_tokens);
        seq_ids .resize(n_tokens + 1);
-       logits .resize(n_tokens);
+       outputs .resize(n_tokens);
        seq_id_0.resize(1);
        seq_id_0[0] = seq_id;
        seq_ids [n_tokens] = nullptr;
@@ -458,13 +458,13 @@
            /*pos =*/ pos.data(),
            /*n_seq_id =*/ n_seq_id.data(),
            /*seq_id =*/ seq_ids.data(),
-           /*logits =*/ logits.data(),
+           /*output =*/ outputs.data(),
        };
        for (int i = 0; i < n_tokens; i++) {
            batch.pos [i] = pos_0 + i;
            batch.n_seq_id[i] = 1;
            batch.seq_id [i] = seq_id_0.data();
-           batch.logits [i] = false;
+           batch.output [i] = false;
        }
    }
};
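For context, a sketch of how this wrapper is meant to be used: it owns the per-token arrays, so an embedding-only batch can be handed to llama_decode without going through llama_batch_init. Names here (`embd`, `n_eval`, `n_past`, `ctx_llama`) are hypothetical:

```cpp
// Sketch: embd points at n_eval rows of image embeddings, n_past is the
// current position, sequence 0. The wrapper's vectors keep the arrays alive.
llava_embd_batch embd_batch(embd, n_eval, n_past, 0);
if (llama_decode(ctx_llama, embd_batch.batch) != 0) {
    // handle the error
}
```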
4 changes: 2 additions & 2 deletions examples/parallel/parallel.cpp
@@ -266,7 +266,7 @@ int main(int argc, char ** argv) {

        // extract the logits only for the last token
        if (batch.n_tokens > 0) {
-           batch.logits[batch.n_tokens - 1] = true;
+           batch.output[batch.n_tokens - 1] = true;
        }

        client.n_prompt = tokens_prompt.size();
@@ -309,7 +309,7 @@ int main(int argc, char ** argv) {
                batch.pos + i,
                batch.n_seq_id + i,
                batch.seq_id + i,
-               batch.logits + i,
+               batch.output + i,
            };

            const int ret = llama_decode(ctx, batch_view);
4 changes: 2 additions & 2 deletions examples/passkey/passkey.cpp
@@ -146,7 +146,7 @@ int main(int argc, char ** argv) {
        }

        if (i + n_batch >= n_tokens_all) {
-           batch.logits[batch.n_tokens - 1] = true;
+           batch.output[batch.n_tokens - 1] = true;
        }

        if (llama_decode(ctx, batch) != 0) {
@@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
        }

        if (i + n_batch >= n_tokens_all) {
-           batch.logits[batch.n_tokens - 1] = true;
+           batch.output[batch.n_tokens - 1] = true;
        }

        if (llama_decode(ctx, batch) != 0) {
14 changes: 7 additions & 7 deletions examples/perplexity/perplexity.cpp
@@ -572,9 +572,9 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
            batch.pos [idx] = j*n_batch + k;
            batch.n_seq_id[idx] = 1;
            batch.seq_id [idx][0] = seq;
-           batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
+           batch.output [idx] = batch.pos[idx] >= first ? 1 : 0;

-           n_outputs += batch.logits[idx] != 0;
+           n_outputs += batch.output[idx] != 0;
        }
        batch.n_tokens += batch_size;
@@ -669,7 +669,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
            batch.pos + i,
            batch.n_seq_id + i,
            batch.seq_id + i,
-           batch.logits + i,
+           batch.output + i,
        };

        const int ret = llama_decode(ctx, batch_view);
@@ -680,7 +680,7 @@

        int n_outputs = 0;
        for (int i = 0; i < n_tokens; ++i) {
-           n_outputs += batch_view.logits[i] != 0;
+           n_outputs += batch_view.output[i] != 0;
        }

        memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
@@ -896,7 +896,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
        for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
            common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
        }
-       batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+       batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
        n_logits += 1;

        for (int s = 0; s < 4; ++s) {
@@ -1177,7 +1177,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
        for (size_t i = 0; i < data[i1].common_prefix; ++i) {
            common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
        }
-       batch.logits[batch.n_tokens - 1] = true;
+       batch.output[batch.n_tokens - 1] = true;
        n_logits += 1;

        for (int s = 0; s < 2; ++s) {
@@ -1545,7 +1545,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
            //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
            common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
        }
-       batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+       batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
        n_logits += 1;

        for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
2 changes: 1 addition & 1 deletion examples/retrieval/retrieval.cpp
@@ -92,7 +92,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
    }

    for (int i = 0; i < batch.n_tokens; i++) {
-       if (!batch.logits[i]) {
+       if (!batch.output[i]) {
            continue;
        }
2 changes: 1 addition & 1 deletion examples/save-load-state/save-load-state.cpp
@@ -52,7 +52,7 @@ int main(int argc, char ** argv) {
    for (size_t i = 0; i < tokens.size(); i++) {
        common_batch_add(batch, tokens[i], i, {0}, false);
    }
-   batch.logits[batch.n_tokens - 1] = true; // generate next token
+   batch.output[batch.n_tokens - 1] = true; // generate next token

    // evaluate prompt
    llama_decode(ctx, batch);
8 changes: 4 additions & 4 deletions examples/server/server.cpp
@@ -2413,7 +2413,7 @@ struct server_context {
        std::vector<float> embd_res(n_embd, 0.0f);

        for (int i = 0; i < batch.n_tokens; ++i) {
-           if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+           if (!batch.output[i] || batch.seq_id[i][0] != slot.id) {
                continue;
            }
@@ -2451,7 +2451,7 @@ struct server_context {
        res->n_tokens = slot.n_prompt_tokens;

        for (int i = 0; i < batch.n_tokens; ++i) {
-           if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+           if (!batch.output[i] || batch.seq_id[i][0] != slot.id) {
                continue;
            }
@@ -3109,7 +3109,7 @@ struct server_context {
                }

                // extract the logits only for the last token
-               batch.logits[batch.n_tokens - 1] = true;
+               batch.output[batch.n_tokens - 1] = true;

                slot.n_decoded = 0;
                slot.i_batch = batch.n_tokens - 1;
@@ -3149,7 +3149,7 @@ struct server_context {
                batch.pos + i,
                batch.n_seq_id + i,
                batch.seq_id + i,
-               batch.logits + i,
+               batch.output + i,
            };

            const int ret = llama_decode(ctx, batch_view);
2 changes: 1 addition & 1 deletion examples/tts/tts.cpp
@@ -722,7 +722,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
    GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());

    // llama_decode will output logits only for the last token of the prompt
-   batch.logits[batch.n_tokens - 1] = true;
+   batch.output[batch.n_tokens - 1] = true;

    if (llama_decode(ctx_ttc, batch) != 0) {
        LOG_ERR("%s: llama_decode() failed\n", __func__);
2 changes: 1 addition & 1 deletion include/llama.h
@@ -252,7 +252,7 @@ extern "C" {
        llama_pos * pos;
        int32_t * n_seq_id;
        llama_seq_id ** seq_id;
-       int8_t * logits; // TODO: rename this to "output"
+       int8_t * output;
    } llama_batch;

    enum llama_model_kv_override_type {
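Putting the hunk in context: after the rename, the public struct reads roughly as below. The fields above the hunk are recalled from the same header, so treat this as an approximation rather than a verbatim copy.

```cpp
typedef struct llama_batch {
    int32_t         n_tokens;
    llama_token  *  token;    // token ids, or NULL when embd is used
    float        *  embd;     // input embeddings, or NULL when token is used
    llama_pos    *  pos;
    int32_t      *  n_seq_id;
    llama_seq_id ** seq_id;
    int8_t       *  output;   // per-token flag: compute logits/embeddings?
} llama_batch;
```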
18 changes: 9 additions & 9 deletions src/llama-batch.cpp
@@ -102,17 +102,17 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
                ubatch.output[ubatch.n_tokens + i] = 1;
                out_ids.push_back(ids[seq.offset + i]);
            }
-       } else if (batch->logits) {
+       } else if (batch->output) {
            if (ubatch.equal_seqs) {
                for (size_t i = 0; i < length; ++i) {
                    size_t id = ids[seq.offset + i];
-                   int8_t is_output = batch->logits[id];
+                   int8_t is_output = batch->output[id];
                    ubatch.output[ubatch.n_tokens + i] = is_output;
                    if (is_output) { out_ids.push_back(id); }
                }
            } else {
                // simple split
-               ubatch.output = batch->logits + seq.offset;
+               ubatch.output = batch->output + seq.offset;
                for (size_t i = 0; i < length; ++i) {
                    if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
                }
@@ -298,10 +298,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
        }
        batch.seq_id = seq_id.data();
    }
-   if (!batch.logits) {
-       logits.resize(batch.n_tokens);
-       logits[logits.size() - 1] = true;
-       batch.logits = logits.data();
+   if (!batch.output) {
+       outputs.resize(batch.n_tokens);
+       outputs[outputs.size() - 1] = true;
+       batch.output = outputs.data();
    }
}
@@ -348,7 +348,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
    }
    batch.seq_id[n_tokens_alloc] = nullptr;

-   batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
+   batch.output = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);

    return batch;
}
@@ -364,5 +364,5 @@ void llama_batch_free(struct llama_batch batch) {
        }
        free(batch.seq_id);
    }
-   if (batch.logits) free(batch.logits);
+   if (batch.output) free(batch.output);
}
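One detail worth noting from the llama_batch_allocr hunk above: if a caller leaves the output array NULL, the allocator marks only the final token for output. A hedged sketch of what that default permits, assuming `ctx` and `tokens` exist as before:

```cpp
// Sketch: a stack-built batch with every optional array left null.
// llama_batch_allocr fills in pos/n_seq_id/seq_id and, with output == nullptr,
// flags only the last token, so this decode yields one set of logits.
llama_batch batch = {
    (int32_t) tokens.size(),
    tokens.data(),
    nullptr, // embd
    nullptr, // pos
    nullptr, // n_seq_id
    nullptr, // seq_id
    nullptr, // output -> defaults to "last token only"
};
llama_decode(ctx, batch);
```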