Commit 1ffbc52

Add examples/beam_search/beam_search.cpp for testing.

1 parent 4b20567 commit 1ffbc52
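For orientation: beam search keeps the k highest-scoring partial token sequences at each step instead of committing to the single greedy choice, trading extra compute for better whole sequences. Below is a minimal, self-contained sketch of that idea over a toy scorer — illustrative only, not the llama_beam_search implementation this commit exercises; every name in it is made up.

// beam_sketch.cpp -- toy beam search; hypothetical names throughout,
// NOT llama.cpp's llama_beam_search.
#include <algorithm>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// Toy scorer standing in for model logits:
// log-probability of appending character c to a prefix.
static float toy_logprob( const std::string & prefix , char c )
{
    if ( !prefix.empty() && prefix.back() == c ) return -0.1f; // favor repeats
    return -1.0f - 0.1f * float( c - 'a' );
}

struct beam
{
    std::string tokens; // partial output sequence
    float       score;  // cumulative log-probability
};

int main()
{
    const int n_beams = 2;
    const int n_steps = 5;
    const std::string vocab = "abc";

    std::vector<beam> beams = { beam{ "" , 0.0f } }; // start from the empty prefix

    for ( int step = 0 ; step < n_steps ; step++ )
    {
        // Expand every live beam by every vocabulary entry ...
        std::vector<beam> next;
        for ( const beam & b : beams )
        {
            for ( char c : vocab )
            {
                next.push_back( beam{ b.tokens + c , b.score + toy_logprob( b.tokens , c ) } );
            }
        }
        // ... then keep only the n_beams highest-scoring candidates.
        std::sort( next.begin() , next.end() ,
                   []( const beam & a , const beam & b ) { return a.score > b.score; } );
        if ( (int) next.size() > n_beams ) next.resize( n_beams );
        beams = std::move( next );
    }

    for ( const beam & b : beams )
    {
        printf( "%s (score %.2f)\n" , b.tokens.c_str() , b.score );
    }
    return 0;
}

A real implementation additionally has to manage the model's KV cache per beam, which is presumably why n_beams feeds into memory allocation (see the common.h change below).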

File tree

6 files changed: +265 −31


examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ else()
     add_subdirectory(train-text-from-scratch)
     add_subdirectory(simple)
     add_subdirectory(embd-input)
+    add_subdirectory(beam_search)
     if (LLAMA_METAL)
         add_subdirectory(metal)
     endif()

examples/beam_search/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+set(TARGET beam_search)
+add_executable(${TARGET} beam_search.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
+endif()

examples/beam_search/beam_search.cpp

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include "common.h"
+#include "llama.h"
+#include "build-info.h"
+
+#include <cassert>
+#include <cinttypes>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#include <signal.h>
+#include <unistd.h>
+#elif defined (_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+#include <windows.h>
+#include <signal.h>
+#endif
+
+
+
+int main(int argc, char ** argv)
+{
+    gpt_params params;
+
+    //---------------------------------
+    // Print help :
+    //---------------------------------
+
+    if ( argc < 2 || argv[1][0] == '-' )
+    {
+        printf( "Usage: %s MODEL_PATH [PROMPT]\n" , argv[0] );
+        return 1 ;
+    }
+
+    //---------------------------------
+    // Load parameters :
+    //---------------------------------
+
+    params.model = argv[1];
+
+    params.n_beams = 2; // Hard-code 2 until we can calculate how much memory is required
+
+    if ( argc > 2 )
+    {
+        params.prompt = argv[2];
+    }
+
+    if ( params.prompt.empty() )
+    {
+        params.prompt = "### Request:\nHow many countries are there?\n\n### Response:\n";
+    }
+
+    //---------------------------------
+    // Init LLM :
+    //---------------------------------
+
+    llama_backend_init(params.numa);
+
+    llama_model * model;
+    llama_context * ctx;
+
+    std::tie(model, ctx) = llama_init_from_gpt_params( params );
+
+    if ( model == NULL )
+    {
+        fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
+        return 1;
+    }
+
+    //---------------------------------
+    // Tokenize the prompt :
+    //---------------------------------
+
+    std::vector<llama_token> tokens_list = llama_tokenize(ctx, params.prompt, true);
+
+    const size_t max_context_size     = llama_n_ctx( ctx );
+    const size_t max_tokens_list_size = max_context_size - 4 ;
+
+    if (tokens_list.size() > max_tokens_list_size)
+    {
+        fprintf( stderr , "%s: error: prompt too long (%lu tokens, max %lu)\n" ,
+             __func__ , tokens_list.size() , max_tokens_list_size );
+        return 1;
+    }
+
+    fprintf( stderr, "\n\n" );
+
+    // Print the tokens from the prompt :
+
+    for( auto id : tokens_list )
+    {
+        printf( "%s" , llama_token_to_str( ctx , id ) );
+    }
+
+    fflush(stdout);
+
+#if 1
+    int n_past = llama_get_kv_cache_token_count(ctx);
+    if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
+    {
+        fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
+        return 1;
+    }
+    n_past += tokens_list.size();
+
+    int const n_predict = 1024;
+    char const* response = llama_beam_search(ctx, params.n_beams, n_past, n_predict, params.n_threads);
+    printf("\nDone:\n\n%s%s\n", params.prompt.c_str(), response);
+#else
+    //---------------------------------
+    // Main prediction loop :
+    //---------------------------------
+
+    // The LLM keeps a contextual cache memory of previous token evaluation.
+    // Usually, once this cache is full, it is required to recompute a compressed context based on previous
+    // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
+    // example, we will just stop the loop once this cache is full or once an end of stream is detected.
+
+    while ( llama_get_kv_cache_token_count( ctx ) < max_context_size )
+    {
+        //---------------------------------
+        // Evaluate the tokens :
+        //---------------------------------
+
+        if ( llama_eval( ctx , tokens_list.data() , tokens_list.size() , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) )
+        {
+            fprintf( stderr, "%s : failed to eval\n" , __func__ );
+            return 1;
+        }
+
+        tokens_list.clear();
+
+        //---------------------------------
+        // Select the best prediction :
+        //---------------------------------
+
+        llama_token new_token_id = 0;
+
+        auto logits  = llama_get_logits( ctx );
+        auto n_vocab = llama_n_vocab( ctx ); // the size of the LLM vocabulary (in tokens)
+
+        std::vector<llama_token_data> candidates;
+        candidates.reserve( n_vocab );
+
+        for( llama_token token_id = 0 ; token_id < n_vocab ; token_id++ )
+        {
+            candidates.emplace_back( llama_token_data{ token_id , logits[ token_id ] , 0.0f } );
+        }
+
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+        // Select it using the "Greedy sampling" method :
+        new_token_id = llama_sample_token_greedy( ctx , &candidates_p );
+
+
+        // is it an end of stream ?
+        if ( new_token_id == llama_token_eos() )
+        {
+            fprintf(stderr, " [end of text]\n");
+            break;
+        }
+
+        // Print the new token :
+        printf( "%s" , llama_token_to_str( ctx , new_token_id ) );
+        fflush( stdout );
+
+        // Push this new token for next evaluation :
+        tokens_list.push_back( new_token_id );
+
+    } // wend of main loop
+#endif
+
+    llama_free( ctx );
+    llama_free_model( model );
+
+    llama_backend_free();
+
+    return 0;
+}
+
+// EOF
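Assuming the target is built like the other examples, a hypothetical invocation follows the usage string the program prints (model path and prompt below are placeholders, not from the commit):

    ./beam_search models/7B/ggml-model.bin "Why is the sky blue?"

With no PROMPT argument the program falls back to the hard-coded "### Request: ... ### Response:" prompt. It echoes the tokenized prompt, evaluates it once with llama_eval(), then hands control to llama_beam_search() for up to 1024 tokens and prints the prompt plus the selected beam after "Done:". The #if 1 / #else split keeps the earlier token-by-token greedy loop around as a disabled reference path.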

examples/common.h

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ struct gpt_params {
     int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
     int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t n_beams = 0; // Used in mem allocation if > 0 and by llama_beam_search().
     float   rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; // rms norm epsilon
     float   rope_freq_base  = 10000.0f; // RoPE base frequency
     float   rope_freq_scale = 1.0f; // RoPE frequency scaling factor
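A minimal sketch of the smallest caller that opts in to the new field, mirroring the init/teardown sequence of beam_search.cpp above; the model path is a placeholder, not from the commit:

// Sketch: opting in to beam search via the new gpt_params field.
#include "common.h"
#include "llama.h"

#include <cstdio>
#include <tuple>

int main()
{
    gpt_params params;
    params.model   = "models/7B/ggml-model.bin"; // placeholder path
    params.n_beams = 2; // > 0 requests beam-state memory, per the new comment

    llama_backend_init( params.numa );

    llama_model *   model;
    llama_context * ctx;
    std::tie( model, ctx ) = llama_init_from_gpt_params( params );
    if ( model == NULL )
    {
        fprintf( stderr , "failed to load model\n" );
        return 1;
    }

    // ... tokenize, llama_eval() the prompt, then llama_beam_search() as in the example ...

    llama_free( ctx );
    llama_free_model( model );
    llama_backend_free();
    return 0;
}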

llama-util.h

Lines changed: 18 additions & 16 deletions
@@ -436,25 +436,27 @@ struct llama_buffer {
     }

     void resize(size_t len) {
-        size = 0;
+        if (size != len) {
+            size = 0;
 #ifdef GGML_USE_METAL
-        free(addr);
-        if (len) {
-            int result = posix_memalign((void **) &addr, getpagesize(), len);
-            if (result == 0) {
-                memset(addr, 0, len);
-                size = len;
-            } else {
-                addr = NULL;
-            }
-        }
+            free(addr);
+            if (len) {
+                int result = posix_memalign((void **) &addr, getpagesize(), len);
+                if (result == 0) {
+                    memset(addr, 0, len);
+                    size = len;
+                } else {
+                    addr = NULL;
+                }
+            }
 #else
-        delete[] addr;
-        if (len) {
-            addr = new uint8_t[len];
-            size = len;
-        }
+            delete[] addr;
+            if (len) {
+                addr = new uint8_t[len];
+                size = len;
+            }
 #endif
+        }
     }
 };
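The net effect of this hunk: resize() is now a no-op when the requested length equals the current size, so a same-size resize keeps the existing allocation (and its contents) instead of freeing and reallocating. A minimal sketch of just that control flow, with the Metal branch omitted for brevity and all names local to the sketch:

// Sketch of the new resize() semantics in isolation (non-Metal path only).
#include <cstdint>
#include <cstdio>

struct buffer {
    uint8_t * addr = nullptr;
    size_t    size = 0;

    void resize(size_t len) {
        if (size != len) {          // new early-out: same length => keep allocation
            size = 0;
            delete[] addr;
            if (len) {
                addr = new uint8_t[len];
                size = len;
            }
        }
    }
    ~buffer() { delete[] addr; }
};

int main() {
    buffer buf;
    buf.resize(16);
    uint8_t * before = buf.addr;
    buf.resize(16);                 // previously freed and reallocated; now a no-op
    printf("same allocation: %s\n", buf.addr == before ? "yes" : "no");
    return 0;
}

The commit doesn't spell out the motivation, but a caller that repeatedly resizes a buffer to its current length — as a beam-search loop plausibly would — previously paid a free/realloc and lost the buffer's contents on every call.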
