Commit 9668aa1

llama : distinguish pieces from decoded text + fix detokenization
1 parent 5d0ffb6 commit 9668aa1

15 files changed with 93 additions and 68 deletions

common/common.cpp

Lines changed: 24 additions & 3 deletions
@@ -733,16 +733,37 @@ std::vector<llama_token> llama_tokenize(
     return result;
 }
 
-std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
+std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_str(ctx, token, result.data(), result.size());
+    const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_str(ctx, token, result.data(), result.size());
+        int check = llama_token_to_piece(ctx, token, result.data(), result.size());
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
     }
 
     return std::string(result.data(), result.size());
 }
+
+std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens) {
+    const llama_token bos_id = llama_token_bos(ctx);
+
+    std::string piece;
+    std::string result;
+
+    for (size_t i = 0; i < tokens.size(); ++i) {
+        piece = llama_token_to_piece(ctx, tokens[i]);
+
+        // remove the leading space of the first non-BOS token
+        if (((tokens[0] == bos_id && i == 1) || (tokens[0] != bos_id && i == 0)) && piece[0] == ' ') {
+            piece = piece.substr(1);
+        }
+
+        result += piece;
+    }
+
+    return result;
+}
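
Usage sketch (not part of this commit; the round_trip driver below is hypothetical): the two helpers compose into a lossless round trip, since llama_detokenize strips the leading space that SPM tokenization adds to the first non-BOS token.

// Hypothetical driver: tokenize a prompt, then rebuild it with llama_detokenize().
#include "common.h"

#include <cstdio>
#include <string>
#include <vector>

static void round_trip(llama_context * ctx, const std::string & prompt) {
    // add_bos = true, so tokens[0] is the BOS token
    const std::vector<llama_token> tokens = llama_tokenize(ctx, prompt, true);

    // llama_detokenize drops the leading space of the first non-BOS piece,
    // so the reconstructed text should match the original prompt
    const std::string text = llama_detokenize(ctx, tokens);
    printf("round-trip %s\n", text == prompt ? "ok" : "mismatch");
}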

common/common.h

Lines changed: 6 additions & 1 deletion
@@ -121,6 +121,11 @@ std::vector<llama_token> llama_tokenize(
         const std::string & text,
         bool add_bos);
 
-std::string llama_token_to_str(
+std::string llama_token_to_piece(
         const struct llama_context * ctx,
         llama_token token);
+
+// removes the leading space from the first non-BOS token
+std::string llama_detokenize(
+        llama_context * ctx,
+        const std::vector<llama_token> & tokens);

examples/beam_search/beam_search.cpp

Lines changed: 3 additions & 3 deletions
@@ -35,7 +35,7 @@ struct ostream_beam_view {
 std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) {
     os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens(";
     for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
-        os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
+        os << llama_token_to_piece(obv.ctx, obv.beam_view.tokens[i]);
     }
     return os << ')';
 }
@@ -156,7 +156,7 @@ int main(int argc, char ** argv)
 
     for( auto id : tokens_list )
     {
-        std::cout << llama_token_to_str(ctx, id);
+        std::cout << llama_token_to_piece(ctx, id);
     }
     std::cout << std::flush;
 
@@ -175,7 +175,7 @@ int main(int argc, char ** argv)
 
     std::cout << "\n\n";
     for (llama_token const token_id : callback_data.response) {
-        std::cout << llama_token_to_str(ctx,token_id);
+        std::cout << llama_token_to_piece(ctx,token_id);
     }
     std::cout << std::endl;
 

examples/embd-input/embd-input-lib.cpp

Lines changed: 1 addition & 1 deletion
@@ -214,7 +214,7 @@ const char * sampling(struct MyModel * mymodel) {
     if (id == llama_token_eos(ctx)) {
         ret = "</s>";
     } else {
-        ret = llama_token_to_str(ctx, id);
+        ret = llama_token_to_piece(ctx, id);
     }
     eval_id(mymodel, id);
     return ret.c_str();

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
         fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
         for (int i = 0; i < (int) embd_inp.size(); i++) {
-            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str());
+            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
         }
         fprintf(stderr, "\n");
     }

examples/main/main.cpp

Lines changed: 7 additions & 7 deletions
@@ -280,22 +280,22 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
         fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
         for (int i = 0; i < (int) embd_inp.size(); i++) {
-            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str());
+            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
         }
 
         if (ctx_guidance) {
             fprintf(stderr, "\n");
             fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
             fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
             for (int i = 0; i < (int) guidance_inp.size(); i++) {
-                fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]).c_str());
+                fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
             }
         }
 
         if (params.n_keep > 0) {
             fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
             for (int i = 0; i < params.n_keep; i++) {
-                fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]).c_str());
+                fprintf(stderr, "%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
             }
             fprintf(stderr, "'\n");
         }
@@ -451,7 +451,7 @@ int main(int argc, char ** argv) {
             //printf("\n---\n");
             //printf("resetting: '");
             //for (int i = 0; i < (int) embd.size(); i++) {
-            //    printf("%s", llama_token_to_str(ctx, embd[i]));
+            //    printf("%s", llama_token_to_piece(ctx, embd[i]));
             //}
             //printf("'\n");
             //printf("\n---\n");
@@ -504,7 +504,7 @@ int main(int argc, char ** argv) {
                 input_size = embd_guidance.size();
                 //fprintf(stderr, "\n---------------------\n");
                 //for (int i = 0; i < (int) embd_guidance.size(); i++) {
-                //fprintf(stderr, "%s", llama_token_to_str(ctx, embd_guidance[i]));
+                //fprintf(stderr, "%s", llama_token_to_piece(ctx, embd_guidance[i]));
                 //}
                 //fprintf(stderr, "\n---------------------\n");
             } else {
@@ -663,7 +663,7 @@ int main(int argc, char ** argv) {
         // display text
         if (input_echo) {
             for (auto id : embd) {
-                printf("%s", llama_token_to_str(ctx, id).c_str());
+                printf("%s", llama_token_to_piece(ctx, id).c_str());
             }
             fflush(stdout);
         }
@@ -679,7 +679,7 @@ int main(int argc, char ** argv) {
         if (params.antiprompt.size()) {
             std::string last_output;
             for (auto id : last_n_tokens) {
-                last_output += llama_token_to_str(ctx, id);
+                last_output += llama_token_to_piece(ctx, id);
             }
 
             is_antiprompt = false;

examples/save-load-state/save-load-state.cpp

Lines changed: 2 additions & 2 deletions
@@ -87,7 +87,7 @@ int main(int argc, char ** argv) {
         }
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
         auto next_token = llama_sample_token(ctx, &candidates_p);
-        auto next_token_str = llama_token_to_str(ctx, next_token);
+        auto next_token_str = llama_token_to_piece(ctx, next_token);
         last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());
@@ -147,7 +147,7 @@ int main(int argc, char ** argv) {
         }
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
         auto next_token = llama_sample_token(ctx2, &candidates_p);
-        auto next_token_str = llama_token_to_str(ctx2, next_token);
+        auto next_token_str = llama_token_to_piece(ctx2, next_token);
         last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());

examples/server/server.cpp

Lines changed: 7 additions & 7 deletions
@@ -94,7 +94,7 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
     std::string ret;
     for (; begin != end; ++begin)
     {
-        ret += llama_token_to_str(ctx, *begin);
+        ret += llama_token_to_piece(ctx, *begin);
     }
     return ret;
 }
@@ -123,7 +123,7 @@ static void server_log(const char *level, const char *function, int line,
 // format incomplete utf-8 multibyte character for output
 static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
 {
-    std::string out = token == -1 ? "" : llama_token_to_str(ctx, token);
+    std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
     // if the size is 1 and first bit is 1, meaning it's a partial character
     // (size > 1 meaning it's already a known token)
     if (out.size() == 1 && (out[0] & 0x80) == 0x80)
@@ -566,7 +566,7 @@ struct llama_server_context
 
         if (!embd.empty() && embd.back() == llama_token_eos(ctx))
         {
-            // stopping_word = llama_token_to_str(ctx, embd.back());
+            // stopping_word = llama_token_to_piece(ctx, embd.back());
             has_next_token = false;
             stopped_eos = true;
             LOG_VERBOSE("eos token found", {});
@@ -613,7 +613,7 @@ struct llama_server_context
     {
         const completion_token_output token_with_probs = nextToken();
 
-        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
+        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
         generated_text += token_text;
 
         if (params.n_probs > 0)
@@ -1248,7 +1248,7 @@ void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
 
 struct token_translator {
     llama_context * ctx;
-    std::string operator()(llama_token tok) const { return llama_token_to_str(ctx, tok); }
+    std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); }
     std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); }
 };
 
@@ -1358,7 +1358,7 @@ int main(int argc, char **argv)
 
             while (llama.has_next_token) {
                 const completion_token_output token_with_probs = llama.doCompletion();
-                const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
+                const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(llama.ctx, token_with_probs.tok);
 
                 stop_pos = llama.findStoppingStrings(llama.generated_text,
                     token_text.size(), STOP_FULL);
@@ -1389,7 +1389,7 @@ int main(int argc, char **argv)
                 if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
                     continue;
                 }
-                const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
+                const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
 
                 size_t pos = std::min(sent_count, llama.generated_text.size());
 
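
Side note on the partial-character guard in tokens_to_output_formatted_string: a lone byte with the high bit set is never a complete UTF-8 sequence, so it cannot be sent to the client as text. A standalone sketch of the same idea (the hex placeholder format here is illustrative, not necessarily the server's exact output):

// Sketch: a single byte with the high bit set is either a continuation byte
// or the start of a multi-byte sequence, so emit a hex placeholder instead.
#include <cstdio>
#include <string>

static std::string format_piece(const std::string & out) {
    if (out.size() == 1 && (out[0] & 0x80) == 0x80) {
        char buf[16];
        snprintf(buf, sizeof(buf), "byte: \\x%02X", (unsigned char) out[0]);
        return std::string(buf);
    }
    return out;
}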

examples/simple/simple.cpp

Lines changed: 2 additions & 2 deletions
@@ -63,7 +63,7 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "\n\n");
 
     for (auto id : tokens_list) {
-        fprintf(stderr, "%s", llama_token_to_str(ctx, id).c_str());
+        fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
     }
 
     fflush(stderr);
@@ -112,7 +112,7 @@ int main(int argc, char ** argv) {
         }
 
         // print the new token :
-        printf("%s", llama_token_to_str(ctx, new_token_id).c_str());
+        printf("%s", llama_token_to_piece(ctx, new_token_id).c_str());
         fflush(stdout);
 
         // push this new token for next evaluation

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 2 additions & 2 deletions
@@ -1964,7 +1964,7 @@ void print_matrix(struct ggml_tensor * probs) {
 
 
 void print_token(struct llama_context * ctx, llama_token token) {
-    printf("%s", llama_token_to_str(ctx, token).c_str());
+    printf("%s", llama_token_to_piece(ctx, token).c_str());
 }
 
 void print_tokens(struct llama_context* ctx, struct ggml_tensor * tokens) {
@@ -2202,7 +2202,7 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto
     const char * in = buf.data();
     const char * end = buf.data() + buf.size();
     for (int i = 0; i < (int) out.size(); ++i) {
-        std::string s = llama_token_to_str(lctx, out[i]);
+        std::string s = llama_token_to_piece(lctx, out[i]);
         int len = s.length();
         if (in >= end) {
             printf("%s: unexpected end of original text.\n", __func__);

llama.cpp

Lines changed: 21 additions & 15 deletions
@@ -796,12 +796,12 @@ static void llama_nop(struct ggml_tensor * tensor) { // don't offload by default
     (void) tensor;
 }
 
-static std::string llama_token_to_text(const struct llama_context * ctx, llama_token token) {
+static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_str(ctx, token, result.data(), result.size());
+    const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_str(ctx, token, result.data(), result.size());
+        int check = llama_token_to_piece(ctx, token, result.data(), result.size());
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -3374,6 +3374,11 @@ struct llm_tokenizer_bpe {
 static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) {
     std::vector<llama_vocab::id> output;
 
+    // OG tokenizer behavior:
+    //
+    // tokenizer.encode('', add_bos=True) returns [1]
+    // tokenizer.encode('', add_bos=False) returns []
+
     if (bos && vocab.special_bos_id != -1) {
         output.push_back(vocab.special_bos_id);
     }
@@ -3382,11 +3387,12 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
         return output;
     }
 
-    raw_text = " " + raw_text;
-
     switch (vocab.type) {
         case LLAMA_VOCAB_TYPE_SPM:
             {
+                // without adding this leading whitespace, we do not get the same results as the original tokenizer
+                raw_text = " " + raw_text;
+
                 llm_tokenizer_spm tokenizer(vocab);
                 llama_escape_whitespace(raw_text);
                 tokenizer.tokenize(raw_text, output);
@@ -4079,16 +4085,16 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
     std::vector<llama_grammar_candidate> candidates_grammar;
 
     for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id   = candidates->data[i].id;
-        const std::string text = llama_token_to_text(ctx, id);
+        const llama_token id    = candidates->data[i].id;
+        const std::string piece = llama_token_to_str(ctx, id);
         if (id == eos) {
             if (!allow_eos) {
                 candidates->data[i].logit = -INFINITY;
             }
-        } else if (text.empty() || text[0] == 0) {
+        } else if (piece.empty() || piece[0] == 0) {
            candidates->data[i].logit = -INFINITY;
        } else {
-            candidates_decoded.push_back(decode_utf8(text.c_str(), grammar->partial_utf8));
+            candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
        }
    }
@@ -4292,10 +4298,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }
 
-    const std::string text = llama_token_to_text(ctx, token);
+    const std::string piece = llama_token_to_str(ctx, token);
 
     // Note terminating 0 in decoded string
-    const auto decoded = decode_utf8(text.c_str(), grammar->partial_utf8);
+    const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
     const auto & code_points = decoded.first;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
         grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
@@ -6089,12 +6095,12 @@ int llama_tokenize_with_model(
     return res.size();
 }
 
-int llama_token_to_str(const struct llama_context * ctx, llama_token token, char * buf, int length) {
-    return llama_token_to_str_with_model(&ctx->model, token, buf, length);
+int llama_token_to_piece(const struct llama_context * ctx, llama_token token, char * buf, int length) {
+    return llama_token_to_piece_with_model(&ctx->model, token, buf, length);
 }
 
-// does not write null-terminator to str
-int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * buf, int length) {
+// does not write null-terminator to buf
+int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) {
     if (0 <= token && token < llama_model_n_vocab(model)) {
         if (llama_is_normal_token(model->vocab, token)) {
             std::string result = model->vocab.id_to_token[token].text;
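
The "OG tokenizer behavior" comment pins down the empty-prompt cases. A sanity check against the C API could look like this (check_empty_prompt is hypothetical and assumes a loaded context with an SPM vocab):

// Hypothetical sanity check for the documented empty-prompt behavior.
#include "llama.h"

#include <cassert>
#include <vector>

static void check_empty_prompt(llama_context * ctx) {
    std::vector<llama_token> tokens(4);

    // encode('', add_bos=True) -> [bos]: exactly one token
    int n = llama_tokenize(ctx, "", tokens.data(), (int) tokens.size(), true);
    assert(n == 1 && tokens[0] == llama_token_bos(ctx));

    // encode('', add_bos=False) -> []: no tokens at all
    n = llama_tokenize(ctx, "", tokens.data(), (int) tokens.size(), false);
    assert(n == 0);
}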

llama.h

Lines changed: 6 additions & 4 deletions
@@ -381,15 +381,17 @@ extern "C" {
                              int   n_max_tokens,
                             bool   add_bos);
 
-    // Token Id -> String. Uses the vocabulary in the provided context
-    // Does not write null terminator to the buffer
-    LLAMA_API int llama_token_to_str(
+    // Token Id -> Piece.
+    // Uses the vocabulary in the provided context.
+    // Does not write null terminator to the buffer.
+    // User code is responsible to remove the leading whitespace of the first non-BOS token.
+    LLAMA_API int llama_token_to_piece(
               const struct llama_context * ctx,
                              llama_token   token,
                                     char * buf,
                                      int   length);
 
-    LLAMA_API int llama_token_to_str_with_model(
+    LLAMA_API int llama_token_to_piece_with_model(
               const struct llama_model * model,
                            llama_token   token,
                                   char * buf,
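
The calling convention matches the wrapper in common/common.cpp above: a negative return value is the negated buffer size required. A minimal sketch of that resize-and-retry pattern (piece_of is a hypothetical helper name):

// Minimal sketch of the resize-and-retry convention for llama_token_to_piece.
#include "llama.h"

#include <string>
#include <vector>

static std::string piece_of(const struct llama_context * ctx, llama_token token) {
    std::vector<char> buf(8, 0);
    int n = llama_token_to_piece(ctx, token, buf.data(), (int) buf.size());
    if (n < 0) {
        buf.resize(-n); // negative return = required size, negated
        n = llama_token_to_piece(ctx, token, buf.data(), (int) buf.size());
    }
    return std::string(buf.data(), n); // no null terminator is written
}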
