Skip to content

Add regex interface with re2 and std::regex implementations #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 18, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions include/pytorch/tokenizers/re2_regex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <memory>
#include <string>

#include <re2/re2.h>

#include <pytorch/tokenizers/regex.h>

namespace tokenizers {

/**
* @brief RE2-based implementation of IRegex.
*/
class Re2Regex : public IRegex {
public:
/**
* @brief Construct a RE2 regex with the given pattern.
*
* @param pattern The regex pattern to compile.
*/
explicit Re2Regex(const std::string& pattern);

/**
* @brief Return all non-overlapping matches found in the input string.
*/
virtual std::vector<Match> find_all(const std::string& text) const override;

/**
* @brief Check if RE2 compiled the pattern successfully.
*/
bool ok() const override;

/**
* @brief Expose internal RE2 pointer to the factory if needed.
*/
const re2::RE2* rawRegex() const;

private:
std::unique_ptr<re2::RE2> regex_;

friend Result<std::unique_ptr<IRegex>> create_regex(
const std::string& pattern);
};

} // namespace tokenizers
55 changes: 55 additions & 0 deletions include/pytorch/tokenizers/regex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <memory>
#include <string>
#include <vector>

#include <pytorch/tokenizers/result.h>

namespace tokenizers {

struct Match {
std::string text;
size_t position;
};

/**
* @brief Abstract interface for regex wrappers.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May like something like this:

#pragma once

#include <string>
#include <vector>

class Regex {
public:
  virtual ~Regex() = default;

  // The only method subclasses have to implement.
  virtual std::pair<size_t, size_t> match(const std::string& text, size_t start) const = 0;

  // Convenience overload to match from the beginning.
  std::pair<size_t, size_t> match(const std::string& text) const {
    return match(text, 0);
  }

  // General implementation to match all.
  std::vector<std::pair<size_t, size_t>> match_all(const std::string& text, size_t start = 0) const {
    std::vector<std::pair<size_t, size_t>> matches;
    for (size_t length = 0;; start += length) {
      std::tie(start, length) = match(text, start);
      if (length == 0) {
        break;
      }
      matches.emplace_back(start, length);
    }
    return matches;
  }
};

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like we should just leave this API as is. We can get into a more granular API design later if necessary but the main point of all of this was to simply just provide a pcre2 fallback if re2 didn't work. I don't really expect people to be adding different regex implementations to be honest so don't want to overengineer too much. Another reason is I'd like to not mess with the current re2 code which uses FindAndConsume, which is stateful and would not fit into the proposed match API

*/
class IRegex {
public:
virtual ~IRegex() = default;

/**
* @brief Find all non-overlapping matches in the input string.
*
* @param text The input string to search.
* @return A vector of strings containing all matched substrings.
*/
virtual std::vector<Match> find_all(const std::string& text) const = 0;

/**
* @brief Check if the regex pattern was compiled successfully.
*
* @return true if the pattern is valid and ready to use, false otherwise.
*/
virtual bool ok() const = 0;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally regex should either fail at constructor, or stay valid forever once created. Do we need a standalone check like this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To see if it failed during construction so it can then fallback on another regex implementation

Copy link

@shoumikhin shoumikhin Apr 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we make this class construction-agnostic and let the subclasses deal with errors during construction? Eg. different regex impls may approach it differently: throw exceptions from constructor, return an error, or have a dedicated "compile" method, etc. And the users of this interface shouldn't care how exactly a regex has been constructed, they just want to get a match using an apriori valid interface. Normally, it's up to a "regex factory" or some higher level concept to deal with failures at construction (eg. pick a proper impl as a fallback according to some logic) and then provide a valid pointer to this interface. Like a platform-specific impl could leverage Apple's NSRegularExpression, but then fallback to std::regex or something if the former fails. But whoever gets a pointer to this interface would never need to reason if it's valid or not, but just use it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that makes sense, I can make this protected? I just added this to use in the factory function create_regex

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's consider how it's gonna be used. I guess IRegex will be injected as a dep to some higher level concept like some concrete impl of ITokenizer?

Some pseudocode:

// MyModelTokenizer is such that requires regex, so it'll use IRegex.
// Other tokenizers may not need regex at all, btw.
class MyModelTokenizer : public ITokenizer {
  MyModelTokenizer(const std::string& filepath, std::unique_ptr<IRegex> regex)
  : regex_(std::move(regex)) {
    // open file, initialize everything else
  }

  std::vector<size_t> encode(const std::string& text) override {
    // Use regex_ to parse the text, etc.
    // It's guaranteed the injected IRegex is ready to use and there's no need to validate it again
    // Tokenizer doesn't need use anything of IRegex beyond matching text
    auto tokens = regex_->match_all(text);
  }
};

// MyModelRunner is such that required text tokenization, so it'll use ITokenizer.
// Other runners may not need tokenization at all, btw, or expect some other components do tokenization and provide them with already ready tokens.
class MyModelRunner : public IRunner {
  MyModelRunner(std::unique_ptr<ITokenizer> tokenizer) 
  : tokenizer_(std::move(tokenizer)) {...}

