Skip to content

Add llama_beam_search(). #2267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct gpt_params {
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t n_beams = 0; // if non-zero then use beam search of given width.
float rope_freq_base = 10000.0f; // RoPE base frequency
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor

Expand Down
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ else()
add_subdirectory(simple)
add_subdirectory(embd-input)
add_subdirectory(llama-bench)
add_subdirectory(beam_search)
if (LLAMA_METAL)
add_subdirectory(metal)
endif()
Expand Down
8 changes: 8 additions & 0 deletions examples/beam_search/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
set(TARGET beam_search)
add_executable(${TARGET} beam_search.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()
188 changes: 188 additions & 0 deletions examples/beam_search/beam_search.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif

#include "common.h"
#include "llama.h"
#include "build-info.h"

#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <signal.h>
#endif

// Used for debugging to print out beam tokens.
struct ostream_beam_view {
llama_context * ctx;
llama_beam_view beam_view;
};
std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) {
os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens(";
for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
}
return os << ')';
}

// Put here anything you want back in beam_search_callback().
struct beam_search_callback_data {
llama_context * ctx;
std::vector<llama_token> response;
};

// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
// For example, eob can be flagged due to maximum token length, stop words, etc.
bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) {
return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
}

// Function matching type llama_beam_search_callback_fn_t.
// Custom callback example is called each time the beams lengths increase:
// * Show progress by printing ',' following by number of convergent beam tokens if any.
// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
// This is also called when the stop condition is met.
// Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_state) {
auto& callback_data = *static_cast<beam_search_callback_data*>(callback_data_ptr);
// Mark beams as EOS as needed.
for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
llama_beam_view& beam_view = beams_state.beam_views[i];
if (!beam_view.eob && is_at_eob(callback_data, beam_view.tokens, beam_view.n_tokens)) {
beam_view.eob = true;
}
}
printf(","); // Show progress
if (const size_t n = beams_state.common_prefix_length) {
callback_data.response.resize(callback_data.response.size() + n);
assert(0u < beams_state.n_beams);
const llama_token * tokens = beams_state.beam_views[0].tokens;
std::copy(tokens, tokens + n, callback_data.response.end() - n);
printf("%lu", n);
}
fflush(stdout);
#if 1 // DEBUG: print current beams for this iteration
std::cout << "\n\nCurrent beams (last_call=" << beams_state.last_call << "):\n";
for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
std::cout << "beams["<<i<<"]: " << ostream_beam_view{callback_data.ctx,beams_state.beam_views[i]} << std::endl;
}
#endif
}

int main(int argc, char ** argv)
{
gpt_params params;
//params.n_gpu_layers = 200;

//---------------------------------
// Print help :
//---------------------------------

if ( argc < 2 || argv[1][0] == '-' )
{
printf( "Usage: %s MODEL_PATH [BEAM_WIDTH=2] [PROMPT]\n" , argv[0] );
return 1 ;
}

//---------------------------------
// Load parameters :
//---------------------------------

params.model = argv[1];

params.n_beams = 2 < argc ? std::stoi(argv[2]) : 2;

if ( argc > 3 )
{
params.prompt = argv[3];
}

if ( params.prompt.empty() )
{
params.prompt = "### Request:\nHow many countries are there?\n\n### Response:\n";
}

//---------------------------------
// Init LLM :
//---------------------------------

llama_backend_init(params.numa);

llama_model * model;
llama_context * ctx;

std::tie(model, ctx) = llama_init_from_gpt_params( params );

if ( model == NULL )
{
fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
return 1;
}

//---------------------------------
// Tokenize the prompt :
//---------------------------------

std::vector<llama_token> tokens_list = llama_tokenize(ctx, params.prompt, true);

const size_t max_context_size = llama_n_ctx( ctx );
const size_t max_tokens_list_size = max_context_size - 4 ;

if (tokens_list.size() > max_tokens_list_size)
{
fprintf( stderr , "%s: error: prompt too long (%lu tokens, max %lu)\n" ,
__func__ , tokens_list.size() , max_tokens_list_size );
return 1;
}

fprintf( stderr, "\n\n" );

// Print the tokens from the prompt :

for( auto id : tokens_list )
{
std::cout << llama_token_to_str(ctx, id);
}
std::cout << std::flush;

int n_past = llama_get_kv_cache_token_count(ctx);
if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
{
fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
return 1;
}
n_past += tokens_list.size();

beam_search_callback_data callback_data{ctx, {}};
size_t const beam_width = static_cast<size_t>(params.n_beams);
int const n_predict = 256;
llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads);

