@@ -4349,7 +4349,7 @@ struct llama_beam {
4349
4349
};
4350
4350
4351
4351
// A struct for calculating logit-related info.
4352
- struct logit_info {
4352
+ struct llama_logit_info {
4353
4353
const float * const logits;
4354
4354
const int n_vocab;
4355
4355
const float max_l;
@@ -4358,7 +4358,7 @@ struct logit_info {
4358
4358
float max_l;
4359
4359
float operator ()(float sum, float l) const { return sum + std::exp (l - max_l); }
4360
4360
};
4361
- logit_info (llama_context * ctx)
4361
+ llama_logit_info (llama_context * ctx)
4362
4362
: logits(llama_get_logits(ctx))
4363
4363
, n_vocab(llama_n_vocab(ctx))
4364
4364
, max_l(*std::max_element (logits, logits + n_vocab))
@@ -4393,7 +4393,7 @@ struct logit_info {
4393
4393
}
4394
4394
};
4395
4395
4396
- struct beam_search {
4396
+ struct llama_beam_search_data {
4397
4397
llama_context * ctx;
4398
4398
size_t n_beams;
4399
4399
int n_past;
@@ -4408,7 +4408,7 @@ struct beam_search {
4408
4408
// Used to communicate to/from callback on beams state.
4409
4409
std::vector<llama_beam_view> beam_views;
4410
4410
4411
- beam_search (llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
4411
+ llama_beam_search_data (llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
4412
4412
: ctx(ctx)
4413
4413
, n_beams(n_beams)
4414
4414
, n_past(n_past)
@@ -4452,7 +4452,7 @@ struct beam_search {
4452
4452
if (!beam.tokens .empty ()) {
4453
4453
llama_eval (ctx, beam.tokens .data (), beam.tokens .size (), n_past, n_threads);
4454
4454
}
4455
- logit_info logit_info (ctx);
4455
+ llama_logit_info logit_info (ctx);
4456
4456
std::vector<llama_token_data> next_tokens = logit_info.top_k (n_beams);
4457
4457
size_t i=0 ;
4458
4458
if (next_beams.size () < n_beams) {
@@ -4569,9 +4569,9 @@ void llama_beam_search(llama_context * ctx,
4569
4569
assert (ctx);
4570
4570
const int64_t t_start_sample_us = ggml_time_us ();
4571
4571
4572
- beam_search beam_search (ctx, n_beams, n_past, n_predict, n_threads);
4572
+ llama_beam_search_data beam_search_data (ctx, n_beams, n_past, n_predict, n_threads);
4573
4573
4574
- beam_search .loop (callback, callback_data);
4574
+ beam_search_data .loop (callback, callback_data);
4575
4575
4576
4576
ctx->t_sample_us += ggml_time_us () - t_start_sample_us;
4577
4577
ctx->n_sample ++;
0 commit comments