Skip to content

Commit

Permalink
Reorganize python package structure (#1214)
Browse files Browse the repository at this point in the history
  • Loading branch information
jgbradley1 authored Oct 10, 2024
1 parent ce8749b commit d9a005c
Show file tree
Hide file tree
Showing 48 changed files with 370 additions and 305 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240926032712236048.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Reorganized api, reporter, and callback code into separate components. Defined debug profiles."
}
39 changes: 33 additions & 6 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -1,12 +1,39 @@
{
"_comment": "Use this file to configure the graphrag project for debugging. You may create other configuration profiles based on these or select one below to use.",
"version": "0.2.0",
"configurations": [
{
"name": "Attach to Node Functions",
"type": "node",
"request": "attach",
"port": 9229,
"preLaunchTask": "func: host start"
"name": "Indexer",
"type": "debugpy",
"request": "launch",
"module": "poetry",
"args": [
"poe", "index",
"--root", "<path_to_ragtest_root_demo>"
],
},
{
"name": "Query",
"type": "debugpy",
"request": "launch",
"module": "poetry",
"args": [
"poe", "query",
"--root", "<path_to_ragtest_root_demo>",
"--method", "global",
"What are the top themes in this story",
]
},
{
"name": "Prompt Tuning",
"type": "debugpy",
"request": "launch",
"module": "poetry",
"args": [
"poe", "prompt_tune",
"--config",
"<path_to_ragtest_root_demo>/settings.yaml",
]
}
]
}
}
1 change: 0 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
],
"python.defaultInterpreterPath": "python/services/.venv/bin/python",
"python.languageServer": "Pylance",
"python.analysis.typeCheckingMode": "basic",
"cSpell.customDictionaries": {
"project-words": {
"name": "project-words",
Expand Down
30 changes: 30 additions & 0 deletions graphrag/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""API for GraphRAG.
WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""

from .index_api import build_index
from .prompt_tune_api import DocSelectionType, generate_indexing_prompts
from .query_api import (
global_search,
global_search_streaming,
local_search,
local_search_streaming,
)

# Public surface of the graphrag.api package: the names re-exported from the
# index, query, and prompt-tuning API modules imported above.
# Entries are grouped by API area (see the inline comments) rather than sorted
# alphabetically, which is why the sorted-__all__ lint rule (RUF022) is
# suppressed on the opening line.
__all__ = [ # noqa: RUF022
# index API
"build_index",
# query API
"global_search",
"global_search_streaming",
"local_search",
"local_search_streaming",
# prompt tuning API
"DocSelectionType",
"generate_indexing_prompts",
]
15 changes: 6 additions & 9 deletions graphrag/index/api.py → graphrag/api/index_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,12 @@
"""

from graphrag.config import CacheType, GraphRagConfig

from .cache.noop_pipeline_cache import NoopPipelineCache
from .create_pipeline_config import create_pipeline_config
from .emit.types import TableEmitterType
from .progress import (
ProgressReporter,
)
from .run import run_pipeline_with_config
from .typing import PipelineRunResult
from graphrag.index.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.index.emit.types import TableEmitterType
from graphrag.index.run import run_pipeline_with_config
from graphrag.index.typing import PipelineRunResult
from graphrag.logging import ProgressReporter


async def build_index(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm import load_llm
from graphrag.index.progress import PrintProgressReporter

from .generator import (
from graphrag.logging import PrintProgressReporter
from graphrag.prompt_tune.generator import (
MAX_TOKEN_COUNT,
create_community_summarization_prompt,
create_entity_extraction_prompt,
Expand All @@ -31,11 +30,11 @@
generate_entity_types,
generate_persona,
)
from .loader import (
from graphrag.prompt_tune.loader import (
MIN_CHUNK_SIZE,
load_docs_in_chunks,
)
from .types import DocSelectionType
from graphrag.prompt_tune.types import DocSelectionType


@validate_call
Expand Down
15 changes: 7 additions & 8 deletions graphrag/query/api.py → graphrag/api/query_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,20 @@
from pydantic import validate_call

from graphrag.config import GraphRagConfig
from graphrag.index.progress.types import PrintProgressReporter
from graphrag.logging import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType

from .factories import get_global_search_engine, get_local_search_engine
from .indexer_adapters import (
from graphrag.query.factories import get_global_search_engine, get_local_search_engine
from graphrag.query.indexer_adapters import (
read_indexer_covariates,
read_indexer_entities,
read_indexer_relationships,
read_indexer_reports,
read_indexer_text_units,
)
from .input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType

reporter = PrintProgressReporter("")

Expand Down
4 changes: 4 additions & 0 deletions graphrag/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing callback implementations."""
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A reporter that writes to a blob storage."""
"""A logger that emits updates from the indexing engine to a blob in Azure Storage."""

import json
from datetime import datetime, timezone
Expand All @@ -14,7 +14,7 @@


class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
"""A reporter that writes to a blob storage."""
"""A logger that writes to a blob storage account."""

_blob_service_client: BlobServiceClient
_container_name: str
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Console-based reporter for the workflow engine."""
"""A logger that emits updates from the indexing engine to the console."""

from datashaper import NoopWorkflowCallbacks


class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks):
"""A reporter that writes to a console."""
"""A logger that writes to a console."""

def on_error(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Load pipeline reporter method."""
"""Create a pipeline reporter."""

from pathlib import Path
from typing import cast
Expand All @@ -20,7 +20,7 @@
from .file_workflow_callbacks import FileWorkflowCallbacks


def load_pipeline_reporter(
def create_pipeline_reporter(
config: PipelineReportingConfig | None, root_dir: str | None
) -> WorkflowCallbacks:
"""Create a reporter for the given pipeline config."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A reporter that writes to a file."""
"""A logger that emits updates from the indexing engine to a local file."""

import json
import logging
Expand All @@ -14,12 +14,12 @@


class FileWorkflowCallbacks(NoopWorkflowCallbacks):
"""A reporter that writes to a file."""
"""A logger that writes to a local file."""

_out_stream: TextIOWrapper

def __init__(self, directory: str):
"""Create a new file-based workflow reporter."""
"""Create a new file-based workflow logger."""
Path(directory).mkdir(parents=True, exist_ok=True)
self._out_stream = open( # noqa: PTH123, SIM115
Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

"""GlobalSearch LLM Callbacks."""

from graphrag.query.llm.base import BaseLLMCallback
from graphrag.query.structured_search.base import SearchResult

from .llm_callbacks import BaseLLMCallback


class GlobalSearchLLMCallback(BaseLLMCallback):
"""GlobalSearch LLM Callbacks."""
Expand Down
15 changes: 15 additions & 0 deletions graphrag/callbacks/llm_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""LLM Callbacks."""


class BaseLLMCallback:
    """Base class for LLM callbacks.

    Accumulates every token passed to ``on_llm_new_token`` in ``response``,
    in the order the tokens arrive.
    """

    def __init__(self) -> None:
        """Create a callback with an empty token buffer."""
        # Created per instance so separate callbacks never share tokens.
        self.response: list[str] = []

    def on_llm_new_token(self, token: str) -> None:
        """Handle when a new token is generated.

        Args:
            token: The newly generated token; it is appended to ``response``.
        """
        self.response.append(token)
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A workflow callback manager that emits updates to a ProgressReporter."""
"""A workflow callback manager that emits updates."""

from typing import Any

from datashaper import ExecutionNode, NoopWorkflowCallbacks, Progress, TableContainer

from graphrag.index.progress import ProgressReporter
from graphrag.logging import ProgressReporter


class ProgressWorkflowCallbacks(NoopWorkflowCallbacks):
Expand Down
2 changes: 1 addition & 1 deletion graphrag/index/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

import argparse

from graphrag.logging import ReporterType
from graphrag.utils.cli import dir_exist, file_exist

from .cli import index_cli
from .emit.types import TableEmitterType
from .progress.types import ReporterType

if __name__ == "__main__":
parser = argparse.ArgumentParser(
Expand Down
9 changes: 4 additions & 5 deletions graphrag/index/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,21 @@
import warnings
from pathlib import Path

import graphrag.api as api
from graphrag.config import (
CacheType,
enable_logging_with_config,
load_config,
resolve_paths,
)
from graphrag.logging import ProgressReporter, ReporterType, create_progress_reporter

from .api import build_index
from .emit.types import TableEmitterType
from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
from .graph.extractors.community_reports.prompts import COMMUNITY_REPORT_PROMPT
from .graph.extractors.graph.prompts import GRAPH_EXTRACTION_PROMPT
from .graph.extractors.summarize.prompts import SUMMARIZE_PROMPT
from .init_content import INIT_DOTENV, INIT_YAML
from .progress import ProgressReporter, ReporterType
from .progress.load_progress_reporter import load_progress_reporter
from .validate_config import validate_config_names

# Ignore warnings from numba
Expand Down Expand Up @@ -118,7 +117,7 @@ def index_cli(
output_dir: str | None,
):
"""Run the pipeline with the given config."""
progress_reporter = load_progress_reporter(reporter)
progress_reporter = create_progress_reporter(reporter)
info, error, success = _logger(progress_reporter)
run_id = resume or update_index_id or time.strftime("%Y%m%d-%H%M%S")

Expand Down Expand Up @@ -161,7 +160,7 @@ def index_cli(
_register_signal_handlers(progress_reporter)

outputs = asyncio.run(
build_index(
api.build_index(
config=config,
run_id=run_id,
is_resume_run=bool(resume),
Expand Down
2 changes: 1 addition & 1 deletion graphrag/index/input/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import pandas as pd

from graphrag.index.config import PipelineCSVInputConfig, PipelineInputConfig
from graphrag.index.progress import ProgressReporter
from graphrag.index.storage import PipelineStorage
from graphrag.index.utils import gen_md5_hash
from graphrag.logging import ProgressReporter

log = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion graphrag/index/input/load_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

from graphrag.config import InputConfig, InputType
from graphrag.index.config import PipelineInputConfig
from graphrag.index.progress import NullProgressReporter, ProgressReporter
from graphrag.index.storage import (
BlobPipelineStorage,
FilePipelineStorage,
)
from graphrag.logging import NullProgressReporter, ProgressReporter

from .csv import input_type as csv
from .csv import load as load_csv
Expand Down
2 changes: 1 addition & 1 deletion graphrag/index/input/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import pandas as pd

from graphrag.index.config import PipelineInputConfig
from graphrag.index.progress import ProgressReporter
from graphrag.index.storage import PipelineStorage
from graphrag.index.utils import gen_md5_hash
from graphrag.logging import ProgressReporter

DEFAULT_FILE_PATTERN = re.compile(
r".*[\\/](?P<source>[^\\/]+)[\\/](?P<year>\d{4})-(?P<month>\d{2})-(?P<day>\d{2})_(?P<author>[^_]+)_\d+\.txt"
Expand Down
Loading

0 comments on commit d9a005c

Please sign in to comment.