Skip to content

Commit c2d4de8

Browse files
committed
PR review
1 parent 934ffa3 commit c2d4de8

File tree

7 files changed

+145
-53
lines changed

7 files changed

+145
-53
lines changed

include/pytorch/tokenizers/re2_regex.h

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
19
#pragma once
210

311
#include <memory>
412
#include <string>
5-
#include "regex.h"
613

7-
// Third Party
814
#include <re2/re2.h>
915

16+
#include <pytorch/tokenizers/regex.h>
17+
18+
namespace tokenizers {
19+
1020
/**
1121
* @brief RE2-based implementation of IRegex.
1222
*/
@@ -24,11 +34,10 @@ class Re2Regex : public IRegex {
2434
*/
2535
virtual std::vector<Match> findAll(const std::string& text) const override;
2636

27-
protected:
2837
/**
2938
* @brief Check if RE2 compiled the pattern successfully.
3039
*/
31-
bool ok() const;
40+
bool ok() const override;
3241

3342
/**
3443
* @brief Expose internal RE2 pointer to the factory if needed.
@@ -38,5 +47,8 @@ class Re2Regex : public IRegex {
3847
private:
3948
std::unique_ptr<re2::RE2> regex_;
4049

41-
friend std::unique_ptr<IRegex> createRegex(const std::string& pattern);
50+
friend Result<std::unique_ptr<IRegex>> createRegex(
51+
const std::string& pattern);
4252
};
53+
54+
} // namespace tokenizers

include/pytorch/tokenizers/regex.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
19
#pragma once
210

311
#include <memory>
412
#include <string>
513
#include <vector>
614

15+
#include <pytorch/tokenizers/result.h>
16+
17+
namespace tokenizers {
18+
719
struct Match {
820
std::string text;
921
size_t position;
@@ -23,6 +35,13 @@ class IRegex {
2335
* @return A vector of Match objects, one per match, each holding the matched text and its position in the input.
2436
*/
2537
virtual std::vector<Match> findAll(const std::string& text) const = 0;
38+
39+
/**
40+
* @brief Check if the regex pattern was compiled successfully.
41+
*
42+
* @return true if the pattern is valid and ready to use, false otherwise.
43+
*/
44+
virtual bool ok() const = 0;
2645
};
2746

2847
/**
@@ -31,4 +50,6 @@ class IRegex {
3150
* @param pattern The regex pattern to compile.
3251
* @return A Result holding a unique pointer to an IRegex-compatible object on success, or an error if the pattern could not be compiled.
3352
*/
34-
std::unique_ptr<IRegex> createRegex(const std::string& pattern);
53+
Result<std::unique_ptr<IRegex>> createRegex(const std::string& pattern);
54+
55+
} // namespace tokenizers

include/pytorch/tokenizers/std_regex.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
19
#pragma once
210

311
#include <memory>
412
#include <regex>
513
#include <string>
614
#include "regex.h"
715

16+
namespace tokenizers {
17+
818
/**
919
* @brief std::regex-based implementation of IRegex.
1020
*/
@@ -23,6 +33,15 @@ class StdRegex : public IRegex {
2333
*/
2434
virtual std::vector<Match> findAll(const std::string& text) const override;
2535

36+
/**
37+
* @brief Check if std::regex compiled the pattern successfully.
38+
*
39+
* @return true if the pattern is valid, false otherwise.
40+
*/
41+
bool ok() const override;
42+
2643
private:
2744
std::regex regex_;
2845
};
46+
47+
} // namespace tokenizers

src/re2_regex.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <pytorch/tokenizers/re2_regex.h>
10+
11+
namespace tokenizers {
12+
13+
/**
 * @brief Build an RE2 object for the given pattern and warm it up.
 *
 * @param pattern The regex pattern handed to RE2 unchanged.
 */
Re2Regex::Re2Regex(const std::string& pattern)
    : regex_(std::make_unique<re2::RE2>(pattern)) {
  // RE2 is slow on its first run, so warm it up here. The returned size is
  // not needed and is deliberately discarded. See:
  // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
  static_cast<void>(regex_->ReverseProgramSize());
}
20+
21+
std::vector<Match> Re2Regex::findAll(const std::string& text) const {
22+
std::vector<Match> result;
23+
re2::StringPiece input(text);
24+
re2::StringPiece piece;
25+
26+
const char* base = input.data();
27+
28+
while (RE2::FindAndConsume(&input, *regex_, &piece)) {
29+
size_t pos = piece.data() - base;
30+
result.push_back({std::string(piece.data(), piece.size()), pos});
31+
}
32+
33+
return result;
34+
}
35+
36+
bool Re2Regex::ok() const {
37+
return regex_ && regex_->ok();
38+
}
39+
40+
/**
 * @brief Give read-only access to the wrapped RE2 object.
 *
 * @return A non-owning pointer to the internal RE2 instance; ownership stays
 *         with this Re2Regex.
 */
const re2::RE2* Re2Regex::rawRegex() const {
  const re2::RE2* compiled = regex_.get();
  return compiled;
}
43+
44+
} // namespace tokenizers

