Add llama_beam_search(). #2267
Merged
Commits (8):
c4269e0 Add llama_beam_search(). (mattpulver)
abe0829 Add '// Beam search' heading to llama.{h,cpp} after llama_grammar_acc… (mattpulver)
9bedaf4 Add space around * pointers and & references. (mattpulver)
e46a8b5 Add spaces around comparison and assignment operators. (mattpulver)
93daad7 Prefer west const. (mattpulver)
fa33614 Use llama_ prefix for structs in global namespace. (mattpulver)
b619cfc Delete obsolete comment from an earlier revision. (mattpulver)
5fa1ea2 Change eos to eob in llama_beam and llama_beam_view structs. (mattpulver)
CMakeLists.txt (new file):
@@ -0,0 +1,8 @@
set(TARGET beam_search)
add_executable(${TARGET} beam_search.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
    add_dependencies(${TARGET} BUILD_INFO)
endif()
beam_search.cpp (new file):
@@ -0,0 +1,188 @@
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif

#include "common.h"
#include "llama.h"
#include "build-info.h"

#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <signal.h>
#endif

// Used for debugging to print out beam tokens.
struct ostream_beam_view {
    llama_context * ctx;
    llama_beam_view beam_view;
};
std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) {
    os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens(";
    for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
        os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
    }
    return os << ')';
}

// Put here anything you want back in beam_search_callback().
struct beam_search_callback_data {
    llama_context * ctx;
    std::vector<llama_token> response;
};

// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
// For example, eob can be flagged due to maximum token length, stop words, etc.
bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) {
    return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
}

// Function matching type llama_beam_search_callback_fn_t.
// Custom callback example is called each time the beams lengths increase:
//  * Show progress by printing ',' following by number of convergent beam tokens if any.
//  * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
//    This is also called when the stop condition is met.
//    Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_state) {
    auto& callback_data = *static_cast<beam_search_callback_data*>(callback_data_ptr);
    // Mark beams as EOS as needed.
    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
        llama_beam_view& beam_view = beams_state.beam_views[i];
        if (!beam_view.eob && is_at_eob(callback_data, beam_view.tokens, beam_view.n_tokens)) {
            beam_view.eob = true;
        }
    }
    printf(",");  // Show progress
    if (const size_t n = beams_state.common_prefix_length) {
        callback_data.response.resize(callback_data.response.size() + n);
        assert(0u < beams_state.n_beams);
        const llama_token * tokens = beams_state.beam_views[0].tokens;
        std::copy(tokens, tokens + n, callback_data.response.end() - n);
        printf("%lu", n);
    }
    fflush(stdout);
#if 1 // DEBUG: print current beams for this iteration
    std::cout << "\n\nCurrent beams (last_call=" << beams_state.last_call << "):\n";
    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
        std::cout << "beams["<<i<<"]: " << ostream_beam_view{callback_data.ctx,beams_state.beam_views[i]} << std::endl;
    }
#endif
}

int main(int argc, char ** argv)
{
    gpt_params params;
    //params.n_gpu_layers = 200;

    //---------------------------------
    // Print help :
    //---------------------------------

    if ( argc < 2 || argv[1][0] == '-' )
    {
        printf( "Usage: %s MODEL_PATH [BEAM_WIDTH=2] [PROMPT]\n" , argv[0] );
        return 1 ;
    }

    //---------------------------------
    // Load parameters :
    //---------------------------------

    params.model = argv[1];

    params.n_beams = 2 < argc ? std::stoi(argv[2]) : 2;

    if ( argc > 3 )
    {
        params.prompt = argv[3];
    }

    if ( params.prompt.empty() )
    {
        params.prompt = "### Request:\nHow many countries are there?\n\n### Response:\n";
    }

    //---------------------------------
    // Init LLM :
    //---------------------------------

    llama_backend_init(params.numa);

    llama_model * model;
    llama_context * ctx;

    std::tie(model, ctx) = llama_init_from_gpt_params( params );

    if ( model == NULL )
    {
        fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
        return 1;
    }

    //---------------------------------
    // Tokenize the prompt :
    //---------------------------------

    std::vector<llama_token> tokens_list = llama_tokenize(ctx, params.prompt, true);

    const size_t max_context_size     = llama_n_ctx( ctx );
    const size_t max_tokens_list_size = max_context_size - 4 ;

    if (tokens_list.size() > max_tokens_list_size)
    {
        fprintf( stderr , "%s: error: prompt too long (%lu tokens, max %lu)\n" ,
                 __func__ , tokens_list.size() , max_tokens_list_size );
        return 1;
    }

    fprintf( stderr, "\n\n" );

    // Print the tokens from the prompt :

    for( auto id : tokens_list )
    {
        std::cout << llama_token_to_str(ctx, id);
    }
    std::cout << std::flush;

    int n_past = llama_get_kv_cache_token_count(ctx);
    if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
    {
        fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
        return 1;
    }
    n_past += tokens_list.size();

    beam_search_callback_data callback_data{ctx, {}};
    size_t const beam_width = static_cast<size_t>(params.n_beams);
    int const n_predict = 256;
    llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads);

    std::cout << "\n\n";
    for (llama_token const token_id : callback_data.response) {
        std::cout << llama_token_to_str(ctx,token_id);
    }
    std::cout << std::endl;

    llama_free( ctx );
    llama_free_model( model );

    llama_backend_free();

    return 0;
}
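
For readers following along without the llama.h/llama.cpp portion of the diff (not shown on this page), the snippet below sketches the shape of the API the example exercises. It is reconstructed purely from how the example uses the types; the stand-in forward declarations, exact field order, and comments are assumptions and may differ from the real header.

// Sketch of the beam-search API as used by the example above; not copied from llama.h.

// Stand-ins for the real declarations in llama.h:
struct llama_context;
typedef int llama_token;

// One view into a single beam: its tokens, cumulative probability, and
// whether the callback has marked it end-of-beam (eob).
struct llama_beam_view {
    const llama_token * tokens;
    size_t              n_tokens;
    float               p;    // cumulative beam probability
    bool                eob;  // callback sets this to end the beam early
};

// State handed to the callback on each iteration.
struct llama_beams_state {
    llama_beam_view * beam_views;           // one view per beam
    size_t            n_beams;              // number of elements in beam_views
    size_t            common_prefix_length; // tokens currently shared by all beams
    bool              last_call;            // true on the final invocation
};

// Callback type matched by beam_search_callback() in the example.
typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, llama_beams_state);

// Entry point called near the end of main() above.
void llama_beam_search(llama_context * ctx,
                       llama_beam_search_callback_fn_t callback,
                       void * callback_data,
                       size_t n_beams,
                       int n_past,
                       int n_predict,
                       int n_threads);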
@mattpulver I noticed this check for llama.params.n_beams, but the n_beams param doesn't seem to be set anywhere. Am I misinterpreting? If I set it myself, will it work along with the grammar for this server example?
Here is an example of where it is set and used:
https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/examples/beam-search/beam-search.cpp#L110
In examples/server/server.cpp it should be set via the command line by server_params_parse(), but it seems that was not yet done. Feel free to submit that as a PR.
I don't think beam search and grammar will currently work together. That is currently an open item: #2923
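
To make that suggestion concrete, here is a minimal sketch of what such a command-line hook could look like. It is illustrative only: the flag name "--beam-search", the stand-in gpt_params definition, and the simplified signature of server_params_parse() are assumptions, not the actual server.cpp code.

#include <cstdio>
#include <cstdlib>
#include <string>

// Stand-in for gpt_params from common.h; only the field used here is shown.
struct gpt_params {
    int n_beams = 0; // 0 leaves beam search disabled
};

// Hypothetical: how server_params_parse() might accept a beam-width flag.
// The real server.cpp parser takes additional parameters and handles many
// other options; this loop only shows where n_beams could be populated.
static void server_params_parse(int argc, char ** argv, gpt_params & params) {
    for (int i = 1; i < argc; ++i) {
        std::string arg = argv[i];
        if (arg == "--beam-search" || arg == "-bs") {
            if (i + 1 >= argc) {
                fprintf(stderr, "error: missing value for %s\n", arg.c_str());
                exit(1);
            }
            params.n_beams = std::stoi(argv[++i]); // > 1 enables beam search
        }
        // ... other server options would be handled here ...
    }
}

With n_beams populated this way, the existing check on llama.params.n_beams in the server would take effect; whether that path then works together with grammar-constrained sampling is the separate open question tracked in #2923.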