From 4770f12991868de9fa178938b3771b3962069b78 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 8 Jan 2025 10:06:49 -0800 Subject: [PATCH 01/10] Remove dependency on pyproject.toml in site package --- py/core/main/api/v3/users_router.py | 40 ++++++++++++------------ py/core/telemetry/telemetry_decorator.py | 9 ++---- py/r2r/__init__.py | 8 ++--- 3 files changed, 24 insertions(+), 33 deletions(-) diff --git a/py/core/main/api/v3/users_router.py b/py/core/main/api/v3/users_router.py index 0b24eb89c..5c4dc9d24 100644 --- a/py/core/main/api/v3/users_router.py +++ b/py/core/main/api/v3/users_router.py @@ -1537,9 +1537,6 @@ async def create_user_api_key( 403, ) - print("name =", name) - print("description =", description) - api_key = await self.services.auth.create_user_api_key( id, name=name, description=description ) @@ -1753,28 +1750,31 @@ async def get_user_limits( "x-codeSamples": [ { "lang": "Python", - "source": """ -from r2r import R2RClient + "source": textwrap.dedent( + """ + from r2r import R2RClient -client = R2RClient() -client.login(...) # Or some other auth flow + client = R2RClient() + client.login(...) # Or some other auth flow -metadata_update = { - "some_key": "some_value", - "old_key": "" -} -updated_user = client.users.patch_metadata("550e8400-e29b-41d4-a716-446655440000", metadata_update) -print(updated_user) - """, + metadata_update = { + "some_key": "some_value", + "old_key": "" + } + updated_user = client.users.patch_metadata("550e8400-e29b-41d4-a716-446655440000", metadata_update) + """, + ), }, { "lang": "cURL", - "source": """ -curl -X PATCH "https://api.example.com/v3/users/550e8400-e29b-41d4-a716-446655440000/metadata" \\ - -H "Authorization: Bearer YOUR_API_TOKEN" \\ - -H "Content-Type: application/json" \\ - -d '{"some_key":"some_value","old_key":""}' - """, + "source": textwrap.dedent( + """ + curl -X PATCH "https://api.example.com/v3/users/550e8400-e29b-41d4-a716-446655440000/metadata" \\ + -H "Authorization: Bearer YOUR_API_TOKEN" \\ + -H "Content-Type: application/json" \\ + -d '{"some_key":"some_value","old_key":""}' + """, + ), }, ] }, diff --git a/py/core/telemetry/telemetry_decorator.py b/py/core/telemetry/telemetry_decorator.py index 40f7216da..24630f9b6 100644 --- a/py/core/telemetry/telemetry_decorator.py +++ b/py/core/telemetry/telemetry_decorator.py @@ -4,11 +4,10 @@ import uuid from concurrent.futures import ThreadPoolExecutor from functools import wraps +from importlib import metadata from pathlib import Path from typing import Optional -import toml - from core.telemetry.events import ErrorEvent, FeatureUsageEvent from core.telemetry.posthog import telemetry_client @@ -25,11 +24,7 @@ class ProductTelemetryClient: def version(self) -> str: if self._version is None: try: - pyproject_path = ( - Path(__file__).parent.parent.parent / "pyproject.toml" - ) - pyproject_data = toml.load(pyproject_path) - self._version = pyproject_data["tool"]["poetry"]["version"] + self._version = metadata.version("r2r") except Exception as e: logger.error( f"Error reading version from pyproject.toml: {str(e)}" diff --git a/py/r2r/__init__.py b/py/r2r/__init__.py index e8a931905..2206e091e 100644 --- a/py/r2r/__init__.py +++ b/py/r2r/__init__.py @@ -1,7 +1,5 @@ import logging -from pathlib import Path - -import toml +from importlib import metadata from sdk.async_client import R2RAsyncClient from sdk.models import R2RException @@ -9,9 +7,7 @@ logger = logging.getLogger() -pyproject_path = Path(__file__).parent.parent / 
"pyproject.toml" -pyproject_data = toml.load(pyproject_path) -__version__ = pyproject_data["tool"]["poetry"]["version"] +__version__ = metadata.version("r2r") __all__ = [ "R2RAsyncClient", From 57acfce6ca0a858083eb812931876c0124081607 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 8 Jan 2025 10:13:34 -0800 Subject: [PATCH 02/10] Add package_data --- py/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/py/pyproject.toml b/py/pyproject.toml index 43f473d72..62e5948a2 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -25,6 +25,7 @@ packages = [ { include = "core", from = "." }, { include = "cli", from = "." }, ] +package_data = { "r2r" = ["*.yaml", "*.toml"] } [tool.poetry.dependencies] # Python Versions From 84c2a343700e3e9aa4844c705f0c435012f2344c Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 8 Jan 2025 10:16:15 -0800 Subject: [PATCH 03/10] Move to include, mapped by poetry --- py/pyproject.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/py/pyproject.toml b/py/pyproject.toml index 62e5948a2..8ff8626f7 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -16,7 +16,9 @@ include = [ "compose.full.yaml", "compose.full_with_replicas.yaml", "pyproject.toml", - "migrations/**/*" + "migrations/**/*", + "r2r/**/*.yaml", + "r2r/**/*.toml" ] packages = [ { include = "r2r" }, @@ -25,7 +27,6 @@ packages = [ { include = "core", from = "." }, { include = "cli", from = "." }, ] -package_data = { "r2r" = ["*.yaml", "*.toml"] } [tool.poetry.dependencies] # Python Versions From 68f9fd9676522a8b89a9bbcbfde1bff1b0c61b12 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 8 Jan 2025 10:18:52 -0800 Subject: [PATCH 04/10] Bump --- py/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/pyproject.toml b/py/pyproject.toml index 8ff8626f7..7d31c6b2b 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "r2r" readme = "README.md" -version = "3.3.22" +version = "3.3.23" description = "SciPhi R2R" authors = ["Owen Colegrove "] From e57f5ba26d038a9f741beba972990a1a98a98264 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 8 Jan 2025 10:48:58 -0800 Subject: [PATCH 05/10] Poetry enforces stricter package rules --- py/cli/utils/docker_utils.py | 6 +- py/pyproject.toml | 21 ++-- py/{ => r2r}/compose.full.yaml | 17 +-- py/{ => r2r}/compose.full_with_replicas.yaml | 0 py/r2r/compose.yaml | 124 +++++++++++++++++++ py/{ => r2r}/r2r.toml | 0 6 files changed, 145 insertions(+), 23 deletions(-) rename py/{ => r2r}/compose.full.yaml (98%) rename py/{ => r2r}/compose.full_with_replicas.yaml (100%) create mode 100644 py/r2r/compose.yaml rename py/{ => r2r}/r2r.toml (100%) diff --git a/py/cli/utils/docker_utils.py b/py/cli/utils/docker_utils.py index 4736445ff..3a3e4bbab 100644 --- a/py/cli/utils/docker_utils.py +++ b/py/cli/utils/docker_utils.py @@ -313,10 +313,10 @@ def get_compose_files(): "..", ) compose_files = { - "base": os.path.join(package_dir, "compose.yaml"), - "full": os.path.join(package_dir, "compose.full.yaml"), + "base": os.path.join(package_dir, "r2r", "compose.yaml"), + "full": os.path.join(package_dir, "r2r", "compose.full.yaml"), "full_scale": os.path.join( - package_dir, "compose.full_with_replicas.yaml" + package_dir, "r2r", 
"compose.full_with_replicas.yaml" ), } diff --git a/py/pyproject.toml b/py/pyproject.toml index 7d31c6b2b..315d4d963 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -5,28 +5,25 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "r2r" readme = "README.md" -version = "3.3.23" +version = "3.3.24" description = "SciPhi R2R" authors = ["Owen Colegrove "] license = "MIT" -include = [ - "r2r.toml", - "compose.yaml", - "compose.full.yaml", - "compose.full_with_replicas.yaml", - "pyproject.toml", - "migrations/**/*", - "r2r/**/*.yaml", - "r2r/**/*.toml" -] + packages = [ - { include = "r2r" }, + { include = "r2r", from = "." }, { include = "sdk", from = "." }, { include = "shared", from = "." }, { include = "core", from = "." }, { include = "cli", from = "." }, ] +include = [ + "migrations/**/*", + "pyproject.toml", + "r2r/**/*.yaml", + "r2r/**/*.toml" +] [tool.poetry.dependencies] # Python Versions diff --git a/py/compose.full.yaml b/py/r2r/compose.full.yaml similarity index 98% rename from py/compose.full.yaml rename to py/r2r/compose.full.yaml index 2711c62a0..647320a3d 100644 --- a/py/compose.full.yaml +++ b/py/r2r/compose.full.yaml @@ -281,14 +281,15 @@ services: retries: 5 r2r: - image: ${R2R_IMAGE:-ragtoriches/prod:latest} - build: - context: . - args: - PORT: ${R2R_PORT:-7272} - R2R_PORT: ${R2R_PORT:-7272} - HOST: ${R2R_HOST:-0.0.0.0} - R2R_HOST: ${R2R_HOST:-0.0.0.0} + # image: ${R2R_IMAGE:-ragtoriches/prod:latest} + image: r2r/test + # build: + # context: . + # args: + # PORT: ${R2R_PORT:-7272} + # R2R_PORT: ${R2R_PORT:-7272} + # HOST: ${R2R_HOST:-0.0.0.0} + # R2R_HOST: ${R2R_HOST:-0.0.0.0} ports: - "${R2R_PORT:-7272}:${R2R_PORT:-7272}" environment: diff --git a/py/compose.full_with_replicas.yaml b/py/r2r/compose.full_with_replicas.yaml similarity index 100% rename from py/compose.full_with_replicas.yaml rename to py/r2r/compose.full_with_replicas.yaml diff --git a/py/r2r/compose.yaml b/py/r2r/compose.yaml new file mode 100644 index 000000000..746bb13c5 --- /dev/null +++ b/py/r2r/compose.yaml @@ -0,0 +1,124 @@ +networks: + r2r-network: + driver: bridge + attachable: true + labels: + - "com.docker.compose.recreate=always" + +volumes: + postgres_data: + name: ${VOLUME_POSTGRES_DATA:-postgres_data} + +services: + postgres: + image: pgvector/pgvector:pg16 + profiles: [postgres] + environment: + - POSTGRES_USER=${R2R_POSTGRES_USER:-postgres} + - POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres} + - POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres} + - POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432} + - POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024} + - PGPORT=${R2R_POSTGRES_PORT:-5432} + volumes: + - postgres_data:/var/lib/postgresql/data + networks: + - r2r-network + ports: + - "${R2R_POSTGRES_PORT:-5432}:${R2R_POSTGRES_PORT:-5432}" + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${R2R_POSTGRES_USER:-postgres}"] + interval: 10s + timeout: 5s + retries: 5 + restart: on-failure + command: > + postgres + -c max_connections=${R2R_POSTGRES_MAX_CONNECTIONS:-1024} + + r2r: + image: ${R2R_IMAGE:-ragtoriches/prod:latest} + build: + context: . 
+ args: + PORT: ${R2R_PORT:-7272} + R2R_PORT: ${R2R_PORT:-7272} + HOST: ${R2R_HOST:-0.0.0.0} + R2R_HOST: ${R2R_HOST:-0.0.0.0} + ports: + - "${R2R_PORT:-7272}:${R2R_PORT:-7272}" + environment: + - PYTHONUNBUFFERED=1 + - R2R_PORT=${R2R_PORT:-7272} + - R2R_HOST=${R2R_HOST:-0.0.0.0} + + # R2R + - R2R_CONFIG_NAME=${R2R_CONFIG_NAME:-} + - R2R_CONFIG_PATH=${R2R_CONFIG_PATH:-} + - R2R_PROJECT_NAME=${R2R_PROJECT_NAME:-r2r_default} + - R2R_SECRET_KEY=${R2R_SECRET_KEY:-} + + # Postgres + - R2R_POSTGRES_USER=${R2R_POSTGRES_USER:-postgres} + - R2R_POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres} + - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres} + - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432} + - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-postgres} + - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024} + - R2R_POSTGRES_STATEMENT_CACHE_SIZE=${R2R_POSTGRES_STATEMENT_CACHE_SIZE:-100} + + # OpenAI + - OPENAI_API_KEY=${OPENAI_API_KEY:-} + - OPENAI_API_BASE=${OPENAI_API_BASE:-} + + # Anthropic + - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-} + + # Azure + - AZURE_API_KEY=${AZURE_API_KEY:-} + - AZURE_API_BASE=${AZURE_API_BASE:-} + - AZURE_API_VERSION=${AZURE_API_VERSION:-} + + # Google Vertex AI + - GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS:-} + - VERTEX_PROJECT=${VERTEX_PROJECT:-} + - VERTEX_LOCATION=${VERTEX_LOCATION:-} + + # AWS Bedrock + - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-} + - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-} + - AWS_REGION_NAME=${AWS_REGION_NAME:-} + + # Groq + - GROQ_API_KEY=${GROQ_API_KEY:-} + + # Cohere + - COHERE_API_KEY=${COHERE_API_KEY:-} + + # Anyscale + - ANYSCALE_API_KEY=${ANYSCALE_API_KEY:-} + + # Ollama + - OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434} + + networks: + - r2r-network + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:${R2R_PORT:-7272}/v3/health"] + interval: 6s + timeout: 5s + retries: 5 + restart: on-failure + volumes: + - ${R2R_CONFIG_PATH:-/}:${R2R_CONFIG_PATH:-/app/config} + extra_hosts: + - host.docker.internal:host-gateway + + r2r-dashboard: + image: emrgntcmplxty/r2r-dashboard:latest + environment: + - NEXT_PUBLIC_R2R_DEPLOYMENT_URL=${R2R_DEPLOYMENT_URL:-http://localhost:7272} + networks: + - r2r-network + ports: + - "${R2R_DASHBOARD_PORT:-7273}:3000" diff --git a/py/r2r.toml b/py/r2r/r2r.toml similarity index 100% rename from py/r2r.toml rename to py/r2r/r2r.toml From afb429560604f54a5ca0c3bda15195244740f505 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 8 Jan 2025 10:51:10 -0800 Subject: [PATCH 06/10] Fix compose.full --- py/compose.yaml | 124 --- .../hatchet/ingestion_workflow.py | 768 ------------------ py/r2r/compose.full.yaml | 17 +- 3 files changed, 8 insertions(+), 901 deletions(-) delete mode 100644 py/compose.yaml delete mode 100644 py/core/main/orchestration/hatchet/ingestion_workflow.py diff --git a/py/compose.yaml b/py/compose.yaml deleted file mode 100644 index 746bb13c5..000000000 --- a/py/compose.yaml +++ /dev/null @@ -1,124 +0,0 @@ -networks: - r2r-network: - driver: bridge - attachable: true - labels: - - "com.docker.compose.recreate=always" - -volumes: - postgres_data: - name: ${VOLUME_POSTGRES_DATA:-postgres_data} - -services: - postgres: - image: pgvector/pgvector:pg16 - profiles: [postgres] - environment: - - POSTGRES_USER=${R2R_POSTGRES_USER:-postgres} - - POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres} - - POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres} - - POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432} 
- - POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024} - - PGPORT=${R2R_POSTGRES_PORT:-5432} - volumes: - - postgres_data:/var/lib/postgresql/data - networks: - - r2r-network - ports: - - "${R2R_POSTGRES_PORT:-5432}:${R2R_POSTGRES_PORT:-5432}" - healthcheck: - test: ["CMD-SHELL", "pg_isready -U ${R2R_POSTGRES_USER:-postgres}"] - interval: 10s - timeout: 5s - retries: 5 - restart: on-failure - command: > - postgres - -c max_connections=${R2R_POSTGRES_MAX_CONNECTIONS:-1024} - - r2r: - image: ${R2R_IMAGE:-ragtoriches/prod:latest} - build: - context: . - args: - PORT: ${R2R_PORT:-7272} - R2R_PORT: ${R2R_PORT:-7272} - HOST: ${R2R_HOST:-0.0.0.0} - R2R_HOST: ${R2R_HOST:-0.0.0.0} - ports: - - "${R2R_PORT:-7272}:${R2R_PORT:-7272}" - environment: - - PYTHONUNBUFFERED=1 - - R2R_PORT=${R2R_PORT:-7272} - - R2R_HOST=${R2R_HOST:-0.0.0.0} - - # R2R - - R2R_CONFIG_NAME=${R2R_CONFIG_NAME:-} - - R2R_CONFIG_PATH=${R2R_CONFIG_PATH:-} - - R2R_PROJECT_NAME=${R2R_PROJECT_NAME:-r2r_default} - - R2R_SECRET_KEY=${R2R_SECRET_KEY:-} - - # Postgres - - R2R_POSTGRES_USER=${R2R_POSTGRES_USER:-postgres} - - R2R_POSTGRES_PASSWORD=${R2R_POSTGRES_PASSWORD:-postgres} - - R2R_POSTGRES_HOST=${R2R_POSTGRES_HOST:-postgres} - - R2R_POSTGRES_PORT=${R2R_POSTGRES_PORT:-5432} - - R2R_POSTGRES_DBNAME=${R2R_POSTGRES_DBNAME:-postgres} - - R2R_POSTGRES_MAX_CONNECTIONS=${R2R_POSTGRES_MAX_CONNECTIONS:-1024} - - R2R_POSTGRES_STATEMENT_CACHE_SIZE=${R2R_POSTGRES_STATEMENT_CACHE_SIZE:-100} - - # OpenAI - - OPENAI_API_KEY=${OPENAI_API_KEY:-} - - OPENAI_API_BASE=${OPENAI_API_BASE:-} - - # Anthropic - - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-} - - # Azure - - AZURE_API_KEY=${AZURE_API_KEY:-} - - AZURE_API_BASE=${AZURE_API_BASE:-} - - AZURE_API_VERSION=${AZURE_API_VERSION:-} - - # Google Vertex AI - - GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS:-} - - VERTEX_PROJECT=${VERTEX_PROJECT:-} - - VERTEX_LOCATION=${VERTEX_LOCATION:-} - - # AWS Bedrock - - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-} - - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-} - - AWS_REGION_NAME=${AWS_REGION_NAME:-} - - # Groq - - GROQ_API_KEY=${GROQ_API_KEY:-} - - # Cohere - - COHERE_API_KEY=${COHERE_API_KEY:-} - - # Anyscale - - ANYSCALE_API_KEY=${ANYSCALE_API_KEY:-} - - # Ollama - - OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434} - - networks: - - r2r-network - healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:${R2R_PORT:-7272}/v3/health"] - interval: 6s - timeout: 5s - retries: 5 - restart: on-failure - volumes: - - ${R2R_CONFIG_PATH:-/}:${R2R_CONFIG_PATH:-/app/config} - extra_hosts: - - host.docker.internal:host-gateway - - r2r-dashboard: - image: emrgntcmplxty/r2r-dashboard:latest - environment: - - NEXT_PUBLIC_R2R_DEPLOYMENT_URL=${R2R_DEPLOYMENT_URL:-http://localhost:7272} - networks: - - r2r-network - ports: - - "${R2R_DASHBOARD_PORT:-7273}:3000" diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py deleted file mode 100644 index f0800091d..000000000 --- a/py/core/main/orchestration/hatchet/ingestion_workflow.py +++ /dev/null @@ -1,768 +0,0 @@ -import asyncio -import logging -import uuid -from typing import TYPE_CHECKING -from uuid import UUID - -from fastapi import HTTPException -from hatchet_sdk import ConcurrencyLimitStrategy, Context -from litellm import AuthenticationError - -from core.base import ( - DocumentChunk, - IngestionStatus, - KGEnrichmentStatus, - OrchestrationProvider, - generate_extraction_id, - increment_version, -) 
-from core.base.abstractions import DocumentResponse, R2RException -from core.utils import ( - generate_default_user_collection_id, - update_settings_from_dict, -) - -from ...services import IngestionService, IngestionServiceAdapter - -if TYPE_CHECKING: - from hatchet_sdk import Hatchet - -logger = logging.getLogger() - - -def hatchet_ingestion_factory( - orchestration_provider: OrchestrationProvider, service: IngestionService -) -> dict[str, "Hatchet.Workflow"]: - @orchestration_provider.workflow( - name="ingest-files", - timeout="60m", - ) - class HatchetIngestFilesWorkflow: - def __init__(self, ingestion_service: IngestionService): - self.ingestion_service = ingestion_service - - @orchestration_provider.concurrency( # type: ignore - max_runs=orchestration_provider.config.ingestion_concurrency_limit, # type: ignore - limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, - ) - def concurrency(self, context: Context) -> str: - # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun - try: - input_data = context.workflow_input()["request"] - parsed_data = IngestionServiceAdapter.parse_ingest_file_input( - input_data - ) - return str(parsed_data["user"].id) - except Exception as e: - return str(uuid.uuid4()) - - @orchestration_provider.step(retries=0, timeout="60m") - async def parse(self, context: Context) -> dict: - try: - logger.info("Initiating ingestion workflow, step: parse") - input_data = context.workflow_input()["request"] - parsed_data = IngestionServiceAdapter.parse_ingest_file_input( - input_data - ) - - ingestion_result = ( - await self.ingestion_service.ingest_file_ingress( - **parsed_data - ) - ) - - document_info = ingestion_result["info"] - - await self.ingestion_service.update_document_status( - document_info, - status=IngestionStatus.PARSING, - ) - - ingestion_config = parsed_data["ingestion_config"] or {} - extractions_generator = ( - await self.ingestion_service.parse_file( - document_info, ingestion_config - ) - ) - - extractions = [] - async for extraction in extractions_generator: - extractions.append(extraction) - - await service.update_document_status( - document_info, status=IngestionStatus.AUGMENTING - ) - await service.augment_document_info( - document_info, - [extraction.to_dict() for extraction in extractions], - ) - - await self.ingestion_service.update_document_status( - document_info, - status=IngestionStatus.EMBEDDING, - ) - - # extractions = context.step_output("parse")["extractions"] - - embedding_generator = ( - await self.ingestion_service.embed_document( - [extraction.to_dict() for extraction in extractions] - ) - ) - - embeddings = [] - async for embedding in embedding_generator: - embeddings.append(embedding) - - await self.ingestion_service.update_document_status( - document_info, - status=IngestionStatus.STORING, - ) - - storage_generator = await self.ingestion_service.store_embeddings( # type: ignore - embeddings - ) - - async for _ in storage_generator: - pass - - await self.ingestion_service.finalize_ingestion(document_info) - - await self.ingestion_service.update_document_status( - document_info, - status=IngestionStatus.SUCCESS, - ) - - collection_ids = context.workflow_input()["request"].get( - "collection_ids" - ) - if not collection_ids: - # TODO: Move logic onto the `management service` - collection_id = generate_default_user_collection_id( - document_info.owner_id - ) - await service.providers.database.collections_handler.assign_document_to_collection_relational( - document_id=document_info.id, - 
collection_id=collection_id, - ) - await service.providers.database.chunks_handler.assign_document_chunks_to_collection( - document_id=document_info.id, - collection_id=collection_id, - ) - await service.providers.database.documents_handler.set_workflow_status( - id=collection_id, - status_type="graph_sync_status", - status=KGEnrichmentStatus.OUTDATED, - ) - await service.providers.database.documents_handler.set_workflow_status( - id=collection_id, - status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still - status=KGEnrichmentStatus.OUTDATED, - ) - else: - for collection_id_str in collection_ids: - collection_id = UUID(collection_id_str) - try: - name = document_info.title or "N/A" - description = "" - await service.providers.database.collections_handler.create_collection( - owner_id=document_info.owner_id, - name=name, - description=description, - collection_id=collection_id, - ) - await self.providers.database.graphs_handler.create( - collection_id=collection_id, - name=name, - description=description, - graph_id=collection_id, - ) - - except Exception as e: - logger.warning( - f"Warning, could not create collection with error: {str(e)}" - ) - - await service.providers.database.collections_handler.assign_document_to_collection_relational( - document_id=document_info.id, - collection_id=collection_id, - ) - await service.providers.database.chunks_handler.assign_document_chunks_to_collection( - document_id=document_info.id, - collection_id=collection_id, - ) - await service.providers.database.documents_handler.set_workflow_status( - id=collection_id, - status_type="graph_sync_status", - status=KGEnrichmentStatus.OUTDATED, - ) - await service.providers.database.documents_handler.set_workflow_status( - id=collection_id, - status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still - status=KGEnrichmentStatus.OUTDATED, - ) - # get server chunk enrichment settings and override parts of it if provided in the ingestion config - server_chunk_enrichment_settings = getattr( - service.providers.ingestion.config, - "chunk_enrichment_settings", - None, - ) - - if server_chunk_enrichment_settings: - chunk_enrichment_settings = update_settings_from_dict( - server_chunk_enrichment_settings, - ingestion_config.get("chunk_enrichment_settings", {}) - or {}, - ) - - if chunk_enrichment_settings.enable_chunk_enrichment: - logger.info("Enriching document with contextual chunks") - - # TODO: the status updating doesn't work because document_info doesn't contain information about collection IDs - # we don't update the document_info when we assign document_to_collection_relational and document_to_collection_vector - # hack: get document_info again from DB - document_info = ( - await self.ingestion_service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. 
- offset=0, - limit=100, - filter_user_ids=[document_info.user_id], - filter_document_ids=[document_info.id], - ) - )["results"][0] - - await self.ingestion_service.update_document_status( - document_info, - status=IngestionStatus.ENRICHING, - ) - - await self.ingestion_service.chunk_enrichment( - document_id=document_info.id, - chunk_enrichment_settings=chunk_enrichment_settings, - ) - - await self.ingestion_service.update_document_status( - document_info, - status=IngestionStatus.ENRICHED, - ) - - return { - "status": "Successfully finalized ingestion", - "document_info": document_info.to_dict(), - } - - except AuthenticationError as e: - raise R2RException( - status_code=401, - message="Authentication error: Invalid API key or credentials.", - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error during ingestion: {str(e)}", - ) - - @orchestration_provider.failure() - async def on_failure(self, context: Context) -> None: - request = context.workflow_input().get("request", {}) - document_id = request.get("document_id") - - if not document_id: - logger.error( - "No document id was found in workflow input to mark a failure." - ) - return - - try: - documents_overview = ( - await self.ingestion_service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. - offset=0, - limit=100, - filter_document_ids=[document_id], - ) - )["results"] - - if not documents_overview: - logger.error( - f"Document with id {document_id} not found in database to mark failure." - ) - return - - document_info = documents_overview[0] - - # Update the document status to FAILED - if document_info.ingestion_status not in [ - IngestionStatus.SUCCESS, - IngestionStatus.ENRICHED, - ]: - await self.ingestion_service.update_document_status( - document_info, - status=IngestionStatus.FAILED, - ) - - except Exception as e: - logger.error( - f"Failed to update document status for {document_id}: {e}" - ) - - # TODO: Implement a check to see if the file is actually changed before updating - @orchestration_provider.workflow(name="update-files", timeout="60m") - class HatchetUpdateFilesWorkflow: - def __init__(self, ingestion_service: IngestionService): - self.ingestion_service = ingestion_service - - @orchestration_provider.step(retries=0, timeout="60m") - async def update_files(self, context: Context) -> None: - data = context.workflow_input()["request"] - parsed_data = IngestionServiceAdapter.parse_update_files_input( - data - ) - - file_datas = parsed_data["file_datas"] - user = parsed_data["user"] - document_ids = parsed_data["document_ids"] - metadatas = parsed_data["metadatas"] - ingestion_config = parsed_data["ingestion_config"] - file_sizes_in_bytes = parsed_data["file_sizes_in_bytes"] - - if not file_datas: - raise R2RException( - status_code=400, message="No files provided for update." - ) - if len(document_ids) != len(file_datas): - raise R2RException( - status_code=400, - message="Number of ids does not match number of files.", - ) - - documents_overview = ( - await self.ingestion_service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. 
- offset=0, - limit=100, - filter_document_ids=document_ids, - filter_user_ids=None if user.is_superuser else [user.id], - ) - )["results"] - - if len(documents_overview) != len(document_ids): - raise R2RException( - status_code=404, - message="One or more documents not found.", - ) - - results = [] - - for idx, ( - file_data, - doc_id, - doc_info, - file_size_in_bytes, - ) in enumerate( - zip( - file_datas, - document_ids, - documents_overview, - file_sizes_in_bytes, - ) - ): - new_version = increment_version(doc_info.version) - - updated_metadata = ( - metadatas[idx] if metadatas else doc_info.metadata - ) - updated_metadata["title"] = ( - updated_metadata.get("title") - or file_data["filename"].split("/")[-1] - ) - - # Prepare input for ingest_file workflow - ingest_input = { - "file_data": file_data, - "user": data.get("user"), - "metadata": updated_metadata, - "document_id": str(doc_id), - "version": new_version, - "ingestion_config": ( - ingestion_config.model_dump_json() - if ingestion_config - else None - ), - "size_in_bytes": file_size_in_bytes, - } - - # Spawn ingest_file workflow as a child workflow - child_result = ( - await context.aio.spawn_workflow( - "ingest-files", - {"request": ingest_input}, - key=f"ingest_file_{doc_id}", - ) - ).result() - results.append(child_result) - - await asyncio.gather(*results) - - return None - - @orchestration_provider.workflow( - name="ingest-chunks", - timeout="60m", - ) - class HatchetIngestChunksWorkflow: - def __init__(self, ingestion_service: IngestionService): - self.ingestion_service = ingestion_service - - @orchestration_provider.step(timeout="60m") - async def ingest(self, context: Context) -> dict: - input_data = context.workflow_input()["request"] - parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input( - input_data - ) - - document_info = await self.ingestion_service.ingest_chunks_ingress( - **parsed_data - ) - - await self.ingestion_service.update_document_status( - document_info, status=IngestionStatus.EMBEDDING - ) - document_id = document_info.id - - extractions = [ - DocumentChunk( - id=generate_extraction_id(document_id, i), - document_id=document_id, - collection_ids=[], - owner_id=document_info.owner_id, - data=chunk.text, - metadata=parsed_data["metadata"], - ).to_dict() - for i, chunk in enumerate(parsed_data["chunks"]) - ] - return { - "status": "Successfully ingested chunks", - "extractions": extractions, - "document_info": document_info.to_dict(), - } - - @orchestration_provider.step(parents=["ingest"], timeout="60m") - async def embed(self, context: Context) -> dict: - document_info_dict = context.step_output("ingest")["document_info"] - document_info = DocumentResponse(**document_info_dict) - - extractions = context.step_output("ingest")["extractions"] - - embedding_generator = await self.ingestion_service.embed_document( - extractions - ) - embeddings = [ - embedding.model_dump() - async for embedding in embedding_generator - ] - - await self.ingestion_service.update_document_status( - document_info, status=IngestionStatus.STORING - ) - - storage_generator = await self.ingestion_service.store_embeddings( - embeddings - ) - async for _ in storage_generator: - pass - - return { - "status": "Successfully embedded and stored chunks", - "document_info": document_info.to_dict(), - } - - @orchestration_provider.step(parents=["embed"], timeout="60m") - async def finalize(self, context: Context) -> dict: - document_info_dict = context.step_output("embed")["document_info"] - document_info = 
DocumentResponse(**document_info_dict) - - await self.ingestion_service.finalize_ingestion(document_info) - - await self.ingestion_service.update_document_status( - document_info, status=IngestionStatus.SUCCESS - ) - - try: - # TODO - Move logic onto the `management service` - collection_ids = context.workflow_input()["request"].get( - "collection_ids" - ) - if not collection_ids: - # TODO: Move logic onto the `management service` - collection_id = generate_default_user_collection_id( - document_info.owner_id - ) - await service.providers.database.collections_handler.assign_document_to_collection_relational( - document_id=document_info.id, - collection_id=collection_id, - ) - await service.providers.database.chunks_handler.assign_document_chunks_to_collection( - document_id=document_info.id, - collection_id=collection_id, - ) - await service.providers.database.documents_handler.set_workflow_status( - id=collection_id, - status_type="graph_sync_status", - status=KGEnrichmentStatus.OUTDATED, - ) - await service.providers.database.documents_handler.set_workflow_status( - id=collection_id, - status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still - status=KGEnrichmentStatus.OUTDATED, - ) - else: - for collection_id_str in collection_ids: - collection_id = UUID(collection_id_str) - try: - name = document_info.title or "N/A" - description = "" - await service.providers.database.collections_handler.create_collection( - owner_id=document_info.owner_id, - name=name, - description=description, - collection_id=collection_id, - ) - await self.providers.database.graphs_handler.create( - collection_id=collection_id, - name=name, - description=description, - graph_id=collection_id, - ) - - except Exception as e: - logger.warning( - f"Warning, could not create collection with error: {str(e)}" - ) - - await service.providers.database.collections_handler.assign_document_to_collection_relational( - document_id=document_info.id, - collection_id=collection_id, - ) - - await service.providers.database.chunks_handler.assign_document_chunks_to_collection( - document_id=document_info.id, - collection_id=collection_id, - ) - - await service.providers.database.documents_handler.set_workflow_status( - id=collection_id, - status_type="graph_sync_status", - status=KGEnrichmentStatus.OUTDATED, - ) - - await service.providers.database.documents_handler.set_workflow_status( - id=collection_id, - status_type="graph_cluster_status", - status=KGEnrichmentStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still - ) - except Exception as e: - logger.error( - f"Error during assigning document to collection: {str(e)}" - ) - - return { - "status": "Successfully finalized ingestion", - "document_info": document_info.to_dict(), - } - - @orchestration_provider.failure() - async def on_failure(self, context: Context) -> None: - request = context.workflow_input().get("request", {}) - document_id = request.get("document_id") - - if not document_id: - logger.error( - "No document id was found in workflow input to mark a failure." - ) - return - - try: - documents_overview = ( - await self.ingestion_service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. 
- offset=0, - limit=100, - filter_document_ids=[document_id], - ) - )["results"] - - if not documents_overview: - logger.error( - f"Document with id {document_id} not found in database to mark failure." - ) - return - - document_info = documents_overview[0] - - if document_info.ingestion_status != IngestionStatus.SUCCESS: - await self.ingestion_service.update_document_status( - document_info, status=IngestionStatus.FAILED - ) - - except Exception as e: - logger.error( - f"Failed to update document status for {document_id}: {e}" - ) - - @orchestration_provider.workflow( - name="update-chunk", - timeout="60m", - ) - class HatchetUpdateChunkWorkflow: - def __init__(self, ingestion_service: IngestionService): - self.ingestion_service = ingestion_service - - @orchestration_provider.step(timeout="60m") - async def update_chunk(self, context: Context) -> dict: - try: - input_data = context.workflow_input()["request"] - parsed_data = IngestionServiceAdapter.parse_update_chunk_input( - input_data - ) - - document_uuid = ( - UUID(parsed_data["document_id"]) - if isinstance(parsed_data["document_id"], str) - else parsed_data["document_id"] - ) - extraction_uuid = ( - UUID(parsed_data["id"]) - if isinstance(parsed_data["id"], str) - else parsed_data["id"] - ) - - await self.ingestion_service.update_chunk_ingress( - document_id=document_uuid, - chunk_id=extraction_uuid, - text=parsed_data.get("text"), - user=parsed_data["user"], - metadata=parsed_data.get("metadata"), - collection_ids=parsed_data.get("collection_ids"), - ) - - return { - "message": "Chunk update completed successfully.", - "task_id": context.workflow_run_id(), - "document_ids": [str(document_uuid)], - } - - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error during chunk update: {str(e)}", - ) - - @orchestration_provider.failure() - async def on_failure(self, context: Context) -> None: - # Handle failure case if necessary - pass - - @orchestration_provider.workflow( - name="create-vector-index", timeout="360m" - ) - class HatchetCreateVectorIndexWorkflow: - def __init__(self, ingestion_service: IngestionService): - self.ingestion_service = ingestion_service - - @orchestration_provider.step(timeout="60m") - async def create_vector_index(self, context: Context) -> dict: - input_data = context.workflow_input()["request"] - parsed_data = ( - IngestionServiceAdapter.parse_create_vector_index_input( - input_data - ) - ) - - await self.ingestion_service.providers.database.chunks_handler.create_index( - **parsed_data - ) - - return { - "status": "Vector index creation queued successfully.", - } - - @orchestration_provider.workflow(name="delete-vector-index", timeout="30m") - class HatchetDeleteVectorIndexWorkflow: - def __init__(self, ingestion_service: IngestionService): - self.ingestion_service = ingestion_service - - @orchestration_provider.step(timeout="10m") - async def delete_vector_index(self, context: Context) -> dict: - input_data = context.workflow_input()["request"] - parsed_data = ( - IngestionServiceAdapter.parse_delete_vector_index_input( - input_data - ) - ) - - await self.ingestion_service.providers.database.chunks_handler.delete_index( - **parsed_data - ) - - return {"status": "Vector index deleted successfully."} - - @orchestration_provider.workflow( - name="update-document-metadata", - timeout="30m", - ) - class HatchetUpdateDocumentMetadataWorkflow: - def __init__(self, ingestion_service: IngestionService): - self.ingestion_service = ingestion_service - - 
@orchestration_provider.step(timeout="30m") - async def update_document_metadata(self, context: Context) -> dict: - try: - input_data = context.workflow_input()["request"] - parsed_data = IngestionServiceAdapter.parse_update_document_metadata_input( - input_data - ) - - document_id = UUID(parsed_data["document_id"]) - metadata = parsed_data["metadata"] - user = parsed_data["user"] - - await self.ingestion_service.update_document_metadata( - document_id=document_id, - metadata=metadata, - user=user, - ) - - return { - "message": "Document metadata update completed successfully.", - "document_id": str(document_id), - "task_id": context.workflow_run_id(), - } - - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error during document metadata update: {str(e)}", - ) - - @orchestration_provider.failure() - async def on_failure(self, context: Context) -> None: - # Handle failure case if necessary - pass - - # Add this to the workflows dictionary in hatchet_ingestion_factory - ingest_files_workflow = HatchetIngestFilesWorkflow(service) - update_files_workflow = HatchetUpdateFilesWorkflow(service) - ingest_chunks_workflow = HatchetIngestChunksWorkflow(service) - update_chunks_workflow = HatchetUpdateChunkWorkflow(service) - update_document_metadata_workflow = HatchetUpdateDocumentMetadataWorkflow( - service - ) - create_vector_index_workflow = HatchetCreateVectorIndexWorkflow(service) - delete_vector_index_workflow = HatchetDeleteVectorIndexWorkflow(service) - - return { - "ingest_files": ingest_files_workflow, - "update_files": update_files_workflow, - "ingest_chunks": ingest_chunks_workflow, - "update_chunk": update_chunks_workflow, - "update_document_metadata": update_document_metadata_workflow, - "create_vector_index": create_vector_index_workflow, - "delete_vector_index": delete_vector_index_workflow, - } diff --git a/py/r2r/compose.full.yaml b/py/r2r/compose.full.yaml index 647320a3d..2711c62a0 100644 --- a/py/r2r/compose.full.yaml +++ b/py/r2r/compose.full.yaml @@ -281,15 +281,14 @@ services: retries: 5 r2r: - # image: ${R2R_IMAGE:-ragtoriches/prod:latest} - image: r2r/test - # build: - # context: . - # args: - # PORT: ${R2R_PORT:-7272} - # R2R_PORT: ${R2R_PORT:-7272} - # HOST: ${R2R_HOST:-0.0.0.0} - # R2R_HOST: ${R2R_HOST:-0.0.0.0} + image: ${R2R_IMAGE:-ragtoriches/prod:latest} + build: + context: . 
+ args: + PORT: ${R2R_PORT:-7272} + R2R_PORT: ${R2R_PORT:-7272} + HOST: ${R2R_HOST:-0.0.0.0} + R2R_HOST: ${R2R_HOST:-0.0.0.0} ports: - "${R2R_PORT:-7272}:${R2R_PORT:-7272}" environment: From 1c212da660f4ae8a670252c723b6814663f1b019 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 8 Jan 2025 10:54:45 -0800 Subject: [PATCH 07/10] Restore workflow --- .../hatchet/ingestion_workflow.py | 768 ++++++++++++++++++ 1 file changed, 768 insertions(+) create mode 100644 py/core/main/orchestration/hatchet/ingestion_workflow.py diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py new file mode 100644 index 000000000..f0800091d --- /dev/null +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -0,0 +1,768 @@ +import asyncio +import logging +import uuid +from typing import TYPE_CHECKING +from uuid import UUID + +from fastapi import HTTPException +from hatchet_sdk import ConcurrencyLimitStrategy, Context +from litellm import AuthenticationError + +from core.base import ( + DocumentChunk, + IngestionStatus, + KGEnrichmentStatus, + OrchestrationProvider, + generate_extraction_id, + increment_version, +) +from core.base.abstractions import DocumentResponse, R2RException +from core.utils import ( + generate_default_user_collection_id, + update_settings_from_dict, +) + +from ...services import IngestionService, IngestionServiceAdapter + +if TYPE_CHECKING: + from hatchet_sdk import Hatchet + +logger = logging.getLogger() + + +def hatchet_ingestion_factory( + orchestration_provider: OrchestrationProvider, service: IngestionService +) -> dict[str, "Hatchet.Workflow"]: + @orchestration_provider.workflow( + name="ingest-files", + timeout="60m", + ) + class HatchetIngestFilesWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.concurrency( # type: ignore + max_runs=orchestration_provider.config.ingestion_concurrency_limit, # type: ignore + limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, + ) + def concurrency(self, context: Context) -> str: + # TODO: Possible bug in hatchet, the job can't find context.workflow_input() when rerun + try: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_ingest_file_input( + input_data + ) + return str(parsed_data["user"].id) + except Exception as e: + return str(uuid.uuid4()) + + @orchestration_provider.step(retries=0, timeout="60m") + async def parse(self, context: Context) -> dict: + try: + logger.info("Initiating ingestion workflow, step: parse") + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_ingest_file_input( + input_data + ) + + ingestion_result = ( + await self.ingestion_service.ingest_file_ingress( + **parsed_data + ) + ) + + document_info = ingestion_result["info"] + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.PARSING, + ) + + ingestion_config = parsed_data["ingestion_config"] or {} + extractions_generator = ( + await self.ingestion_service.parse_file( + document_info, ingestion_config + ) + ) + + extractions = [] + async for extraction in extractions_generator: + extractions.append(extraction) + + await service.update_document_status( + document_info, status=IngestionStatus.AUGMENTING + ) + await service.augment_document_info( + document_info, + [extraction.to_dict() for extraction in 
extractions], + ) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.EMBEDDING, + ) + + # extractions = context.step_output("parse")["extractions"] + + embedding_generator = ( + await self.ingestion_service.embed_document( + [extraction.to_dict() for extraction in extractions] + ) + ) + + embeddings = [] + async for embedding in embedding_generator: + embeddings.append(embedding) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.STORING, + ) + + storage_generator = await self.ingestion_service.store_embeddings( # type: ignore + embeddings + ) + + async for _ in storage_generator: + pass + + await self.ingestion_service.finalize_ingestion(document_info) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.SUCCESS, + ) + + collection_ids = context.workflow_input()["request"].get( + "collection_ids" + ) + if not collection_ids: + # TODO: Move logic onto the `management service` + collection_id = generate_default_user_collection_id( + document_info.owner_id + ) + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=KGEnrichmentStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + status=KGEnrichmentStatus.OUTDATED, + ) + else: + for collection_id_str in collection_ids: + collection_id = UUID(collection_id_str) + try: + name = document_info.title or "N/A" + description = "" + await service.providers.database.collections_handler.create_collection( + owner_id=document_info.owner_id, + name=name, + description=description, + collection_id=collection_id, + ) + await self.providers.database.graphs_handler.create( + collection_id=collection_id, + name=name, + description=description, + graph_id=collection_id, + ) + + except Exception as e: + logger.warning( + f"Warning, could not create collection with error: {str(e)}" + ) + + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=KGEnrichmentStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + status=KGEnrichmentStatus.OUTDATED, + ) + # get server chunk enrichment settings and override parts of it if provided in the ingestion config + server_chunk_enrichment_settings = getattr( + service.providers.ingestion.config, + "chunk_enrichment_settings", + None, + ) + + if server_chunk_enrichment_settings: + 
chunk_enrichment_settings = update_settings_from_dict( + server_chunk_enrichment_settings, + ingestion_config.get("chunk_enrichment_settings", {}) + or {}, + ) + + if chunk_enrichment_settings.enable_chunk_enrichment: + logger.info("Enriching document with contextual chunks") + + # TODO: the status updating doesn't work because document_info doesn't contain information about collection IDs + # we don't update the document_info when we assign document_to_collection_relational and document_to_collection_vector + # hack: get document_info again from DB + document_info = ( + await self.ingestion_service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. + offset=0, + limit=100, + filter_user_ids=[document_info.user_id], + filter_document_ids=[document_info.id], + ) + )["results"][0] + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.ENRICHING, + ) + + await self.ingestion_service.chunk_enrichment( + document_id=document_info.id, + chunk_enrichment_settings=chunk_enrichment_settings, + ) + + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.ENRICHED, + ) + + return { + "status": "Successfully finalized ingestion", + "document_info": document_info.to_dict(), + } + + except AuthenticationError as e: + raise R2RException( + status_code=401, + message="Authentication error: Invalid API key or credentials.", + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error during ingestion: {str(e)}", + ) + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + request = context.workflow_input().get("request", {}) + document_id = request.get("document_id") + + if not document_id: + logger.error( + "No document id was found in workflow input to mark a failure." + ) + return + + try: + documents_overview = ( + await self.ingestion_service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. + offset=0, + limit=100, + filter_document_ids=[document_id], + ) + )["results"] + + if not documents_overview: + logger.error( + f"Document with id {document_id} not found in database to mark failure." 
+ ) + return + + document_info = documents_overview[0] + + # Update the document status to FAILED + if document_info.ingestion_status not in [ + IngestionStatus.SUCCESS, + IngestionStatus.ENRICHED, + ]: + await self.ingestion_service.update_document_status( + document_info, + status=IngestionStatus.FAILED, + ) + + except Exception as e: + logger.error( + f"Failed to update document status for {document_id}: {e}" + ) + + # TODO: Implement a check to see if the file is actually changed before updating + @orchestration_provider.workflow(name="update-files", timeout="60m") + class HatchetUpdateFilesWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(retries=0, timeout="60m") + async def update_files(self, context: Context) -> None: + data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_update_files_input( + data + ) + + file_datas = parsed_data["file_datas"] + user = parsed_data["user"] + document_ids = parsed_data["document_ids"] + metadatas = parsed_data["metadatas"] + ingestion_config = parsed_data["ingestion_config"] + file_sizes_in_bytes = parsed_data["file_sizes_in_bytes"] + + if not file_datas: + raise R2RException( + status_code=400, message="No files provided for update." + ) + if len(document_ids) != len(file_datas): + raise R2RException( + status_code=400, + message="Number of ids does not match number of files.", + ) + + documents_overview = ( + await self.ingestion_service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. + offset=0, + limit=100, + filter_document_ids=document_ids, + filter_user_ids=None if user.is_superuser else [user.id], + ) + )["results"] + + if len(documents_overview) != len(document_ids): + raise R2RException( + status_code=404, + message="One or more documents not found.", + ) + + results = [] + + for idx, ( + file_data, + doc_id, + doc_info, + file_size_in_bytes, + ) in enumerate( + zip( + file_datas, + document_ids, + documents_overview, + file_sizes_in_bytes, + ) + ): + new_version = increment_version(doc_info.version) + + updated_metadata = ( + metadatas[idx] if metadatas else doc_info.metadata + ) + updated_metadata["title"] = ( + updated_metadata.get("title") + or file_data["filename"].split("/")[-1] + ) + + # Prepare input for ingest_file workflow + ingest_input = { + "file_data": file_data, + "user": data.get("user"), + "metadata": updated_metadata, + "document_id": str(doc_id), + "version": new_version, + "ingestion_config": ( + ingestion_config.model_dump_json() + if ingestion_config + else None + ), + "size_in_bytes": file_size_in_bytes, + } + + # Spawn ingest_file workflow as a child workflow + child_result = ( + await context.aio.spawn_workflow( + "ingest-files", + {"request": ingest_input}, + key=f"ingest_file_{doc_id}", + ) + ).result() + results.append(child_result) + + await asyncio.gather(*results) + + return None + + @orchestration_provider.workflow( + name="ingest-chunks", + timeout="60m", + ) + class HatchetIngestChunksWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="60m") + async def ingest(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input( + input_data + ) + + document_info = await 
self.ingestion_service.ingest_chunks_ingress( + **parsed_data + ) + + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.EMBEDDING + ) + document_id = document_info.id + + extractions = [ + DocumentChunk( + id=generate_extraction_id(document_id, i), + document_id=document_id, + collection_ids=[], + owner_id=document_info.owner_id, + data=chunk.text, + metadata=parsed_data["metadata"], + ).to_dict() + for i, chunk in enumerate(parsed_data["chunks"]) + ] + return { + "status": "Successfully ingested chunks", + "extractions": extractions, + "document_info": document_info.to_dict(), + } + + @orchestration_provider.step(parents=["ingest"], timeout="60m") + async def embed(self, context: Context) -> dict: + document_info_dict = context.step_output("ingest")["document_info"] + document_info = DocumentResponse(**document_info_dict) + + extractions = context.step_output("ingest")["extractions"] + + embedding_generator = await self.ingestion_service.embed_document( + extractions + ) + embeddings = [ + embedding.model_dump() + async for embedding in embedding_generator + ] + + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.STORING + ) + + storage_generator = await self.ingestion_service.store_embeddings( + embeddings + ) + async for _ in storage_generator: + pass + + return { + "status": "Successfully embedded and stored chunks", + "document_info": document_info.to_dict(), + } + + @orchestration_provider.step(parents=["embed"], timeout="60m") + async def finalize(self, context: Context) -> dict: + document_info_dict = context.step_output("embed")["document_info"] + document_info = DocumentResponse(**document_info_dict) + + await self.ingestion_service.finalize_ingestion(document_info) + + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.SUCCESS + ) + + try: + # TODO - Move logic onto the `management service` + collection_ids = context.workflow_input()["request"].get( + "collection_ids" + ) + if not collection_ids: + # TODO: Move logic onto the `management service` + collection_id = generate_default_user_collection_id( + document_info.owner_id + ) + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=KGEnrichmentStatus.OUTDATED, + ) + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + status=KGEnrichmentStatus.OUTDATED, + ) + else: + for collection_id_str in collection_ids: + collection_id = UUID(collection_id_str) + try: + name = document_info.title or "N/A" + description = "" + await service.providers.database.collections_handler.create_collection( + owner_id=document_info.owner_id, + name=name, + description=description, + collection_id=collection_id, + ) + await self.providers.database.graphs_handler.create( + collection_id=collection_id, + name=name, + description=description, + graph_id=collection_id, + ) + + except Exception as e: + logger.warning( + f"Warning, could not create collection 
with error: {str(e)}" + ) + + await service.providers.database.collections_handler.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + + await service.providers.database.chunks_handler.assign_document_chunks_to_collection( + document_id=document_info.id, + collection_id=collection_id, + ) + + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_sync_status", + status=KGEnrichmentStatus.OUTDATED, + ) + + await service.providers.database.documents_handler.set_workflow_status( + id=collection_id, + status_type="graph_cluster_status", + status=KGEnrichmentStatus.OUTDATED, # NOTE - we should actually check that cluster has been made first, if not it should be PENDING still + ) + except Exception as e: + logger.error( + f"Error during assigning document to collection: {str(e)}" + ) + + return { + "status": "Successfully finalized ingestion", + "document_info": document_info.to_dict(), + } + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + request = context.workflow_input().get("request", {}) + document_id = request.get("document_id") + + if not document_id: + logger.error( + "No document id was found in workflow input to mark a failure." + ) + return + + try: + documents_overview = ( + await self.ingestion_service.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. + offset=0, + limit=100, + filter_document_ids=[document_id], + ) + )["results"] + + if not documents_overview: + logger.error( + f"Document with id {document_id} not found in database to mark failure." + ) + return + + document_info = documents_overview[0] + + if document_info.ingestion_status != IngestionStatus.SUCCESS: + await self.ingestion_service.update_document_status( + document_info, status=IngestionStatus.FAILED + ) + + except Exception as e: + logger.error( + f"Failed to update document status for {document_id}: {e}" + ) + + @orchestration_provider.workflow( + name="update-chunk", + timeout="60m", + ) + class HatchetUpdateChunkWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="60m") + async def update_chunk(self, context: Context) -> dict: + try: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_update_chunk_input( + input_data + ) + + document_uuid = ( + UUID(parsed_data["document_id"]) + if isinstance(parsed_data["document_id"], str) + else parsed_data["document_id"] + ) + extraction_uuid = ( + UUID(parsed_data["id"]) + if isinstance(parsed_data["id"], str) + else parsed_data["id"] + ) + + await self.ingestion_service.update_chunk_ingress( + document_id=document_uuid, + chunk_id=extraction_uuid, + text=parsed_data.get("text"), + user=parsed_data["user"], + metadata=parsed_data.get("metadata"), + collection_ids=parsed_data.get("collection_ids"), + ) + + return { + "message": "Chunk update completed successfully.", + "task_id": context.workflow_run_id(), + "document_ids": [str(document_uuid)], + } + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error during chunk update: {str(e)}", + ) + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + # Handle failure case if necessary + pass + + @orchestration_provider.workflow( + 
name="create-vector-index", timeout="360m" + ) + class HatchetCreateVectorIndexWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="60m") + async def create_vector_index(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + parsed_data = ( + IngestionServiceAdapter.parse_create_vector_index_input( + input_data + ) + ) + + await self.ingestion_service.providers.database.chunks_handler.create_index( + **parsed_data + ) + + return { + "status": "Vector index creation queued successfully.", + } + + @orchestration_provider.workflow(name="delete-vector-index", timeout="30m") + class HatchetDeleteVectorIndexWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="10m") + async def delete_vector_index(self, context: Context) -> dict: + input_data = context.workflow_input()["request"] + parsed_data = ( + IngestionServiceAdapter.parse_delete_vector_index_input( + input_data + ) + ) + + await self.ingestion_service.providers.database.chunks_handler.delete_index( + **parsed_data + ) + + return {"status": "Vector index deleted successfully."} + + @orchestration_provider.workflow( + name="update-document-metadata", + timeout="30m", + ) + class HatchetUpdateDocumentMetadataWorkflow: + def __init__(self, ingestion_service: IngestionService): + self.ingestion_service = ingestion_service + + @orchestration_provider.step(timeout="30m") + async def update_document_metadata(self, context: Context) -> dict: + try: + input_data = context.workflow_input()["request"] + parsed_data = IngestionServiceAdapter.parse_update_document_metadata_input( + input_data + ) + + document_id = UUID(parsed_data["document_id"]) + metadata = parsed_data["metadata"] + user = parsed_data["user"] + + await self.ingestion_service.update_document_metadata( + document_id=document_id, + metadata=metadata, + user=user, + ) + + return { + "message": "Document metadata update completed successfully.", + "document_id": str(document_id), + "task_id": context.workflow_run_id(), + } + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error during document metadata update: {str(e)}", + ) + + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + # Handle failure case if necessary + pass + + # Add this to the workflows dictionary in hatchet_ingestion_factory + ingest_files_workflow = HatchetIngestFilesWorkflow(service) + update_files_workflow = HatchetUpdateFilesWorkflow(service) + ingest_chunks_workflow = HatchetIngestChunksWorkflow(service) + update_chunks_workflow = HatchetUpdateChunkWorkflow(service) + update_document_metadata_workflow = HatchetUpdateDocumentMetadataWorkflow( + service + ) + create_vector_index_workflow = HatchetCreateVectorIndexWorkflow(service) + delete_vector_index_workflow = HatchetDeleteVectorIndexWorkflow(service) + + return { + "ingest_files": ingest_files_workflow, + "update_files": update_files_workflow, + "ingest_chunks": ingest_chunks_workflow, + "update_chunk": update_chunks_workflow, + "update_document_metadata": update_document_metadata_workflow, + "create_vector_index": create_vector_index_workflow, + "delete_vector_index": delete_vector_index_workflow, + } From fb3b3dc4271bb3152642286774188c603f01cc9e Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 8 Jan 2025 11:08:34 
-0800 Subject: [PATCH 08/10] config location --- py/core/main/config.py | 2 +- py/tests/unit/test_config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/py/core/main/config.py b/py/core/main/config.py index 2c371d8d2..158387d12 100644 --- a/py/core/main/config.py +++ b/py/core/main/config.py @@ -26,7 +26,7 @@ class R2RConfig: current_file_path = os.path.dirname(__file__) config_dir_root = os.path.join(current_file_path, "..", "configs") default_config_path = os.path.join( - current_file_path, "..", "..", "r2r.toml" + current_file_path, "..", "..", "r2r", "r2r.toml" ) CONFIG_OPTIONS: dict[str, Optional[str]] = {} diff --git a/py/tests/unit/test_config.py b/py/tests/unit/test_config.py index f61a5f28a..5242057ee 100644 --- a/py/tests/unit/test_config.py +++ b/py/tests/unit/test_config.py @@ -11,7 +11,7 @@ @pytest.fixture def base_config(): """Load the base r2r.toml config""" - config_path = Path(__file__).parent.parent.parent / "r2r.toml" + config_path = Path(__file__).parent.parent.parent / "r2r/r2r.toml" with open(config_path) as f: return toml.load(f) From 01503c84304123c09d03d459f57322f15a44cb53 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 8 Jan 2025 11:13:23 -0800 Subject: [PATCH 09/10] Test location --- py/tests/unit/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/tests/unit/test_config.py b/py/tests/unit/test_config.py index 5242057ee..36455d05f 100644 --- a/py/tests/unit/test_config.py +++ b/py/tests/unit/test_config.py @@ -140,7 +140,7 @@ def get_config_files(): def test_config_required_keys(config_file): """Test that all required keys are present in all config files""" if config_file == "r2r.toml": - file_path = Path(__file__).parent.parent.parent / "r2r.toml" + file_path = Path(__file__).parent.parent.parent / "r2r/r2r.toml" else: file_path = ( Path(__file__).parent.parent.parent From 231fae5ebab08775e52168216d0e76708ab028d2 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 8 Jan 2025 11:21:42 -0800 Subject: [PATCH 10/10] Include files AND packages in wheel --- py/pyproject.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/py/pyproject.toml b/py/pyproject.toml index 315d4d963..5a0def39c 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -19,10 +19,10 @@ packages = [ { include = "cli", from = "." }, ] include = [ - "migrations/**/*", - "pyproject.toml", - "r2r/**/*.yaml", - "r2r/**/*.toml" + { path = "migrations/**/*", format = ["sdist", "wheel"] }, + { path = "pyproject.toml", format = ["sdist", "wheel"] }, + { path = "r2r/**/*.yaml", format = ["sdist", "wheel"] }, + { path = "r2r/**/*.toml", format = ["sdist", "wheel"] } ] [tool.poetry.dependencies]
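With the include entries now declared for both sdist and wheel, r2r/r2r.toml and the compose YAML files ship inside the installed package instead of existing only in a source checkout. A quick way to sanity-check a built wheel is to resolve the data files through importlib.resources against the installed r2r package. The snippet below is a minimal sketch under stated assumptions (the wheel produced by this series is installed and the toml dependency is available); it is not the lookup that core/main/config.py performs, which, as the PATCH 08 hunk above shows, joins a path relative to __file__, and the helper name load_packaged_config is illustrative only.

    from importlib import resources

    import toml


    def load_packaged_config() -> dict:
        # resources.files("r2r") resolves the installed r2r package, so the lookup
        # works the same from a wheel install or a source checkout and does not
        # depend on the repository layout the way a __file__-relative join does.
        config_path = resources.files("r2r") / "r2r.toml"
        with config_path.open("r") as f:
            return toml.load(f)


    if __name__ == "__main__":
        # Illustrative smoke test only: if the data files were packaged correctly,
        # the default config's top-level sections should print here; a missing file
        # surfaces as FileNotFoundError on open.
        config = load_packaged_config()
        print(sorted(config.keys()))

Run against a fresh install of the built wheel, this doubles as a cheap packaging check after poetry build, since it fails loudly whenever the r2r/**/*.toml include is dropped from the wheel again.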