src/re2_regex.cpp

Lines changed: 0 additions & 33 deletions
This file was deleted.

src/regex.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,31 @@
1-
#include "pytorch/tokenizers/regex.h"
2-
#include "pytorch/tokenizers/re2_regex.h"
3-
#include "pytorch/tokenizers/std_regex.h"
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <pytorch/tokenizers/regex.h>
10+
#include <pytorch/tokenizers/re2_regex.h>
11+
#include <pytorch/tokenizers/std_regex.h>
412

513
#include <re2/re2.h>
614
#include <iostream>
715
#include <memory>
816

17+
namespace tokenizers {
18+
919
/**
1020
* @brief Factory function that creates a regex object using RE2 if possible.
1121
* Falls back to std::regex if RE2 rejects the pattern with
12-
* ErrorBadPerlOp.
22+
* ErrorBadPerlOp.
1323
*/
14-
std::unique_ptr<IRegex> createRegex(const std::string& pattern) {
15-
auto re2 = std::make_unique<Re2Regex>(pattern);
24+
Result<std::unique_ptr<IRegex>> createRegex(const std::string& pattern) {
25+
auto re2 = std::make_unique<Re2Regex>("(" + pattern + ")");
1626

1727
if (re2->ok()) {
18-
return re2;
28+
return static_cast<std::unique_ptr<IRegex>>(std::move(re2));
1929
}
2030

2131
const re2::RE2* raw = re2->rawRegex();
@@ -24,14 +34,17 @@ std::unique_ptr<IRegex> createRegex(const std::string& pattern) {
2434
std::cout
2535
<< "RE2 is unable to support things such as negative lookaheads in "
2636
<< pattern << ", defaulting to std::regex.";
27-
return std::make_unique<StdRegex>(pattern);
37+
auto std_regex = std::make_unique<StdRegex>("(" + pattern + ")");
38+
return static_cast<std::unique_ptr<IRegex>>(std::move(std_regex));
2839
} catch (const std::regex_error& e) {
2940
std::cerr << "std::regex failed: " << e.what() << std::endl;
30-
return nullptr;
41+
return tokenizers::Error::LoadFailure;
3142
}
3243
} else {
3344
std::cerr << "RE2 failed to compile pattern: " << pattern << "\n";
3445
std::cerr << "Error: " << (raw ? raw->error() : "unknown") << std::endl;
35-
return nullptr;
46+
return tokenizers::Error::LoadFailure;
3647
}
3748
}
49+
50+
} // namespace tokenizers

src/std_regex.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
1-
#include "pytorch/tokenizers/std_regex.h"
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <pytorch/tokenizers/std_regex.h>
210
#include <regex>
311

4-
StdRegex::StdRegex(const std::string& pattern)
5-
: regex_("(" + pattern + ")") // Add parentheses like RE2 version
6-
{}
12+
namespace tokenizers {
13+
14+
/**
 * @brief Compile the pattern with std::regex exactly as given.
 *
 * @param pattern The regex pattern to compile.
 */
StdRegex::StdRegex(const std::string& pattern)
    : regex_(pattern) {
}
715

816
std::vector<Match> StdRegex::findAll(const std::string& text) const {
917
std::vector<Match> result;
@@ -20,3 +28,11 @@ std::vector<Match> StdRegex::findAll(const std::string& text) const {
2028

2129
return result;
2230
}
31+
32+
/**
 * @brief Report whether the pattern compiled.
 *
 * @return Always true.
 */
bool StdRegex::ok() const {
  // The std::regex constructor throws on an invalid pattern, so an object
  // that exists at all was built from a valid one.
  const bool compiled = true;
  return compiled;
}
37+
38+
} // namespace tokenizers

0 commit comments

Comments
 (0)