1
- from abc import ABC , abstractmethod
2
- from enum import IntEnum
3
- import importlib
4
1
import io
5
2
import os
6
3
import json
7
- import csv
8
4
import warnings
9
- import pickle
10
5
11
6
import numpy as np
12
- import pandas as pd
13
7
from cv2 import imencode
14
8
15
- from .elements import *
16
- from .io import load_dataframe
17
-
18
- __all__ = ["GCVFeatureType" , "GCVAgent" , "TesseractFeatureType" , "TesseractAgent" ]
9
+ from .base import BaseOCRAgent , BaseOCRElementType
10
+ from ..elements import Layout , TextBlock , Quadrilateral , TextBlock
19
11
20
12
21
13
def _cvt_GCV_vertices_to_points (vertices ):
22
14
return np .array ([[vertex .x , vertex .y ] for vertex in vertices ])
23
15
24
16
25
- class BaseOCRElementType (IntEnum ):
26
- @property
27
- @abstractmethod
28
- def attr_name (self ):
29
- pass
30
-
31
-
32
- class BaseOCRAgent (ABC ):
33
- @property
34
- @abstractmethod
35
- def DEPENDENCIES (self ):
36
- """DEPENDENCIES lists all necessary dependencies for the class."""
37
- pass
38
-
39
- @property
40
- @abstractmethod
41
- def MODULES (self ):
42
- """MODULES instructs how to import these necessary libraries.
43
-
44
- Note:
45
- Sometimes a python module have different installation name and module name (e.g.,
46
- `pip install tensorflow-gpu` when installing and `import tensorflow` when using
47
- ). And sometimes we only need to import a submodule but not whole module. MODULES
48
- is designed for this purpose.
49
-
50
- Returns:
51
- :obj: list(dict): A list of dict indicate how the model is imported.
52
-
53
- Example::
54
-
55
- [{
56
- "import_name": "_vision",
57
- "module_path": "google.cloud.vision"
58
- }]
59
-
60
- is equivalent to self._vision = importlib.import_module("google.cloud.vision")
61
- """
62
- pass
63
-
64
- @classmethod
65
- def _import_module (cls ):
66
- for m in cls .MODULES :
67
- if importlib .util .find_spec (m ["module_path" ]):
68
- setattr (
69
- cls , m ["import_name" ], importlib .import_module (m ["module_path" ])
70
- )
71
- else :
72
- raise ModuleNotFoundError (
73
- f"\n "
74
- f"\n Please install the following libraries to support the class { cls .__name__ } :"
75
- f"\n pip install { ' ' .join (cls .DEPENDENCIES )} "
76
- f"\n "
77
- )
78
-
79
- def __new__ (cls , * args , ** kwargs ):
80
-
81
- cls ._import_module ()
82
- return super ().__new__ (cls )
83
-
84
- @abstractmethod
85
- def detect (self , image ):
86
- pass
87
-
88
-
89
17
class GCVFeatureType (BaseOCRElementType ):
90
18
"""
91
19
The element types from Google Cloud Vision API
@@ -341,166 +269,4 @@ def save_response(self, res, file_name):
341
269
342
270
with open (file_name , "w" ) as f :
343
271
json_file = json .loads (res )
344
- json .dump (json_file , f )
345
-
346
-
347
- class TesseractFeatureType (BaseOCRElementType ):
348
- """
349
- The element types for Tesseract Detection API
350
- """
351
-
352
- PAGE = 0
353
- BLOCK = 1
354
- PARA = 2
355
- LINE = 3
356
- WORD = 4
357
-
358
- @property
359
- def attr_name (self ):
360
- name_cvt = {
361
- TesseractFeatureType .PAGE : "page_num" ,
362
- TesseractFeatureType .BLOCK : "block_num" ,
363
- TesseractFeatureType .PARA : "par_num" ,
364
- TesseractFeatureType .LINE : "line_num" ,
365
- TesseractFeatureType .WORD : "word_num" ,
366
- }
367
- return name_cvt [self ]
368
-
369
- @property
370
- def group_levels (self ):
371
- levels = ["page_num" , "block_num" , "par_num" , "line_num" , "word_num" ]
372
- return levels [: self + 1 ]
373
-
374
-
375
- class TesseractAgent (BaseOCRAgent ):
376
- """
377
- A wrapper for `Tesseract <https://github.com/tesseract-ocr/tesseract>`_ Text
378
- Detection APIs based on `PyTesseract <https://github.com/tesseract-ocr/tesseract>`_.
379
- """
380
-
381
- DEPENDENCIES = ["pytesseract" ]
382
- MODULES = [{"import_name" : "_pytesseract" , "module_path" : "pytesseract" }]
383
-
384
- def __init__ (self , languages = "eng" , ** kwargs ):
385
- """Create a Tesseract OCR Agent.
386
-
387
- Args:
388
- languages (:obj:`list` or :obj:`str`, optional):
389
- You can specify the language code(s) of the documents to detect to improve
390
- accuracy. The supported language and their code can be found on
391
- `its github repo <https://github.com/tesseract-ocr/langdata>`_.
392
- It supports two formats: 1) you can pass in the languages code as a string
393
- of format like `"eng+fra"`, or 2) you can pack them as a list of strings
394
- `["eng", "fra"]`.
395
- Defaults to 'eng'.
396
- """
397
- self .lang = languages if isinstance (languages , str ) else "+" .join (languages )
398
- self .configs = kwargs
399
-
400
- @classmethod
401
- def with_tesseract_executable (cls , tesseract_cmd_path , ** kwargs ):
402
-
403
- cls ._pytesseract .pytesseract .tesseract_cmd = tesseract_cmd_path
404
- return cls (** kwargs )
405
-
406
- def _detect (self , img_content ):
407
- res = {}
408
- res ["text" ] = self ._pytesseract .image_to_string (
409
- img_content , lang = self .lang , ** self .configs
410
- )
411
- _data = self ._pytesseract .image_to_data (
412
- img_content , lang = self .lang , ** self .configs
413
- )
414
- res ["data" ] = pd .read_csv (
415
- io .StringIO (_data ), quoting = csv .QUOTE_NONE , encoding = "utf-8" , sep = "\t "
416
- )
417
- return res
418
-
419
- def detect (
420
- self , image , return_response = False , return_only_text = True , agg_output_level = None
421
- ):
422
- """Send the input image for OCR.
423
-
424
- Args:
425
- image (:obj:`np.ndarray` or :obj:`str`):
426
- The input image array or the name of the image file
427
- return_response (:obj:`bool`, optional):
428
- Whether directly return all output (string and boxes
429
- info) from Tesseract.
430
- Defaults to `False`.
431
- return_only_text (:obj:`bool`, optional):
432
- Whether return only the texts in the OCR results.
433
- Defaults to `False`.
434
- agg_output_level (:obj:`~TesseractFeatureType`, optional):
435
- When set, aggregate the GCV output with respect to the
436
- specified aggregation level. Defaults to `None`.
437
- """
438
-
439
- res = self ._detect (image )
440
-
441
- if return_response :
442
- return res
443
-
444
- if return_only_text :
445
- return res ["text" ]
446
-
447
- if agg_output_level is not None :
448
- return self .gather_data (res , agg_output_level )
449
-
450
- return res ["text" ]
451
-
452
- @staticmethod
453
- def gather_data (response , agg_level ):
454
- """
455
- Gather the OCR'ed text, bounding boxes, and confidence
456
- in a given aggeragation level.
457
- """
458
- assert isinstance (
459
- agg_level , TesseractFeatureType
460
- ), f"Invalid agg_level { agg_level } "
461
- res = response ["data" ]
462
- df = (
463
- res [~ res .text .isna ()]
464
- .groupby (agg_level .group_levels )
465
- .apply (
466
- lambda gp : pd .Series (
467
- [
468
- gp ["left" ].min (),
469
- gp ["top" ].min (),
470
- gp ["width" ].max (),
471
- gp ["height" ].max (),
472
- gp ["conf" ].mean (),
473
- gp ["text" ].str .cat (sep = " " ),
474
- ]
475
- )
476
- )
477
- .reset_index (drop = True )
478
- .reset_index ()
479
- .rename (
480
- columns = {
481
- 0 : "x_1" ,
482
- 1 : "y_1" ,
483
- 2 : "w" ,
484
- 3 : "h" ,
485
- 4 : "score" ,
486
- 5 : "text" ,
487
- "index" : "id" ,
488
- }
489
- )
490
- .assign (x_2 = lambda x : x .x_1 + x .w , y_2 = lambda x : x .y_1 + x .h , block_type = "rectangle" )
491
- .drop (columns = ["w" , "h" ])
492
- )
493
-
494
- return load_dataframe (df )
495
-
496
- @staticmethod
497
- def load_response (filename ):
498
- with open (filename , "rb" ) as fp :
499
- res = pickle .load (fp )
500
- return res
501
-
502
- @staticmethod
503
- def save_response (res , file_name ):
504
-
505
- with open (file_name , "wb" ) as fp :
506
- pickle .dump (res , fp , protocol = pickle .HIGHEST_PROTOCOL )
272
+ json .dump (json_file , f )
0 commit comments