Skip to content

Commit

Permalink
Implement query api (#839)
Browse files Browse the repository at this point in the history
* initial API redesign

* typo fix

* update docstring

* update docstring

* remove artifacts caused by the merge from main

* minor typo updates

* add semversioner check

* switch API to async function calls

---------

Co-authored-by: Alonso Guevara <[email protected]>
  • Loading branch information
jgbradley1 and AlonsoGuevara authored Aug 12, 2024
1 parent 7fd23fa commit 4bcbfd1
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 115 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20240806062641863317.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Implement query engine API."
}
11 changes: 7 additions & 4 deletions graphrag/query/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ def __str__(self):


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
prog="python -m graphrag.query",
description="The graphrag query engine",
)

parser.add_argument(
"--config",
Expand All @@ -49,22 +52,22 @@ def __str__(self):

parser.add_argument(
"--method",
help="The method to run, one of: local or global",
help="The method to run",
required=True,
type=SearchType,
choices=list(SearchType),
)

parser.add_argument(
"--community_level",
help="Community level in the Leiden community hierarchy from which we will load the community reports higher value means we use reports on smaller communities",
help="Community level in the Leiden community hierarchy from which we will load the community reports higher value means we use reports on smaller communities. Default: 2",
type=int,
default=2,
)

parser.add_argument(
"--response_type",
help="Free form text describing the response type and format, can be anything, e.g. Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report",
help="Free form text describing the response type and format, can be anything, e.g. Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report. Default: Multiple Paragraphs",
type=str,
default="Multiple Paragraphs",
)
Expand Down
192 changes: 192 additions & 0 deletions graphrag/query/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""
Query Engine API.
This API provides access to the query engine of graphrag, allowing external applications
to hook into graphrag and run queries over a knowledge graph generated by graphrag.
WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""

from typing import Any

import pandas as pd
from pydantic import validate_call

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.progress.types import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType

from .factories import get_global_search_engine, get_local_search_engine
from .indexer_adapters import (
read_indexer_covariates,
read_indexer_entities,
read_indexer_relationships,
read_indexer_reports,
read_indexer_text_units,
)
from .input.loaders.dfs import store_entity_semantic_embeddings

# Module-level progress reporter shared by all query API functions;
# the empty prefix means messages are printed without a label.
reporter = PrintProgressReporter("")


def __get_embedding_description_store(
    entities: list[Entity],
    vector_store_type: str = VectorStoreType.LanceDB,
    config_args: dict | None = None,
):
    """Create and connect a vector store holding entity description embeddings.

    Parameters
    ----------
    - entities (list[Entity]): Entities whose description embeddings may be dumped
      into the store (only when overwriting, see below).
    - vector_store_type (str): Which vector store backend to instantiate.
      Default: LanceDB.
    - config_args (dict | None): Backend-specific connection/configuration options.
      Recognized keys include "query_collection_name", "overwrite", and "db_uri".

    Returns
    -------
    A connected vector store containing the entity description embeddings.
    """
    # Work on a copy so the caller's configuration dict is never mutated
    # (this dict may be config.embeddings.vector_store, shared with the caller).
    config_args = dict(config_args) if config_args else {}

    collection_name = config_args.get(
        "query_collection_name", "entity_description_embeddings"
    )
    config_args["collection_name"] = collection_name

    description_embedding_store = VectorStoreFactory.get_vector_store(
        vector_store_type=vector_store_type, kwargs=config_args
    )
    description_embedding_store.connect(**config_args)

    if config_args.get("overwrite", True):
        # this step assumes the embeddings were originally stored in a file rather
        # than a vector database

        # dump embeddings from the entities list to the description_embedding_store
        store_entity_semantic_embeddings(
            entities=entities, vectorstore=description_embedding_store
        )
    else:
        # load description embeddings to an in-memory lancedb vectorstore
        # and connect to a remote db, specify url and port values.
        description_embedding_store = LanceDBVectorStore(
            collection_name=collection_name
        )
        description_embedding_store.connect(
            db_uri=config_args.get("db_uri", "./lancedb")
        )

        # load data from an existing table
        description_embedding_store.document_collection = (
            description_embedding_store.db_connection.open_table(
                description_embedding_store.collection_name
            )
        )

    return description_embedding_store


@validate_call(config={"arbitrary_types_allowed": True})
async def global_search(
    config: GraphRagConfig,
    nodes: pd.DataFrame,
    entities: pd.DataFrame,
    community_reports: pd.DataFrame,
    community_level: int,
    response_type: str,
    query: str,
) -> str | dict[str, Any] | list[dict[str, Any]]:
    """Run a global search over the indexed knowledge graph.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - nodes (pd.DataFrame): The final nodes (from create_final_nodes.parquet)
    - entities (pd.DataFrame): The final entities (from create_final_entities.parquet)
    - community_reports (pd.DataFrame): The final community reports (from create_final_community_reports.parquet)
    - community_level (int): The community level to search at.
    - response_type (str): The type of response to return.
    - query (str): The user query to search for.
    Returns
    -------
    TODO: Document the search response type and format.
    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    # Adapt the raw indexer dataframes into query-engine model objects.
    indexed_reports = read_indexer_reports(community_reports, nodes, community_level)
    indexed_entities = read_indexer_entities(nodes, entities, community_level)

    engine = get_global_search_engine(
        config,
        reports=indexed_reports,
        entities=indexed_entities,
        response_type=response_type,
    )

    result = await engine.asearch(query=query)
    reporter.success(f"Global Search Response: {result.response}")
    return result.response


@validate_call(config={"arbitrary_types_allowed": True})
async def local_search(
    config: GraphRagConfig,
    nodes: pd.DataFrame,
    entities: pd.DataFrame,
    community_reports: pd.DataFrame,
    text_units: pd.DataFrame,
    relationships: pd.DataFrame,
    covariates: pd.DataFrame | None,
    community_level: int,
    response_type: str,
    query: str,
) -> str | dict[str, Any] | list[dict[str, Any]]:
    """Perform a local search.
    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
    - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
    - relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
    - covariates (pd.DataFrame | None): A DataFrame containing the final covariates (from create_final_covariates.parquet), or None if covariates were not generated
    - community_level (int): The community level to search at.
    - response_type (str): The response type to return.
    - query (str): The user query to search for.
    Returns
    -------
    TODO: Document the search response type and format.
    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    # Vector-store settings come from the embeddings section of the config;
    # fall back to an empty dict when none were configured.
    vector_store_args = (
        config.embeddings.vector_store if config.embeddings.vector_store else {}
    )

    reporter.info(f"Vector Store Args: {vector_store_args}")
    # NOTE(review): presumably vector_store_args behaves like a mapping here
    # (.get is called on it) — confirm the config model guarantees that.
    vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

    _entities = read_indexer_entities(nodes, entities, community_level)
    # Build/connect the store of entity description embeddings used for
    # semantic retrieval during local search.
    description_embedding_store = __get_embedding_description_store(
        entities=_entities,
        vector_store_type=vector_store_type,
        config_args=vector_store_args,
    )
    # Covariates are optional output of the indexer; use an empty list when absent.
    _covariates = read_indexer_covariates(covariates) if covariates is not None else []

    search_engine = get_local_search_engine(
        config=config,
        reports=read_indexer_reports(community_reports, nodes, community_level),
        text_units=read_indexer_text_units(text_units),
        entities=_entities,
        relationships=read_indexer_relationships(relationships),
        covariates={"claims": _covariates},
        description_embedding_store=description_embedding_store,
        response_type=response_type,
    )

    result = await search_engine.asearch(query=query)
    reporter.success(f"Local Search Response: {result.response}")
    return result.response
Loading

0 comments on commit 4bcbfd1

Please sign in to comment.