Skip to content

Commit

Permalink
Merge pull request #227 from enoch3712/226-reasoning-models-support--…
Browse files Browse the repository at this point in the history
…-dynamic-build

dynamic generation of the output
  • Loading branch information
enoch3712 authored Jan 27, 2025
2 parents a3500f9 + 7b7b567 commit 7236493
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 64 deletions.
72 changes: 64 additions & 8 deletions extract_thinker/llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional
import instructor
import litellm
from litellm import Router
from extract_thinker.utils import add_classification_structure, extract_thinking_json

# Add these constants at the top of the file, after the imports
DYNAMIC_PROMPT_TEMPLATE = """Please provide your thinking process within <think> tags, followed by your JSON output.
JSON structure:
{prompt}
OUTPUT example:
<think>
Your step-by-step reasoning and analysis goes here...
</think>
##JSON OUTPUT
{{
...
}}
"""

class LLM:
TEMPERATURE = 0 # Always zero for deterministic outputs (IDP)
Expand All @@ -14,35 +32,73 @@ def __init__(self,
self.model = model
self.router = None
self.token_limit = token_limit

self.is_dynamic = False
def load_router(self, router: Router) -> None:
self.router = router

def request(self, messages: List[Dict[str, str]], response_model: str) -> Any:
def set_dynamic(self, is_dynamic: bool) -> None:
"""Set whether the LLM should handle dynamic content.
When dynamic is True, the LLM will attempt to parse and validate JSON responses.
This is useful for handling structured outputs like masking mappings.
Args:
is_dynamic (bool): Whether to enable dynamic content handling
"""
self.is_dynamic = is_dynamic

def request(
self,
messages: List[Dict[str, str]],
response_model: Optional[str] = None
) -> Any:
# Uncomment the following lines if you need to calculate max_tokens
# contents = map(lambda message: message['content'], messages)
# all_contents = ' '.join(contents)
# max_tokens = num_tokens_from_string(all_contents)

# if is sync, response model is None if dynamic true and used for dynamic parsing after llm request
request_model = None if self.is_dynamic else response_model

# Add model structure and prompt engineering if dynamic parsing is enabled
working_messages = messages.copy()
if self.is_dynamic and response_model:
structure = add_classification_structure(response_model)
prompt = DYNAMIC_PROMPT_TEMPLATE.format(prompt=structure)
working_messages.append({
"role": "system",
"content": prompt
})

if self.router:
response = self.router.completion(
model=self.model,
messages=messages,
response_model=response_model,
messages=working_messages,
response_model=request_model,
temperature=self.TEMPERATURE,
timeout=self.TIMEOUT,
)
else:
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
messages=working_messages,
temperature=self.TEMPERATURE,
response_model=response_model,
response_model=request_model,
max_retries=1,
max_tokens=self.token_limit,
timeout=self.TIMEOUT,
)

return response
# If response_model is provided, return the response directly
if self.is_dynamic == False:
return response

# Otherwise get content and handle dynamic parsing if enabled
content = response.choices[0].message.content
if self.is_dynamic:
return extract_thinking_json(content, response_model)

return content

def raw_completion(self, messages: List[Dict[str, str]]) -> str:
"""Make raw completion request without response model."""
Expand Down
65 changes: 64 additions & 1 deletion extract_thinker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,4 +472,67 @@ def check_mime_type(mime: str, supported_formats: List[str]) -> bool:
return True
elif mime == expected_mime:
return True
return False
return False

def extract_thinking_json(thinking_text: str, response_model: type[BaseModel]) -> Any:
"""
Extract JSON from thinking text and convert it to a Pydantic model.
Handles various formats including:
- JSON with or without backticks
- JSON with or without language identifier
- Thinking text before/after JSON
- Malformed thinking tags
Args:
thinking_text (str): Text containing thinking process and JSON output
response_model (type[BaseModel]): Pydantic model to parse the JSON into
Returns:
Any: Parsed Pydantic model
Raises:
ValueError: If no valid JSON is found or if JSON doesn't match expected structure
"""
try:
# Remove thinking tags if present
thinking_text = thinking_text.replace('<think>', '').replace('</think>', '')

# Try different JSON patterns in order of specificity
patterns = [
r'```json\s*({\s*.*?})\s*```', # JSON with language identifier
r'```\s*({\s*.*?})\s*```', # JSON with backticks only
r'({(?:[^{}]|{[^{}]*})*})', # Bare JSON object (most permissive)
]

json_str = None
for pattern in patterns:
match = re.search(pattern, thinking_text, re.DOTALL)
if match:
json_str = match.group(1)
break

if not json_str:
# If no JSON found in patterns, try to find the last occurrence of a JSON-like structure
possible_json = thinking_text.strip()
if possible_json.startswith('{') and possible_json.endswith('}'):
json_str = possible_json

if not json_str:
raise ValueError("No JSON structure found in thinking text")

# Clean up the JSON string
json_str = json_str.strip()

# Handle potential string formatting placeholders
json_str = json_str.replace('{content}', '') # Remove any {content} placeholders

# Parse JSON string
data = json.loads(json_str)

# Convert to Pydantic model
return response_model(**data)

except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format: {str(e)}\nJSON string was: {json_str}")
except Exception as e:
raise ValueError(f"Failed to parse thinking output: {str(e)}\nInput text was: {thinking_text[:200]}...")
2 changes: 1 addition & 1 deletion tests/models/invoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def convert_quantity_to_int(cls, v):
class InvoiceContract(Contract):
invoice_number: str
invoice_date: str
lines: List[InvoiceLine]
total_amount: float
lines: List[InvoiceLine]

