Cleanup factory methods (#1482)
* clean up factory methods to have a similar design pattern across the codebase

* add semversioner file

* clean up logging factory

* update developer guide

* add comment

* typo fix

* clean up reporter terminology

* rename reporter to logger

* fix comments

* update comment

* instantiate factory classes correctly and update the index API callback parameter

---------

Co-authored-by: Alonso Guevara <[email protected]>
jgbradley1 and AlonsoGuevara authored Dec 10, 2024
1 parent 0440580 commit 8233421
Showing 51 changed files with 1,249 additions and 1,152 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241209064913440349.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "cleanup and refactor factory classes."
}
48 changes: 40 additions & 8 deletions DEVELOPING.md
@@ -10,24 +10,56 @@
# Getting Started

## Install Dependencies

```sh
# Install Python dependencies.
```shell
# install python dependencies
poetry install
```

## Executing the Indexing Engine

```sh
## Execute the indexing engine
```shell
poetry run poe index <...args>
```

## Executing Queries
## Execute prompt tuning
```shell
poetry run poe prompt_tune <...args>
```

```sh
## Execute Queries
```shell
poetry run poe query <...args>
```

## Repository Structure
An overview of the repository's top-level folder structure is provided below, detailing the overall design and purpose.
We leverage a factory design pattern where possible, enabling a variety of implementations for each core component of graphrag.

```shell
graphrag
├── api # library API definitions
├── cache # cache module supporting several options
│   └─ factory.py # └─ main entrypoint to create a cache
├── callbacks # a collection of commonly used callback functions
├── cli # library CLI
│   └─ main.py # └─ primary CLI entrypoint
├── config # configuration management
├── index # indexing engine
│   └─ run/run.py # └─ main entrypoint to build an index
├── llm # generic llm interfaces
├── logger # logger module supporting several options
│   └─ factory.py # └─ main entrypoint to create a logger
├── model # data model definitions associated with the knowledge graph
├── prompt_tune # prompt tuning module
├── prompts # a collection of all the system prompts used by graphrag
├── query # query engine
├── storage # storage module supporting several options
│   └─ factory.py # └─ main entrypoint to create/load a storage endpoint
├── utils # helper functions used throughout the library
└── vector_stores # vector store module containing a few options
└─ factory.py # └─ main entrypoint to create a vector store
```
Where appropriate, the factories expose a registration method for users to provide their own custom implementations if desired.
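
For example, a custom cache implementation might be registered and used like this. This is a minimal sketch assuming the `CacheFactory` API introduced in this change; `TaggedMemoryCache` is hypothetical and subclasses the concrete `InMemoryCache` so no abstract methods need reimplementing:

```python
# Minimal sketch of the registration hook (assumptions noted above).
from graphrag.cache.factory import CacheFactory
from graphrag.cache.memory_pipeline_cache import InMemoryCache


class TaggedMemoryCache(InMemoryCache):
    """A stand-in custom cache; any PipelineCache-compatible class works."""


CacheFactory.register("tagged", TaggedMemoryCache)

# Unknown type names fall through to the registry of custom implementations.
cache = CacheFactory.create_cache("tagged", root_dir=".", kwargs={})
```

The same registration pattern applies to the logger, storage, and vector store factories.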

## Versioning

We use [semversioner](https://github.com/raulgomis/semversioner) to automate and enforce semantic versioning in the release process. Our CI/CD pipeline checks that all PR's include a json file generated by semversioner. When submitting a PR, please run:
2 changes: 1 addition & 1 deletion docs/config/env_vars.md
@@ -156,7 +156,7 @@ This section controls the storage mechanism used by the pipeline used for export

| Parameter | Description | Type | Required or Optional | Default |
| ------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | -------------------- | ------- |
| `GRAPHRAG_STORAGE_TYPE` | The type of reporter to use. Options are `file`, `memory`, or `blob` | `str` | optional | `file` |
| `GRAPHRAG_STORAGE_TYPE` | The type of storage to use. Options are `file`, `memory`, or `blob` | `str` | optional | `file` |
| `GRAPHRAG_STORAGE_STORAGE_ACCOUNT_BLOB_URL` | The Azure Storage blob endpoint to use when in `blob` mode and using managed identity. Will have the format `https://<storage_account_name>.blob.core.windows.net` | `str` | optional | None |
| `GRAPHRAG_STORAGE_CONNECTION_STRING` | The Azure Storage connection string to use when in `blob` mode. | `str` | optional | None |
| `GRAPHRAG_STORAGE_CONTAINER_NAME` | The Azure Storage container name to use when in `blob` mode. | `str` | optional | None |
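
As a hedged illustration of the corrected row, selecting blob storage through these variables might look like the following (placeholder values; graphrag's own config loader, not this snippet, consumes them):

```python
import os

# Placeholder values for illustration only; see the table above for semantics.
os.environ["GRAPHRAG_STORAGE_TYPE"] = "blob"  # `file` is the default
os.environ["GRAPHRAG_STORAGE_STORAGE_ACCOUNT_BLOB_URL"] = (
    "https://<storage_account_name>.blob.core.windows.net"
)
os.environ["GRAPHRAG_STORAGE_CONTAINER_NAME"] = "my-container"
```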
24 changes: 12 additions & 12 deletions graphrag/api/index.py
@@ -17,7 +17,7 @@
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.index.run import run_pipeline_with_config
from graphrag.index.typing import PipelineRunResult
from graphrag.logging.base import ProgressReporter
from graphrag.logger.base import ProgressLogger


async def build_index(
@@ -26,7 +26,7 @@ async def build_index(
is_resume_run: bool = False,
memory_profile: bool = False,
callbacks: list[WorkflowCallbacks] | None = None,
progress_reporter: ProgressReporter | None = None,
progress_logger: ProgressLogger | None = None,
) -> list[PipelineRunResult]:
"""Run the pipeline with the given configuration.
@@ -42,8 +42,8 @@ async def build_index(
Whether to enable memory profiling.
callbacks : list[WorkflowCallbacks] | None default=None
A list of callbacks to register.
progress_reporter : ProgressReporter | None default=None
The progress reporter.
progress_logger : ProgressLogger | None default=None
The progress logger.
Returns
-------
@@ -60,26 +60,26 @@
pipeline_cache = (
NoopPipelineCache() if config.cache.type == CacheType.none is None else None
)
# create a pipeline reporter and add to any additional callbacks
# TODO: remove the type ignore once the new config engine has been refactored
callbacks = (
[create_pipeline_reporter(config.reporting, None)] if config.reporting else None # type: ignore
) # type: ignore
callbacks = callbacks or []
callbacks.append(create_pipeline_reporter(config.reporting, None)) # type: ignore
outputs: list[PipelineRunResult] = []
async for output in run_pipeline_with_config(
pipeline_config,
run_id=run_id,
memory_profile=memory_profile,
cache=pipeline_cache,
callbacks=callbacks,
progress_reporter=progress_reporter,
logger=progress_logger,
is_resume_run=is_resume_run,
is_update_run=is_update_run,
):
outputs.append(output)
if progress_reporter:
if progress_logger:
if output.errors and len(output.errors) > 0:
progress_reporter.error(output.workflow)
progress_logger.error(output.workflow)
else:
progress_reporter.success(output.workflow)
progress_reporter.info(str(output.result))
progress_logger.success(output.workflow)
progress_logger.info(str(output.result))
return outputs
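
With this change, caller-supplied callbacks are preserved (the pipeline reporter is appended rather than replacing the list) and the parameter is now `progress_logger`. A hedged usage sketch follows; `config` construction is omitted, and keyword names not visible in the hunks above are assumptions:

```python
# Sketch only: exercising the renamed progress_logger parameter of build_index.
from graphrag.api.index import build_index
from graphrag.logger.print_progress import PrintProgressLogger


async def run_index(config):  # `config` is a GraphRagConfig assembled elsewhere
    results = await build_index(
        config=config,
        progress_logger=PrintProgressLogger(""),  # was `progress_reporter`
    )
    # PipelineRunResult exposes workflow, result, and errors (see diff above).
    for r in results:
        print(f"{r.workflow}: {'failed' if r.errors else 'ok'}")
```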
28 changes: 14 additions & 14 deletions graphrag/api/prompt_tune.py
@@ -16,7 +16,7 @@

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm.load_llm import load_llm
from graphrag.logging.print_progress import PrintProgressReporter
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT
from graphrag.prompt_tune.generator.community_report_rating import (
generate_community_report_rating,
@@ -80,15 +80,15 @@ async def generate_indexing_prompts(
-------
tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
"""
reporter = PrintProgressReporter("")
logger = PrintProgressLogger("")

# Retrieve documents
doc_list = await load_docs_in_chunks(
root=root,
config=config,
limit=limit,
select_method=selection_method,
reporter=reporter,
logger=logger,
chunk_size=chunk_size,
n_subset_max=n_subset_max,
k=k,
@@ -103,25 +103,25 @@
)

if not domain:
reporter.info("Generating domain...")
logger.info("Generating domain...")
domain = await generate_domain(llm, doc_list)
reporter.info(f"Generated domain: {domain}")
logger.info(f"Generated domain: {domain}") # noqa

if not language:
reporter.info("Detecting language...")
logger.info("Detecting language...")
language = await detect_language(llm, doc_list)

reporter.info("Generating persona...")
logger.info("Generating persona...")
persona = await generate_persona(llm, domain)

reporter.info("Generating community report ranking description...")
logger.info("Generating community report ranking description...")
community_report_ranking = await generate_community_report_rating(
llm, domain=domain, persona=persona, docs=doc_list
)

entity_types = None
if discover_entity_types:
reporter.info("Generating entity types...")
logger.info("Generating entity types...")
entity_types = await generate_entity_types(
llm,
domain=domain,
@@ -130,7 +130,7 @@
json_mode=config.llm.model_supports_json or False,
)

reporter.info("Generating entity relationship examples...")
logger.info("Generating entity relationship examples...")
examples = await generate_entity_relationship_examples(
llm,
persona=persona,
@@ -140,7 +140,7 @@
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
)

reporter.info("Generating entity extraction prompt...")
logger.info("Generating entity extraction prompt...")
entity_extraction_prompt = create_entity_extraction_prompt(
entity_types=entity_types,
docs=doc_list,
@@ -152,18 +152,18 @@
min_examples_required=min_examples_required,
)

reporter.info("Generating entity summarization prompt...")
logger.info("Generating entity summarization prompt...")
entity_summarization_prompt = create_entity_summarization_prompt(
persona=persona,
language=language,
)

reporter.info("Generating community reporter role...")
logger.info("Generating community reporter role...")
community_reporter_role = await generate_community_reporter_role(
llm, domain=domain, persona=persona, docs=doc_list
)

reporter.info("Generating community summarization prompt...")
logger.info("Generating community summarization prompt...")
community_summarization_prompt = create_community_summarization_prompt(
persona=persona,
role=community_reporter_role,
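
The renamed logger is used only for status reporting; the tuning flow itself is unchanged. A hedged sketch of calling this API (only `root`, `config`, and the three-prompt return shape are visible above; defaults for the remaining keywords are assumed):

```python
# Sketch only: the prompt-tuning entrypoint after the reporter -> logger rename.
from graphrag.api.prompt_tune import generate_indexing_prompts


async def tune_prompts(config):
    # Returns (entity extraction, entity summarization, community summarization)
    # prompts, per the docstring in the hunk above.
    extraction, summarization, community = await generate_indexing_prompts(
        config=config,
        root=".",  # project root containing the input documents (assumed default)
    )
    return extraction, summarization, community
```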
18 changes: 10 additions & 8 deletions graphrag/api/query.py
@@ -19,7 +19,7 @@

from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any

import pandas as pd
from pydantic import validate_call
@@ -29,7 +29,7 @@
community_full_content_embedding,
entity_description_embedding,
)
from graphrag.logging.print_progress import PrintProgressReporter
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.query.factory import (
get_drift_search_engine,
get_global_search_engine,
@@ -44,13 +44,15 @@
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.structured_search.base import SearchResult # noqa: TC001
from graphrag.utils.cli import redact
from graphrag.utils.embeddings import create_collection_name
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.factory import VectorStoreFactory

reporter = PrintProgressReporter("")
if TYPE_CHECKING:
from graphrag.query.structured_search.base import SearchResult

logger = PrintProgressLogger("")


@validate_call(config={"arbitrary_types_allowed": True})
@@ -241,7 +243,7 @@ async def local_search(
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa

description_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
@@ -307,7 +309,7 @@ async def local_search_streaming(
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa

description_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
@@ -380,7 +382,7 @@ async def drift_search(
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa

description_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
@@ -430,7 +432,7 @@ def _get_embedding_store(
collection_name = create_collection_name(
config_args.get("container_name", "default"), embedding_name
)
embedding_store = VectorStoreFactory.get_vector_store(
embedding_store = VectorStoreFactory().create_vector_store(
vector_store_type=vector_store_type,
kwargs={**config_args, "collection_name": collection_name},
)
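
Two small patterns above are worth noting: `SearchResult` moves under a `TYPE_CHECKING` guard so it is imported for annotations only, and the vector store factory is now instantiated before calling `create_vector_store`. A generic sketch of the guard pattern (the `summarize` function and `response` attribute are illustrative, not part of graphrag):

```python
# Sketch of the TYPE_CHECKING import guard used in graphrag/api/query.py.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; no import cost or cycle at runtime.
    from graphrag.query.structured_search.base import SearchResult


def summarize(result: SearchResult) -> str:
    # Annotations stay unevaluated strings via `from __future__ import annotations`,
    # so SearchResult need not exist at runtime. `response` is illustrative.
    return str(getattr(result, "response", ""))
```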
66 changes: 35 additions & 31 deletions graphrag/cache/factory.py
@@ -5,49 +5,53 @@

from __future__ import annotations

from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, ClassVar

from graphrag.config.enums import CacheType
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage

if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.config.cache import (
PipelineBlobCacheConfig,
PipelineCacheConfig,
PipelineFileCacheConfig,
)

from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache


def create_cache(
    config: PipelineCacheConfig | None, root_dir: str | None
) -> PipelineCache:
    """Create a cache from the given config."""
    if config is None:
        return NoopPipelineCache()

    match config.type:
        case CacheType.none:
            return NoopPipelineCache()
        case CacheType.memory:
            return InMemoryCache()
        case CacheType.file:
            config = cast("PipelineFileCacheConfig", config)
            storage = FilePipelineStorage(root_dir).child(config.base_dir)
            return JsonPipelineCache(storage)
        case CacheType.blob:
            config = cast("PipelineBlobCacheConfig", config)
            storage = BlobPipelineStorage(
                config.connection_string,
                config.container_name,
                storage_account_blob_url=config.storage_account_blob_url,
            ).child(config.base_dir)
            return JsonPipelineCache(storage)
        case _:
            msg = f"Unknown cache type: {config.type}"
            raise ValueError(msg)

class CacheFactory:
    """A factory class for cache implementations.

    Includes a method for users to register a custom cache implementation.
    """

    cache_types: ClassVar[dict[str, type]] = {}

    @classmethod
    def register(cls, cache_type: str, cache: type):
        """Register a custom cache implementation."""
        cls.cache_types[cache_type] = cache

    @classmethod
    def create_cache(
        cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict
    ) -> PipelineCache:
        """Create or get a cache from the provided type."""
        if not cache_type:
            return NoopPipelineCache()
        match cache_type:
            case CacheType.none:
                return NoopPipelineCache()
            case CacheType.memory:
                return InMemoryCache()
            case CacheType.file:
                return JsonPipelineCache(
                    FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
                )
            case CacheType.blob:
                return JsonPipelineCache(BlobPipelineStorage(**kwargs))
            case _:
                if cache_type in cls.cache_types:
                    return cls.cache_types[cache_type](**kwargs)
                msg = f"Unknown cache type: {cache_type}"
                raise ValueError(msg)
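
A hedged sketch of the new surface, showing the built-in arms of the `match` above (the file case reads `base_dir` out of `kwargs`, and a `None` type short-circuits to the no-op cache; the paths are illustrative):

```python
# Sketch only: exercising the CacheFactory arms shown above.
from graphrag.cache.factory import CacheFactory
from graphrag.config.enums import CacheType

noop = CacheFactory.create_cache(None, root_dir=".", kwargs={})
memory = CacheFactory.create_cache(CacheType.memory, root_dir=".", kwargs={})
file_cache = CacheFactory.create_cache(
    CacheType.file,
    root_dir="./ragtest",          # project root (illustrative path)
    kwargs={"base_dir": "cache"},  # subdirectory consumed by the file arm
)
```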
