diff --git a/extract_thinker/__init__.py b/extract_thinker/__init__.py index c14ad4c..1ac0e48 100644 --- a/extract_thinker/__init__.py +++ b/extract_thinker/__init__.py @@ -3,6 +3,7 @@ from .document_loader.cached_document_loader import CachedDocumentLoader from .document_loader.document_loader_tesseract import DocumentLoaderTesseract from .document_loader.document_loader_spreadsheet import DocumentLoaderSpreadSheet +from .document_loader.document_loader_azure_document_intelligence import DocumentLoaderAzureForm from .document_loader.document_loader_pypdf import DocumentLoaderPyPdf from .document_loader.document_loader_text import DocumentLoaderText from .models import classification, classification_response @@ -18,6 +19,7 @@ 'DocumentLoader', 'CachedDocumentLoader', 'DocumentLoaderTesseract', + 'DocumentLoaderAzureForm', 'DocumentLoaderPyPdf', 'DocumentLoaderText', 'classification', diff --git a/extract_thinker/document_loader/document_loader_azure_document_intelligence.py b/extract_thinker/document_loader/document_loader_azure_document_intelligence.py index 5f46acd..6fdd0ee 100644 --- a/extract_thinker/document_loader/document_loader_azure_document_intelligence.py +++ b/extract_thinker/document_loader/document_loader_azure_document_intelligence.py @@ -10,7 +10,7 @@ class DocumentLoaderAzureForm(CachedDocumentLoader): - def __init__(self, subscription_key: str, endpoint: str, is_container: bool = False, content: Any = None, cache_ttl: int = 300): + def __init__(self, subscription_key: str, endpoint: str, content: Any = None, cache_ttl: int = 300): super().__init__(content, cache_ttl) self.subscription_key = subscription_key self.endpoint = endpoint @@ -42,33 +42,50 @@ def process_result(self, result: AnalyzeResult) -> List[dict]: for page in result.pages: paragraphs = [p.content for p in page.lines] tables = self.build_tables(result.tables) - words_with_locations = self.process_words(page) + # words_with_locations = self.process_words(page) + # Remove lines that are present in tables + paragraphs = self.remove_lines_present_in_tables(paragraphs, tables) output = { - "type": "pdf", - "content": result.content, + #"content": result.content, "paragraphs": paragraphs, - "words": words_with_locations, - "tables": tables + #"words": words_with_locations, + "tables": tables.get(page.page_number, []) } extract_results.append(output) - return extract_results + return {"pages": extract_results} + + def remove_lines_present_in_tables(self, paragraphs: List[str], tables: dict[int, List[List[str]]]) -> List[str]: + for table in tables.values(): + for row in table: + for cell in row: + if cell in paragraphs: + paragraphs.remove(cell) + return paragraphs + + def page_to_string(self, page: DocumentPage) -> str: + page_string = "" + for word in page.words: + for point in word.polygon: + page_string += f"({point.x}, {point.y}): {word.content}\n" + return page_string def process_words(self, page: DocumentPage) -> List[dict]: words_with_locations = [] - for line in page.lines: - for word in line.words: - word_info = { - "content": word.content, - "bounding_box": { - "points": self.build_points(word.bounding_box) - }, - "page_number": page.page_number - } - words_with_locations.append(word_info) + + for word in page.words: + word_info = { + "content": word.content, + "bounding_box": { + "points": word.polygon + }, + "page_number": page.page_number + } + words_with_locations.append(word_info) + return words_with_locations - def build_tables(self, tables: List[DocumentTable]) -> List[List[str]]: - table_data = [] + def build_tables(self, tables: List[DocumentTable]) -> dict[int, List[List[str]]]: + table_data = {} for table in tables: rows = [] for row_idx in range(table.row_count): @@ -77,7 +94,8 @@ def build_tables(self, tables: List[DocumentTable]) -> List[List[str]]: if cell.row_index == row_idx: row.append(cell.content) rows.append(row) - table_data.append(rows) + # Use the page number as the key for the dictionary + table_data[table.bounding_regions[0].page_number] = rows return table_data def build_points(self, bounding_box: List[Point]) -> List[dict]: diff --git a/pyproject.toml b/pyproject.toml index abf51c1..5040f4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "extract_thinker" -version = "0.0.3" +version = "0.0.4" description = "Library to extract data from files and documents agnositicaly using LLMs" authors = ["JĂșlio Almeida "] readme = "README.md" diff --git a/tests/document_loader_azure_document_intelligence.py b/tests/document_loader_azure_document_intelligence.py index 3a80ab9..05d6632 100644 --- a/tests/document_loader_azure_document_intelligence.py +++ b/tests/document_loader_azure_document_intelligence.py @@ -1,12 +1,5 @@ import os - -import sys -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -from io import BytesIO from dotenv import load_dotenv -import pytest -from azure.core.exceptions import AzureError from extract_thinker.document_loader.document_loader_azure_document_intelligence import DocumentLoaderAzureForm @@ -17,48 +10,16 @@ subscription_key = os.getenv("AZURE_SUBSCRIPTION_KEY") endpoint = os.getenv("AZURE_ENDPOINT") loader = DocumentLoaderAzureForm(subscription_key, endpoint) -test_file_path = os.path.join(cwd, "tests", "test_documents", "invoice.pdf") +test_file_path = os.path.join(cwd, "test_images", "invoice.png") def test_load_content_from_file(): # Act - try: - content = loader.load_content_from_file("C:\\Users\\Lopez\\Downloads\\LNKD_INVOICE_7894414780.pdf") - except AzureError as e: - pytest.fail(f"AzureError occurred: {e}") - - # Assert - assert content is not None - assert isinstance(content, list) - assert len(content) > 0 - - -def test_load_content_from_stream(): - with open(test_file_path, 'rb') as f: - test_document_stream = BytesIO(f.read()) - - # Act - try: - content = loader.load_content_from_stream(test_document_stream) - except AzureError as e: - pytest.fail(f"AzureError occurred: {e}") - - # Assert - assert content is not None - assert isinstance(content, list) - assert len(content) > 0 + content = loader.load_content_from_file(test_file_path) - -def test_cache_for_file(): - # Act - try: - content1 = loader.load_content_from_file(test_file_path) - content2 = loader.load_content_from_file(test_file_path) - except AzureError as e: - pytest.fail(f"AzureError occurred: {e}") + firstPage = content["pages"][0] # Assert - assert content1 is content2 - - -test_load_content_from_file() \ No newline at end of file + assert firstPage is not None + assert firstPage["paragraphs"][0] == "Invoice 0000001" + assert len(firstPage["tables"][0]) == 4 diff --git a/tests/extractor.py b/tests/extractor.py index d710f4f..4b1a3d0 100644 --- a/tests/extractor.py +++ b/tests/extractor.py @@ -4,6 +4,7 @@ from extract_thinker.extractor import Extractor from extract_thinker.document_loader.document_loader_tesseract import DocumentLoaderTesseract from tests.models.invoice import InvoiceContract +from extract_thinker.document_loader.document_loader_azure_document_intelligence import DocumentLoaderAzureForm load_dotenv() cwd = os.getcwd() @@ -28,3 +29,24 @@ def test_extract_with_tessaract_and_claude(): assert result is not None assert result.invoice_number == "0000001" assert result.invoice_date == "2014-05-07" + + +def test_extract_with_azure_di_and_claude(): + subscription_key = os.getenv("AZURE_SUBSCRIPTION_KEY") + endpoint = os.getenv("AZURE_ENDPOINT") + test_file_path = os.path.join(cwd, "test_images", "invoice.png") + + extractor = Extractor() + extractor.load_document_loader( + DocumentLoaderAzureForm(subscription_key, endpoint) + ) + extractor.load_llm("claude-3-haiku-20240307") + # Act + result = extractor.extract(test_file_path, InvoiceContract) + + # Assert + assert result is not None + assert result.lines[0].description == "Website Redesign" + assert result.lines[0].quantity == 1 + assert result.lines[0].unit_price == 2500 + assert result.lines[0].amount == 2500 diff --git a/tests/models/invoice.py b/tests/models/invoice.py index 5eed8d0..17d3832 100644 --- a/tests/models/invoice.py +++ b/tests/models/invoice.py @@ -1,6 +1,16 @@ +from typing import List from extract_thinker.models.contract import Contract +class LinesContract(Contract): + description: str + quantity: int + unit_price: int + amount: int + + class InvoiceContract(Contract): invoice_number: str invoice_date: str + lines: List[LinesContract] + total_amount: int