class CreditNoteContract(Contract):
credit_note_number: str
Expand Down
136 changes: 82 additions & 54 deletions tests/test_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,60 +159,6 @@ def test_extract_with_invalid_file_path():

assert "Failed to extract from source" in str(exc_info.value.args[0])

def test_batch_extraction_single_source():
# Arrange
load_dotenv()
tesseract_path = os.getenv("TESSERACT_PATH")
test_file_path = os.path.join(os.getcwd(), "tests", "test_images", "invoice.png")

extractor = Extractor()
extractor.load_document_loader(DocumentLoaderTesseract(tesseract_path))
extractor.load_llm("gpt-4o-mini")

# Act
batch_job = extractor.extract_batch(test_file_path, InvoiceContract)

# Assert batch status
status = asyncio.run(batch_job.get_status())
assert status in ["queued", "processing", "completed"]
print(f"Batch status: {status}")

result = asyncio.run(batch_job.get_result())

# Get results and verify
assert result.invoice_number == "0000001"
assert result.invoice_date == "2014-05-07"

def test_cancel_batch_extraction():
# Arrange
tesseract_path = os.getenv("TESSERACT_PATH")
test_file_path = os.path.join(os.getcwd(), "tests", "test_images", "invoice.png")
batch_file_path = os.path.join(os.getcwd(), "tests", "batch_input.jsonl")
output_file_path = os.path.join(os.getcwd(), "tests", "batch_output.jsonl")

extractor = Extractor()
extractor.load_document_loader(DocumentLoaderTesseract(tesseract_path))
extractor.load_llm("gpt-4o-mini")

# Act
batch_job = extractor.extract_batch(
test_file_path,
InvoiceContract,
batch_file_path=batch_file_path,
output_file_path=output_file_path
)

# Cancel the batch job
cancel_success = asyncio.run(batch_job.cancel())
assert cancel_success, "Batch job cancellation failed"

# Add a small delay to ensure cleanup has time to complete
time.sleep(1)

# Check if files were removed
assert not os.path.exists(batch_job.file_path), f"Batch input file was not removed: {batch_job.file_path}"
assert not os.path.exists(batch_job.output_path), f"Batch output file was not removed: {batch_job.output_path}"

def test_forbidden_strategy_with_token_limit():
test_file_path = os.path.join(os.getcwd(), "tests", "test_images", "eu_tax_chart.png")
tesseract_path = os.getenv("TESSERACT_PATH")
Expand Down Expand Up @@ -390,3 +336,85 @@ def test_llm_timeout():
# Verify normal operation works after reset
result = extractor.extract(test_file_path, InvoiceContract)
assert result is not None

def test_dynamic_json_parsing():
"""Test dynamic JSON parsing with local Ollama model."""
# Initialize components
llm = LLM(model="ollama/deepseek-r1:1.5b")
llm.set_dynamic(True) # Enable dynamic JSON parsing

document_loader = DocumentLoaderPyPdf()
extractor = Extractor(document_loader=document_loader, llm=llm)

# Test content that should produce JSON response
test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")

# Extract with dynamic parsing
try:
result = extractor.extract(test_file_path, InvoiceContract)

# Verify the result is an InvoiceContract instance
assert isinstance(result, InvoiceContract)

# Verify invoice fields
assert result.invoice_number is not None
assert result.invoice_date is not None
assert result.total_amount is not None
assert isinstance(result.lines, list)

except Exception as e:
pytest.fail(f"Dynamic JSON parsing test failed: {str(e)}")

def test_batch_extraction_single_source():
# Arrange
load_dotenv()
tesseract_path = os.getenv("TESSERACT_PATH")
test_file_path = os.path.join(os.getcwd(), "tests", "test_images", "invoice.png")

extractor = Extractor()
extractor.load_document_loader(DocumentLoaderTesseract(tesseract_path))
extractor.load_llm("gpt-4o-mini")

# Act
batch_job = extractor.extract_batch(test_file_path, InvoiceContract)

# Assert batch status
status = asyncio.run(batch_job.get_status())
assert status in ["queued", "processing", "completed"]
print(f"Batch status: {status}")

result = asyncio.run(batch_job.get_result())

# Get results and verify
assert result.invoice_number == "0000001"
assert result.invoice_date == "2014-05-07"

def test_cancel_batch_extraction():
# Arrange
tesseract_path = os.getenv("TESSERACT_PATH")
test_file_path = os.path.join(os.getcwd(), "tests", "test_images", "invoice.png")
batch_file_path = os.path.join(os.getcwd(), "tests", "batch_input.jsonl")
output_file_path = os.path.join(os.getcwd(), "tests", "batch_output.jsonl")

extractor = Extractor()
extractor.load_document_loader(DocumentLoaderTesseract(tesseract_path))
extractor.load_llm("gpt-4o-mini")

# Act
batch_job = extractor.extract_batch(
test_file_path,
InvoiceContract,
batch_file_path=batch_file_path,
output_file_path=output_file_path
)

# Cancel the batch job
cancel_success = asyncio.run(batch_job.cancel())
assert cancel_success, "Batch job cancellation failed"

# Add a small delay to ensure cleanup has time to complete
time.sleep(1)

# Check if files were removed
assert not os.path.exists(batch_job.file_path), f"Batch input file was not removed: {batch_job.file_path}"
assert not os.path.exists(batch_job.output_path), f"Batch output file was not removed: {batch_job.output_path}"

0 comments on commit 7236493

Please sign in to comment.