diff --git a/llama_extract/extract.py b/llama_extract/extract.py index 3450e77..f60516e 100644 --- a/llama_extract/extract.py +++ b/llama_extract/extract.py @@ -12,12 +12,13 @@ ExtractConfig, ExtractJob, ExtractJobCreate, - ExtractResultset, ExtractRun, File, ExtractMode, StatusEnum, Project, + ExtractTarget, + LlamaExtractSettings, ) from llama_cloud.client import AsyncLlamaCloud from llama_extract.utils import JSONObjectType, augment_async_errors @@ -33,11 +34,10 @@ SchemaInput = Union[JSONObjectType, Type[BaseModel]] DEFAULT_EXTRACT_CONFIG = ExtractConfig( - extraction_mode=ExtractMode.PER_DOC, + extraction_target=ExtractTarget.PER_DOC, + extraction_mode=ExtractMode.ACCURATE, ) -ExtractionResult = Tuple[ExtractJob, ExtractResultset] - class ExtractionAgent: """Class representing a single extraction agent with methods for extraction operations.""" @@ -192,6 +192,58 @@ def save(self) -> None: ) ) + async def _queue_extraction_test( + self, + files: Union[FileInput, List[FileInput]], + extract_settings: LlamaExtractSettings, + ) -> Union[ExtractJob, List[ExtractJob]]: + if not isinstance(files, list): + files = [files] + single_file = True + else: + single_file = False + + upload_tasks = [self._upload_file(file) for file in files] + with augment_async_errors(): + uploaded_files = await run_jobs( + upload_tasks, + workers=self.num_workers, + desc="Uploading files", + show_progress=self.show_progress, + ) + + async def run_job(file: File) -> ExtractRun: + job_queued = await self._client.llama_extract.run_job_test_user( + job_create=ExtractJobCreate( + extraction_agent_id=self.id, + file_id=file.id, + data_schema_override=self.data_schema, + config_override=self.config, + ), + extract_settings=extract_settings, + ) + return await self._wait_for_job_result(job_queued.id) + + job_tasks = [run_job(file) for file in uploaded_files] + with augment_async_errors(): + extract_jobs = await run_jobs( + job_tasks, + workers=self.num_workers, + desc="Creating extraction jobs", + show_progress=self.show_progress, + ) + + if self._verbose: + for file, job in zip(files, extract_jobs): + file_repr = ( + str(file) if isinstance(file, (str, Path)) else "" + ) + print( + f"Queued file extraction for file {file_repr} under job_id {job.id}" + ) + + return extract_jobs[0] if single_file else extract_jobs + async def queue_extraction( self, files: Union[FileInput, List[FileInput]], diff --git a/poetry.lock b/poetry.lock index 56933fe..1b92eb0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -285,13 +285,13 @@ files = [ [[package]] name = "attrs" -version = "24.3.0" +version = "25.1.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.8" files = [ - {file = "attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308"}, - {file = "attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff"}, + {file = "attrs-25.1.0-py3-none-any.whl", hash = "sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a"}, + {file = "attrs-25.1.0.tar.gz", hash = "sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e"}, ] [package.extras] @@ -737,20 +737,20 @@ files = [ [[package]] name = "deprecated" -version = "1.2.15" +version = "1.2.18" description = "Python @deprecated decorator to deprecate old python classes, functions or methods." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" files = [ - {file = "Deprecated-1.2.15-py2.py3-none-any.whl", hash = "sha256:353bc4a8ac4bfc96800ddab349d89c25dec1079f65fd53acdcc1e0b975b21320"}, - {file = "deprecated-1.2.15.tar.gz", hash = "sha256:683e561a90de76239796e6b6feac66b99030d2dd3fcf61ef996330f14bbb9b0d"}, + {file = "Deprecated-1.2.18-py2.py3-none-any.whl", hash = "sha256:bd5011788200372a32418f888e326a09ff80d0214bd961147cfed01b5c018eec"}, + {file = "deprecated-1.2.18.tar.gz", hash = "sha256:422b6f6d859da6f2ef57857761bfb392480502a64c3028ca9bbe86085d72115d"}, ] [package.dependencies] wrapt = ">=1.10,<2" [package.extras] -dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "jinja2 (>=3.0.3,<3.1.0)", "setuptools", "sphinx (<2)", "tox"] +dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "setuptools", "tox"] [[package]] name = "dirtyjson" @@ -1784,13 +1784,13 @@ rapidfuzz = ">=3.9.0,<4.0.0" [[package]] name = "llama-cloud" -version = "0.1.10" +version = "0.1.11" description = "" optional = false python-versions = "<4,>=3.8" files = [ - {file = "llama_cloud-0.1.10-py3-none-any.whl", hash = "sha256:d91198ad92ea6c3a25757e5d6cb565b4bd6db385dc4fa596a725c0fb81a68f4e"}, - {file = "llama_cloud-0.1.10.tar.gz", hash = "sha256:56ffe8f2910c2047dd4eb1b13da31ee5f67321a000794eee559e0b56954d2f76"}, + {file = "llama_cloud-0.1.11-py3-none-any.whl", hash = "sha256:b703765d03783a5a0fc57a52adc9892f8b91b0c19bbecb85a54ad4e813342951"}, + {file = "llama_cloud-0.1.11.tar.gz", hash = "sha256:d4be5b48659fd9fe1698727be257269a22d7f2733a2ed11bce7065768eb94cbe"}, ] [package.dependencies] @@ -1905,13 +1905,13 @@ files = [ [[package]] name = "marshmallow" -version = "3.25.1" +version = "3.26.0" description = "A lightweight library for converting complex datatypes to and from native Python datatypes." optional = false python-versions = ">=3.9" files = [ - {file = "marshmallow-3.25.1-py3-none-any.whl", hash = "sha256:ec5d00d873ce473b7f2ffcb7104286a376c354cab0c2fa12f5573dab03e87210"}, - {file = "marshmallow-3.25.1.tar.gz", hash = "sha256:f4debda3bb11153d81ac34b0d582bf23053055ee11e791b54b4b35493468040a"}, + {file = "marshmallow-3.26.0-py3-none-any.whl", hash = "sha256:1287bca04e6a5f4094822ac153c03da5e214a0a60bcd557b140f3e66991b8ca1"}, + {file = "marshmallow-3.26.0.tar.gz", hash = "sha256:eb36762a1cc76d7abf831e18a3a1b26d3d481bbc74581b8e532a3d3a8115e1cb"}, ] [package.dependencies] @@ -2751,13 +2751,13 @@ files = [ [[package]] name = "pydantic" -version = "2.10.5" +version = "2.10.6" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.10.5-py3-none-any.whl", hash = "sha256:4dd4e322dbe55472cb7ca7e73f4b63574eecccf2835ffa2af9021ce113c83c53"}, - {file = "pydantic-2.10.5.tar.gz", hash = "sha256:278b38dbbaec562011d659ee05f63346951b3a248a6f3642e1bc68894ea2b4ff"}, + {file = "pydantic-2.10.6-py3-none-any.whl", hash = "sha256:427d664bf0b8a2b34ff5dd0f5a18df00591adcee7198fbd71981054cef37b584"}, + {file = "pydantic-2.10.6.tar.gz", hash = "sha256:ca5daa827cce33de7a42be142548b0096bf05a7e7b365aebfa5f8eeec7128236"}, ] [package.dependencies] @@ -3307,13 +3307,13 @@ all = ["numpy"] [[package]] name = "referencing" -version = "0.36.1" +version = "0.36.2" description = "JSON Referencing + Python" optional = false python-versions = ">=3.9" files = [ - {file = "referencing-0.36.1-py3-none-any.whl", hash = "sha256:363d9c65f080d0d70bc41c721dce3c7f3e77fc09f269cd5c8813da18069a6794"}, - {file = "referencing-0.36.1.tar.gz", hash = "sha256:ca2e6492769e3602957e9b831b94211599d2aade9477f5d44110d2530cf9aade"}, + {file = "referencing-0.36.2-py3-none-any.whl", hash = "sha256:e8699adbbf8b5c7de96d8ffa0eb5c158b3beafce084968e2ea8bb08c6794dcd0"}, + {file = "referencing-0.36.2.tar.gz", hash = "sha256:df2e89862cd09deabbdba16944cc3f10feb6b3e6f18e902f7cc25609a34775aa"}, ] [package.dependencies] @@ -4317,4 +4317,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "90b87c91c45412dae185dac7a8eb80b8e6f37d9677990ff394590acd751c7565" +content-hash = "1ff53e863a18be137ee0eff10a8c4412e2db95ca7cdc7aedf22339325d3fc818" diff --git a/pyproject.toml b/pyproject.toml index 6f6e56b..9d84511 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ packages = [{include = "llama_extract"}] [tool.poetry.dependencies] python = ">=3.9,<4.0" llama-index-core = "^0.11.0" -llama-cloud = "0.1.10" +llama-cloud = "0.1.11" python-dotenv = "^1.0.1" [tool.poetry.group.dev.dependencies] diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py new file mode 100644 index 0000000..044bd16 --- /dev/null +++ b/tests/test_benchmark.py @@ -0,0 +1,149 @@ +import os +import pytest +from pathlib import Path + +from llama_extract import LlamaExtract, ExtractionAgent +from dotenv import load_dotenv +from time import perf_counter +from collections import namedtuple +import json +import uuid +from llama_cloud.core.api_error import ApiError +from llama_cloud.types import ( + ExtractConfig, + ExtractMode, + LlamaParseParameters, + LlamaExtractSettings, +) + +load_dotenv(Path(__file__).parent.parent / ".env.dev", override=True) + + +TEST_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") +# Get configuration from environment +LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY") +LLAMA_CLOUD_BASE_URL = os.getenv("LLAMA_CLOUD_BASE_URL") +LLAMA_CLOUD_PROJECT_ID = os.getenv("LLAMA_CLOUD_PROJECT_ID") + +TestCase = namedtuple( + "TestCase", ["name", "schema_path", "config", "input_file", "expected_output"] +) + + +def get_test_cases(): + """Get all test cases from TEST_DIR. + + Returns: + List[TestCase]: List of test cases + """ + test_cases = [] + + for data_type in os.listdir(TEST_DIR): + data_type_dir = os.path.join(TEST_DIR, data_type) + if not os.path.isdir(data_type_dir): + continue + + schema_path = os.path.join(data_type_dir, "schema.json") + if not os.path.exists(schema_path): + continue + + input_files = [] + + for file in os.listdir(data_type_dir): + file_path = os.path.join(data_type_dir, file) + if ( + not os.path.isfile(file_path) + or file == "schema.json" + or file.endswith(".test.json") + ): + continue + + input_files.append(file_path) + + settings = [ + ExtractConfig(extraction_mode=ExtractMode.FAST), + ExtractConfig(extraction_mode=ExtractMode.ACCURATE), + ] + + for input_file in sorted(input_files): + base_name = os.path.splitext(os.path.basename(input_file))[0] + expected_output = os.path.join(data_type_dir, f"{base_name}.test.json") + + if not os.path.exists(expected_output): + continue + + test_name = f"{data_type}/{os.path.basename(input_file)}" + for setting in settings: + test_cases.append( + TestCase( + name=test_name, + schema_path=schema_path, + input_file=input_file, + config=setting, + expected_output=expected_output, + ) + ) + + return test_cases + + +@pytest.fixture(scope="session") +def extractor(): + """Create a single LlamaExtract instance for all tests.""" + extract = LlamaExtract( + api_key=LLAMA_CLOUD_API_KEY, + base_url=LLAMA_CLOUD_BASE_URL, + project_id=LLAMA_CLOUD_PROJECT_ID, + verbose=True, + ) + yield extract + # Cleanup thread pool at end of session + extract._thread_pool.shutdown() + + +@pytest.fixture +def extraction_agent(test_case: TestCase, extractor: LlamaExtract): + """Fixture to create and cleanup extraction agent for each test.""" + # Create unique name with random UUID (important for CI to avoid conflicts) + unique_id = uuid.uuid4().hex[:8] + agent_name = f"{test_case.name}_{unique_id}" + + with open(test_case.schema_path, "r") as f: + schema = json.load(f) + + # Clean up any existing agents with this name + try: + agents = extractor.list_agents() + for agent in agents: + if agent.name == agent_name: + extractor.delete_agent(agent.id) + except Exception as e: + print(f"Warning: Failed to cleanup existing agent: {str(e)}") + + # Create new agent + agent = extractor.create_agent(agent_name, schema, config=test_case.config) + yield agent + + +@pytest.mark.skipif( + "CI" in os.environ, + reason="CI environment is not suitable for benchmarking", +) +@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda x: x.name) +@pytest.mark.asyncio(loop_scope="session") +async def test_extraction( + test_case: TestCase, extraction_agent: ExtractionAgent +) -> None: + start = perf_counter() + result = await extraction_agent._queue_extraction_test( + test_case.input_file, + extract_settings=LlamaExtractSettings( + llama_parse_params=LlamaParseParameters( + invalidate_cache=True, + do_not_cache=True, + ) + ), + ) + end = perf_counter() + print(f"Time taken: {end - start} seconds") + print(result) diff --git a/tests/test_extract_e2e.py b/tests/test_extract_e2e.py index 31d3d9c..6e54518 100644 --- a/tests/test_extract_e2e.py +++ b/tests/test_extract_e2e.py @@ -8,6 +8,7 @@ import json import uuid from llama_cloud.core.api_error import ApiError +from llama_cloud.types import ExtractConfig, ExtractMode, ExtractConfig from deepdiff import DeepDiff from tests.util import json_subset_match_score @@ -21,7 +22,7 @@ LLAMA_CLOUD_PROJECT_ID = os.getenv("LLAMA_CLOUD_PROJECT_ID") TestCase = namedtuple( - "TestCase", ["name", "schema_path", "input_file", "expected_output"] + "TestCase", ["name", "schema_path", "config", "input_file", "expected_output"] ) @@ -55,6 +56,11 @@ def get_test_cases(): input_files.append(file_path) + settings = [ + ExtractConfig(extraction_mode=ExtractMode.FAST), + ExtractConfig(extraction_mode=ExtractMode.ACCURATE), + ] + for input_file in sorted(input_files): base_name = os.path.splitext(os.path.basename(input_file))[0] expected_output = os.path.join(data_type_dir, f"{base_name}.test.json") @@ -63,14 +69,16 @@ def get_test_cases(): continue test_name = f"{data_type}/{os.path.basename(input_file)}" - test_cases.append( - TestCase( - name=test_name, - schema_path=schema_path, - input_file=input_file, - expected_output=expected_output, + for setting in settings: + test_cases.append( + TestCase( + name=test_name, + schema_path=schema_path, + input_file=input_file, + config=setting, + expected_output=expected_output, + ) ) - ) return test_cases @@ -109,7 +117,7 @@ def extraction_agent(test_case: TestCase, extractor: LlamaExtract): print(f"Warning: Failed to cleanup existing agent: {str(e)}") # Create new agent - agent = extractor.create_agent(agent_name, schema) + agent = extractor.create_agent(agent_name, schema, config=test_case.config) yield agent # Cleanup after test