Skip to content

Commit

Permalink
Update to llama cloud v 0.1.11 (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
neerajprad authored Jan 28, 2025
1 parent e049c74 commit 90eb45b
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 34 deletions.
60 changes: 56 additions & 4 deletions llama_extract/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
ExtractConfig,
ExtractJob,
ExtractJobCreate,
ExtractResultset,
ExtractRun,
File,
ExtractMode,
StatusEnum,
Project,
ExtractTarget,
LlamaExtractSettings,
)
from llama_cloud.client import AsyncLlamaCloud
from llama_extract.utils import JSONObjectType, augment_async_errors
Expand All @@ -33,11 +34,10 @@
SchemaInput = Union[JSONObjectType, Type[BaseModel]]

DEFAULT_EXTRACT_CONFIG = ExtractConfig(
extraction_mode=ExtractMode.PER_DOC,
extraction_target=ExtractTarget.PER_DOC,
extraction_mode=ExtractMode.ACCURATE,
)

ExtractionResult = Tuple[ExtractJob, ExtractResultset]


class ExtractionAgent:
"""Class representing a single extraction agent with methods for extraction operations."""
Expand Down Expand Up @@ -192,6 +192,58 @@ def save(self) -> None:
)
)

async def _queue_extraction_test(
self,
files: Union[FileInput, List[FileInput]],
extract_settings: LlamaExtractSettings,
) -> Union[ExtractJob, List[ExtractJob]]:
if not isinstance(files, list):
files = [files]
single_file = True
else:
single_file = False

upload_tasks = [self._upload_file(file) for file in files]
with augment_async_errors():
uploaded_files = await run_jobs(
upload_tasks,
workers=self.num_workers,
desc="Uploading files",
show_progress=self.show_progress,
)

async def run_job(file: File) -> ExtractRun:
job_queued = await self._client.llama_extract.run_job_test_user(
job_create=ExtractJobCreate(
extraction_agent_id=self.id,
file_id=file.id,
data_schema_override=self.data_schema,
config_override=self.config,
),
extract_settings=extract_settings,
)
return await self._wait_for_job_result(job_queued.id)

job_tasks = [run_job(file) for file in uploaded_files]
with augment_async_errors():
extract_jobs = await run_jobs(
job_tasks,
workers=self.num_workers,
desc="Creating extraction jobs",
show_progress=self.show_progress,
)

if self._verbose:
for file, job in zip(files, extract_jobs):
file_repr = (
str(file) if isinstance(file, (str, Path)) else "<bytes/buffer>"
)
print(
f"Queued file extraction for file {file_repr} under job_id {job.id}"
)

return extract_jobs[0] if single_file else extract_jobs