  std::vector<size_t> preprocess(const std::string& text) override {
    return tokenizer_->encode(text);
  }

  size_t generate(const std::vector<size_t>& tokens) override { ... }
};

So we can inject various regex implementations into tokenizers that do need regexp, and the latter never have to deal with regex creation or check its validity.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ended up removing it

};

/**
* @brief Creates a regex instance. Tries RE2 first, falls back to std::regex.
*
* @param pattern The regex pattern to compile.
* @return A unique pointer to an IRegex-compatible object.
*/
Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern);

} // namespace tokenizers
47 changes: 47 additions & 0 deletions include/pytorch/tokenizers/std_regex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <memory>
#include <regex>
#include <string>
#include "regex.h"

namespace tokenizers {

/**
* @brief std::regex-based implementation of IRegex.
*/
class StdRegex : public IRegex {
public:
/**
* @brief Construct a std::regex wrapper with the given pattern.
*
* @param pattern The regex pattern to compile.
* @throws std::regex_error if the pattern is invalid.
*/
explicit StdRegex(const std::string& pattern);

/**
* @brief Find all non-overlapping matches in the input string.
*/
virtual std::vector<Match> find_all(const std::string& text) const override;

/**
* @brief Check if std::regex compiled the pattern successfully.
*
* @return true if the pattern is valid, false otherwise.
*/
bool ok() const override;

private:
std::regex regex_;
};

} // namespace tokenizers
44 changes: 44 additions & 0 deletions src/re2_regex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <pytorch/tokenizers/re2_regex.h>

namespace tokenizers {

Re2Regex::Re2Regex(const std::string& pattern) {
regex_ = std::make_unique<re2::RE2>(pattern);
// Warmup re2 as it is slow on the first run, void the return value as it's
// not needed Refer to
// https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
(void)regex_->ReverseProgramSize();
}

std::vector<Match> Re2Regex::find_all(const std::string& text) const {
std::vector<Match> result;
re2::StringPiece input(text);
re2::StringPiece piece;

const char* base = input.data();

while (RE2::FindAndConsume(&input, *regex_, &piece)) {
size_t pos = piece.data() - base;
result.push_back({std::string(piece.data(), piece.size()), pos});
}

return result;
}

bool Re2Regex::ok() const {
return regex_ && regex_->ok();
}

const re2::RE2* Re2Regex::rawRegex() const {
return regex_.get();
}

} // namespace tokenizers
51 changes: 51 additions & 0 deletions src/regex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <pytorch/tokenizers/regex.h>
#include <pytorch/tokenizers/re2_regex.h>
#include <pytorch/tokenizers/std_regex.h>

#include <re2/re2.h>
#include <iostream>
#include <memory>

namespace tokenizers {

/**
* @brief Factory function that creates a regex object using RE2 if possible.
* Falls back to std::regex if RE2 rejects the pattern with
* ErrorBadPerlOp.
*/
Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern) {
// Try RE2 first
auto re2 = std::make_unique<Re2Regex>("(" + pattern + ")");

if (re2->ok()) {
return static_cast<std::unique_ptr<IRegex>>(std::move(re2));
}

const re2::RE2* raw = re2->rawRegex();
if (raw && raw->error_code() == re2::RE2::ErrorBadPerlOp) {
try {
std::cout
<< "RE2 is unable to support things such as negative lookaheads in "
<< pattern << ", defaulting to std::regex.";
auto std_regex = std::make_unique<StdRegex>("(" + pattern + ")");
return static_cast<std::unique_ptr<IRegex>>(std::move(std_regex));
} catch (const std::regex_error& e) {
std::cerr << "std::regex failed: " << e.what() << std::endl;
return tokenizers::Error::LoadFailure;
}
} else {
std::cerr << "RE2 failed to compile pattern: " << pattern << "\n";
std::cerr << "Error: " << (raw ? raw->error() : "unknown") << std::endl;
return tokenizers::Error::LoadFailure;
}
}

} // namespace tokenizers
38 changes: 38 additions & 0 deletions src/std_regex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <pytorch/tokenizers/std_regex.h>
#include <regex>

namespace tokenizers {

StdRegex::StdRegex(const std::string& pattern) : regex_(pattern) {}

std::vector<Match> StdRegex::find_all(const std::string& text) const {
std::vector<Match> result;
std::sregex_iterator iter(text.begin(), text.end(), regex_);
std::sregex_iterator end;

for (; iter != end; ++iter) {
const auto& match = *iter;
result.push_back({
match[1].str(), // capture group 1
static_cast<size_t>(match.position(1)) // position of group 1
});
}

return result;
}

bool StdRegex::ok() const {
// std::regex constructor throws if the pattern is invalid
// If we got here, the pattern is valid
return true;
}

} // namespace tokenizers