Add regex interface with re2 and std::regex implementations #48

Merged
merged 5 commits into from
Apr 18, 2025
42 changes: 42 additions & 0 deletions include/pytorch/tokenizers/re2_regex.h
@@ -0,0 +1,42 @@
#pragma once

#include <memory>
#include <string>
#include "regex.h"

// Third Party
#include <re2/re2.h>

/**
* @brief RE2-based implementation of IRegex.
*/
class Re2Regex : public IRegex {
public:
/**
* @brief Construct a RE2 regex with the given pattern.
*
* @param pattern The regex pattern to compile.
*/
explicit Re2Regex(const std::string& pattern);

/**
* @brief Return all non-overlapping matches found in the input string.
*/
virtual std::vector<Match> findAll(const std::string& text) const override;
Contributor commented:

Should this return a Result? Is this changed in the top PR?

Contributor Author replied:

I don't expect this to error.


protected:
/**
* @brief Check if RE2 compiled the pattern successfully.
*/
bool ok() const;

/**
* @brief Expose internal RE2 pointer to the factory if needed.
*/
const re2::RE2* rawRegex() const;

private:
std::unique_ptr<re2::RE2> regex_;

friend std::unique_ptr<IRegex> createRegex(const std::string& pattern);
};
34 changes: 34 additions & 0 deletions include/pytorch/tokenizers/regex.h
@@ -0,0 +1,34 @@
#pragma once

#include <memory>
#include <string>
#include <vector>

struct Match {
std::string text;
size_t position;
};

/**
* @brief Abstract interface for regex wrappers.

Reviewer commented:

Maybe something like this:

#pragma once

#include <cstddef>
#include <string>
#include <tuple>   // std::tie
#include <utility> // std::pair
#include <vector>

class Regex {
public:
  virtual ~Regex() = default;

  // The only method subclasses have to implement.
  virtual std::pair<size_t, size_t> match(const std::string& text, size_t start) const = 0;

  // Convenience overload to match from the beginning.
  std::pair<size_t, size_t> match(const std::string& text) const {
    return match(text, 0);
  }

  // General implementation to match all.
  std::vector<std::pair<size_t, size_t>> match_all(const std::string& text, size_t start = 0) const {
    std::vector<std::pair<size_t, size_t>> matches;
    for (size_t length = 0;; start += length) {
      std::tie(start, length) = match(text, start);
      if (length == 0) {
        break;
      }
      matches.emplace_back(start, length);
    }
    return matches;
  }
};

Contributor Author replied:

I feel like we should just leave this API as is. We can get into a more granular API design later if necessary, but the main point of all of this was simply to provide a pcre2 fallback if re2 didn't work. To be honest, I don't expect people to add other regex implementations, so I don't want to overengineer. Another reason is that I'd like to avoid touching the current re2 code, which uses FindAndConsume; that call is stateful and would not fit into the proposed match API.

*/
class IRegex {
public:
virtual ~IRegex() = default;

/**
* @brief Find all non-overlapping matches in the input string.
*
* @param text The input string to search.
* @return A vector of Match objects, each holding the matched text and its position in the input.
*/
virtual std::vector<Match> findAll(const std::string& text) const = 0;
};

/**
* @brief Creates a regex instance. Tries RE2 first, falls back to std::regex.
*
* @param pattern The regex pattern to compile.
* @return A unique pointer to an IRegex-compatible object.
*/
std::unique_ptr<IRegex> createRegex(const std::string& pattern);
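
For comparison with `findAll`, here is a minimal sketch of how the `match(text, start)` interface suggested in the review thread above could be backed by `std::regex`. The `Regex` class follows the reviewer's suggestion with the convenience overload left out. `StdBackedRegex` is a hypothetical name used only for illustration and is not part of this PR.

```cpp
#include <cstddef>
#include <regex>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// Interface shaped like the reviewer's suggestion (convenience overload omitted).
class Regex {
 public:
  virtual ~Regex() = default;

  // Returns {position, length} of the first match at or after `start`.
  // A length of 0 signals "no match".
  virtual std::pair<size_t, size_t> match(const std::string& text,
                                          size_t start) const = 0;

  // Generic loop from the suggestion: match, record, advance past the match.
  std::vector<std::pair<size_t, size_t>> match_all(const std::string& text,
                                                   size_t start = 0) const {
    std::vector<std::pair<size_t, size_t>> matches;
    for (size_t length = 0;; start += length) {
      std::tie(start, length) = match(text, start);
      if (length == 0) {
        break;
      }
      matches.emplace_back(start, length);
    }
    return matches;
  }
};

// Hypothetical backend (an assumption, not from the PR): match() via std::regex.
class StdBackedRegex : public Regex {
 public:
  explicit StdBackedRegex(const std::string& pattern) : regex_(pattern) {}

  std::pair<size_t, size_t> match(const std::string& text,
                                  size_t start) const override {
    if (start > text.size()) {
      return {start, 0};
    }
    std::smatch m;
    const auto first = text.begin() + static_cast<std::ptrdiff_t>(start);
    if (!std::regex_search(first, text.end(), m, regex_)) {
      return {start, 0};
    }
    // position(0) is relative to `first`, so shift it back by `start`.
    return {start + static_cast<size_t>(m.position(0)),
            static_cast<size_t>(m.length(0))};
  }

 private:
  std::regex regex_;
};
```

Each call searches again from `start`, so this shape keeps no state between calls. One consequence of the design is that an empty match cannot be told apart from "no match", since both report a length of 0.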
28 changes: 28 additions & 0 deletions include/pytorch/tokenizers/std_regex.h
@@ -0,0 +1,28 @@
#pragma once

#include <memory>
#include <regex>
#include <string>
#include "regex.h"

/**
* @brief std::regex-based implementation of IRegex.
*/
class StdRegex : public IRegex {
public:
/**
* @brief Construct a std::regex wrapper with the given pattern.
*
* @param pattern The regex pattern to compile.
* @throws std::regex_error if the pattern is invalid.
*/
explicit StdRegex(const std::string& pattern);

/**
* @brief Find all non-overlapping matches in the input string.
*/
virtual std::vector<Match> findAll(const std::string& text) const override;

private:
std::regex regex_;
};
33 changes: 33 additions & 0 deletions src/re2_regex.cpp
@@ -0,0 +1,33 @@
#include "pytorch/tokenizers/re2_regex.h"
#include <re2/re2.h>

Re2Regex::Re2Regex(const std::string& pattern) {
regex_ = std::make_unique<re2::RE2>("(" + pattern + ")");
// Warm up RE2, as it is slow on the first run. The return value is not
// needed, so it is discarded. Refer to
// https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
(void)regex_->ReverseProgramSize();
}

bool Re2Regex::ok() const {
return regex_ && regex_->ok();
}

const re2::RE2* Re2Regex::rawRegex() const {
return regex_.get();
}

std::vector<Match> Re2Regex::findAll(const std::string& text) const {
std::vector<Match> result;
re2::StringPiece input(text);
re2::StringPiece piece;

const char* base = input.data();

while (RE2::FindAndConsume(&input, *regex_, &piece)) {
size_t pos = piece.data() - base;
result.push_back({ std::string(piece.data(), piece.size()), pos });
}

return result;
}
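
The consume-and-advance loop above can be sketched without RE2 to show how each position falls out of pointer arithmetic against the start of the input. In this sketch `findAndConsume` is a hypothetical stand-in built on `std::regex`, not the RE2 function, and `findAllByConsuming` is likewise an illustrative name. The empty-match guard is an addition that is not in the PR: if a pattern can match the empty string, a loop of this shape may otherwise never advance.

```cpp
#include <cstddef>
#include <regex>
#include <string>
#include <vector>

struct Match {
  std::string text;
  size_t position;
};

// Hypothetical stand-in for a "find and consume" call, built on std::regex so
// the sketch compiles without RE2. On success it reports the match bounds in
// [*matchBegin, *matchEnd) and moves *cursor to the end of the match.
bool findAndConsume(const char** cursor, const char* end, const std::regex& re,
                    const char** matchBegin, const char** matchEnd) {
  std::cmatch m;
  if (!std::regex_search(*cursor, end, m, re)) {
    return false;
  }
  *matchBegin = m[0].first;
  *matchEnd = m[0].second;
  *cursor = m[0].second;
  return true;
}

std::vector<Match> findAllByConsuming(const std::string& text,
                                      const std::regex& re) {
  std::vector<Match> result;
  const char* base = text.data();
  const char* end = base + text.size();
  const char* cursor = base;
  const char* b = nullptr;
  const char* e = nullptr;

  while (findAndConsume(&cursor, end, re, &b, &e)) {
    // The offset is the distance from the start of the original input.
    result.push_back({std::string(b, e), static_cast<size_t>(b - base)});
    if (b == e) {
      // Empty match: nothing was consumed, so step forward manually
      // (or stop at the end) to guarantee progress.
      if (cursor == end) {
        break;
      }
      ++cursor;
    }
  }
  return result;
}
```

Because each search starts at the advanced cursor, anchors in the pattern see the remaining input rather than the whole string.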
37 changes: 37 additions & 0 deletions src/regex.cpp
@@ -0,0 +1,37 @@
#include "pytorch/tokenizers/regex.h"
#include "pytorch/tokenizers/re2_regex.h"
#include "pytorch/tokenizers/std_regex.h"

#include <re2/re2.h>
#include <iostream>
#include <memory>

/**
* @brief Factory function that creates a regex object using RE2 if possible.
* Falls back to std::regex if RE2 rejects the pattern with
* ErrorBadPerlOp.
*/
std::unique_ptr<IRegex> createRegex(const std::string& pattern) {
auto re2 = std::make_unique<Re2Regex>(pattern);

if (re2->ok()) {
return re2;
}

const re2::RE2* raw = re2->rawRegex();
if (raw && raw->error_code() == re2::RE2::ErrorBadPerlOp) {
try {
std::cout
<< "RE2 is unable to support things such as negative lookaheads in "
<< pattern << ", defaulting to std::regex." << std::endl;
return std::make_unique<StdRegex>(pattern);
} catch (const std::regex_error& e) {
std::cerr << "std::regex failed: " << e.what() << std::endl;
return nullptr;
}
} else {
std::cerr << "RE2 failed to compile pattern: " << pattern << "\n";
std::cerr << "Error: " << (raw ? raw->error() : "unknown") << std::endl;
return nullptr;
}
}
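
The factory above has three outcomes: the primary engine accepts the pattern, the primary engine rejects it for an unsupported operator and the fallback is tried, or compilation fails and the caller gets `nullptr`. The sketch below isolates that decision flow. Everything in it is an assumption for illustration: `compileWithPrimary` only simulates a primary engine by flagging lookahead syntax, and the enum and function names are hypothetical. It does not reproduce how RE2 actually classifies errors.

```cpp
#include <regex>
#include <string>

enum class Engine { Primary, Fallback, None };
enum class PrimaryStatus { Ok, UnsupportedOperator, OtherError };

// Simulated primary engine (an assumption): lookaheads count as the
// "unsupported operator" case, and anything std::regex rejects counts as a
// generic error. A real engine reports this through its own error codes.
PrimaryStatus compileWithPrimary(const std::string& pattern) {
  if (pattern.find("(?!") != std::string::npos ||
      pattern.find("(?=") != std::string::npos) {
    return PrimaryStatus::UnsupportedOperator;
  }
  try {
    std::regex probe(pattern);
    (void)probe;
  } catch (const std::regex_error&) {
    return PrimaryStatus::OtherError;
  }
  return PrimaryStatus::Ok;
}

// Fallback check: wrap the pattern in a group, as the PR's StdRegex does.
bool compilesWithStdRegex(const std::string& pattern) {
  try {
    std::regex probe("(" + pattern + ")");
    (void)probe;
    return true;
  } catch (const std::regex_error&) {
    return false;
  }
}

// Mirrors the factory's branching: only an unsupported operator triggers the
// fallback; any other primary failure yields None (nullptr in the PR).
Engine chooseEngine(const std::string& pattern) {
  switch (compileWithPrimary(pattern)) {
    case PrimaryStatus::Ok:
      return Engine::Primary;
    case PrimaryStatus::UnsupportedOperator:
      return compilesWithStdRegex(pattern) ? Engine::Fallback : Engine::None;
    case PrimaryStatus::OtherError:
      return Engine::None;
  }
  return Engine::None;
}
```

Since `createRegex` can return `nullptr`, callers presumably need to check the returned pointer before calling `findAll`.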
22 changes: 22 additions & 0 deletions src/std_regex.cpp
@@ -0,0 +1,22 @@
#include "pytorch/tokenizers/std_regex.h"
#include <regex>

StdRegex::StdRegex(const std::string& pattern)
: regex_("(" + pattern + ")") // Add parentheses like RE2 version
{}

std::vector<Match> StdRegex::findAll(const std::string& text) const {
std::vector<Match> result;
std::sregex_iterator iter(text.begin(), text.end(), regex_);
std::sregex_iterator end;

for (; iter != end; ++iter) {
const auto& match = *iter;
result.push_back({
match[1].str(), // capture group 1
static_cast<size_t>(match.position(1)) // position of group 1
});
}

return result;
}
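
A small standalone sketch of the same technique may help show what the fallback buys. `std::regex` uses the ECMAScript grammar by default, which accepts lookahead assertions such as `(?!...)`. The function name `findAllWithStdRegex` is hypothetical, and `Match` is redeclared here only so the block compiles on its own.

```cpp
#include <cstddef>
#include <regex>
#include <string>
#include <vector>

struct Match {
  std::string text;
  size_t position;
};

// Wraps the pattern in a capture group and reads group 1 from each match,
// following the same approach as StdRegex::findAll in this PR.
std::vector<Match> findAllWithStdRegex(const std::string& pattern,
                                       const std::string& text) {
  const std::regex re("(" + pattern + ")");
  std::vector<Match> result;
  for (std::sregex_iterator it(text.begin(), text.end(), re), end; it != end;
       ++it) {
    result.push_back(
        {(*it)[1].str(), static_cast<size_t>(it->position(1))});
  }
  return result;
}
```

For example, the pattern `\s+(?!\S)` applied to `"a   b"` (three spaces) matches only the first two spaces: the greedy `\s+` backs off by one character so that the negative lookahead sees whitespace rather than `b`.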