Refactor[Requests]: Validate requests from endpoints using pydantic (#35)

* refactor[summary]: clean up code to remove deprecated processes

* fix leftovers

* fix leftovers

* fix leftovers

* refactor: code improvements

* remove extra print statement

* refactor: clean up a potential bug

* fix: remove extra print statement
ArslanSaleem authored and gventuri committed Oct 24, 2024
1 parent bb1b581 commit 60ce24b
Showing 6 changed files with 98 additions and 48 deletions.
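At a high level, the commit replaces the raw response.json() dictionaries returned by the PandaETL API client with validated pydantic models (TextExtractionResponse, ExtractFieldsResponse), so a malformed response fails at the API boundary instead of surfacing later as a KeyError deep in the processing pipeline. Below is a minimal sketch of that pattern; the schema fields mirror backend/app/requests/schemas.py from this diff, but the endpoint URL and the fetch_extraction helper are hypothetical, shown only to illustrate the idea.

# Minimal sketch of the validation pattern introduced by this commit.
# The schema fields mirror backend/app/requests/schemas.py; the URL and the
# fetch_extraction helper are hypothetical, not part of the codebase.
from typing import List, Optional

import requests
from pydantic import BaseModel, ValidationError


class SentenceMetadata(BaseModel):
    page_number: Optional[int] = None


class StructuredSentence(BaseModel):
    text: str
    metadata: Optional[SentenceMetadata] = None


class TextExtractionResponse(BaseModel):
    content: List[StructuredSentence]
    word_count: int
    lang: str


def fetch_extraction(url: str, token: str) -> TextExtractionResponse:
    response = requests.post(
        url, headers={"x-authorization": f"Bearer {token}"}, timeout=360
    )
    if response.status_code not in (200, 201):
        raise Exception("Unable to process file!")
    try:
        # A missing or mistyped field raises here, at the boundary,
        # instead of as a KeyError inside the pipeline.
        return TextExtractionResponse(**response.json())
    except ValidationError as exc:
        raise ValueError(f"Unexpected response shape: {exc}") from exc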
29 changes: 15 additions & 14 deletions backend/app/processing/file_preprocessing.py
@@ -1,4 +1,5 @@
from concurrent.futures import ThreadPoolExecutor
from app.requests.schemas import TextExtractionResponse
from sqlalchemy.orm.exc import ObjectDeletedError
from app.models.asset_content import AssetProcessingStatus
from app.database import SessionLocal
@@ -27,22 +28,22 @@ def process_segmentation(project_id: int, asset_id: int, asset_file_name: str):
with SessionLocal() as db:
asset_content = project_repository.get_asset_content(db, asset_id)

# segmentation = extract_file_segmentation(
# api_token=api_key, pdf_content=asset_content.content
# )

vectorstore = ChromaDB(f"panda-etl-{project_id}")
vectorstore.add_docs(
docs=asset_content.content["content"],
metadatas=[
{

docs = []
metadatas = []
for content in asset_content.content["content"]:
docs.append(content["text"])
metadatas.append({
"asset_id": asset_id,
"filename": asset_file_name,
"project_id": project_id,
"page_number": asset_content.content["page_number_data"][index],
}
for index, _ in enumerate(asset_content.content["content"])
],
**(content["metadata"] if content.get("metadata") else {"page_number": 1}), # Unpack all metadata or default to page_number: 1
})

vectorstore.add_docs(
docs=docs,
metadatas=metadatas
)

project_repository.update_asset_content_status(
@@ -88,7 +89,7 @@ def preprocess_file(asset_id: int):
while retries < settings.max_retries and not success:
try:
# Perform the expensive operation here, without holding the DB connection
pdf_content = extract_text_from_file(api_key, asset.path, asset.type)
pdf_content: TextExtractionResponse = extract_text_from_file(api_key, asset.path)

success = True

@@ -111,7 +112,7 @@ def preprocess_file(asset_id: int):
if success and pdf_content:
with SessionLocal() as db:
asset_content = project_repository.update_or_add_asset_content(
db, asset_id, pdf_content
db, asset_id, pdf_content.model_dump()
)
# Submit the segmentation task once the content is saved
file_segmentation_executor.submit(
23 changes: 14 additions & 9 deletions backend/app/processing/process_queue.py
@@ -160,6 +160,7 @@ def process_task(process_id: int):
if not all_process_steps_ready:
logger.info(f"Process id: [{process.id}] some steps preprocessing is missing moving to waiting queue")
process_execution_scheduler.add_process_to_queue(process.id)
db.commit()
# Skip status update since not all steps are ready
return

@@ -241,7 +242,7 @@ def extract_process(api_key, process, process_step, asset_content):

if not pdf_content:
pdf_content = (
"\n".join(asset_content.content["content"])
"\n".join(item["text"] for item in asset_content.content.get("content", []) if "text" in item)
if asset_content.content
else None
)
@@ -256,20 +257,22 @@ def extract_process(api_key, process, process_step, asset_content):
vectorstore = ChromaDB(f"panda-etl-{process.project_id}", similarity_threshold=3)
all_relevant_docs = []

for context in data["context"]:
for sources in context:
for references in data.references:
for reference in references:
page_numbers = []
for source_index, source in enumerate(sources["sources"]):
for source_index, source in enumerate(reference.sources):
if len(source) < 30:

best_match = find_best_match_for_short_reference(
source,
all_relevant_docs,
process_step.asset.id,
process.project_id
)
if best_match:
sources["sources"][source_index] = best_match["text"]
reference.sources[source_index] = best_match["text"]
page_numbers.append(best_match["page_number"])

else:
relevant_docs = vectorstore.get_relevant_docs(
source,
@@ -297,7 +300,7 @@ def extract_process(api_key, process, process_step, asset_content):
break

if not match and len(relevant_docs["documents"][0]) > 0:
sources["sources"][source_index] = relevant_docs["documents"][0][0]
reference.sources[source_index] = relevant_docs["documents"][0][0]
if relevant_docs["documents"][0]:
page_numbers.append(
relevant_docs["metadatas"][0][most_relevant_index]["page_number"]
@@ -311,11 +314,13 @@ def extract_process(api_key, process, process_step, asset_content):
)

if page_numbers:
sources["page_numbers"] = page_numbers
reference.page_numbers = page_numbers

data_dict = data.model_dump()

return {
"fields": data["fields"],
"context": data["context"],
"fields": data_dict["fields"],
"context": data_dict["references"],
}

def find_best_match_for_short_reference(source, all_relevant_docs, asset_id, project_id, threshold=0.8):
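The extract_process changes above amount to swapping dict indexing on the raw response for attribute access on the validated ExtractFieldsResponse, with model_dump() converting back to plain dicts for the returned payload. A small sketch of that round trip follows; the sample field and reference values are made up for illustration.

# Hypothetical values, just to show the dict -> model -> dict round trip
# used by extract_process after this commit.
from app.requests.schemas import ExtractFieldsResponse

data = ExtractFieldsResponse(
    fields=[{"field1": "value1"}],
    references=[[{"name": "ESG_Reporting_Assurance", "sources": ["Assurance"]}]],
)

reference = data.references[0][0]
reference.sources[0] = "matched text"   # was: sources["sources"][source_index] = ...
reference.page_numbers = [1]            # was: sources["page_numbers"] = page_numbers

data_dict = data.model_dump()
result = {"fields": data_dict["fields"], "context": data_dict["references"]}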
38 changes: 24 additions & 14 deletions backend/app/requests.py → backend/app/requests/__init__.py
@@ -1,6 +1,7 @@
import json
import os
from app.exceptions import CreditLimitExceededException
from .schemas import ExtractFieldsResponse, TextExtractionResponse
import requests
from app.config import settings
from app.logger import Logger
@@ -30,31 +31,34 @@ def request_api_key(email: str):
return data.get("message", "No message in response")


def extract_text_from_file(api_token: str, file_path: str, type: str):
def extract_text_from_file(api_token: str, file_path: str, metadata: bool=True) -> TextExtractionResponse:
# Prepare the headers with the Bearer token
headers = {"x-authorization": f"Bearer {api_token}"}
files = {}
file = open(file_path, "rb")
files["file"] = (os.path.basename(file_path), file)

response = requests.post(
f"{settings.pandaetl_server_url}/v1/extract/file/content",
files=files,
headers=headers,
timeout=360,
)
with open(file_path, "rb") as file:
files["file"] = (os.path.basename(file_path), file)

response = requests.post(
f"{settings.pandaetl_server_url}/v1/parse",
files=files,
headers=headers,
timeout=360,
params={"metadata": metadata}
)

# Check the response status code
if response.status_code == 201 or response.status_code == 200:
return response.json()
data = response.json()
return TextExtractionResponse(**data)
else:
logger.error(
f"Unable to process file ${file_path} during text extraction. It returned {response.status_code} code: {response.text}"
)
raise Exception("Unable to process file!")


def extract_data(api_token, fields, file_path=None, pdf_content=None):
def extract_data(api_token, fields, file_path=None, pdf_content=None) -> ExtractFieldsResponse:
fields_data = fields if isinstance(fields, str) else json.dumps(fields)

# Prepare the headers with the Bearer token
@@ -68,8 +72,8 @@ def extract_data(api_token, fields, file_path=None, pdf_content=None):
if not os.path.isfile(file_path):
raise FileNotFoundError(f"The file at {file_path} does not exist.")

file = open(file_path, "rb")
files["file"] = (os.path.basename(file_path), file)
with open(file_path, "rb") as file:
files["file"] = (os.path.basename(file_path), file)

elif pdf_content:
data["pdf_content"] = pdf_content
@@ -81,11 +85,17 @@ def extract_data(api_token, fields, file_path=None, pdf_content=None):
data=data,
headers=headers,
timeout=360,
params={"references": True}
)

# Check the response status code
if response.status_code == 201 or response.status_code == 200:
return response.json()

data = response.json()

return ExtractFieldsResponse(
**data
)

elif response.status_code == 402:
raise CreditLimitExceededException(
26 changes: 26 additions & 0 deletions backend/app/requests/schemas.py
@@ -0,0 +1,26 @@
from typing import Dict, List, Optional
from pydantic import BaseModel


class SentenceMetadata(BaseModel):
page_number: Optional[int] = None

class StructuredSentence(BaseModel):
text: str
metadata: Optional[SentenceMetadata] = None

class TextExtractionResponse(BaseModel):
content: List[StructuredSentence]
word_count: int
lang: str


class ReferenceData(BaseModel):
name: str
sources: List[str]
page_numbers: Optional[List[int]] = None


class ExtractFieldsResponse(BaseModel):
fields: List[Dict]
references: Optional[List[List[ReferenceData]]]
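Since the new response models are plain pydantic BaseModels, a quick check makes the contract visible. The payloads below are made up solely to illustrate what the schemas accept and reject.

# Made-up payloads illustrating the new schemas' validation behavior.
from pydantic import ValidationError

from app.requests.schemas import TextExtractionResponse

ok = TextExtractionResponse(
    content=[{"text": "Hello world", "metadata": {"page_number": 1}}],
    word_count=2,
    lang="en",
)
print(ok.content[0].metadata.page_number)  # -> 1
print(ok.model_dump())  # plain dict, as passed to update_or_add_asset_content

try:
    TextExtractionResponse(content="not-a-list", lang="en")
except ValidationError as exc:
    print(exc)  # reports the missing word_count and the invalid content type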
1 change: 0 additions & 1 deletion backend/app/utils.py
@@ -51,7 +51,6 @@ def fetch_html_and_save(url, file_path):
),
"Referer": url,
"Accept-Language": "en-US,en;q=0.9",
"Accept-Encoding": "gzip, deflate, br",
"Connection": "keep-alive",
}

29 changes: 19 additions & 10 deletions backend/tests/processing/test_process_queue.py
@@ -1,3 +1,4 @@
from app.requests.schemas import ExtractFieldsResponse
import pytest
from unittest.mock import Mock, patch
from app.processing.process_queue import (
@@ -47,10 +48,14 @@ def test_extract_process(mock_chroma, mock_extract_data):
}]],
"documents": [["Test document"]]
}
mock_extract_data.return_value = {
"fields": {"field1": "value1"},
"context": [[{"sources": ["source1"], "page_numbers": [1]}]]
}
mock_extract_data.return_value = ExtractFieldsResponse(fields=[{"field1": "value1"}],
references=[[{
"name": "ESG_Reporting_Assurance",
"sources": [
"Assurance"
]
}]]
)

process = Mock(id=1, project_id=1, details={"fields": [{"key": "field1"}]})
process_step = Mock(id=1, asset=Mock(id=1))
@@ -60,8 +65,8 @@ def test_extract_process(mock_chroma, mock_extract_data):

assert "fields" in result
assert "context" in result
assert result["fields"] == {"field1": "value1"}
assert result["context"][0][0]["page_numbers"] == [1]
assert result["fields"] == [{"field1": "value1"}]
assert result["context"] == [[{'name': 'ESG_Reporting_Assurance', 'sources': ['Assurance'], 'page_numbers': None}]]
mock_extract_data.assert_called_once()
mock_chroma_instance.get_relevant_docs.assert_called()

@@ -158,10 +163,14 @@ def test_find_best_match_for_short_reference_parametrized(mock_findall, short_re
def test_chroma_db_initialization(mock_extract_data, mock_chroma):
mock_chroma_instance = Mock()
mock_chroma.return_value = mock_chroma_instance
mock_extract_data.return_value = {
"fields": {"field1": "value1"},
"context": [[{"sources": ["source1"], "page_numbers": [1]}]]
}
mock_extract_data.return_value = ExtractFieldsResponse(fields=[{"field1": "value1"}],
references=[[{
"name": "ESG_Reporting_Assurance",
"sources": [
"Assurance"
]
}]]
)

process = Mock(id=1, project_id=1, details={"fields": [{"key": "field1"}]})
process_step = Mock(id=1, asset=Mock(id=1))