
Commit 5d0ffb6

llama : prefix input text for tokenization with whitespace

1 parent 5cad62b commit 5d0ffb6

File tree: 5 files changed, +95 -63 lines changed

examples/embedding/embedding.cpp

Lines changed: 0 additions & 3 deletions
@@ -56,9 +56,6 @@ int main(int argc, char ** argv) {
 
     int n_past = 0;
 
-    // Add a space in front of the first character to match OG llama tokenizer behavior
-    params.prompt.insert(0, 1, ' ');
-
     // tokenize the prompt
     auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
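With the prefix handled inside the tokenizer, callers no longer patch the prompt themselves. A minimal before/after sketch of the caller side (names taken from this diff; surrounding setup elided):

    // before this commit: each caller prepended the space manually
    params.prompt.insert(0, 1, ' ');
    auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);

    // after this commit: pass the prompt through unchanged;
    // the leading space is added inside tokenization
    auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);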

examples/main/main.cpp

Lines changed: 0 additions & 5 deletions
@@ -195,11 +195,6 @@ int main(int argc, char ** argv) {
     // tokenize the prompt
     std::vector<llama_token> embd_inp;
 
-    if (llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM) {
-        // Add a space in front of the first character to match OG llama tokenizer behavior
-        params.prompt.insert(0, 1, ' ');
-    }
-
     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
         embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
     } else {
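Note that main.cpp guarded the manual prefix behind an SPM vocab-type check while embedding.cpp applied it unconditionally; centralizing the prefix in llama_tokenize_internal removes that inconsistency between callers.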

llama.cpp

Lines changed: 11 additions & 10 deletions
@@ -1635,7 +1635,7 @@ static void llm_load_hparams(
 }
 
 // TODO: This should probably be in llama.h
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos);
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos);
 
 static void llm_load_vocab(
     llama_model_loader & ml,
@@ -3026,10 +3026,8 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
     return vocab.token_to_id.at(buf);
 }
 
-static std::string llama_escape_whitespace(const std::string& text) {
-    std::string result = text;
-    replace_all(result, " ", "\xe2\x96\x81");
-    return result;
+static void llama_escape_whitespace(std::string & text) {
+    replace_all(text, " ", "\xe2\x96\x81");
 }
 
 static void llama_unescape_whitespace(std::string & word) {
@@ -3373,22 +3371,25 @@ struct llm_tokenizer_bpe {
     llm_bigram_bpe::queue work_queue;
 };
 
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos) {
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) {
     std::vector<llama_vocab::id> output;
 
+    if (bos && vocab.special_bos_id != -1) {
+        output.push_back(vocab.special_bos_id);
+    }
+
     if (raw_text.empty()) {
         return output;
     }
 
-    if (bos && vocab.special_bos_id != -1) {
-        output.push_back(vocab.special_bos_id);
-    }
+    raw_text = " " + raw_text;
 
     switch (vocab.type) {
         case LLAMA_VOCAB_TYPE_SPM:
             {
                 llm_tokenizer_spm tokenizer(vocab);
-                tokenizer.tokenize(llama_escape_whitespace(raw_text), output);
+                llama_escape_whitespace(raw_text);
+                tokenizer.tokenize(raw_text, output);
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
            {
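Taken together, the llama.cpp changes reorder llama_tokenize_internal: BOS is emitted before the empty-input early return, the text is prefixed with a single space, and whitespace escaping now mutates the by-value raw_text in place. A minimal self-contained sketch of that control flow (the helper names and the escape loop are illustrative stand-ins, not the library's API):

    #include <string>
    #include <vector>

    // illustrative stand-ins; the real code uses llama_vocab and llm_tokenizer_spm
    using llama_token = int;

    static void escape_whitespace(std::string & text) {
        // the real llama_escape_whitespace calls replace_all(text, " ", "\xe2\x96\x81"),
        // mapping each space to U+2581, the SentencePiece "lower one eighth block" marker
        std::string out;
        for (const char c : text) {
            if (c == ' ') {
                out += "\xe2\x96\x81";
            } else {
                out += c;
            }
        }
        text = std::move(out);
    }

    static std::vector<llama_token> tokenize_sketch(std::string raw_text, bool bos, llama_token bos_id) {
        std::vector<llama_token> output;

        // 1. BOS is pushed before the empty-input check, so an empty prompt
        //    with bos=true now yields { bos_id } instead of { }
        if (bos && bos_id != -1) {
            output.push_back(bos_id);
        }
        if (raw_text.empty()) {
            return output;
        }

        // 2. prefix the text with a single space, matching the original
        //    LLaMA (SentencePiece) tokenizer behavior for every caller
        raw_text = " " + raw_text;

        // 3. SPM path: escape whitespace in place (no extra string copy now
        //    that raw_text is taken by value), then hand off to the tokenizer
        escape_whitespace(raw_text);
        // ... tokenizer.tokenize(raw_text, output) runs here in llama.cpp ...

        return output;
    }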

tests/test-tokenizer-0.cpp

Lines changed: 47 additions & 41 deletions
@@ -6,7 +6,7 @@
 #include <map>
 #include <vector>
 
-static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) {
+static std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens) {
     std::string result;
     for (size_t i = 0; i < tokens.size(); ++i) {
         result += llama_token_to_str(ctx, tokens[i]);
@@ -16,38 +16,40 @@ static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) {
 
 static const std::map<std::string, std::vector<llama_token>> & k_tests() {
     static std::map<std::string, std::vector<llama_token>> _k_tests = {
-        { " ",                  { 1,   259, }, },
-        { "  ",                 { 1,  1678, }, },
-        { "   ",                { 1,   268, }, },
-        { "\t",                 { 1, 29871,    12, }, },
-        { "\n",                 { 1, 29871,    13, }, },
-        { "\t\n",               { 1, 29871,    12,    13, }, },
-        { "Hello world",        { 1, 15043,  3186, }, },
-        { " Hello world",       { 1, 29871, 15043,  3186, }, },
-        { "Hello World",        { 1, 15043,  2787, }, },
-        { " Hello World",       { 1, 29871, 15043,  2787, }, },
-        { " Hello World!",      { 1, 29871, 15043,  2787, 29991, }, },
-        { "Hello, world!",      { 1, 15043, 29892,  3186, 29991, }, },
-        { " Hello, world!",     { 1, 29871, 15043, 29892,  3186, 29991, }, },
-        { " this is 🦙.cpp",    { 1, 29871,   445,   338, 29871,   243,   162,   169,   156, 29889,  8223, }, },
-        { "w048 7tuijk dsdfhu", { 1,   281, 29900, 29946, 29947, 29871, 29955,  9161, 13535, 18031,  2176,  6905, }, },
-        { "нещо на Български",  { 1,  1538,  4851,   665,  1386, 29713,  1305, }, },
-        { "កាន់តែពិសេសអាចខលចេញ", { 1, 29871, 31849, 31324, 31934, 228, 162, 142, 228, 161,
-                                   146, 228, 162, 133, 228, 161, 153, 228, 161, 186,
-                                   31708, 228, 162, 132, 31708, 228, 161, 165, 31324, 228,
-                                   161, 136, 228, 161, 132, 228, 161, 158, 228, 161,
-                                   136, 228, 162, 132, 228, 161, 140, }, },
+        { "",                   { }, },
+        { " ",                  {   259, }, },
+        { "  ",                 {  1678, }, },
+        { "   ",                {   268, }, },
+        { "\t",                 { 29871,    12, }, },
+        { "\n",                 { 29871,    13, }, },
+        { "\t\n",               { 29871,    12,    13, }, },
+        { "Hello world",        { 15043,  3186, }, },
+        { " Hello world",       { 29871, 15043,  3186, }, },
+        { "Hello World",        { 15043,  2787, }, },
+        { " Hello World",       { 29871, 15043,  2787, }, },
+        { " Hello World!",      { 29871, 15043,  2787, 29991, }, },
+        { "Hello, world!",      { 15043, 29892,  3186, 29991, }, },
+        { " Hello, world!",     { 29871, 15043, 29892,  3186, 29991, }, },
+        { " this is 🦙.cpp",    { 29871,   445,   338, 29871,   243,   162,   169,   156, 29889,  8223, }, },
+        { "w048 7tuijk dsdfhu", {   281, 29900, 29946, 29947, 29871, 29955,  9161, 13535, 18031,  2176,  6905, }, },
+        { "нещо на Български",  {  1538,  4851,   665,  1386, 29713,  1305, }, },
+        { "កាន់តែពិសេសអាចខលចេញ",
+          { 29871, 31849, 31324, 31934, 228, 162, 142, 228, 161,
+            146, 228, 162, 133, 228, 161, 153, 228, 161, 186,
+            31708, 228, 162, 132, 31708, 228, 161, 165, 31324, 228,
+            161, 136, 228, 161, 132, 228, 161, 158, 228, 161,
+            136, 228, 162, 132, 228, 161, 140, }, },
         { "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)",
-          { 1, 29871, 243, 162, 157, 131, 313, 8945, 29897, 29871,
-            243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598,
-            313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681,
-            313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, },
-        { "Hello",              { 1, 15043, }, },
-        { " Hello",             { 1, 29871, 15043, }, },
-        { "  Hello",            { 1,   259, 15043, }, },
-        { "   Hello",           { 1,  1678, 15043, }, },
-        { "    Hello",          { 1,   268, 15043, }, },
-        { "    Hello\n    Hello", { 1, 268, 15043, 13, 1678, 15043, }, },
+          { 29871, 243, 162, 157, 131, 313, 8945, 29897, 29871,
+            243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598,
+            313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681,
+            313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, },
+        { "Hello",              { 15043, }, },
+        { " Hello",             { 29871, 15043, }, },
+        { "  Hello",            {   259, 15043, }, },
+        { "   Hello",           {  1678, 15043, }, },
+        { "    Hello",          {   268, 15043, }, },
+        { "    Hello\n    Hello", { 268, 15043, 13, 1678, 15043, }, },
     };
 
     return _k_tests;
@@ -102,30 +104,34 @@ int main(int argc, char **argv) {
     bool success = true;
 
     for (const auto & test_kv : k_tests()) {
-        // Add a space in front of the first character to match OG llama tokenizer behavior
-        std::vector<llama_token> res = llama_tokenize(ctx, " " + test_kv.first, true);
-        fprintf(stderr, "%s : '%s' tokenized to '%s'\n",
-            __func__, test_kv.first.c_str(), unescape_whitespace(ctx, res).c_str());
+        const std::vector<llama_token> res_bos   = llama_tokenize(ctx, test_kv.first, true);
+        const std::vector<llama_token> res_nobos = llama_tokenize(ctx, test_kv.first, false);
 
-        bool correct = res.size() == test_kv.second.size();
+        fprintf(stderr, "%s : '%s' tokenized to '%s'\n", __func__, test_kv.first.c_str(), llama_detokenize(ctx, res_bos).c_str());
 
-        for (int i = 0; i < (int) res.size() && correct; ++i) {
-            if (res[i] != test_kv.second[i]) {
+        bool correct = res_nobos.size() == test_kv.second.size() && res_bos.size() == res_nobos.size() + 1 && res_bos[0] == 1;
+
+        for (int i = 0; i < (int) res_nobos.size() && correct; ++i) {
+            if (test_kv.second[i] != res_bos[i + 1]) {
+                correct = false;
+            }
+            if (test_kv.second[i] != res_nobos[i]) {
                 correct = false;
             }
         }
 
         if (!correct) {
            fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
            fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
-               unescape_whitespace(ctx, res).c_str(), unescape_whitespace(ctx, test_kv.second).c_str());
+               llama_detokenize(ctx, res_nobos).c_str(),
+               llama_detokenize(ctx, test_kv.second).c_str());
            fprintf(stderr, "%s : expected tokens: ", __func__);
            for (const auto & t : test_kv.second) {
                fprintf(stderr, "%6d, ", t);
            }
            fprintf(stderr, "\n");
            fprintf(stderr, "%s : got tokens: ", __func__);
-           for (const auto & t : res) {
+           for (const auto & t : res_nobos) {
                fprintf(stderr, "%6d, ", t);
            }
            fprintf(stderr, "\n");

tests/test-tokenizer-0.py

Lines changed: 37 additions & 4 deletions
@@ -12,7 +12,40 @@
 
 tokenizer = SentencePieceProcessor(dir_tokenizer + '/tokenizer.model')
 
-text = 'Hello, world!'
-print(text)
-print(tokenizer.encode(text, add_bos=True))
-print(tokenizer.decode(tokenizer.encode(text, add_bos=True)))
+tests = [
+    "",
+    " ",
+    "  ",
+    "   ",
+    "\t",
+    "\n",
+    "\t\n",
+    "Hello world",
+    " Hello world",
+    "Hello World",
+    " Hello World",
+    " Hello World!",
+    "Hello, world!",
+    " Hello, world!",
+    " this is 🦙.cpp",
+    "w048 7tuijk dsdfhu",
+    "нещо на Български",
+    "កាន់តែពិសេសអាចខលចេញ",
+    "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)",
+    "Hello",
+    " Hello",
+    "  Hello",
+    "   Hello",
+    "    Hello",
+    "    Hello\n    Hello",
+]
+
+
+for text in tests:
+    print('text: ', text)
+    print('\nwith bos:')
+    print(tokenizer.encode(text, add_bos=True))
+    print(tokenizer.decode(tokenizer.encode(text, add_bos=True)))
+    print('\nwithout bos:')
+    print(tokenizer.encode(text, add_bos=False))
+    print(tokenizer.decode(tokenizer.encode(text, add_bos=False)))
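This script runs the same strings through the reference SentencePiece tokenizer with and without BOS; its output is presumably the source of the expected token IDs in the C++ k_tests table above.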