std::cout << "\n\n";
for (llama_token const token_id : callback_data.response) {
std::cout << llama_token_to_str(ctx,token_id);
}
std::cout << std::endl;

llama_free( ctx );
llama_free_model( model );

llama_backend_free();

return 0;
}
90 changes: 77 additions & 13 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,62 @@ static void log_server_request(const Request &req, const Response &res)
});
}

bool is_at_eob(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
}

// Function matching type llama_beam_search_callback_fn_t.
// Custom callback example is called each time the beams lengths increase:
// * Show progress by printing ',' following by number of convergent beam tokens if any.
// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
// This is also called when the stop condition is met.
// Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
auto & llama = *static_cast<llama_server_context*>(callback_data);
// Mark beams as EOS as needed.
for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
llama_beam_view& beam_view = beams_state.beam_views[i];
if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
beam_view.eob = true;
}
}
printf(","); // Show progress
if (const size_t n = beams_state.common_prefix_length) {
llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
assert(0u < beams_state.n_beams);
const llama_token * tokens = beams_state.beam_views[0].tokens;
const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
printf("%lu", n);
}
fflush(stdout);
#if 0 // DEBUG: print current beams for this iteration
std::cout << "\n\nCurrent beams:\n";
for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
}
#endif
}

struct token_translator {
llama_context * ctx;
std::string operator()(llama_token tok) const { return llama_token_to_str(ctx, tok); }
std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); }
};

void append_to_generated_text_from_generated_token_probs(llama_server_context & llama) {
auto & gtps = llama.generated_token_probs;
auto translator = token_translator{llama.ctx};
auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); };
const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
llama.generated_text.reserve(llama.generated_text.size() + len);
}
for (const completion_token_output & cto : gtps) {
llama.generated_text += translator(cto);
}
}

int main(int argc, char **argv)
{
// own arguments required by this example
Expand Down Expand Up @@ -1291,22 +1347,30 @@ int main(int argc, char **argv)
llama.beginCompletion();

if (!llama.stream) {
size_t stop_pos = std::string::npos;
if (llama.params.n_beams) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mattpulver I noticed this check for llama.params.n_beams, but n_beams param doesn't seem to be set anywhere. Am I misinterpreting? If I set it myself, will it work along with the grammar for this server example?

Copy link
Contributor Author

@mattpulver mattpulver Sep 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is an example of where it is set and used:
https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/examples/beam-search/beam-search.cpp#L110

In examples/server/server.cpp I believe it may be set via the command line it should be set by server_params_parse() but it seems that was not yet done. Feel free to submit that as a PR.

I don't think beam search and grammar will currently work together. That is currently an open item: #2923

// Fill llama.generated_token_probs vector with final beam.
llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
llama.n_past, llama.n_remain, llama.params.n_threads);
// Translate llama.generated_token_probs to llama.generated_text.
append_to_generated_text_from_generated_token_probs(llama);
} else {
size_t stop_pos = std::string::npos;

while (llama.has_next_token) {
const completion_token_output token_with_probs = llama.doCompletion();
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
while (llama.has_next_token) {
const completion_token_output token_with_probs = llama.doCompletion();
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);

stop_pos = llama.findStoppingStrings(llama.generated_text,
token_text.size(), STOP_FULL);
}
stop_pos = llama.findStoppingStrings(llama.generated_text,
token_text.size(), STOP_FULL);
}

if (stop_pos == std::string::npos) {
stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
}
if (stop_pos != std::string::npos) {
llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
llama.generated_text.end());
if (stop_pos == std::string::npos) {
stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
}
if (stop_pos != std::string::npos) {
llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
llama.generated_text.end());
}
}

const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
Expand Down
Loading