Skip to content

Commit

Permalink
feat: add naive RAG
Browse files Browse the repository at this point in the history
  • Loading branch information
gusye1234 committed Sep 5, 2024
1 parent e5ec32b commit 4d0ae78
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 6 deletions.
39 changes: 35 additions & 4 deletions nano_graphrag/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
from typing import Union
from collections import Counter, defaultdict

from openai import AsyncOpenAI

from ._llm import gpt_4o_complete
from ._utils import (
logger,
clean_str,
Expand Down Expand Up @@ -992,4 +989,38 @@ async def global_query(
report_data=points_context, response_type=query_param.response_type
),
)
return response
return response


async def naive_query(
query,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
):
use_model_func = global_config["best_model_func"]
results = await chunks_vdb.query(query, top_k=query_param.top_k)
if not len(results):
return PROMPTS["fail_response"]
chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)

maybe_trun_chunks = truncate_list_by_token_size(
chunks,
key=lambda x: x["content"],
max_token_size=query_param.naive_max_token_for_text_unit,
)
logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
if query_param.only_need_context:
return section
sys_prompt_temp = PROMPTS["naive_rag_response"]
sys_prompt = sys_prompt_temp.format(
content_data=section, response_type=query_param.response_type
)
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
return response
4 changes: 3 additions & 1 deletion nano_graphrag/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

@dataclass
class QueryParam:
mode: Literal["local", "global"] = "global"
mode: Literal["local", "global", "naive"] = "global"
only_need_context: bool = False
response_type: str = "Multiple Paragraphs"
level: int = 2
top_k: int = 20
# naive search
naive_max_token_for_text_unit = 12000
# local search
local_max_token_for_text_unit: int = 4000 # 12000 * 0.33
local_max_token_for_local_context: int = 4800 # 12000 * 0.4
Expand Down
34 changes: 33 additions & 1 deletion nano_graphrag/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
generate_community_report,
local_query,
global_query,
naive_query,
)
from ._storage import (
JsonKVStorage,
Expand Down Expand Up @@ -54,6 +55,7 @@ class GraphRAG:
)
# graph mode
enable_local: bool = True
enable_naive_rag: bool = False

# text chunking
chunk_token_size: int = 1200
Expand Down Expand Up @@ -151,11 +153,20 @@ def __post_init__(self):
namespace="entities",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"entity_name"}
meta_fields={"entity_name"},
)
if self.enable_local
else None
)
self.chunks_vdb = (
self.vector_db_storage_cls(
namespace="chunks",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
if self.enable_naive_rag
else None
)

self.best_model_func = limit_async_func_call(self.best_model_max_async)(
partial(self.best_model_func, hashing_kv=self.llm_response_cache)
Expand All @@ -172,9 +183,15 @@ def query(self, query: str, param: QueryParam = QueryParam()):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param))

def eval(self, querys: list[str], contexts: list[str], answers: list[str]):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aeval(querys, contexts, answers))

async def aquery(self, query: str, param: QueryParam = QueryParam()):
if param.mode == "local" and not self.enable_local:
raise ValueError("enable_local is False, cannot query in local mode")
if param.mode == "naive" and not self.enable_naive_rag:
raise ValueError("enable_naive_rag is False, cannot query in local mode")
if param.mode == "local":
response = await local_query(
query,
Expand All @@ -195,6 +212,14 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()):
param,
asdict(self),
)
elif param.mode == "naive":
response = await naive_query(
query,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
Expand Down Expand Up @@ -242,6 +267,9 @@ async def ainsert(self, string_or_strings):
logger.warning(f"All chunks are already in the storage")
return
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
if self.enable_naive_rag:
logger.info("Insert chunks for naive RAG")
await self.chunks_vdb.upsert(inserting_chunks)

# TODO: no incremental update for communities now, so just drop all
await self.community_reports.drop()
Expand Down Expand Up @@ -273,6 +301,9 @@ async def ainsert(self, string_or_strings):
finally:
await self._insert_done()

async def aeval(self, querys: list[str], contexts: list[str], answers: list[str]):
pass

async def _insert_done(self):
tasks = []
for storage_inst in [
Expand All @@ -281,6 +312,7 @@ async def _insert_done(self):
self.llm_response_cache,
self.community_reports,
self.entities_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
Expand Down
14 changes: 14 additions & 0 deletions nano_graphrag/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,20 @@
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
"""

PROMPTS[
"naive_rag_response"
] = """You're a helpful assistant
Below are the knowledge you know:
{content_data}
---
If you don't know the answer or if the provided knowledge do not contain sufficient information to provide an answer, just say so. Do not make anything up.
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
If you don't know the answer, just say so. Do not make anything up.
Do not include information where the supporting evidence for it is not provided.
---Target response length and format---
{response_type}
"""

PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."

PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]

0 comments on commit 4d0ae78

Please sign in to comment.