async def queue_extraction(
self,
files: Union[FileInput, List[FileInput]],
Expand Down
40 changes: 20 additions & 20 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ packages = [{include = "llama_extract"}]
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
llama-index-core = "^0.11.0"
llama-cloud = "0.1.10"
llama-cloud = "0.1.11"
python-dotenv = "^1.0.1"

[tool.poetry.group.dev.dependencies]
Expand Down
149 changes: 149 additions & 0 deletions tests/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import os
import pytest
from pathlib import Path

from llama_extract import LlamaExtract, ExtractionAgent
from dotenv import load_dotenv
from time import perf_counter
from collections import namedtuple
import json
import uuid
from llama_cloud.core.api_error import ApiError
from llama_cloud.types import (
ExtractConfig,
ExtractMode,
LlamaParseParameters,
LlamaExtractSettings,
)

load_dotenv(Path(__file__).parent.parent / ".env.dev", override=True)


# Directory holding benchmark fixtures: one sub-directory per data type with a
# schema.json, input files, and <name>.test.json expected outputs.
TEST_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
# Get configuration from environment
LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
LLAMA_CLOUD_BASE_URL = os.getenv("LLAMA_CLOUD_BASE_URL")
LLAMA_CLOUD_PROJECT_ID = os.getenv("LLAMA_CLOUD_PROJECT_ID")

# One benchmark case: a named input file, the schema it is extracted against,
# the ExtractConfig to use, and the golden-output path.
TestCase = namedtuple(
    "TestCase", ["name", "schema_path", "config", "input_file", "expected_output"]
)


def get_test_cases():
    """Collect benchmark test cases from TEST_DIR.

    Each sub-directory of TEST_DIR is one data type and must contain a
    ``schema.json``; every input file that has a matching
    ``<name>.test.json`` expected output yields one TestCase per
    extraction mode (FAST and ACCURATE). Files without an expected
    output, and non-directories, are skipped.

    Returns:
        List[TestCase]: List of test cases
    """
    # The two configs are loop-invariant, so build them once. (The original
    # rebuilt them per data type but already shared instances across the
    # input files of one type, so sharing across types is equally safe.)
    settings = [
        ExtractConfig(extraction_mode=ExtractMode.FAST),
        ExtractConfig(extraction_mode=ExtractMode.ACCURATE),
    ]

    test_cases = []
    for data_type in os.listdir(TEST_DIR):
        data_type_dir = os.path.join(TEST_DIR, data_type)
        if not os.path.isdir(data_type_dir):
            continue

        schema_path = os.path.join(data_type_dir, "schema.json")
        if not os.path.exists(schema_path):
            continue

        input_files = []
        for file in os.listdir(data_type_dir):
            file_path = os.path.join(data_type_dir, file)
            # Skip sub-directories, the schema itself, and golden outputs.
            if (
                not os.path.isfile(file_path)
                or file == "schema.json"
                or file.endswith(".test.json")
            ):
                continue
            input_files.append(file_path)

        for input_file in sorted(input_files):
            base_name = os.path.splitext(os.path.basename(input_file))[0]
            expected_output = os.path.join(data_type_dir, f"{base_name}.test.json")
            if not os.path.exists(expected_output):
                # Only benchmark files that have a golden output.
                continue

            test_name = f"{data_type}/{os.path.basename(input_file)}"
            for setting in settings:
                test_cases.append(
                    TestCase(
                        name=test_name,
                        schema_path=schema_path,
                        input_file=input_file,
                        config=setting,
                        expected_output=expected_output,
                    )
                )

    return test_cases


@pytest.fixture(scope="session")
def extractor():
    """Session-scoped LlamaExtract client shared by every benchmark test."""
    client = LlamaExtract(
        api_key=LLAMA_CLOUD_API_KEY,
        base_url=LLAMA_CLOUD_BASE_URL,
        project_id=LLAMA_CLOUD_PROJECT_ID,
        verbose=True,
    )
    try:
        yield client
    finally:
        # Cleanup thread pool at end of session
        client._thread_pool.shutdown()


@pytest.fixture
def extraction_agent(test_case: TestCase, extractor: LlamaExtract):
    """Fixture to create and cleanup extraction agent for each test."""
    # Create unique name with random UUID (important for CI to avoid conflicts)
    unique_id = uuid.uuid4().hex[:8]
    agent_name = f"{test_case.name}_{unique_id}"

    # Load the JSON schema this agent will extract against.
    with open(test_case.schema_path, "r") as f:
        schema = json.load(f)

    # Clean up any existing agents with this name
    # NOTE(review): agent_name embeds a freshly generated UUID, so this
    # equality check can never match a pre-existing agent — the loop is
    # effectively dead code. Matching on the test_case.name prefix was
    # probably intended; confirm before changing.
    try:
        agents = extractor.list_agents()
        for agent in agents:
            if agent.name == agent_name:
                extractor.delete_agent(agent.id)
    except Exception as e:
        print(f"Warning: Failed to cleanup existing agent: {str(e)}")

    # Create new agent
    agent = extractor.create_agent(agent_name, schema, config=test_case.config)
    yield agent
    # NOTE(review): no teardown after yield — each test leaves its agent
    # behind on the server. Consider deleting the agent here.


@pytest.mark.skipif(
    "CI" in os.environ,
    reason="CI environment is not suitable for benchmarking",
)
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda x: x.name)
@pytest.mark.asyncio(loop_scope="session")
async def test_extraction(
    test_case: TestCase, extraction_agent: ExtractionAgent
) -> None:
    """Benchmark one extraction run and print its wall-clock duration.

    NOTE(review): test_case.expected_output is collected but never asserted
    against `result` — this test only measures timing; confirm whether an
    accuracy check was intended.
    """
    start = perf_counter()
    # Cache is bypassed on both read and write so timings reflect a cold run.
    result = await extraction_agent._queue_extraction_test(
        test_case.input_file,
        extract_settings=LlamaExtractSettings(
            llama_parse_params=LlamaParseParameters(
                invalidate_cache=True,
                do_not_cache=True,
            )
        ),
    )
    end = perf_counter()
    print(f"Time taken: {end - start} seconds")
    print(result)
Loading

0 comments on commit 90eb45b

Please sign in to comment.