Commit
* initial API redesign
* typo fix
* update docstring
* update docstring
* remove artifacts caused by the merge from main
* minor typo updates
* add semversioner check
* switch API to async function calls

Co-authored-by: Alonso Guevara <[email protected]>
1 parent 7fd23fa · commit 4bcbfd1 · Showing 4 changed files with 237 additions and 115 deletions.
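Because this commit switches the query API to async function calls, external callers now await global_search and local_search directly, or bridge from synchronous code with asyncio.run. Below is a minimal usage sketch, not part of the commit: the import path of the new module, the output directory, and the community_level / response_type values are assumptions for illustration.

# Minimal usage sketch (not part of the commit). Assumes the new module is importable
# as graphrag.query.api and that indexing output lives under ./output/artifacts.
import asyncio

import pandas as pd
import yaml

from graphrag.config import create_graphrag_config  # builds a GraphRagConfig from a values dict
from graphrag.query.api import global_search  # assumed import path for the new API module

ARTIFACTS = "./output/artifacts"  # illustrative location of the indexer's parquet outputs

# One way to build a GraphRagConfig, assuming a settings.yaml at the project root.
with open("settings.yaml", encoding="utf-8") as f:
    config = create_graphrag_config(yaml.safe_load(f), ".")

# The API is now async, so it must be awaited (or run via asyncio.run from sync code).
response = asyncio.run(
    global_search(
        config=config,
        nodes=pd.read_parquet(f"{ARTIFACTS}/create_final_nodes.parquet"),
        entities=pd.read_parquet(f"{ARTIFACTS}/create_final_entities.parquet"),
        community_reports=pd.read_parquet(f"{ARTIFACTS}/create_final_community_reports.parquet"),
        community_level=2,                    # illustrative value
        response_type="Multiple Paragraphs",  # illustrative value
        query="What are the main themes in the dataset?",
    )
)
print(response)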
@@ -0,0 +1,4 @@
{
    "type": "minor",
    "description": "Implement query engine API."
}
@@ -0,0 +1,192 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""
Query Engine API.

This API provides access to the query engine of graphrag, allowing external applications
to hook into graphrag and run queries over a knowledge graph generated by graphrag.

WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""

from typing import Any

import pandas as pd
from pydantic import validate_call

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.progress.types import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType

from .factories import get_global_search_engine, get_local_search_engine
from .indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)
from .input.loaders.dfs import store_entity_semantic_embeddings

reporter = PrintProgressReporter("")


def __get_embedding_description_store(
    entities: list[Entity],
    vector_store_type: str = VectorStoreType.LanceDB,
    config_args: dict | None = None,
):
    """Get the embedding description store."""
    if not config_args:
        config_args = {}

    collection_name = config_args.get(
        "query_collection_name", "entity_description_embeddings"
    )
    config_args.update({"collection_name": collection_name})
    description_embedding_store = VectorStoreFactory.get_vector_store(
        vector_store_type=vector_store_type, kwargs=config_args
    )

    description_embedding_store.connect(**config_args)

    if config_args.get("overwrite", True):
        # this step assumes the embeddings were originally stored in a file rather
        # than a vector database

        # dump embeddings from the entities list to the description_embedding_store
        store_entity_semantic_embeddings(
            entities=entities, vectorstore=description_embedding_store
        )
    else:
        # load description embeddings to an in-memory lancedb vectorstore
        # to connect to a remote db, specify url and port values.
        description_embedding_store = LanceDBVectorStore(
            collection_name=collection_name
        )
        description_embedding_store.connect(
            db_uri=config_args.get("db_uri", "./lancedb")
        )

        # load data from an existing table
        description_embedding_store.document_collection = (
            description_embedding_store.db_connection.open_table(
                description_embedding_store.collection_name
            )
        )

    return description_embedding_store


@validate_call(config={"arbitrary_types_allowed": True})
async def global_search(
    config: GraphRagConfig,
    nodes: pd.DataFrame,
    entities: pd.DataFrame,
    community_reports: pd.DataFrame,
    community_level: int,
    response_type: str,
    query: str,
) -> str | dict[str, Any] | list[dict[str, Any]]:
    """Perform a global search.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
    - community_level (int): The community level to search at.
    - response_type (str): The type of response to return.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    reports = read_indexer_reports(community_reports, nodes, community_level)
    _entities = read_indexer_entities(nodes, entities, community_level)
    search_engine = get_global_search_engine(
        config,
        reports=reports,
        entities=_entities,
        response_type=response_type,
    )
    result = await search_engine.asearch(query=query)
    reporter.success(f"Global Search Response: {result.response}")
    return result.response


@validate_call(config={"arbitrary_types_allowed": True})
async def local_search(
    config: GraphRagConfig,
    nodes: pd.DataFrame,
    entities: pd.DataFrame,
    community_reports: pd.DataFrame,
    text_units: pd.DataFrame,
    relationships: pd.DataFrame,
    covariates: pd.DataFrame | None,
    community_level: int,
    response_type: str,
    query: str,
) -> str | dict[str, Any] | list[dict[str, Any]]:
    """Perform a local search.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
    - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
    - relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
    - covariates (pd.DataFrame): A DataFrame containing the final covariates (from create_final_covariates.parquet)
    - community_level (int): The community level to search at.
    - response_type (str): The response type to return.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    vector_store_args = (
        config.embeddings.vector_store if config.embeddings.vector_store else {}
    )

    reporter.info(f"Vector Store Args: {vector_store_args}")
    vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

    _entities = read_indexer_entities(nodes, entities, community_level)
    description_embedding_store = __get_embedding_description_store(
        entities=_entities,
        vector_store_type=vector_store_type,
        config_args=vector_store_args,
    )
    _covariates = read_indexer_covariates(covariates) if covariates is not None else []

    search_engine = get_local_search_engine(
        config=config,
        reports=read_indexer_reports(community_reports, nodes, community_level),
        text_units=read_indexer_text_units(text_units),
        entities=_entities,
        relationships=read_indexer_relationships(relationships),
        covariates={"claims": _covariates},
        description_embedding_store=description_embedding_store,
        response_type=response_type,
    )

    result = await search_engine.asearch(query=query)
    reporter.success(f"Local Search Response: {result.response}")
    return result.response
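As a companion sketch, again not part of the diff, local_search takes the additional text unit, relationship, and optional covariate frames and can be awaited directly from async code. The parquet filenames follow the docstring above; the import path, output directory, and parameter values are assumptions.

# Illustrative local_search call from async code. The config is built as in the
# global_search sketch above; paths and parameter values here are assumptions.
from pathlib import Path

import pandas as pd

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.query.api import local_search  # assumed import path for the new API module

ARTIFACTS = Path("./output/artifacts")  # illustrative location of the indexer's parquet outputs


async def run_local_query(config: GraphRagConfig, query: str):
    # Covariates (claims) are optional; only pass them if claim extraction produced them.
    covariates_path = ARTIFACTS / "create_final_covariates.parquet"
    covariates = pd.read_parquet(covariates_path) if covariates_path.exists() else None

    return await local_search(
        config=config,
        nodes=pd.read_parquet(ARTIFACTS / "create_final_nodes.parquet"),
        entities=pd.read_parquet(ARTIFACTS / "create_final_entities.parquet"),
        community_reports=pd.read_parquet(ARTIFACTS / "create_final_community_reports.parquet"),
        text_units=pd.read_parquet(ARTIFACTS / "create_final_text_units.parquet"),
        relationships=pd.read_parquet(ARTIFACTS / "create_final_relationships.parquet"),
        covariates=covariates,
        community_level=2,                    # illustrative value
        response_type="Multiple Paragraphs",  # illustrative value
        query=query,
    )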