Commit f52b18b

Decouple tokenizers from Re2 and use IRegex interface

Differential Revision: D73238728
Pull Request resolved: #49

1 parent bca09a2 · commit f52b18b
13 files changed: +138 −128 lines changed
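The new <pytorch/tokenizers/regex.h> header is not itself shown in this commit, but the call sites below pin down its shape: create_regex_ now produces a std::unique_ptr<IRegex>, and find_all returns matches carrying start/end byte offsets (m.start, m.end). A minimal sketch of what the interface plausibly looks like, inferred from those call sites; this is a guess, not the actual header, and any name beyond IRegex, find_all, start, and end is an assumption:

// Sketch of the IRegex abstraction inferred from this diff's call sites.
#include <cstddef>
#include <string>
#include <vector>

namespace tokenizers {

struct Match {
  size_t start; // byte offset where the match begins
  size_t end;   // byte offset one past the match end
};

class IRegex {
 public:
  virtual ~IRegex() = default;

  // Return all non-overlapping matches in `text`, in order of appearance.
  virtual std::vector<Match> find_all(const std::string& text) const = 0;
};

} // namespace tokenizers

Any RE2-backed implementation can then live behind this interface, which is what lets the headers below drop their direct re2/re2.h dependency.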

include/pytorch/tokenizers/bpe_tokenizer_base.h

Lines changed: 12 additions & 9 deletions
@@ -18,19 +18,16 @@
 #include <unordered_map>
 #include <vector>
 
-// Third Party
-#include <re2/re2.h>
-
 // Local
 #include <pytorch/tokenizers/error.h>
+#include <pytorch/tokenizers/regex.h>
 #include <pytorch/tokenizers/result.h>
 #include <pytorch/tokenizers/string_integer_map.h>
 #include <pytorch/tokenizers/tokenizer.h>
 
 namespace tokenizers {
 namespace detail {
 
-using Re2UPtr = std::unique_ptr<re2::RE2>;
 using TokenMap = StringIntegerMap<>;
 
 template <typename TToken, typename TRank>
@@ -119,9 +116,15 @@ class BPETokenizerBase : public Tokenizer {
   explicit BPETokenizerBase() {}
   virtual ~BPETokenizerBase() override {}
 
-  std::pair<std::optional<std::string>, re2::StringPiece>
+  std::pair<std::optional<std::string>, std::string>
+  split_with_allowed_special_token_(
+      const std::string& input,
+      const TokenMap& allowed_special) const;
+
+  std::pair<std::optional<std::string>, std::string>
   split_with_allowed_special_token_(
-      re2::StringPiece& input,
+      const std::string& input,
+      size_t offset,
       const TokenMap& allowed_special) const;
 
   Result<std::pair<std::vector<uint64_t>, uint64_t>> encode_with_special_token_(
@@ -133,17 +136,17 @@ class BPETokenizerBase : public Tokenizer {
       const TokenMap& encoder) const;
 
   // Protected members that can be overloaded by other BPE tokenizers
-  Re2UPtr special_token_regex_;
+  std::unique_ptr<IRegex> special_token_regex_;
   std::optional<TokenMap> token_map_;
   std::optional<TokenMap> special_token_map_;
 
  private:
   virtual Error _encode(
-      re2::StringPiece& input,
+      const std::string& input,
       std::vector<uint64_t>& ret,
       uint64_t& last_piece_token_len) const = 0;
 
-  virtual void _decode(re2::StringPiece input, std::string& ret) const = 0;
+  virtual void _decode(const std::string& input, std::string& ret) const = 0;
 };
 
 } // namespace detail
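The header now declares two overloads of split_with_allowed_special_token_, but src/bpe_tokenizer_base.cpp below only defines the offset-taking one. Presumably the two-argument form forwards with offset 0; a sketch of that assumed forwarding definition (not shown in this diff):

// Assumed convenience overload; the actual definition is not in this commit.
std::pair<std::optional<std::string>, std::string>
BPETokenizerBase::split_with_allowed_special_token_(
    const std::string& input,
    const TokenMap& allowed_special) const {
  return split_with_allowed_special_token_(input, 0, allowed_special);
}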

include/pytorch/tokenizers/hf_tokenizer.h

Lines changed: 2 additions & 5 deletions
@@ -15,9 +15,6 @@
 // Standard
 #include <string>
 
-// Third Party
-#include <re2/re2.h>
-
 // Local
 #include <pytorch/tokenizers/bpe_tokenizer_base.h>
 #include <pytorch/tokenizers/error.h>
@@ -43,11 +40,11 @@ class HFTokenizer : public detail::BPETokenizerBase {
 
  private:
   Error _encode(
-      re2::StringPiece& input,
+      const std::string& input,
       std::vector<uint64_t>& ret,
       uint64_t& last_piece_token_len) const override;
 
-  void _decode(re2::StringPiece input, std::string& ret) const override;
+  void _decode(const std::string& input, std::string& ret) const override;
 
   PreTokenizer::Ptr _pretokenizer;
   TokenDecoder::Ptr _decoder;

include/pytorch/tokenizers/pre_tokenizer.h

Lines changed: 11 additions & 8 deletions
@@ -19,6 +19,9 @@
 #include <nlohmann/json.hpp>
 #include <re2/re2.h>
 
+// Local
+#include <pytorch/tokenizers/regex.h>
+
 namespace tokenizers {
 
 // -- Base ---------------------------------------------------------------------
@@ -42,7 +45,7 @@ class PreTokenizer {
    * https://abseil.io/docs/cpp/guides/strings#string_view
    */
   virtual std::vector<std::string> pre_tokenize(
-      re2::StringPiece input) const = 0;
+      const std::string& input) const = 0;
 
   virtual ~PreTokenizer() = default;
 }; // end class PreTokenizer
@@ -138,18 +141,16 @@ class PreTokenizerConfig {
 
 class RegexPreTokenizer : public PreTokenizer {
  public:
-  typedef std::unique_ptr<re2::RE2> Re2UPtr;
-
   explicit RegexPreTokenizer(const std::string& pattern)
       : regex_(RegexPreTokenizer::create_regex_(pattern)) {}
 
   /** Pre-tokenize with the stored regex */
-  std::vector<std::string> pre_tokenize(re2::StringPiece input) const;
+  std::vector<std::string> pre_tokenize(const std::string& input) const;
 
  protected:
-  static Re2UPtr create_regex_(const std::string& pattern);
+  static std::unique_ptr<IRegex> create_regex_(const std::string& pattern);
 
-  Re2UPtr regex_;
+  std::unique_ptr<IRegex> regex_;
 
 }; // end class RegexPreTokenizer
@@ -185,7 +186,8 @@ class ByteLevelPreTokenizer : public PreTokenizer {
       : ByteLevelPreTokenizer(true, pattern) {}
 
   /** Perform pre-tokenization */
-  std::vector<std::string> pre_tokenize(re2::StringPiece input) const override;
+  std::vector<std::string> pre_tokenize(
+      const std::string& input) const override;
 
  private:
   const std::string pattern_;
@@ -206,7 +208,8 @@ class SequencePreTokenizer : public PreTokenizer {
   explicit SequencePreTokenizer(std::vector<PreTokenizer::Ptr> pre_tokenizers);
 
   /** Perform pre-tokenization */
-  std::vector<std::string> pre_tokenize(re2::StringPiece input) const override;
+  std::vector<std::string> pre_tokenize(
+      const std::string& input) const override;
 
  private:
   const std::vector<PreTokenizer::Ptr> pre_tokenizers_;
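Since pre_tokenize now takes const std::string& and returns std::vector<std::string>, each piece is an owned copy rather than a re2::StringPiece view aliasing the caller's buffer. A hypothetical call, with a GPT-2-style piece pattern chosen purely for illustration (the pattern is not from this diff):

// Hypothetical usage sketch.
#include <pytorch/tokenizers/pre_tokenizer.h>

void demo() {
  tokenizers::RegexPreTokenizer pre(
      R"('s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+)");
  std::string text = "Hello world";
  // Pieces are independent strings; `text` may be destroyed afterwards.
  std::vector<std::string> pieces = pre.pre_tokenize(text);
}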

include/pytorch/tokenizers/result.h

Lines changed: 17 additions & 0 deletions
@@ -185,6 +185,23 @@ T* Result<T>::operator->() {
 
 } // namespace tokenizers
 
+/**
+ * Unwraps a Result<T> value, throwing a runtime_error if the result contains an
+ * error.
+ *
+ * @param[in] result__ The Result<T> to unwrap
+ */
+#define TK_UNWRAP_THROW(result__)                                      \
+  ({                                                                   \
+    auto unwrap_result__ = (result__);                                 \
+    if (!unwrap_result__.ok()) {                                       \
+      throw std::runtime_error(                                        \
+          "Error: " +                                                  \
+          std::to_string(static_cast<int>(unwrap_result__.error())));  \
+    }                                                                  \
+    std::move(unwrap_result__.get());                                  \
+  })
+
 /**
  * Unwrap a Result to obtain its value. If the Result contains an error,
  * propogate the error via trivial function return.
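TK_UNWRAP_THROW complements the existing TK_UNWRAP: rather than propagating the error through an early return (which only works inside a function that itself returns a Result or Error), it throws std::runtime_error, so it can be used from constructors and ordinary call sites. Note that ({ ... }) is a GNU statement expression (GCC/Clang), not standard C++. A usage sketch, where an encode call returning Result<std::vector<uint64_t>> is assumed for illustration:

// Hypothetical call site for the macro above.
std::vector<uint64_t> ids =
    TK_UNWRAP_THROW(tok.encode("hello world", /*bos=*/1, /*eos=*/0));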

include/pytorch/tokenizers/tiktoken.h

Lines changed: 5 additions & 4 deletions
@@ -15,10 +15,11 @@
 #include <cstdint>
 
 // Third Party
-#include "re2/re2.h"
+#include <re2/re2.h>
 
 // Local
 #include <pytorch/tokenizers/bpe_tokenizer_base.h>
+#include <pytorch/tokenizers/regex.h>
 #include <pytorch/tokenizers/result.h>
 #include <pytorch/tokenizers/tokenizer.h>
 
@@ -77,11 +78,11 @@ class Tiktoken : public detail::BPETokenizerBase {
   }
 
   Error _encode(
-      re2::StringPiece& input,
+      const std::string& input,
       std::vector<uint64_t>& ret,
       uint64_t& last_piece_token_len) const override;
 
-  void _decode(re2::StringPiece input, std::string& ret) const override;
+  void _decode(const std::string& input, std::string& ret) const override;
 
   detail::TokenMap _build_special_token_map(ssize_t num_base_tokens) const;
 
@@ -93,7 +94,7 @@ class Tiktoken : public detail::BPETokenizerBase {
   const std::string _pattern =
       R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";
 
-  detail::Re2UPtr _regex;
+  std::unique_ptr<IRegex> _regex;
 };
 
 } // namespace tokenizers
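For intuition, _pattern is the familiar tiktoken-style splitter: contractions, letter runs with one optional leading non-letter byte, digit runs of at most three, punctuation, and whitespace. A hand-traced, unverified example of the splits it should produce:

// "Hello's world 123" -> ["Hello", "'s", " world", " ", "123"]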

include/pytorch/tokenizers/token_decoder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class TokenDecoder {
4545
*
4646
* @returns decoded: The decoded token string
4747
*/
48-
virtual std::string decode(re2::StringPiece token) const = 0;
48+
virtual std::string decode(const std::string& token) const = 0;
4949

5050
// virtual destructor
5151
virtual ~TokenDecoder() = default;
@@ -92,7 +92,7 @@ class TokenDecoderConfig {
9292

9393
class ByteLevelTokenDecoder : public TokenDecoder {
9494
public:
95-
std::string decode(re2::StringPiece token) const override;
95+
std::string decode(const std::string& token) const override;
9696

9797
}; // end class ByteLevelTokenDecoder
9898

src/bpe_tokenizer_base.cpp

Lines changed: 17 additions & 36 deletions
@@ -130,42 +130,25 @@ static std::vector<uint64_t> _byte_pair_merge(
 // ---- Helper utils end -------------------------------------------------------
 // ---- protected start --------------------------------------------------------
 
-std::pair<std::optional<std::string>, re2::StringPiece>
+std::pair<std::optional<std::string>, std::string>
 BPETokenizerBase::split_with_allowed_special_token_(
-    re2::StringPiece& input,
+    const std::string& input,
+    size_t offset,
     const TokenMap& allowed_special) const {
   if (!special_token_regex_) {
-    return std::make_pair(std::nullopt, input);
+    return std::make_pair(std::nullopt, input.substr(offset));
   }
 
-#if __cplusplus >= 202002L
-  auto start = input.begin();
-#else
-  const char* start = input.data();
-#endif
+  auto matches = special_token_regex_->find_all(input.substr(offset));
 
-  std::string special;
-  while (true) {
-    if (!re2::RE2::FindAndConsume(&input, *special_token_regex_, &special)) {
-      // No special token.
-      break;
+  for (const auto& m : matches) {
+    std::string matched_text = input.substr(offset + m.start, m.end - m.start);
+    if (allowed_special.tryGetInteger(matched_text).has_value()) {
+      return {matched_text, input.substr(offset, m.start)};
     }
-
-    if (allowed_special.tryGetInteger(special).has_value()) {
-      // Found an allowed special token, split the text with it.
-#if __cplusplus >= 202002L
-      return std::make_pair(
-          special,
-          re2::StringPiece(start, input.begin() - start - special.size()));
-#else
-      return std::make_pair(
-          special,
-          re2::StringPiece(start, (input.data() - start) - special.size()));
-#endif
-    } // else try to find the next special token
   }
 
-  return std::make_pair(std::nullopt, input);
+  return {std::nullopt, input.substr(offset)};
 }
 
@@ -174,33 +157,31 @@ BPETokenizerBase::encode_with_special_token_(
     const TokenMap& allowed_special) const {
   std::vector<uint64_t> tokens;
   uint64_t last_piece_token_len = 0;
-  re2::StringPiece input(text);
-  while (true) {
+  size_t offset = 0;
+
+  while (offset < text.size()) {
     auto [special, sub_input] =
-        split_with_allowed_special_token_(input, allowed_special);
+        split_with_allowed_special_token_(text, offset, allowed_special);
 
     TK_CHECK_OK_OR_RETURN_ERROR(
         _encode(sub_input, tokens, last_piece_token_len));
+    offset += sub_input.size();
 
     if (special) {
      const auto result = special_token_map_->tryGetInteger(*special);
      if (!result) {
-        // Should never go here, since special pattern includes all special
-        // chars.
        TK_LOG(Error, "unknown special token: %s\n", special->c_str());
        return Error::EncodeFailure;
      }
 
      tokens.push_back(*result);
      last_piece_token_len = 0;
+      offset += special->size(); // advance past the matched token
    } else {
      break;
    }
  }
 
-  // last_piece_token_len is how many tokens came from the last regex split.
-  // This is used for determining unstable tokens, since you can't merge
-  // across (stable) regex splits
   return std::make_pair(tokens, last_piece_token_len);
 }
 
@@ -273,7 +254,7 @@ Result<std::string> BPETokenizerBase::decode(uint64_t prev, uint64_t cur)
   } else {
     token_bytes = *result;
   }
-  _decode(token_bytes, ret);
+  _decode(std::string(token_bytes), ret);
 
   return ret;
 }
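The rewrite replaces RE2's consuming FindAndConsume loop with plain offset arithmetic over an immutable std::string: each pass encodes the ordinary text before the next allowed special token, then jumps past the token itself. A self-contained sketch of that scanning discipline, using std::string::find as a stand-in for the regex so it runs without the library (the special-token literal and the printing are assumptions for the demo):

#include <iostream>
#include <optional>
#include <string>
#include <utility>

// Stand-in for split_with_allowed_special_token_: finds the next occurrence
// of a single special token instead of running a regex over the tail.
static std::pair<std::optional<std::string>, std::string> split_at_special(
    const std::string& input, size_t offset, const std::string& special) {
  size_t pos = input.find(special, offset);
  if (pos == std::string::npos) {
    return {std::nullopt, input.substr(offset)};
  }
  return {special, input.substr(offset, pos - offset)};
}

int main() {
  const std::string special = "<|endoftext|>"; // assumed token literal
  const std::string text = "hello <|endoftext|> world";

  size_t offset = 0;
  while (offset < text.size()) {
    auto [matched, sub_input] = split_at_special(text, offset, special);
    std::cout << "encode plain: \"" << sub_input << "\"\n";
    offset += sub_input.size();  // advance past the plain text
    if (!matched) {
      break;                     // no more specials: done
    }
    std::cout << "emit special: " << *matched << "\n";
    offset += matched->size();   // advance past the special token
  }
}

Prints "encode plain" for "hello " and " world" with one "emit special" in between, mirroring how encode_with_special_token_ interleaves _encode calls with special-token IDs.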

src/hf_tokenizer.cpp

Lines changed: 4 additions & 6 deletions
@@ -100,9 +100,11 @@ Error HFTokenizer::load(const std::string& path) {
 
   // Set up the pre-tokenizer
   try {
+    std::cout << "Setting up pretokenizer..." << std::endl;
     _pretokenizer = PreTokenizerConfig()
                         .parse_json(parsed_json.at("pre_tokenizer"))
                         .create();
+    std::cout << "Pretokenizer set up" << std::endl;
   } catch (const json::out_of_range& e) {
     fprintf(stderr, "Could not parse pre_tokenizer: %s\n", e.what());
     return Error::LoadFailure;
@@ -231,7 +233,7 @@ Error HFTokenizer::load(const std::string& path) {
 // -------------------------private method start--------------------------------
 
 Error HFTokenizer::_encode(
-    re2::StringPiece& input,
+    const std::string& input,
     std::vector<uint64_t>& ret,
     uint64_t& last_piece_token_len) const {
   for (const auto& piece : _pretokenizer->pre_tokenize(input)) {
@@ -249,15 +251,11 @@ Error HFTokenizer::_encode(
   return Error::Ok;
 }
 
-void HFTokenizer::_decode(re2::StringPiece input, std::string& ret) const {
+void HFTokenizer::_decode(const std::string& input, std::string& ret) const {
   if (_decoder) {
     ret += _decoder->decode(input);
   } else {
-#ifdef _USE_INTERNAL_STRING_VIEW
-    ret += input.as_string();
-#else
     ret += input;
-#endif
   }
 }
