Skip to content

Commit

Permalink
Merge pull request #7 from maxent-ai/pipeline
Browse files Browse the repository at this point in the history
Pipeline
  • Loading branch information
bharathgs authored Jul 3, 2022
2 parents bdd094e + 0a5646d commit 27a7b96
Show file tree
Hide file tree
Showing 12 changed files with 216 additions and 99 deletions.
82 changes: 46 additions & 36 deletions ocrpy/io/reader.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,59 @@
import io
import os
import pdf2image
from io import BytesIO
from ..utils import LOGGER
from dotenv import load_dotenv
from attr import define, field
from pdf2image import convert_from_bytes
from ..utils import FileTypeNotSupported
from typing import Union, Generator, ByteString
from ..utils import guess_extension, guess_storage
from cloudpathlib import S3Client, GSClient, AnyPath

__all__ = ['DocumentReader']

# TO-DO: Add logging and improve the error handling


@define
class DocumentReader:
"""
Read an image file from a given location and returns a byte array.
Reads an image or a pdf file from a local or remote location.
Note: Currently supports Google Storage and Amazon S3 Remote Files.
Attributes
----------
file : str
The path to the file to be read.
credentials : str
The path to the credentials file.
Note:
If the Remote storage is AWS S3, the credentials file must be in the .env format.
If the Remote storage is Google Storage, the credentials file must be in the .json format.
"""
file = field()
credentials = field(default=None)
file: str = field()
credentials: str = field(default=None)
storage_type = field(default=None, init=False)

def read(self):
file_type = self.get_file_type()
if file_type == 'image':
return self._read_image(self.file)
elif file_type == 'pdf':
return self._read_pdf(self.file)
else:
raise ValueError("File type not supported")
def __attrs_post_init__(self):
    """
    Populate ``storage_type`` once the instance has been initialised.

    The value is whatever ``guess_storage`` reports for ``self.file``.
    """
    detected_storage = guess_storage(self.file)
    self.storage_type = detected_storage

def get_file_type(self):
if self.file.endswith(".png") or self.file.endswith(".jpg"):
file_type = "image"
elif self.file.endswith(".pdf"):
file_type = "pdf"
else:
file_type = "unknown"
return file_type
def read(self) -> Union[Generator, ByteString]:
    """
    Reads the file from a local or remote location and
    returns the data in byte-string for an image or as a
    generator of byte-strings for a pdf.

    Returns
    -------
    data : Union[Generator, ByteString]
        The data in byte-string for an image or as a
        generator of byte-strings for a pdf.

    Raises
    ------
    FileTypeNotSupported
        If ``guess_extension`` does not classify ``self.file`` as
        ``'IMAGE'`` or ``'PDF'``.
    """
    file_type = guess_extension(self.file)
    reader_methods = {'IMAGE': self._read_image, 'PDF': self._read_pdf}
    reader = reader_methods.get(file_type)
    if reader is None:
        # The exception must be raised: returning the instance would hand
        # callers an exception object where they expect bytes/a generator.
        raise FileTypeNotSupported(
            f"""We failed to understand the file type of {self.file}. The supported file-types are .png, .jpg or .pdf files. Please check the file type and try again.""")
    return reader(self.file)

def _read_image(self, file):
    """
    Return the content of an image file.

    Image files need no conversion, so this simply delegates to ``_read``.
    """
    payload = self._read(file)
    return payload
Expand All @@ -58,22 +69,21 @@ def _read(self, file):
return file_data.read_bytes()

def _get_client(self, file):
    """
    Build the cloud storage client matching ``self.storage_type``.

    A client is only created when a credentials path is configured:
    a ``GSClient`` for ``'GS'`` storage, an ``S3Client`` for ``'S3'``
    storage (its keys are read from the environment after loading the
    credentials file with ``load_dotenv``). In every other case ``None``
    is returned.
    """
    kind = self.storage_type

    if kind == "GS" and self.credentials:
        return GSClient(application_credentials=self.credentials)

    if kind == 'S3' and self.credentials:
        # The .env file supplies the keys read back through os.getenv below.
        load_dotenv(self.credentials)
        access_key = os.getenv('aws_access_key_id')
        secret_key = os.getenv('aws_secret_access_key')
        return S3Client(aws_access_key_id=access_key,
                        aws_secret_access_key=secret_key)

    return None

