-
Notifications
You must be signed in to change notification settings - Fork 7
Add regex interface with re2 and std::regex implementations #48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include <string> | ||
|
||
#include <re2/re2.h> | ||
|
||
#include <pytorch/tokenizers/regex.h> | ||
|
||
namespace tokenizers { | ||
|
||
/** | ||
* @brief RE2-based implementation of IRegex. | ||
*/ | ||
class Re2Regex : public IRegex { | ||
public: | ||
/** | ||
* @brief Construct a RE2 regex with the given pattern. | ||
* | ||
* @param pattern The regex pattern to compile. | ||
*/ | ||
explicit Re2Regex(const std::string& pattern); | ||
|
||
/** | ||
* @brief Return all non-overlapping matches found in the input string. | ||
*/ | ||
virtual std::vector<Match> find_all(const std::string& text) const override; | ||
|
||
/** | ||
* @brief Check if RE2 compiled the pattern successfully. | ||
*/ | ||
bool ok() const override; | ||
|
||
/** | ||
* @brief Expose internal RE2 pointer to the factory if needed. | ||
*/ | ||
const re2::RE2* rawRegex() const; | ||
|
||
private: | ||
std::unique_ptr<re2::RE2> regex_; | ||
|
||
friend Result<std::unique_ptr<IRegex>> create_regex( | ||
const std::string& pattern); | ||
}; | ||
|
||
} // namespace tokenizers |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include <pytorch/tokenizers/result.h> | ||
|
||
namespace tokenizers { | ||
|
||
struct Match { | ||
std::string text; | ||
size_t position; | ||
}; | ||
|
||
/** | ||
* @brief Abstract interface for regex wrappers. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. May like something like this: #pragma once
#include <string>
#include <vector>
class Regex {
public:
virtual ~Regex() = default;
// The only method subclasses have to implement.
virtual std::pair<size_t, size_t> match(const std::string& text, size_t start) const = 0;
// Convenience overload to match from the beginning.
std::pair<size_t, size_t> match(const std::string& text) const {
return match(text, 0);
}
// General implementation to match all.
std::vector<std::pair<size_t, size_t>> match_all(const std::string& text, size_t start = 0) const {
std::vector<std::pair<size_t, size_t>> matches;
for (size_t length = 0;; start += length) {
std::tie(start, length) = match(text, start);
if (length == 0) {
break;
}
matches.emplace_back(start, length);
}
return matches;
}
}; There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like we should just leave this API as is. We can get into a more granular API design later if necessary but the main point of all of this was to simply just provide a pcre2 fallback if re2 didn't work. I don't really expect people to be adding different regex implementations to be honest so don't want to overengineer too much. Another reason is I'd like to not mess with the current re2 code which uses |
||
*/ | ||
class IRegex { | ||
jackzhxng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
public: | ||
virtual ~IRegex() = default; | ||
|
||
/** | ||
* @brief Find all non-overlapping matches in the input string. | ||
* | ||
* @param text The input string to search. | ||
* @return A vector of strings containing all matched substrings. | ||
*/ | ||
virtual std::vector<Match> find_all(const std::string& text) const = 0; | ||
|
||
/** | ||
* @brief Check if the regex pattern was compiled successfully. | ||
* | ||
* @return true if the pattern is valid and ready to use, false otherwise. | ||
*/ | ||
virtual bool ok() const = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ideally regex should either fail at constructor, or stay valid forever once created. Do we need a standalone check like this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To see if it failed during construction so it can then fallback on another regex implementation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if we make this class construction-agnostic and let the subclasses deal with errors during construction? Eg. different regex impls may approach it differently: throw exceptions from constructor, return an error, or have a dedicated "compile" method, etc. And the users of this interface shouldn't care how exactly a regex has been constructed, they just want to get a match using an apriori valid interface. Normally, it's up to a "regex factory" or some higher level concept to deal with failures at construction (eg. pick a proper impl as a fallback according to some logic) and then provide a valid pointer to this interface. Like a platform-specific impl could leverage Apple's There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah that makes sense, I can make this protected? I just added this to use in the factory function There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's consider how it's gonna be used. I guess Some pseudocode: // MyModelTokenizer is such that requires regex, so it'll use IRegex.
// Other tokenizers may not need regex at all, btw.
class MyModelTokenizer : public ITokenizer {
MyModelTokenizer(const std::string& filepath, std::unique_ptr<IRegex> regex)
: regex_(std::move(regex)) {
// open file, initialize everything else
}
std::vector<size_t> encode(const std::string& text) override {
// Use regex_ to parse the text, etc.
// It's guaranteed the injected IRegex is ready to use and there's no need to validate it again
// Tokenizer doesn't need use anything of IRegex beyond matching text
auto tokens = regex_->match_all(text);
}
};
// MyModelRunner is such that required text tokenization, so it'll use ITokenizer.
// Other runners may not need tokenization at all, btw, or expect some other components do tokenization and provide them with already ready tokens.
class MyModelRunner : public IRunner {
MyModelRunner(std::unique_ptr<ITokenizer> tokenizer)
: tokenizer_(std::move(tokenizer)) {...}
std::vector<size_t> preprocess(const std::string& text) override {
return tokenizer_->encode(text);
}
size_t generate(const std::vector<size_t>& tokens) override { ... }
}; So we can inject various regex implementations into tokenizers that do need regexp, and the latter never have to deal with regex creation or check its validity. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I ended up removing it |
||
}; | ||
|
||
/** | ||
* @brief Creates a regex instance. Tries RE2 first, falls back to std::regex. | ||
* | ||
* @param pattern The regex pattern to compile. | ||
* @return A unique pointer to an IRegex-compatible object. | ||
*/ | ||
Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern); | ||
|
||
} // namespace tokenizers |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include <regex> | ||
#include <string> | ||
#include "regex.h" | ||
jackzhxng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
namespace tokenizers { | ||
|
||
/** | ||
* @brief std::regex-based implementation of IRegex. | ||
*/ | ||
class StdRegex : public IRegex { | ||
public: | ||
/** | ||
* @brief Construct a std::regex wrapper with the given pattern. | ||
* | ||
* @param pattern The regex pattern to compile. | ||
* @throws std::regex_error if the pattern is invalid. | ||
*/ | ||
explicit StdRegex(const std::string& pattern); | ||
|
||
/** | ||
* @brief Find all non-overlapping matches in the input string. | ||
*/ | ||
virtual std::vector<Match> find_all(const std::string& text) const override; | ||
|
||
/** | ||
* @brief Check if std::regex compiled the pattern successfully. | ||
* | ||
* @return true if the pattern is valid, false otherwise. | ||
*/ | ||
bool ok() const override; | ||
|
||
private: | ||
std::regex regex_; | ||
}; | ||
|
||
} // namespace tokenizers |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <pytorch/tokenizers/re2_regex.h> | ||
|
||
namespace tokenizers { | ||
|
||
Re2Regex::Re2Regex(const std::string& pattern) { | ||
regex_ = std::make_unique<re2::RE2>(pattern); | ||
// Warmup re2 as it is slow on the first run, void the return value as it's | ||
// not needed Refer to | ||
// https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141 | ||
(void)regex_->ReverseProgramSize(); | ||
} | ||
|
||
std::vector<Match> Re2Regex::find_all(const std::string& text) const { | ||
std::vector<Match> result; | ||
re2::StringPiece input(text); | ||
re2::StringPiece piece; | ||
|
||
const char* base = input.data(); | ||
|
||
while (RE2::FindAndConsume(&input, *regex_, &piece)) { | ||
size_t pos = piece.data() - base; | ||
result.push_back({std::string(piece.data(), piece.size()), pos}); | ||
} | ||
|
||
return result; | ||
} | ||
|
||
bool Re2Regex::ok() const { | ||
return regex_ && regex_->ok(); | ||
} | ||
|
||
const re2::RE2* Re2Regex::rawRegex() const { | ||
return regex_.get(); | ||
} | ||
|
||
} // namespace tokenizers |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <pytorch/tokenizers/regex.h> | ||
#include <pytorch/tokenizers/re2_regex.h> | ||
#include <pytorch/tokenizers/std_regex.h> | ||
|
||
#include <re2/re2.h> | ||
#include <iostream> | ||
#include <memory> | ||
|
||
namespace tokenizers { | ||
|
||
/** | ||
* @brief Factory function that creates a regex object using RE2 if possible. | ||
* Falls back to std::regex if RE2 rejects the pattern with | ||
* ErrorBadPerlOp. | ||
*/ | ||
Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern) { | ||
// Try RE2 first | ||
auto re2 = std::make_unique<Re2Regex>("(" + pattern + ")"); | ||
|
||
if (re2->ok()) { | ||
return static_cast<std::unique_ptr<IRegex>>(std::move(re2)); | ||
} | ||
|
||
const re2::RE2* raw = re2->rawRegex(); | ||
if (raw && raw->error_code() == re2::RE2::ErrorBadPerlOp) { | ||
try { | ||
std::cout | ||
<< "RE2 is unable to support things such as negative lookaheads in " | ||
<< pattern << ", defaulting to std::regex."; | ||
auto std_regex = std::make_unique<StdRegex>("(" + pattern + ")"); | ||
return static_cast<std::unique_ptr<IRegex>>(std::move(std_regex)); | ||
} catch (const std::regex_error& e) { | ||
std::cerr << "std::regex failed: " << e.what() << std::endl; | ||
return tokenizers::Error::LoadFailure; | ||
} | ||
} else { | ||
std::cerr << "RE2 failed to compile pattern: " << pattern << "\n"; | ||
std::cerr << "Error: " << (raw ? raw->error() : "unknown") << std::endl; | ||
return tokenizers::Error::LoadFailure; | ||
} | ||
} | ||
|
||
} // namespace tokenizers |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <pytorch/tokenizers/std_regex.h> | ||
#include <regex> | ||
|
||
namespace tokenizers { | ||
|
||
StdRegex::StdRegex(const std::string& pattern) : regex_(pattern) {} | ||
|
||
std::vector<Match> StdRegex::find_all(const std::string& text) const { | ||
std::vector<Match> result; | ||
std::sregex_iterator iter(text.begin(), text.end(), regex_); | ||
std::sregex_iterator end; | ||
|
||
for (; iter != end; ++iter) { | ||
const auto& match = *iter; | ||
result.push_back({ | ||
match[1].str(), // capture group 1 | ||
static_cast<size_t>(match.position(1)) // position of group 1 | ||
}); | ||
} | ||
|
||
return result; | ||
} | ||
|
||
bool StdRegex::ok() const { | ||
// std::regex constructor throws if the pattern is invalid | ||
// If we got here, the pattern is valid | ||
return true; | ||
} | ||
|
||
} // namespace tokenizers |
Uh oh!
There was an error while loading. Please reload this page.