Decouple tokenizers from Re2 and use IRegex interface #49

Merged · 1 commit · Apr 19, 2025
21 changes: 12 additions & 9 deletions include/pytorch/tokenizers/bpe_tokenizer_base.h
@@ -18,19 +18,16 @@
#include <unordered_map>
#include <vector>

// Third Party
#include <re2/re2.h>

// Local
#include <pytorch/tokenizers/error.h>
#include <pytorch/tokenizers/regex.h>
#include <pytorch/tokenizers/result.h>
#include <pytorch/tokenizers/string_integer_map.h>
#include <pytorch/tokenizers/tokenizer.h>

namespace tokenizers {
namespace detail {

using Re2UPtr = std::unique_ptr<re2::RE2>;
using TokenMap = StringIntegerMap<>;

template <typename TToken, typename TRank>
@@ -119,9 +116,15 @@ class BPETokenizerBase : public Tokenizer {
explicit BPETokenizerBase() {}
virtual ~BPETokenizerBase() override {}

std::pair<std::optional<std::string>, re2::StringPiece>
std::pair<std::optional<std::string>, std::string>
split_with_allowed_special_token_(
const std::string& input,
const TokenMap& allowed_special) const;

std::pair<std::optional<std::string>, std::string>
split_with_allowed_special_token_(
re2::StringPiece& input,

Review comment: why not use std::string_view?

const std::string& input,
size_t offset,
const TokenMap& allowed_special) const;

Result<std::pair<std::vector<uint64_t>, uint64_t>> encode_with_special_token_(
@@ -133,17 +136,17 @@
const TokenMap& encoder) const;

// Protected members that can be overloaded by other BPE tokenizers
Re2UPtr special_token_regex_;
std::unique_ptr<IRegex> special_token_regex_;
std::optional<TokenMap> token_map_;
std::optional<TokenMap> special_token_map_;

private:
virtual Error _encode(
re2::StringPiece& input,
const std::string& input,
std::vector<uint64_t>& ret,
uint64_t& last_piece_token_len) const = 0;

virtual void _decode(re2::StringPiece input, std::string& ret) const = 0;
virtual void _decode(const std::string& input, std::string& ret) const = 0;
};

} // namespace detail
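For context, the new header <pytorch/tokenizers/regex.h> is not shown on this page. Below is a minimal sketch of the IRegex surface this diff appears to rely on (a find_all method plus a match type exposing start/end byte offsets). The names and layout are assumptions inferred from the calls in the diff, not the actual header:

// Hypothetical sketch only; the real <pytorch/tokenizers/regex.h> may differ.
#include <string>
#include <vector>

namespace tokenizers {

struct Match {
  size_t start; // byte offset where the match begins, relative to the input
  size_t end;   // one past the last byte of the match
};

class IRegex {
 public:
  virtual ~IRegex() = default;

  // Returns all non-overlapping matches, in order of appearance.
  virtual std::vector<Match> find_all(const std::string& input) const = 0;
};

} // namespace tokenizers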
7 changes: 2 additions & 5 deletions include/pytorch/tokenizers/hf_tokenizer.h
@@ -15,9 +15,6 @@
// Standard
#include <string>

// Third Party
#include <re2/re2.h>

// Local
#include <pytorch/tokenizers/bpe_tokenizer_base.h>
#include <pytorch/tokenizers/error.h>
@@ -43,11 +40,11 @@ class HFTokenizer : public detail::BPETokenizerBase {

private:
Error _encode(
re2::StringPiece& input,
const std::string& input,
std::vector<uint64_t>& ret,
uint64_t& last_piece_token_len) const override;

void _decode(re2::StringPiece input, std::string& ret) const override;
void _decode(const std::string& input, std::string& ret) const override;

PreTokenizer::Ptr _pretokenizer;
TokenDecoder::Ptr _decoder;
19 changes: 11 additions & 8 deletions include/pytorch/tokenizers/pre_tokenizer.h
@@ -19,6 +19,9 @@
#include <nlohmann/json.hpp>
#include <re2/re2.h>

// Local
#include <pytorch/tokenizers/regex.h>

namespace tokenizers {

// -- Base ---------------------------------------------------------------------
@@ -42,7 +45,7 @@ class PreTokenizer {
* https://abseil.io/docs/cpp/guides/strings#string_view
*/
virtual std::vector<std::string> pre_tokenize(
re2::StringPiece input) const = 0;
const std::string& input) const = 0;

virtual ~PreTokenizer() = default;
}; // end class PreTokenizer
@@ -138,18 +141,16 @@

class RegexPreTokenizer : public PreTokenizer {
public:
typedef std::unique_ptr<re2::RE2> Re2UPtr;

explicit RegexPreTokenizer(const std::string& pattern)
: regex_(RegexPreTokenizer::create_regex_(pattern)) {}

/** Pre-tokenize with the stored regex */
std::vector<std::string> pre_tokenize(re2::StringPiece input) const;
std::vector<std::string> pre_tokenize(const std::string& input) const;

protected:
static Re2UPtr create_regex_(const std::string& pattern);
static std::unique_ptr<IRegex> create_regex_(const std::string& pattern);

Re2UPtr regex_;
std::unique_ptr<IRegex> regex_;

}; // end class RegexPreTokenizer

@@ -185,7 +186,8 @@ class ByteLevelPreTokenizer : public PreTokenizer {
: ByteLevelPreTokenizer(true, pattern) {}

/** Perform pre-tokenization */
std::vector<std::string> pre_tokenize(re2::StringPiece input) const override;
std::vector<std::string> pre_tokenize(
const std::string& input) const override;

private:
const std::string pattern_;
@@ -206,7 +208,8 @@ class SequencePreTokenizer : public PreTokenizer {
explicit SequencePreTokenizer(std::vector<PreTokenizer::Ptr> pre_tokenizers);

/** Perform pre-tokenization */
std::vector<std::string> pre_tokenize(re2::StringPiece input) const override;
std::vector<std::string> pre_tokenize(
const std::string& input) const override;

private:
const std::vector<PreTokenizer::Ptr> pre_tokenizers_;
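A hedged sketch of how RegexPreTokenizer::pre_tokenize could be implemented against IRegex, under the same Match/find_all shape assumed above; the actual definition lives in the corresponding .cpp file and may differ:

// Illustrative only: emit each regex match as one pre-token.
std::vector<std::string> RegexPreTokenizer::pre_tokenize(
    const std::string& input) const {
  std::vector<std::string> pieces;
  for (const auto& m : regex_->find_all(input)) {
    pieces.push_back(input.substr(m.start, m.end - m.start));
  }
  return pieces;
}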
17 changes: 17 additions & 0 deletions include/pytorch/tokenizers/result.h
@@ -185,6 +185,23 @@ T* Result<T>::operator->() {

} // namespace tokenizers

/**
* Unwraps a Result<T> value, throwing a runtime_error if the result contains an
* error.
*
* @param[in] result__ The Result<T> to unwrap
*/
#define TK_UNWRAP_THROW(result__) \
({ \
auto unwrap_result__ = (result__); \
if (!unwrap_result__.ok()) { \
throw std::runtime_error( \
"Error: " + \
std::to_string(static_cast<int>(unwrap_result__.error()))); \
} \
std::move(unwrap_result__.get()); \
})

/**
* Unwrap a Result to obtain its value. If the Result contains an error,
propagate the error via trivial function return.
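Note that the macro body is a GNU statement expression (({ ... })), so it relies on a GCC/Clang extension and is not standard C++. A usage sketch for exception-based call sites follows; read_file is a hypothetical Result-returning function invented for illustration, not part of this PR:

#include <pytorch/tokenizers/result.h>

#include <stdexcept>
#include <string>

// Hypothetical helper for illustration only.
tokenizers::Result<std::string> read_file(const std::string& path);

std::string read_file_or_throw(const std::string& path) {
  // Unwraps the Result, throwing std::runtime_error carrying the numeric
  // error code if read_file failed.
  return TK_UNWRAP_THROW(read_file(path));
}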
9 changes: 5 additions & 4 deletions include/pytorch/tokenizers/tiktoken.h
@@ -15,10 +15,11 @@
#include <cstdint>

// Third Party
#include "re2/re2.h"
#include <re2/re2.h>

// Local
#include <pytorch/tokenizers/bpe_tokenizer_base.h>
#include <pytorch/tokenizers/regex.h>
#include <pytorch/tokenizers/result.h>
#include <pytorch/tokenizers/tokenizer.h>

@@ -77,11 +78,11 @@ class Tiktoken : public detail::BPETokenizerBase {
}

Error _encode(
re2::StringPiece& input,
const std::string& input,
std::vector<uint64_t>& ret,
uint64_t& last_piece_token_len) const override;

void _decode(re2::StringPiece input, std::string& ret) const override;
void _decode(const std::string& input, std::string& ret) const override;

detail::TokenMap _build_special_token_map(ssize_t num_base_tokens) const;

@@ -93,7 +94,7 @@ class Tiktoken : public detail::BPETokenizerBase {
const std::string _pattern =
R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";

detail::Re2UPtr _regex;
std::unique_ptr<IRegex> _regex;
};

} // namespace tokenizers
4 changes: 2 additions & 2 deletions include/pytorch/tokenizers/token_decoder.h
@@ -45,7 +45,7 @@ class TokenDecoder {
*
* @returns decoded: The decoded token string
*/
virtual std::string decode(re2::StringPiece token) const = 0;
virtual std::string decode(const std::string& token) const = 0;

// virtual destructor
virtual ~TokenDecoder() = default;
@@ -92,7 +92,7 @@

class ByteLevelTokenDecoder : public TokenDecoder {
public:
std::string decode(re2::StringPiece token) const override;
std::string decode(const std::string& token) const override;

}; // end class ByteLevelTokenDecoder

53 changes: 17 additions & 36 deletions src/bpe_tokenizer_base.cpp
@@ -130,42 +130,25 @@ static std::vector<uint64_t> _byte_pair_merge(
// ---- Helper utils end -------------------------------------------------------
// ---- protected start --------------------------------------------------------

std::pair<std::optional<std::string>, re2::StringPiece>
std::pair<std::optional<std::string>, std::string>
BPETokenizerBase::split_with_allowed_special_token_(
re2::StringPiece& input,
const std::string& input,
size_t offset,
const TokenMap& allowed_special) const {
if (!special_token_regex_) {
return std::make_pair(std::nullopt, input);
return std::make_pair(std::nullopt, input.substr(offset));
}

#if __cplusplus >= 202002L
auto start = input.begin();
#else
const char* start = input.data();
#endif
auto matches = special_token_regex_->find_all(input.substr(offset));

std::string special;
while (true) {
if (!re2::RE2::FindAndConsume(&input, *special_token_regex_, &special)) {
// No special token.
break;
for (const auto& m : matches) {
std::string matched_text = input.substr(offset + m.start, m.end - m.start);
if (allowed_special.tryGetInteger(matched_text).has_value()) {
return {matched_text, input.substr(offset, m.start)};
}

if (allowed_special.tryGetInteger(special).has_value()) {
// Found an allowed special token, split the text with it.
#if __cplusplus >= 202002L
return std::make_pair(
special,
re2::StringPiece(start, input.begin() - start - special.size()));
#else
return std::make_pair(
special,
re2::StringPiece(start, (input.data() - start) - special.size()));
#endif
} // else try to find the next special token
}

return std::make_pair(std::nullopt, input);
return {std::nullopt, input.substr(offset)};
}

Result<std::pair<std::vector<uint64_t>, uint64_t>>
@@ -174,33 +157,31 @@ BPETokenizerBase::encode_with_special_token_(
const TokenMap& allowed_special) const {
std::vector<uint64_t> tokens;
uint64_t last_piece_token_len = 0;
re2::StringPiece input(text);
while (true) {
size_t offset = 0;

while (offset < text.size()) {
auto [special, sub_input] =
split_with_allowed_special_token_(input, allowed_special);
split_with_allowed_special_token_(text, offset, allowed_special);

TK_CHECK_OK_OR_RETURN_ERROR(
_encode(sub_input, tokens, last_piece_token_len));
offset += sub_input.size();

if (special) {
const auto result = special_token_map_->tryGetInteger(*special);
if (!result) {
// Should never go here, since special pattern includes all special
// chars.
TK_LOG(Error, "unknown special token: %s\n", special->c_str());
return Error::EncodeFailure;
}

tokens.push_back(*result);
last_piece_token_len = 0;
offset += special->size(); // advance past the matched token
} else {
break;
}
}

// last_piece_token_len is how many tokens came from the last regex split.
// This is used for determining unstable tokens, since you can't merge
// across (stable) regex splits
return std::make_pair(tokens, last_piece_token_len);
}

@@ -273,7 +254,7 @@ Result<std::string> BPETokenizerBase::decode(uint64_t prev, uint64_t cur)
} else {
token_bytes = *result;
}
_decode(token_bytes, ret);
_decode(std::string(token_bytes), ret);

return ret;
}
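To make the new offset bookkeeping concrete, here is a self-contained sketch that mimics one iteration of encode_with_special_token_ with a hard-coded match standing in for a real IRegex; the token name and offsets are invented for illustration:

#include <iostream>
#include <string>

int main() {
  std::string text = "hello <|eot|> world";
  size_t offset = 0;

  // Pretend find_all(text.substr(offset)) reported an allowed special token
  // at bytes [6, 13) of the remaining input.
  size_t m_start = 6, m_end = 13;

  std::string sub_input = text.substr(offset, m_start);                 // "hello "
  std::string special = text.substr(offset + m_start, m_end - m_start); // "<|eot|>"

  // encode_with_special_token_ advances past both pieces:
  offset += sub_input.size() + special.size(); // now 13; " world" remains
  std::cout << sub_input << '|' << special << '|' << text.substr(offset) << '\n';
  return 0;
}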
10 changes: 4 additions & 6 deletions src/hf_tokenizer.cpp
@@ -100,9 +100,11 @@ Error HFTokenizer::load(const std::string& path) {

// Set up the pre-tokenizer
try {
std::cout << "Setting up pretokenizer..." << std::endl;
_pretokenizer = PreTokenizerConfig()
.parse_json(parsed_json.at("pre_tokenizer"))
.create();
std::cout << "Pretokenizer set up" << std::endl;
} catch (const json::out_of_range& e) {
fprintf(stderr, "Could not parse pre_tokenizer: %s\n", e.what());
return Error::LoadFailure;
@@ -231,7 +233,7 @@
// -------------------------private method start--------------------------------

Error HFTokenizer::_encode(
re2::StringPiece& input,
const std::string& input,
std::vector<uint64_t>& ret,
uint64_t& last_piece_token_len) const {
for (const auto& piece : _pretokenizer->pre_tokenize(input)) {
@@ -249,15 +251,11 @@ Error HFTokenizer::_encode(
return Error::Ok;
}

void HFTokenizer::_decode(re2::StringPiece input, std::string& ret) const {
void HFTokenizer::_decode(const std::string& input, std::string& ret) const {
if (_decoder) {
ret += _decoder->decode(input);
} else {
#ifdef _USE_INTERNAL_STRING_VIEW
ret += input.as_string();
#else
ret += input;
#endif
}
}
