Skip to content

Commit e8d5488

Browse files
authored
Re-org the OCR Model Files (#64)
1 parent 2afd2a6 commit e8d5488

File tree

4 files changed

+241
-237
lines changed

4 files changed

+241
-237
lines changed

src/layoutparser/ocr/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .gcv_agent import GCVAgent, GCVFeatureType
2+
from .tesseract_agent import TesseractAgent, TesseractFeatureType

src/layoutparser/ocr/base.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from abc import ABC, abstractmethod
2+
from enum import IntEnum
3+
import importlib
4+
5+
6+
class BaseOCRElementType(IntEnum):
    """Base class for the integer enums that name an OCR engine's element levels.

    Concrete enums subclass this, declare their members, and override
    :attr:`attr_name`.
    """

    @property
    @abstractmethod
    def attr_name(self):
        """Name of the attribute that corresponds to this element type.

        Raises:
            NotImplementedError: If the concrete enum does not override this
                property.
        """
        # ``IntEnum`` is built by ``EnumMeta`` rather than ``ABCMeta``, so the
        # ``abstractmethod`` marker above is never enforced at class creation.
        # Fail loudly here instead of silently returning ``None``.
        raise NotImplementedError(
            f"{type(self).__name__} must override the `attr_name` property"
        )
11+
12+
13+
class BaseOCRAgent(ABC):
    """Abstract base class for OCR engine wrappers.

    Subclasses declare the packages they need (``DEPENDENCIES``) and how to
    import them (``MODULES``). The modules are imported lazily, when an
    instance is created, and bound as attributes on the class.
    """

    @property
    @abstractmethod
    def DEPENDENCIES(self):
        """DEPENDENCIES lists all necessary dependencies for the class."""
        pass

    @property
    @abstractmethod
    def MODULES(self):
        """MODULES instructs how to import these necessary libraries.

        Note:
            Sometimes a Python module has a different installation name and
            module name (e.g., `pip install tensorflow-gpu` when installing and
            `import tensorflow` when using). And sometimes we only need to
            import a submodule but not the whole module. MODULES is designed
            for this purpose.

        Returns:
            :obj:`list` of :obj:`dict`: A list of dicts indicating how each
            module is imported.

        Example::

            [{
                "import_name": "_vision",
                "module_path": "google.cloud.vision"
            }]

            is equivalent to self._vision = importlib.import_module("google.cloud.vision")
        """
        pass

    @classmethod
    def _import_module(cls):
        """Import every entry of ``MODULES`` and bind it on the class.

        Raises:
            ModuleNotFoundError: If any module cannot be found. The message
                lists the ``pip install`` command built from ``DEPENDENCIES``;
                the original import error is attached as ``__cause__``.
        """
        for spec in cls.MODULES:
            # Import directly instead of probing with ``find_spec`` first:
            # ``find_spec`` raises on its own when a parent package of a dotted
            # path is missing, which would skip the install hint below.
            try:
                module = importlib.import_module(spec["module_path"])
            except ModuleNotFoundError as err:
                raise ModuleNotFoundError(
                    f"\n "
                    f"\nPlease install the following libraries to support the class {cls.__name__}:"
                    f"\n pip install {' '.join(cls.DEPENDENCIES)}"
                    f"\n "
                ) from err
            setattr(cls, spec["import_name"], module)

    def __new__(cls, *args, **kwargs):
        """Import the required modules before the instance is allocated."""
        cls._import_module()
        return super().__new__(cls)

    @abstractmethod
    def detect(self, image):
        """Run OCR on ``image``; concrete agents must implement this."""
        pass

src/layoutparser/ocr.py renamed to src/layoutparser/ocr/gcv_agent.py

Lines changed: 3 additions & 237 deletions
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,19 @@
1-
from abc import ABC, abstractmethod
2-
from enum import IntEnum
3-
import importlib
41
import io
52
import os
63
import json
7-
import csv
84
import warnings
9-
import pickle
105

116
import numpy as np
12-
import pandas as pd
137
from cv2 import imencode
148

15-
from .elements import *
16-
from .io import load_dataframe
17-
18-
__all__ = ["GCVFeatureType", "GCVAgent", "TesseractFeatureType", "TesseractAgent"]
9+
from .base import BaseOCRAgent, BaseOCRElementType
10+
from ..elements import Layout, TextBlock, Quadrilateral, TextBlock
1911

2012

2113
def _cvt_GCV_vertices_to_points(vertices):
2214
return np.array([[vertex.x, vertex.y] for vertex in vertices])
2315

2416

25-
class BaseOCRElementType(IntEnum):
    """Base class for the integer enums that name an OCR engine's element levels.

    Concrete enums subclass this, declare their members, and override
    :attr:`attr_name`.
    """

    @property
    @abstractmethod
    def attr_name(self):
        """Name of the attribute that corresponds to this element type.

        Raises:
            NotImplementedError: If the concrete enum does not override this
                property.
        """
        # ``IntEnum`` is built by ``EnumMeta`` rather than ``ABCMeta``, so the
        # ``abstractmethod`` marker above is never enforced at class creation.
        # Fail loudly here instead of silently returning ``None``.
        raise NotImplementedError(
            f"{type(self).__name__} must override the `attr_name` property"
        )
30-
31-
32-
class BaseOCRAgent(ABC):
    """Abstract base class for OCR engine wrappers.

    Subclasses declare the packages they need (``DEPENDENCIES``) and how to
    import them (``MODULES``). The modules are imported lazily, when an
    instance is created, and bound as attributes on the class.
    """

    @property
    @abstractmethod
    def DEPENDENCIES(self):
        """DEPENDENCIES lists all necessary dependencies for the class."""
        pass

    @property
    @abstractmethod
    def MODULES(self):
        """MODULES instructs how to import these necessary libraries.

        Note:
            Sometimes a Python module has a different installation name and
            module name (e.g., `pip install tensorflow-gpu` when installing and
            `import tensorflow` when using). And sometimes we only need to
            import a submodule but not the whole module. MODULES is designed
            for this purpose.

        Returns:
            :obj:`list` of :obj:`dict`: A list of dicts indicating how each
            module is imported.

        Example::

            [{
                "import_name": "_vision",
                "module_path": "google.cloud.vision"
            }]

            is equivalent to self._vision = importlib.import_module("google.cloud.vision")
        """
        pass

    @classmethod
    def _import_module(cls):
        """Import every entry of ``MODULES`` and bind it on the class.

        Raises:
            ModuleNotFoundError: If any module cannot be found. The message
                lists the ``pip install`` command built from ``DEPENDENCIES``;
                the original import error is attached as ``__cause__``.
        """
        for spec in cls.MODULES:
            # Import directly instead of probing with ``find_spec`` first:
            # ``find_spec`` raises on its own when a parent package of a dotted
            # path is missing, which would skip the install hint below.
            try:
                module = importlib.import_module(spec["module_path"])
            except ModuleNotFoundError as err:
                raise ModuleNotFoundError(
                    f"\n "
                    f"\nPlease install the following libraries to support the class {cls.__name__}:"
                    f"\n pip install {' '.join(cls.DEPENDENCIES)}"
                    f"\n "
                ) from err
            setattr(cls, spec["import_name"], module)

    def __new__(cls, *args, **kwargs):
        """Import the required modules before the instance is allocated."""
        cls._import_module()
        return super().__new__(cls)

    @abstractmethod
    def detect(self, image):
        """Run OCR on ``image``; concrete agents must implement this."""
        pass
87-
88-
8917
class GCVFeatureType(BaseOCRElementType):
9018
"""
9119
The element types from Google Cloud Vision API
@@ -341,166 +269,4 @@ def save_response(self, res, file_name):
341269

342270
with open(file_name, "w") as f:
343271
json_file = json.loads(res)
344-
json.dump(json_file, f)
345-
346-
347-
class TesseractFeatureType(BaseOCRElementType):
    """
    The element types for Tesseract Detection API
    """

    PAGE = 0
    BLOCK = 1
    PARA = 2
    LINE = 3
    WORD = 4

    @property
    def attr_name(self):
        """Column name associated with this level (e.g. ``"par_num"`` for ``PARA``)."""
        # Kept local to the method: a tuple assigned in the enum body would
        # itself be turned into an enum member.
        column_names = ("page_num", "block_num", "par_num", "line_num", "word_num")
        return column_names[int(self)]

    @property
    def group_levels(self):
        """Column names of every level from ``PAGE`` down to this one, inclusive."""
        return [level.attr_name for level in type(self) if level <= self]
373-
374-
375-
class TesseractAgent(BaseOCRAgent):
    """
    A wrapper for `Tesseract <https://github.com/tesseract-ocr/tesseract>`_ Text
    Detection APIs based on `PyTesseract <https://github.com/madmaze/pytesseract>`_.
    """

    DEPENDENCIES = ["pytesseract"]
    MODULES = [{"import_name": "_pytesseract", "module_path": "pytesseract"}]

    def __init__(self, languages="eng", **kwargs):
        """Create a Tesseract OCR Agent.

        Args:
            languages (:obj:`list` or :obj:`str`, optional):
                You can specify the language code(s) of the documents to detect to improve
                accuracy. The supported language and their code can be found on
                `its github repo <https://github.com/tesseract-ocr/langdata>`_.
                It supports two formats: 1) you can pass in the languages code as a string
                of format like `"eng+fra"`, or 2) you can pack them as a list of strings
                `["eng", "fra"]`.
                Defaults to 'eng'.
            **kwargs:
                Stored as ``self.configs`` and forwarded unchanged to every
                ``image_to_string`` / ``image_to_data`` call.
        """
        self.lang = languages if isinstance(languages, str) else "+".join(languages)
        self.configs = kwargs

    @classmethod
    def with_tesseract_executable(cls, tesseract_cmd_path, **kwargs):
        """Create an agent that uses the Tesseract executable at ``tesseract_cmd_path``.

        Args:
            tesseract_cmd_path (:obj:`str`):
                Value assigned to ``pytesseract.pytesseract.tesseract_cmd``.
            **kwargs:
                Passed on to the class constructor.
        """
        # ``_pytesseract`` only exists on the class once ``_import_module`` has
        # run. Call it here so this constructor also works when no instance
        # has been created yet.
        cls._import_module()
        cls._pytesseract.pytesseract.tesseract_cmd = tesseract_cmd_path
        return cls(**kwargs)

    def _detect(self, img_content):
        """Run Tesseract on ``img_content``.

        Returns:
            :obj:`dict`: ``"text"`` holds the ``image_to_string`` output and
            ``"data"`` holds the ``image_to_data`` output parsed into a
            :obj:`pd.DataFrame`.
        """
        res = {}
        res["text"] = self._pytesseract.image_to_string(
            img_content, lang=self.lang, **self.configs
        )
        _data = self._pytesseract.image_to_data(
            img_content, lang=self.lang, **self.configs
        )
        # The data is tab-separated; QUOTE_NONE keeps quote characters that
        # appear in recognised text from being treated as CSV quoting.
        res["data"] = pd.read_csv(
            io.StringIO(_data), quoting=csv.QUOTE_NONE, encoding="utf-8", sep="\t"
        )
        return res

    def detect(
        self, image, return_response=False, return_only_text=True, agg_output_level=None
    ):
        """Send the input image for OCR.

        The flags are checked in order: ``return_response`` first, then
        ``return_only_text``, then ``agg_output_level``. Because
        ``return_only_text`` defaults to `True`, ``agg_output_level`` only
        takes effect when ``return_only_text=False`` is passed explicitly.

        Args:
            image (:obj:`np.ndarray` or :obj:`str`):
                The input image array or the name of the image file
            return_response (:obj:`bool`, optional):
                Whether directly return all output (string and boxes
                info) from Tesseract.
                Defaults to `False`.
            return_only_text (:obj:`bool`, optional):
                Whether return only the texts in the OCR results.
                Defaults to `True`.
            agg_output_level (:obj:`~TesseractFeatureType`, optional):
                When set, aggregate the Tesseract output with respect to the
                specified aggregation level. Defaults to `None`.

        Returns:
            The full response dict, the recognised text, or the output of
            :meth:`gather_data`, depending on the flags above.
        """

        res = self._detect(image)

        if return_response:
            return res

        if return_only_text:
            return res["text"]

        if agg_output_level is not None:
            return self.gather_data(res, agg_output_level)

        return res["text"]

    @staticmethod
    def gather_data(response, agg_level):
        """
        Gather the OCR'ed text, bounding boxes, and confidence
        in a given aggregation level.

        Args:
            response (:obj:`dict`):
                A response whose ``"data"`` entry is the DataFrame built by
                :meth:`_detect`.
            agg_level (:obj:`~TesseractFeatureType`):
                The level to group the rows by.

        Returns:
            The result of ``load_dataframe`` applied to the aggregated rows
            (columns ``id``, ``x_1``, ``y_1``, ``x_2``, ``y_2``, ``score``,
            ``text`` and ``block_type``).
        """
        assert isinstance(
            agg_level, TesseractFeatureType
        ), f"Invalid agg_level {agg_level}"
        res = response["data"]
        df = (
            res[~res.text.isna()]
            .groupby(agg_level.group_levels)
            .apply(
                lambda gp: pd.Series(
                    [
                        gp["left"].min(),
                        gp["top"].min(),
                        # Width/height of the union of the group's boxes: the
                        # farthest right/bottom edge minus the nearest
                        # left/top edge. Taking the largest single width or
                        # height would only cover one element of the group.
                        (gp["left"] + gp["width"]).max() - gp["left"].min(),
                        (gp["top"] + gp["height"]).max() - gp["top"].min(),
                        gp["conf"].mean(),
                        gp["text"].str.cat(sep=" "),
                    ]
                )
            )
            .reset_index(drop=True)
            .reset_index()
            .rename(
                columns={
                    0: "x_1",
                    1: "y_1",
                    2: "w",
                    3: "h",
                    4: "score",
                    5: "text",
                    "index": "id",
                }
            )
            .assign(x_2=lambda x: x.x_1 + x.w, y_2=lambda x: x.y_1 + x.h, block_type="rectangle")
            .drop(columns=["w", "h"])
        )

        return load_dataframe(df)

    @staticmethod
    def load_response(filename):
        """Load a response previously written by :meth:`save_response`.

        Args:
            filename (:obj:`str`): Path of the pickle file to read.
        """
        # NOTE: unpickling can execute arbitrary code. Only load files that
        # come from a trusted source.
        with open(filename, "rb") as fp:
            res = pickle.load(fp)
        return res

    @staticmethod
    def save_response(res, file_name):
        """Pickle ``res`` to ``file_name`` using the highest pickle protocol."""

        with open(file_name, "wb") as fp:
            pickle.dump(res, fp, protocol=pickle.HIGHEST_PROTOCOL)
272+
json.dump(json_file, f)

0 commit comments

Comments
 (0)