Skip to content

Commit

Permalink
✨ feat(GraphRAG): enhance GraphRAG by graph community summary (#1801)
Browse files Browse the repository at this point in the history
Co-authored-by: Florian <[email protected]>
Co-authored-by: KingSkyLi <[email protected]>
Co-authored-by: aries_ckt <[email protected]>
Co-authored-by: Fangyin Cheng <[email protected]>
Co-authored-by: yvonneyx <[email protected]>
  • Loading branch information
6 people authored and csunny committed Sep 2, 2024
1 parent 4af8423 commit a90f81e
Show file tree
Hide file tree
Showing 59 changed files with 29,315 additions and 410 deletions.
9 changes: 7 additions & 2 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ EMBEDDING_MODEL=text2vec
#EMBEDDING_MODEL=bge-large-zh
KNOWLEDGE_CHUNK_SIZE=500
KNOWLEDGE_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_SEARCH_TOP_SIZE=50
KNOWLEDGE_GRAPH_SEARCH_TOP_SIZE=200
## Maximum number of chunks to load at once, if your single document is too large,
## you can set this value to a higher value for better performance.
## if out of memory when load large document, you can set this value to a lower value.
Expand Down Expand Up @@ -157,6 +157,11 @@ EXECUTE_LOCAL_COMMANDS=False
#*******************************************************************#
VECTOR_STORE_TYPE=Chroma
GRAPH_STORE_TYPE=TuGraph
GRAPH_COMMUNITY_SUMMARY_ENABLED=True
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE=0.3
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE=20
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE=0.0

### Chroma vector db config
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data
Expand Down Expand Up @@ -187,7 +192,7 @@ ElasticSearch_PASSWORD={your_password}
#TUGRAPH_PASSWORD=73@TuGraph
#TUGRAPH_VERTEX_TYPE=entity
#TUGRAPH_EDGE_TYPE=relation
#TUGRAPH_EDGE_NAME_KEY=label
#TUGRAPH_PLUGIN_NAMES=leiden

#*******************************************************************#
#** WebServer Language Support **#
Expand Down
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ __pycache__/
*$py.class

# C extensions
*.so
message/
dbgpt/util/extensions/
.env*
Expand Down Expand Up @@ -185,4 +184,4 @@ thirdparty
/examples/**/*.gv
/examples/**/*.gv.pdf
/i18n/locales/**/**/*_ai_translated.po
/i18n/locales/**/**/*~
/i18n/locales/**/**/*~
3 changes: 3 additions & 0 deletions dbgpt/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ def __init__(self) -> None:

# Vector Store Configuration
self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
self.GRAPH_COMMUNITY_SUMMARY_ENABLED = (
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
)
self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
Expand Down
6 changes: 4 additions & 2 deletions dbgpt/app/knowledge/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,15 @@ def arguments(space_id: str):


@router.post("/knowledge/{space_name}/recall_test")
def recall_test(
async def recall_test(
space_name: str,
request: DocumentRecallTestRequest,
):
print(f"/knowledge/{space_name}/recall_test params:")
try:
return Result.succ(knowledge_space_service.recall_test(space_name, request))
return Result.succ(
await knowledge_space_service.recall_test(space_name, request)
)
except Exception as e:
return Result.failed(code="E000X", msg=f"{space_name} recall_test error {e}")

Expand Down
19 changes: 11 additions & 8 deletions dbgpt/app/knowledge/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def get_knowledge_space_by_ids(self, ids):
"""
return knowledge_space_dao.get_knowledge_space_by_ids(ids)

def recall_test(
async def recall_test(
self, space_name, doc_recall_test_request: DocumentRecallTestRequest
):
logger.info(f"recall_test {space_name}, {doc_recall_test_request}")
Expand Down Expand Up @@ -338,7 +338,7 @@ def recall_test(
knowledge_space_retriever = KnowledgeSpaceRetriever(
space_id=space.id, top_k=top_k
)
chunks = knowledge_space_retriever.retrieve_with_scores(
chunks = await knowledge_space_retriever.aretrieve_with_scores(
question, score_threshold
)
retrievers_end_time = timeit.default_timer()
Expand Down Expand Up @@ -646,13 +646,16 @@ def query_graph(self, space_name, limit):
graph = vector_store_connector.client.query_graph(limit=limit)
res = {"nodes": [], "edges": []}
for node in graph.vertices():
res["nodes"].append({"vid": node.vid})
for edge in graph.edges():
res["edges"].append(
res["nodes"].append(
{
"src": edge.sid,
"dst": edge.tid,
"label": edge.props[graph.edge_label],
"id": node.vid,
"communityId": node.get_prop("_community_id"),
"name": node.vid,
"type": "",
}
)
for edge in graph.edges():
res["edges"].append(
{"source": edge.sid, "target": edge.tid, "name": edge.name, "type": ""}
)
return res
28 changes: 21 additions & 7 deletions dbgpt/datasource/conn_tugraph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""TuGraph Connector."""

import json
from typing import Dict, List, cast
from typing import Dict, Generator, List, cast

from .base import BaseConnector

Expand All @@ -23,11 +23,16 @@ def __init__(self, driver, graph):
def create_graph(self, graph_name: str) -> None:
"""Create a new graph."""
# run the query to get vertex labels
with self._driver.session(database="default") as session:
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
exists = any(item["graph_name"] == graph_name for item in graph_list)
if not exists:
session.run(f"CALL dbms.graph.createGraph('{graph_name}', '', 2048)")
try:
with self._driver.session(database="default") as session:
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
exists = any(item["graph_name"] == graph_name for item in graph_list)
if not exists:
session.run(
f"CALL dbms.graph.createGraph('{graph_name}', '', 2048)"
)
except Exception as e:
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}")

def delete_graph(self, graph_name: str) -> None:
"""Delete a graph."""
Expand Down Expand Up @@ -89,10 +94,19 @@ def close(self):
self._driver.close()

def run(self, query: str, fetch: str = "all") -> List:
"""Run query."""
with self._driver.session(database=self._graph) as session:
try:
result = session.run(query)
return list(result)
except Exception as e:
raise Exception(f"Query execution failed: {e}")

def run_stream(self, query: str) -> Generator:
"""Run GQL."""
with self._driver.session(database=self._graph) as session:
result = session.run(query)
return list(result)
yield from result

def get_columns(self, table_name: str, table_type: str = "vertex") -> List[Dict]:
"""Get fields about specified graph.
Expand Down
12 changes: 6 additions & 6 deletions dbgpt/rag/embedding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401

__ALL__ = [
"CrossEncoderRerankEmbeddings",
"DefaultEmbeddingFactory",
"EmbeddingFactory",
"Embeddings",
"HuggingFaceBgeEmbeddings",
"HuggingFaceEmbeddings",
"HuggingFaceInferenceAPIEmbeddings",
"HuggingFaceInstructEmbeddings",
"JinaEmbeddings",
"OpenAPIEmbeddings",
"OllamaEmbeddings",
"DefaultEmbeddingFactory",
"EmbeddingFactory",
"WrappedEmbeddingFactory",
"TongYiEmbeddings",
"CrossEncoderRerankEmbeddings",
"OpenAPIEmbeddings",
"OpenAPIRerankEmbeddings",
"QianFanEmbeddings",
"TongYiEmbeddings",
"WrappedEmbeddingFactory",
]
10 changes: 9 additions & 1 deletion dbgpt/rag/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def __init__(self, executor: Optional[Executor] = None):
"""Init index store."""
self._executor = executor or ThreadPoolExecutor()

@abstractmethod
def get_config(self) -> IndexStoreConfig:
"""Get the index store config."""

@abstractmethod
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document in index database.
Expand Down Expand Up @@ -104,6 +108,10 @@ def delete_by_ids(self, ids: str) -> List[str]:
ids(str): The vector ids to delete, separated by comma.
"""

@abstractmethod
def truncate(self) -> List[str]:
"""Truncate data by name."""

@abstractmethod
def delete_vector_name(self, index_name: str):
"""Delete index by name.
Expand Down Expand Up @@ -188,7 +196,7 @@ def similar_search(
Return:
List[Chunk]: The similar documents.
"""
return self.similar_search_with_scores(text, topk, 1.0, filters)
return self.similar_search_with_scores(text, topk, 0.0, filters)

async def asimilar_search_with_scores(
self,
Expand Down
16 changes: 16 additions & 0 deletions dbgpt/rag/transformer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,27 @@
class TransformerBase:
"""Transformer base class."""

@abstractmethod
def truncate(self):
"""Truncate operation."""

@abstractmethod
def drop(self):
"""Clean operation."""


class EmbedderBase(TransformerBase, ABC):
"""Embedder base class."""


class SummarizerBase(TransformerBase, ABC):
"""Summarizer base class."""

@abstractmethod
async def summarize(self, **args) -> str:
"""Summarize result."""


class ExtractorBase(TransformerBase, ABC):
"""Extractor base class."""

Expand Down
Loading

0 comments on commit a90f81e

Please sign in to comment.