Skip to content

Commit 6a2e239

Browse files
committed
Run LSTM recognition in multiple threads
Use the init-time option lstm_num_threads to set the number of LSTM threads.
1 parent 2991d36 commit 6a2e239

File tree

6 files changed

+130
-59
lines changed

6 files changed

+130
-59
lines changed

src/ccmain/control.cpp

+88-40
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <cstdint> // for int16_t, int32_t
2727
#include <cstdio> // for fclose, fopen, FILE
2828
#include <ctime> // for clock
29+
#include <future>
2930
#include "control.h"
3031
#ifndef DISABLED_LEGACY_ENGINE
3132
# include "docqual.h"
@@ -194,36 +195,42 @@ void Tesseract::SetupWordPassN(int pass_n, WordData *word) {
194195
}
195196
}
196197

197-
// Runs word recognition on all the words.
198-
bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC *monitor, PAGE_RES_IT *pr_it,
199-
std::vector<WordData> *words) {
200-
// TODO(rays) Before this loop can be parallelized (it would yield a massive
201-
// speed-up) all remaining member globals need to be converted to local/heap
202-
// (eg set_pass1 and set_pass2) and an intermediate adaption pass needs to be
203-
// added. The results will be significantly different with adaption on, and
204-
// deterioration will need investigation.
205-
pr_it->restart_page();
206-
for (unsigned w = 0; w < words->size(); ++w) {
207-
WordData *word = &(*words)[w];
208-
if (w > 0) {
209-
word->prev_word = &(*words)[w - 1];
198+
bool Tesseract::RecogWordsSegment(std::vector<WordData>::iterator start,
199+
std::vector<WordData>::iterator end,
200+
int pass_n,
201+
ETEXT_DESC *monitor,
202+
PAGE_RES *page_res,
203+
LSTMRecognizer *lstm_recognizer,
204+
std::atomic<int>& words_done,
205+
int total_words,
206+
std::mutex& monitor_mutex) {
207+
PAGE_RES_IT pr_it(page_res);
208+
// Process a segment of the words vector
209+
pr_it.restart_page();
210+
211+
for (auto it = start; it != end; ++it, ++words_done) {
212+
WordData *word = &(*it);
213+
if (it != start) {
214+
word->prev_word = &(*(it - 1));
210215
}
211216
if (monitor != nullptr) {
217+
std::lock_guard<std::mutex> lock(monitor_mutex);
212218
monitor->ocr_alive = true;
213219
if (pass_n == 1) {
214-
monitor->progress = 70 * w / words->size();
220+
monitor->progress = 70 * words_done / total_words;
215221
} else {
216-
monitor->progress = 70 + 30 * w / words->size();
222+
monitor->progress = 70 + 30 * words_done / total_words;
217223
}
224+
// NOTE: the progress callback below is invoked by every worker thread (serialized by monitor_mutex), not only by the first thread.
218225
if (monitor->progress_callback2 != nullptr) {
219-
TBOX box = pr_it->word()->word->bounding_box();
226+
TBOX box = pr_it.word()->word->bounding_box();
220227
(*monitor->progress_callback2)(monitor, box.left(), box.right(), box.top(), box.bottom());
221228
}
222229
if (monitor->deadline_exceeded() ||
223-
(monitor->cancel != nullptr && (*monitor->cancel)(monitor->cancel_this, words->size()))) {
230+
(monitor->cancel != nullptr && (*monitor->cancel)(monitor->cancel_this, total_words))) {
224231
// Timeout. Fake out the rest of the words.
225-
for (; w < words->size(); ++w) {
226-
(*words)[w].word->SetupFake(unicharset);
232+
for (; it != end; ++it) {
233+
it->word->SetupFake(unicharset);
227234
}
228235
return false;
229236
}
@@ -238,31 +245,69 @@ bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC *monitor, PAGE_RES_IT
238245
}
239246
}
240247
// Sync pr_it with the WordData.
241-
while (pr_it->word() != nullptr && pr_it->word() != word->word) {
242-
pr_it->forward();
248+
while (pr_it.word() != nullptr && pr_it.word() != word->word) {
249+
pr_it.forward();
243250
}
244-
ASSERT_HOST(pr_it->word() != nullptr);
251+
ASSERT_HOST(pr_it.word() != nullptr);
245252
bool make_next_word_fuzzy = false;
246253
#ifndef DISABLED_LEGACY_ENGINE
247-
if (!AnyLSTMLang() && ReassignDiacritics(pass_n, pr_it, &make_next_word_fuzzy)) {
254+
if (!AnyLSTMLang() && ReassignDiacritics(pass_n, &pr_it, &make_next_word_fuzzy)) {
248255
// Needs to be setup again to see the new outlines in the chopped_word.
249256
SetupWordPassN(pass_n, word);
250257
}
251258
#endif // ndef DISABLED_LEGACY_ENGINE
252259

253-
classify_word_and_language(pass_n, pr_it, word);
260+
classify_word_and_language(pass_n, &pr_it, word, lstm_recognizer);
254261
if (tessedit_dump_choices || debug_noise_removal) {
255262
tprintf("Pass%d: %s [%s]\n", pass_n, word->word->best_choice->unichar_string().c_str(),
256263
word->word->best_choice->debug_string().c_str());
257264
}
258-
pr_it->forward();
259-
if (make_next_word_fuzzy && pr_it->word() != nullptr) {
260-
pr_it->MakeCurrentWordFuzzy();
265+
pr_it.forward();
266+
if (make_next_word_fuzzy && pr_it.word() != nullptr) {
267+
pr_it.MakeCurrentWordFuzzy();
261268
}
262269
}
263270
return true;
264271
}
265272

273+
// Runs word recognition on all the words.
274+
bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC *monitor, PAGE_RES *page_res,
275+
std::vector<WordData> *words) {
276+
int total_words = words->size();
277+
int segment_size = total_words / lstm_num_threads;
278+
std::atomic<int> words_done(0);
279+
std::mutex monitor_mutex;
280+
std::vector<std::future<bool>> futures;
281+
282+
// Launch multiple threads to recognize the words in parallel
283+
auto segment_start = words->begin() + segment_size;
284+
for (int i = 1; i < lstm_num_threads; ++i) {
285+
auto segment_end = (i == lstm_num_threads - 1) ? words->end() : segment_start + segment_size;
286+
futures.push_back(std::async(std::launch::async, &Tesseract::RecogWordsSegment,
287+
this, segment_start, segment_end, pass_n, monitor, page_res,
288+
lstm_recognizers_[i], std::ref(words_done), total_words, std::ref(monitor_mutex)));
289+
segment_start = segment_end;
290+
}
291+
292+
// Process the first segment in this thread
293+
bool overall_result = RecogWordsSegment(words->begin(),
294+
words->begin() + segment_size,
295+
pass_n,
296+
monitor,
297+
page_res,
298+
lstm_recognizers_[0],
299+
std::ref(words_done),
300+
total_words,
301+
std::ref(monitor_mutex));
302+
303+
// Wait for all threads to complete and aggregate results
304+
for (auto &f : futures) {
305+
overall_result &= f.get();
306+
}
307+
308+
return overall_result;
309+
}
310+
266311
/**
267312
* recog_all_words()
268313
*
@@ -340,7 +385,7 @@ bool Tesseract::recog_all_words(PAGE_RES *page_res, ETEXT_DESC *monitor,
340385

341386
most_recently_used_ = this;
342387
// Run pass 1 word recognition.
343-
if (!RecogAllWordsPassN(1, monitor, &page_res_it, &words)) {
388+
if (!RecogAllWordsPassN(1, monitor, page_res, &words)) {
344389
return false;
345390
}
346391
// Pass 1 post-processing.
@@ -380,11 +425,10 @@ bool Tesseract::recog_all_words(PAGE_RES *page_res, ETEXT_DESC *monitor,
380425
}
381426
most_recently_used_ = this;
382427
// Run pass 2 word recognition.
383-
if (!RecogAllWordsPassN(2, monitor, &page_res_it, &words)) {
428+
if (!RecogAllWordsPassN(2, monitor, page_res, &words)) {
384429
return false;
385430
}
386431
}
387-
388432
// The next passes are only required for Tess-only.
389433
if (AnyTessLang() && !AnyLSTMLang()) {
390434
// ****************** Pass 3 *******************
@@ -871,14 +915,15 @@ static int SelectBestWords(double rating_ratio, double certainty_margin, bool de
871915
// Returns positive if this recognizer found more new best words than the
872916
// number kept from best_words.
873917
int Tesseract::RetryWithLanguage(const WordData &word_data, WordRecognizer recognizer, bool debug,
874-
WERD_RES **in_word, PointerVector<WERD_RES> *best_words) {
918+
WERD_RES **in_word, PointerVector<WERD_RES> *best_words,
919+
LSTMRecognizer *lstm_recognizer) {
875920
if (debug) {
876921
tprintf("Trying word using lang %s, oem %d\n", lang.c_str(),
877922
static_cast<int>(tessedit_ocr_engine_mode));
878923
}
879924
// Run the recognizer on the word.
880925
PointerVector<WERD_RES> new_words;
881-
(this->*recognizer)(word_data, in_word, &new_words);
926+
(this->*recognizer)(word_data, in_word, &new_words, lstm_recognizer);
882927
if (new_words.empty()) {
883928
// Transfer input word to new_words, as the classifier must have put
884929
// the result back in the input.
@@ -1300,7 +1345,10 @@ float Tesseract::ClassifyBlobAsWord(int pass_n, PAGE_RES_IT *pr_it, C_BLOB *blob
13001345
// Recognizes in the current language, and if successful that is all.
13011346
// If recognition was not successful, tries all available languages until
13021347
// it gets a successful result or runs out of languages. Keeps the best result.
1303-
void Tesseract::classify_word_and_language(int pass_n, PAGE_RES_IT *pr_it, WordData *word_data) {
1348+
void Tesseract::classify_word_and_language(int pass_n, PAGE_RES_IT *pr_it, WordData *word_data,
1349+
LSTMRecognizer *lstm_recognizer_thread_local) {
1350+
LSTMRecognizer *lstm_recognizer = lstm_recognizer_thread_local ? lstm_recognizer_thread_local
1351+
: lstm_recognizer_;
13041352
#ifdef DISABLED_LEGACY_ENGINE
13051353
WordRecognizer recognizer = &Tesseract::classify_word_pass1;
13061354
#else
@@ -1333,19 +1381,19 @@ void Tesseract::classify_word_and_language(int pass_n, PAGE_RES_IT *pr_it, WordD
13331381
}
13341382
}
13351383
most_recently_used_->RetryWithLanguage(*word_data, recognizer, debug, &word_data->lang_words[sub],
1336-
&best_words);
1384+
&best_words, lstm_recognizer);
13371385
Tesseract *best_lang_tess = most_recently_used_;
13381386
if (!WordsAcceptable(best_words)) {
13391387
// Try all the other languages to see if they are any better.
13401388
if (most_recently_used_ != this &&
13411389
this->RetryWithLanguage(*word_data, recognizer, debug,
1342-
&word_data->lang_words[sub_langs_.size()], &best_words) > 0) {
1390+
&word_data->lang_words[sub_langs_.size()], &best_words, lstm_recognizer) > 0) {
13431391
best_lang_tess = this;
13441392
}
13451393
for (unsigned i = 0; !WordsAcceptable(best_words) && i < sub_langs_.size(); ++i) {
13461394
if (most_recently_used_ != sub_langs_[i] &&
13471395
sub_langs_[i]->RetryWithLanguage(*word_data, recognizer, debug, &word_data->lang_words[i],
1348-
&best_words) > 0) {
1396+
&best_words, lstm_recognizer) > 0) {
13491397
best_lang_tess = sub_langs_[i];
13501398
}
13511399
}
@@ -1378,7 +1426,7 @@ void Tesseract::classify_word_and_language(int pass_n, PAGE_RES_IT *pr_it, WordD
13781426
*/
13791427

13801428
void Tesseract::classify_word_pass1(const WordData &word_data, WERD_RES **in_word,
1381-
PointerVector<WERD_RES> *out_words) {
1429+
PointerVector<WERD_RES> *out_words, LSTMRecognizer *lstm_recognizer) {
13821430
ROW *row = word_data.row;
13831431
BLOCK *block = word_data.block;
13841432
prev_word_best_choice_ =
@@ -1390,14 +1438,14 @@ void Tesseract::classify_word_pass1(const WordData &word_data, WERD_RES **in_wor
13901438
tessedit_ocr_engine_mode == OEM_TESSERACT_LSTM_COMBINED) {
13911439
#endif // def DISABLED_LEGACY_ENGINE
13921440
if (!(*in_word)->odd_size || tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
1393-
LSTMRecognizeWord(*block, row, *in_word, out_words);
1441+
LSTMRecognizeWord(*block, row, *in_word, out_words, lstm_recognizer);
13941442
if (!out_words->empty()) {
13951443
return; // Successful lstm recognition.
13961444
}
13971445
}
13981446
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
13991447
// No fallback allowed, so use a fake.
1400-
(*in_word)->SetupFake(lstm_recognizer_->GetUnicharset());
1448+
(*in_word)->SetupFake(lstm_recognizer->GetUnicharset());
14011449
return;
14021450
}
14031451

@@ -1534,7 +1582,7 @@ bool Tesseract::TestNewNormalization(int original_misfits, float baseline_shift,
15341582
*/
15351583

15361584
void Tesseract::classify_word_pass2(const WordData &word_data, WERD_RES **in_word,
1537-
PointerVector<WERD_RES> *out_words) {
1585+
PointerVector<WERD_RES> *out_words, LSTMRecognizer *lstm_recognizer) {
15381586
// Return if we do not want to run Tesseract.
15391587
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
15401588
return;

src/ccmain/linerec.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ ImageData *Tesseract::GetRectImage(const TBOX &box, const BLOCK &block, int padd
228228
// Recognizes a word or group of words, converting to WERD_RES in *words.
229229
// Analogous to classify_word_pass1, but can handle a group of words as well.
230230
void Tesseract::LSTMRecognizeWord(const BLOCK &block, ROW *row, WERD_RES *word,
231-
PointerVector<WERD_RES> *words) {
231+
PointerVector<WERD_RES> *words, LSTMRecognizer *lstm_recognizer) {
232232
TBOX word_box = word->word->bounding_box();
233233
// Get the word image - no frills.
234234
if (tessedit_pageseg_mode == PSM_SINGLE_WORD || tessedit_pageseg_mode == PSM_RAW_LINE) {
@@ -251,30 +251,30 @@ void Tesseract::LSTMRecognizeWord(const BLOCK &block, ROW *row, WERD_RES *word,
251251

252252
bool do_invert = tessedit_do_invert;
253253
float threshold = do_invert ? double(invert_threshold) : 0.0f;
254-
lstm_recognizer_->RecognizeLine(*im_data, threshold, classify_debug_level > 0,
255-
kWorstDictCertainty / kCertaintyScale, word_box, words,
256-
lstm_choice_mode, lstm_choice_iterations);
254+
lstm_recognizer->RecognizeLine(*im_data, threshold, classify_debug_level > 0,
255+
kWorstDictCertainty / kCertaintyScale, word_box, words,
256+
lstm_choice_mode, lstm_choice_iterations);
257257
delete im_data;
258-
SearchWords(words);
258+
SearchWords(words, lstm_recognizer);
259259
}
260260

261261
// Apply segmentation search to the given set of words, within the constraints
262262
// of the existing ratings matrix. If there is already a best_choice on a word
263263
// leaves it untouched and just sets the done/accepted etc flags.
264-
void Tesseract::SearchWords(PointerVector<WERD_RES> *words) {
264+
void Tesseract::SearchWords(PointerVector<WERD_RES> *words, LSTMRecognizer *lstm_recognizer) {
265265
// Run the segmentation search on the network outputs and make a BoxWord
266266
// for each of the output words.
267267
// If we drop a word as junk, then there is always a space in front of the
268268
// next.
269-
const Dict *stopper_dict = lstm_recognizer_->GetDict();
269+
const Dict *stopper_dict = lstm_recognizer->GetDict();
270270
if (stopper_dict == nullptr) {
271271
stopper_dict = &getDict();
272272
}
273273
for (unsigned w = 0; w < words->size(); ++w) {
274274
WERD_RES *word = (*words)[w];
275275
if (word->best_choice == nullptr) {
276276
// It is a dud.
277-
word->SetupFake(lstm_recognizer_->GetUnicharset());
277+
word->SetupFake(lstm_recognizer->GetUnicharset());
278278
} else {
279279
// Set the best state.
280280
for (unsigned i = 0; i < word->best_choice->length(); ++i) {

src/ccmain/tessedit.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,11 @@ bool Tesseract::init_tesseract_lang_data(const std::string &arg0,
169169
tessedit_ocr_engine_mode == OEM_TESSERACT_LSTM_COMBINED) {
170170
#endif // ndef DISABLED_LEGACY_ENGINE
171171
if (mgr->IsComponentAvailable(TESSDATA_LSTM)) {
172-
lstm_recognizer_ = new LSTMRecognizer(language_data_path_prefix.c_str());
173-
ASSERT_HOST(lstm_recognizer_->Load(this->params(), lstm_use_matrix ? language : "", mgr));
172+
for (int i = 0; i < lstm_num_threads; ++i) {
173+
lstm_recognizers_.push_back(new LSTMRecognizer(language_data_path_prefix.c_str()));
174+
lstm_recognizers_.back()->Load(this->params(), lstm_use_matrix ? language : "", mgr);
175+
}
176+
lstm_recognizer_ = lstm_recognizers_[0];
174177
} else {
175178
tprintf("Error: LSTM requested, but not present!! Loading tesseract.\n");
176179
tessedit_ocr_engine_mode.set_value(OEM_TESSERACT_ONLY);

src/ccmain/tesseractclass.cpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,10 @@ Tesseract::Tesseract()
439439
"lstm_choice_mode. Note that lstm_choice_mode must be set to a "
440440
"value greater than 0 to produce results.",
441441
this->params())
442+
, INT_INIT_MEMBER(lstm_num_threads, 1,
443+
"Sets the number of threads used by the LSTM recognizer. The "
444+
"default value is 1.",
445+
this->params())
442446
, double_MEMBER(lstm_rating_coefficient, 5,
443447
"Sets the rating coefficient for the lstm choices. The smaller the "
444448
"coefficient, the better are the ratings for each choice and less "
@@ -477,7 +481,10 @@ Tesseract::~Tesseract() {
477481
for (auto *lang : sub_langs_) {
478482
delete lang;
479483
}
480-
delete lstm_recognizer_;
484+
for (int i = 0; i < lstm_recognizers_.size(); ++i) {
485+
delete lstm_recognizers_[i];
486+
}
487+
lstm_recognizers_.clear();
481488
lstm_recognizer_ = nullptr;
482489
}
483490

0 commit comments

Comments
 (0)