def _bytes_to_images(self, data):
    """
    Convert pdf bytes into PNG-encoded images.

    Each image produced by ``convert_from_bytes`` is saved into an
    in-memory buffer as PNG and its bytes are yielded in order.
    """
    for page in convert_from_bytes(data):
        stream = BytesIO()
        page.save(stream, format='PNG')
        yield stream.getvalue()
3 changes: 2 additions & 1 deletion ocrpy/parsers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ class AbstractTextOCR:
Abstract class for Text OCR backends.
"""
reader: Any = field()
credentials: str = field()

@abc.abstractproperty
@abc.abstractmethod
def parse(self):
return NotImplemented

Expand Down
73 changes: 46 additions & 27 deletions ocrpy/parsers/text/aws_text.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import os
import boto3
import time
from cloudpathlib import AnyPath
import boto3
from dotenv import load_dotenv
from attr import define, field
from typing import List, Dict, Any
from ...utils.errors import NotSupportedError
from cloudpathlib import AnyPath
from ...utils.exceptions import AttributeNotSupported
from typing import List, Dict, Any, Union, Generator, ByteString
from ..core import AbstractTextOCR, AbstractLineSegmenter, AbstractBlockSegmenter

__all__ = ['AwsTextOCR']

# TO-DO: Add logging and improve the error handling


def aws_region_extractor(block):
x1, x2 = block['Geometry']['BoundingBox']['Left'], block['Geometry']['BoundingBox']['Left'] + \
Expand All @@ -28,7 +30,7 @@ def aws_token_formator(token):
return token


def is_job_complete(client, job_id):
def _job_status(client, job_id):
time.sleep(1)
response = client.get_document_text_detection(JobId=job_id)
status = response["JobStatus"]
Expand All @@ -40,7 +42,7 @@ def is_job_complete(client, job_id):
return status


def get_job_results(client, job_id):
def _fetch_job_result(client, job_id):
pages = []
response = client.get_document_text_detection(JobId=job_id)
pages.append(response)
Expand Down Expand Up @@ -101,27 +103,45 @@ class AwsBlockSegmenter(AbstractBlockSegmenter):

@property
def blocks(self):
    """
    Block segmentation is not available for the AWS backend.

    Raises
    ------
    AttributeNotSupported
        Always, whenever the property is accessed.
    """
    message = "AWS Backend does not support block segmentation yet."
    raise AttributeNotSupported(message)


@define
class AwsTextOCR(AbstractTextOCR):
env_file = field(default=None)
textract = field(repr=False, init=False)
document = field(default=None, repr=False)
"""
AWS Textract OCR Engine
Attributes
----------
reader: Any
Reader object that can be used to read the document.
credentials : str
Path to credentials file.
Note: The credentials file must be in .env format.
"""
_client: boto3.client = field(repr=False, init=False)
_document: Union[Generator, ByteString] = field(
default=None, repr=False, init=False)

def __attrs_post_init__(self):
    """
    Create the Textract client once the instance has been initialised.

    When a credentials path is given it is loaded with ``load_dotenv``
    first; the region and keys are then read from the environment.
    """
    if self.credentials:
        load_dotenv(self.credentials)

    client_settings = {
        'region_name': os.getenv('region_name'),
        'aws_access_key_id': os.getenv('aws_access_key_id'),
        'aws_secret_access_key': os.getenv('aws_secret_access_key'),
    }
    self._client = boto3.client('textract', **client_settings)

def parse(self) -> Dict[int, Dict]:
    """
    Parse the document and return the ocr data as a dictionary
    of pages along with additional metadata.

    Returns
    -------
    parsed_data : dict
        Dictionary of pages.
    """
    parsed_data = self._process_data()
    return parsed_data

def _process_data(self):
Expand All @@ -130,34 +150,33 @@ def _process_data(self):
if not isinstance(ocr, list):
ocr = [ocr]
for index, page in enumerate(ocr):
print("Processing page {}".format(index))
data = dict(text=self._get_text(page), lines=self._get_lines(
page), blocks=self._get_blocks(page), tokens=self._get_tokens(page))
result[index] = data
return result

def _get_ocr(self):
    """
    Run Textract on the document behind ``self.reader``.

    For files stored on S3 an asynchronous text-detection job is started
    on the object itself and its result is fetched with
    ``_fetch_job_result``. For any other storage the document is read
    through the reader and every byte-string is sent to
    ``detect_document_text``.

    Returns
    -------
    ocr
        Whatever ``_fetch_job_result`` returns for S3 files, otherwise a
        list with one ``detect_document_text`` response per byte-string.
    """
    storage_type = self.reader.storage_type

    # DocumentReader._get_client compares this attribute against the
    # upper-case labels 'GS'/'S3', so a lower-case 's3' check never matched.
    # Compare case-insensitively to accept either spelling.
    if isinstance(storage_type, str) and storage_type.upper() == 'S3':
        path = AnyPath(self.reader.file)
        response = self._client.start_document_text_detection(DocumentLocation={
            'S3Object': {
                'Bucket': path.bucket,
                'Name': path.key
            }})
        job_id = response['JobId']
        # The client lives in ``_client``; there is no ``textract`` attribute.
        # The status returned here was never consulted, so it is not bound.
        _job_status(self._client, job_id)
        return _fetch_job_result(self._client, job_id)

    self._document = self.reader.read()
    if isinstance(self._document, bytes):
        # A single image comes back as raw bytes; treat it as one page.
        self._document = [self._document]
    return [
        self._client.detect_document_text(Document={'Bytes': document})
        for document in self._document
    ]
Expand Down
48 changes: 33 additions & 15 deletions ocrpy/parsers/text/gcp_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from google.cloud import vision
from typing import List, Dict, Any
from google.oauth2 import service_account
from ...utils.errors import NotSupportedError
from ...utils.exceptions import AttributeNotSupported
from ..core import AbstractTextOCR, AbstractLineSegmenter, AbstractBlockSegmenter

__all__ = ['GcpTextOCR']

# TO-DO: Add logging and improve the error handling

def gcp_region_extractor(block):
x_points = [v.x for v in block]
Expand Down Expand Up @@ -74,35 +75,52 @@ class GCPLineSegmenter(AbstractLineSegmenter):

@property
def lines(self):
    """
    Line segmentation is not available for the GCP backend.

    Raises
    ------
    AttributeNotSupported
        Always, whenever the property is accessed.
    """
    # The exception used to be instantiated and discarded, which made the
    # property silently evaluate to None instead of signalling the gap.
    raise AttributeNotSupported("GCP Backend does not support line segmentation yet.")

@define
class GcpTextOCR(AbstractTextOCR):
env_file = field(default=None)
client = field(repr=False, init=False)
document = field(default=None, repr=False)
"""
Google Cloud Vision OCR Engine
Attributes
----------
reader : Any
Reader object that can be used to read the document.
credentials : str
Path to credentials file.
Note: The credentials file must be in .json format.
"""
_client = field(repr=False, init=False)
_document = field(default=None, repr=False, init=False)

def __attrs_post_init__(self):
    """
    Create the Vision client and read the document.

    With a credentials path the client is built from that service-account
    file; without one, ``ImageAnnotatorClient`` is created with no
    arguments. The document is then read through ``self.reader``.
    """
    if not self.credentials:
        self._client = vision.ImageAnnotatorClient()
    else:
        account = service_account.Credentials.from_service_account_file(
            self.credentials)
        self._client = vision.ImageAnnotatorClient(credentials=account)

    self._document = self.reader.read()

@property
def parse(self):
    """
    Parses the document and returns the ocr data as a dictionary of pages
    along with additional metadata.

    NOTE(review): this is still exposed as a property, whereas
    ``AwsTextOCR.parse`` is a plain method and the abstract base declares
    ``parse`` with ``abc.abstractmethod``. Confirm whether callers use
    ``ocr.parse`` or ``ocr.parse()`` before aligning the two backends.

    Returns
    -------
    parsed_data : dict
        Dictionary of pages.
    """
    return self._process_data()

def _process_data(self):
is_image = False
if isinstance(self.document, bytes):
self.document = [self.document]
if isinstance(self._document, bytes):
self._document = [self._document]
is_image = True

result = {}
for index, document in enumerate(self.document):
for index, document in enumerate(self._document):

ocr = self._get_ocr(document)
blocks = self._get_blocks(ocr)
Expand Down Expand Up @@ -147,7 +165,7 @@ def _get_text(self, ocr):

def _get_ocr(self, image):
    """
    Run document text detection on one image.

    The image bytes are wrapped in a ``vision.types.Image`` and the
    ``full_text_annotation`` of the response is returned.
    """
    request_image = vision.types.Image(content=image)
    response = self._client.document_text_detection(image=request_image)
    return response.full_text_annotation


Loading

0 comments on commit 27a7b96

Please sign in to comment.