Add parallelization to azuredocint_parser
DL committed Sep 11, 2024
1 parent eb3d440 commit 2ed455c
Showing 1 changed file with 84 additions and 26 deletions.
110 changes: 84 additions & 26 deletions src/llmsearch/parsers/tables/azuredocint_parser.py
@@ -1,4 +1,6 @@
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from tenacity import retry, wait_exponential, stop_after_attempt
import os
from functools import cached_property
from pathlib import Path
@@ -34,12 +36,21 @@
raise ValueError


def log_attempt_number(retry_state):
error_message = str(retry_state.outcome.exception())
logger.error(
f"API call attempt {retry_state.attempt_number} failed with error: {error_message}. Retrying..."
)
# logger.error(f"API call attempt failed. Retrying: {retry_state.attempt_number}...")
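A minimal sketch of how this callback plugs into tenacity's `after` hook (the `flaky_call` function is hypothetical and stands in for the real API call); tenacity invokes the hook once per failed attempt, so each retry leaves a log entry:

from tenacity import retry, stop_after_attempt, wait_fixed

@retry(wait=wait_fixed(0), stop=stop_after_attempt(3), after=log_attempt_number)
def flaky_call():
    raise RuntimeError("transient failure")

# flaky_call()  # would log three lines like:
# "API call attempt 1 failed with error: transient failure. Retrying..."
# and then raise tenacity.RetryError after the final attempt.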


class AzureParsedTable(GenericParsedTable):
def __init__(self, table, page_mapping: dict, default_dpi: int = 72):

page_number, bbox = self.extract_page_and_bbox(table, dpi=default_dpi)

page_number = page_mapping[page_number - 1]

logger.info(f"Get bounding box for the table: {bbox}, page: {page_number}")
super().__init__(page_number, bbox)
self.table = table
@@ -76,7 +87,7 @@ def df(self) -> Optional[pd.DataFrame]:
df_temp = df_temp.reset_index(drop=True)

# Rename duplicate columns, if present
df_temp = df_temp.rename(columns = ColumnRenamer(separator='_'))
df_temp = df_temp.rename(columns=ColumnRenamer(separator="_"))
return df_temp

def clean_content(self, content: str) -> str:
@@ -120,7 +131,7 @@ def xml(self) -> List[str]:
return []
return XMLConverter.convert(self.df)


class ColumnRenamer:
def __init__(self, separator=None):
self.counter = Counter()
@@ -130,47 +141,96 @@ def __call__(self, x):
index = self.counter[x] # Counter returns 0 for missing elements
self.counter[x] = index + 1 # Uses something like `setdefault`
return f'{x}{self.sep if self.sep and index else ""}{index if index else ""}'
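A quick standalone illustration of the renamer's behavior, assuming it runs in this module; `pd.DataFrame.rename` applies the callable once per column label, in order, so the first occurrence keeps its name and later duplicates get a numeric suffix:

import pandas as pd

df = pd.DataFrame([[1, 2, 3]], columns=["a", "a", "b"])
renamed = df.rename(columns=ColumnRenamer(separator="_"))
print(renamed.columns.tolist())  # ['a', 'a_1', 'b']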



class AzureDocIntelligenceTableParser:
def __init__(self, fn: Path, cache_folder: Path):
self.fn = fn

# Initialize Document Intelligence client
self.document_intelligence_client = DocumentIntelligenceClient(
endpoint=doc_intelligence_endpoint,
credential=AzureKeyCredential(doc_intelligence_key),
)

self.table_pages_extractor = PDFTablePagesExtractor(fn, cache_folder / "azuredoc_temp")
self.table_pages_extractor = PDFTablePagesExtractor(
fn, cache_folder / "azuredoc_temp"
)
self._parsed_tables: Optional[List[AzureParsedTable]] = (
None # Cache for parsed tables
)

def detect_and_parse_tables(self) -> List[AzureParsedTable]:
tables = self.table_pages_extractor.extract_table_pages()
# def detect_and_parse_tables(self) -> List[AzureParsedTable]:
# tables = self.table_pages_extractor.extract_table_pages()

all_tables = []
# all_tables = []

for fn, page_mapping in tables:
with open(fn, "rb") as f:
poller = self.document_intelligence_client.begin_analyze_document(
"prebuilt-layout",
analyze_request=f,
content_type="application/octet-stream",
output_content_format=ContentFormat.MARKDOWN,
)
# for fn, page_mapping in tables:
# with open(fn, "rb") as f:
# poller = self.document_intelligence_client.begin_analyze_document(
# "prebuilt-layout",
# analyze_request=f,
# content_type="application/octet-stream",
# output_content_format=ContentFormat.MARKDOWN,
# )

# logger.info(f"Calling AzureDocument Intelligence for {fn}")
# result = poller.result()

# out = []
# if result.tables:
# logger.info(f"\tGot {len(result.tables)} table, extracting...")
# out = [AzureParsedTable(table, page_mapping) for table in result.tables]
# all_tables += out

logger.info(f"Calling AzureDocument Intelligence for {fn}")
result = poller.result()
# return all_tables

out = []
if result.tables:
logger.info(f"\tGot {len(result.tables)} table, extracting...")
out = [AzureParsedTable(table, page_mapping) for table in result.tables]
all_tables += out
def detect_and_parse_tables(self) -> List[AzureParsedTable]:
tables = self.table_pages_extractor.extract_table_pages()

with ThreadPoolExecutor(max_workers=max(1, min(10, len(tables)))) as executor:  # guard: max_workers must be >= 1, even when no table pages were extracted
future_to_fn = {
executor.submit(self._analyze_document, fn, page_mapping): (
fn,
page_mapping,
)
for fn, page_mapping in tables
}
all_tables = []
for future in as_completed(future_to_fn):
fn, page_mapping = future_to_fn[future]
try:
result = future.result()
if result is not None:
all_tables += result
except Exception as e:
logger.error(
"Exception occurred while analyzing document", exc_info=e
)

return all_tables
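For reference, a self-contained sketch of the same submit/as_completed fan-out pattern (the `fetch` helper is hypothetical and stands in for `_analyze_document`); note that as_completed yields futures in completion order, not submission order, and a failure in one future does not abort the rest:

from concurrent.futures import ThreadPoolExecutor, as_completed

def fetch(n):
    if n == 2:
        raise RuntimeError("boom")
    return [n * n]

items = [1, 2, 3]
results = []
with ThreadPoolExecutor(max_workers=max(1, min(10, len(items)))) as executor:
    future_to_item = {executor.submit(fetch, n): n for n in items}
    for future in as_completed(future_to_item):
        try:
            results += future.result()  # collect each item's output
        except Exception:
            pass  # log and continue, as detect_and_parse_tables does
print(sorted(results))  # [1, 9] -- the failing item is skipped, not fatal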

@retry(
wait=wait_exponential(multiplier=1, min=4, max=60),
stop=stop_after_attempt(3),
after=log_attempt_number,
)
def _analyze_document(self, fn: Path, page_mapping):
with open(fn, "rb") as f:
poller = self.document_intelligence_client.begin_analyze_document(
"prebuilt-layout",
analyze_request=f,
content_type="application/octet-stream",
output_content_format=ContentFormat.MARKDOWN,
)
logger.info(f"Calling AzureDocument Intelligence for {fn}")
result = poller.result()

out = []
if result.tables:
logger.info(f"\tGot {len(result.tables)} table, extracting...")
out = [AzureParsedTable(table, page_mapping) for table in result.tables]
return out
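For intuition, the wait schedule this decorator produces, assuming tenacity's exponential wait of roughly multiplier * 2**attempt clamped into [min, max] (the exact exponent offset differs across tenacity versions, but the clamped result below does not):

multiplier, floor, cap = 1, 4, 60
for attempt in (1, 2):  # only two waits occur before stop_after_attempt(3) gives up
    print(max(floor, min(multiplier * 2 ** attempt, cap)))  # 4, then 4
# With a larger attempt budget, the waits would keep doubling until capped at 60s.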

@property
def parsed_tables(self) -> List[AzureParsedTable]:
"""Lazy-loads the parsed tables when requested.
@@ -293,9 +353,7 @@ def extract_save_pages(self, output_filename: Path, pages: List[int]):
if __name__ == "__main__":

path = Path("/home/snexus/Downloads/Table_Example2.pdf")
parser = AzureDocIntelligenceTableParser(
fn=path, cache_folder=Path(".")
)
parser = AzureDocIntelligenceTableParser(fn=path, cache_folder=Path("."))
# ex = PDFTablePagesExtractor(fn = path, temp_folder = Path("./azuredoc_temp"))

tables = parser.parsed_tables
