Slightly cleaned version without reflection signatures
NumberChiffre committed Sep 8, 2024
1 parent 0bb2b74 commit e347ac1
Showing 10 changed files with 420 additions and 338 deletions.
124 changes: 24 additions & 100 deletions examples/benchmarks/dspy_entity.py
@@ -1,57 +1,19 @@
import dspy
import os
from dotenv import load_dotenv
from openai import AsyncOpenAI, OpenAI
from openai import AsyncOpenAI
import logging
import asyncio
from nano_graphrag._op import extract_entities, extract_entities_dspy
from nano_graphrag.entity_extraction.extract import extract_entities_dspy
from nano_graphrag._storage import NetworkXStorage, BaseKVStorage
from nano_graphrag._utils import compute_mdhash_id, compute_args_hash
from nano_graphrag.prompt import PROMPTS

WORKING_DIR = "./nano_graphrag_cache_dspy_entity"

load_dotenv()

logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.DEBUG)


class DeepSeek(dspy.Module):
def __init__(self, model, api_key, **kwargs):
self.model = model
self.api_key = api_key
self.client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
self.provider = "default",
self.history = []
self.kwargs = {
"temperature": 0.2,
"max_tokens": 2048,
**kwargs
}

def basic_request(self, prompt, **kwargs):
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt}
],
stream=False,
**self.kwargs
)
self.history.append({"prompt": prompt, "response": response})
return response

def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
response = self.basic_request(prompt, **kwargs)
completions = [choice.message.content for choice in response.choices]
return completions

def inspect_history(self, n: int = 1):
if len(self.history) < n:
return self.history
return self.history[-n:]
logger = logging.getLogger("nano-graphrag")
logger.setLevel(logging.DEBUG)


async def deepseepk_model_if_cache(
@@ -88,64 +50,12 @@ async def deepseepk_model_if_cache(
return response.choices[0].message.content


class EntityTypeExtractionSignature(dspy.Signature):
input_text = dspy.InputField(desc="The text to extract entity types from")
entity_types = dspy.OutputField(desc="List of entity types present in the text")


class EntityExtractionSignature(dspy.Signature):
input_text = dspy.InputField(desc="The text to extract entities and relationships from")
entities = dspy.OutputField(desc="List of extracted entities with their types and descriptions")
relationships = dspy.OutputField(desc="List of relationships between entities, including descriptions and importance scores")
reasoning = dspy.OutputField(desc="Step-by-step reasoning for entity and relationship extraction")


class EntityExtractor(dspy.Module):
def __init__(self):
super().__init__()
self.type_extractor = dspy.ChainOfThought(EntityTypeExtractionSignature)
self.cot = dspy.ChainOfThought(EntityExtractionSignature)

def forward(self, input_text):
type_result = self.type_extractor(input_text=input_text)
entity_types = type_result.entity_types
prompt_template = PROMPTS["entity_extraction"]
formatted_prompt = prompt_template.format(
input_text=input_text,
entity_types=entity_types,
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"]
)
return self.cot(input_text=formatted_prompt)


# def extract_entities_dspy(text):
# dspy_extractor = EntityExtractor()
# dspy_result = dspy_extractor(input_text=text)

# print("DSPY Result:")
# print("\nReasoning:")
# print(dspy_result.reasoning)
# print("\nEntities:")
# entities = dspy_result.entities.split(PROMPTS["DEFAULT_RECORD_DELIMITER"])
# for entity in entities:
# if entity.strip():
# print(entity.strip())

# print("\nRelationships:")
# relationships = dspy_result.relationships.split(PROMPTS["DEFAULT_RECORD_DELIMITER"])
# for relationship in relationships:
# if relationship.strip():
# print(relationship.strip())


async def nano_entity_extraction(text):
async def nano_entity_extraction(text: str, system_prompt: str = None):
graph_storage = NetworkXStorage(namespace="test", global_config={
"working_dir": WORKING_DIR,
"entity_summary_to_max_tokens": 500,
"cheap_model_func": deepseepk_model_if_cache,
"best_model_func": deepseepk_model_if_cache,
"cheap_model_func": lambda *args, **kwargs: deepseepk_model_if_cache(*args, system_prompt=system_prompt, **kwargs),
"best_model_func": lambda *args, **kwargs: deepseepk_model_if_cache(*args, system_prompt=system_prompt, **kwargs),
"cheap_model_max_token_size": 4096,
"best_model_max_token_size": 4096,
"tiktoken_model_name": "gpt-4o",
@@ -176,11 +86,25 @@ async def nano_entity_extraction(text):


if __name__ == "__main__":
lm = DeepSeek(model="deepseek-chat", api_key=os.environ["DEEPSEEK_API_KEY"])
system_prompt = """
You are a world-class AI system, capable of complex reasoning and reflection.
Reason through the query, and then provide your final response.
If you detect that you made a mistake in your reasoning at any point, correct yourself.
Think carefully.
"""
lm = dspy.OpenAI(
model="deepseek-chat",
model_type="chat",
api_key=os.environ["DEEPSEEK_API_KEY"],
base_url=os.environ["DEEPSEEK_BASE_URL"],
system_prompt=system_prompt,
temperature=0.3,
top_p=1,
max_tokens=4096
)
dspy.settings.configure(lm=lm)

with open("./examples/data/test.txt", encoding="utf-8-sig") as f:
text = f.read()

asyncio.run(nano_entity_extraction(text))
# extract_entities_dspy(text)
asyncio.run(nano_entity_extraction(text, system_prompt))
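
Side note on the new model-function wiring above: the two lambdas bind the example's system_prompt into deepseepk_model_if_cache so that every extraction call issued by the graph storage carries it. Below is a minimal sketch of the same idea outside the diff, assuming only the model-function signature shown above (fake_model_func and make_prompted_model_func are illustrative names; functools.partial is equivalent to the inline lambdas).

import asyncio
from functools import partial

async def fake_model_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    # Stand-in for deepseepk_model_if_cache: just report what it was handed.
    return f"system={system_prompt!r} user={prompt!r}"

def make_prompted_model_func(model_func, system_prompt):
    # Same effect as the lambdas in the config dict: bind system_prompt once,
    # forward prompt/history/kwargs untouched on every call.
    return partial(model_func, system_prompt=system_prompt)

model_func = make_prompted_model_func(fake_model_func, "Reason carefully, then answer.")
print(asyncio.run(model_func("Extract entities from this chunk.")))
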
149 changes: 149 additions & 0 deletions examples/using_dspy_entity_extraction.py
@@ -0,0 +1,149 @@
import os
from openai import AsyncOpenAI
from dotenv import load_dotenv
import logging
import numpy as np
import dspy
from sentence_transformers import SentenceTransformer
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag._llm import gpt_4o_mini_complete
from nano_graphrag._storage import HNSWVectorStorage
from nano_graphrag.base import BaseKVStorage
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs

logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.DEBUG)

WORKING_DIR = "./nano_graphrag_cache_using_hnsw_as_vectorDB"

load_dotenv()


EMBED_MODEL = SentenceTransformer(
"sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu"
)


@wrap_embedding_func_with_attrs(
embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
max_token_size=EMBED_MODEL.max_seq_length,
)
async def local_embedding(texts: list[str]) -> np.ndarray:
return EMBED_MODEL.encode(texts, normalize_embeddings=True)


async def deepseepk_model_if_cache(
prompt, model: str = "deepseek-chat", system_prompt=None, history_messages=[], **kwargs
) -> str:
openai_async_client = AsyncOpenAI(
api_key=os.environ.get("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com"
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})

    # Return the cached response if one exists -------------------
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# -----------------------------------------------------

response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)

    # Cache the response for future calls -------------------
if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response.choices[0].message.content, "model": model}}
)
# -----------------------------------------------------
return response.choices[0].message.content



def remove_if_exist(file):
if os.path.exists(file):
os.remove(file)


def insert():
from time import time

with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
# with open("./examples/data/test.txt", encoding="utf-8-sig") as f:
FAKE_TEXT = f.read()

remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
rag = GraphRAG(
working_dir=WORKING_DIR,
enable_llm_cache=True,
vector_db_storage_cls=HNSWVectorStorage,
vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50},
best_model_max_async=10,
cheap_model_max_async=10,
best_model_func=deepseepk_model_if_cache,
cheap_model_func=deepseepk_model_if_cache,
embedding_func=local_embedding,
)
start = time()
rag.insert(FAKE_TEXT)
print("indexing time:", time() - start)


def query():
rag = GraphRAG(
working_dir=WORKING_DIR,
enable_llm_cache=True,
vector_db_storage_cls=HNSWVectorStorage,
vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50},
best_model_max_token_size=8196,
cheap_model_max_token_size=8196,
best_model_max_async=4,
cheap_model_max_async=4,
best_model_func=gpt_4o_mini_complete,
cheap_model_func=gpt_4o_mini_complete,
embedding_func=local_embedding,

)
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="global")
)
)
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)


if __name__ == "__main__":
system_prompt = """
You are a world-class AI system, capable of complex reasoning and reflection.
Reason through the query, and then provide your final response.
If you detect that you made a mistake in your reasoning at any point, correct yourself.
Think carefully.
"""
lm = dspy.OpenAI(
model="deepseek-chat",
model_type="chat",
api_key=os.environ["DEEPSEEK_API_KEY"],
base_url=os.environ["DEEPSEEK_BASE_URL"],
system_prompt=system_prompt,
temperature=0.3,
top_p=1,
max_tokens=4096
)
dspy.settings.configure(lm=lm)
insert()
query()
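
The deepseepk_model_if_cache helper in this new example uses the same cache-aside pattern as the repo's other examples: hash the (model, messages) pair, return a stored completion on a hit, otherwise call the API and persist the result. A minimal sketch of that flow, with a plain dict standing in for the hashing_kv BaseKVStorage and a fake completion in place of the DeepSeek call (args_hash and cached_complete are illustrative names):

import asyncio
import hashlib
import json

cache: dict[str, dict] = {}  # stands in for the hashing_kv storage backend

def args_hash(model: str, messages: list[dict]) -> str:
    # Stable digest of the request arguments, in the spirit of compute_args_hash.
    return hashlib.md5(json.dumps([model, messages], sort_keys=True).encode()).hexdigest()

async def cached_complete(prompt: str, model: str = "deepseek-chat") -> str:
    messages = [{"role": "user", "content": prompt}]
    key = args_hash(model, messages)
    if key in cache:
        return cache[key]["return"]             # cache hit: no API call
    result = f"fake completion for {prompt!r}"  # stand-in for the chat.completions call
    cache[key] = {"return": result, "model": model}
    return result

print(asyncio.run(cached_complete("What are the top themes?")))
print(asyncio.run(cached_complete("What are the top themes?")))  # served from the cache
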
43 changes: 0 additions & 43 deletions examples/using_hnsw_as_vectorDB.py
@@ -61,7 +61,6 @@ def insert():
from time import time

with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
# with open("./examples/data/test.txt", encoding="utf-8-sig") as f:
FAKE_TEXT = f.read()

remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
@@ -109,49 +108,7 @@ def query():
)
)

import dspy
from openai import OpenAI


class DeepSeek(dspy.Module):
def __init__(self, model, api_key, **kwargs):
self.model = model
self.api_key = api_key
self.client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
self.provider = "default",
self.history = []
self.kwargs = {
"temperature": 0.2,
"max_tokens": 2048,
**kwargs
}

def basic_request(self, prompt, **kwargs):
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt}
],
stream=False,
**self.kwargs
)
self.history.append({"prompt": prompt, "response": response})
return response

def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
response = self.basic_request(prompt, **kwargs)
completions = [choice.message.content for choice in response.choices]
return completions

def inspect_history(self, n: int = 1):
if len(self.history) < n:
return self.history
return self.history[-n:]


if __name__ == "__main__":
lm = DeepSeek(model="deepseek-chat", api_key=os.environ["DEEPSEEK_API_KEY"])
dspy.settings.configure(lm=lm)
insert()
query()
4 changes: 2 additions & 2 deletions nano_graphrag/_llm.py
@@ -12,7 +12,7 @@


@retry(
stop=stop_after_attempt(3),
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
@@ -69,7 +69,7 @@ async def gpt_4o_mini_complete(

@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
(Diffs for the remaining 6 changed files are not shown.)
