Merge pull request #1 from gusye1234/main
MERGE Master
Ashes47 authored Sep 24, 2024
2 parents 4e3fb7d + b33b2b8 commit 7ac7908
Showing 25 changed files with 17,315 additions and 917 deletions.
4 changes: 2 additions & 2 deletions docs/ROADMAP.md
@@ -1,7 +1,7 @@
 ## Next Version
 
-- [ ] Add neo4j for better visualization
-- [x] Add DSpy for prompt-tuning
+- [ ] Add neo4j for better visualization @gusye1234
+- [ ] Add DSpy for prompt-tuning so that small models (Qwen2 7B, Llama 3.1 8B, ...) can extract entities. @NumberChiffre



11 changes: 6 additions & 5 deletions examples/benchmarks/dspy_entity.py
@@ -114,19 +114,20 @@ async def run_benchmark(text: str):
     system_prompt_dspy = f"{system_prompt} Time: {time.time()}."
     lm = dspy.OpenAI(
         model="deepseek-chat",
-        model_type="chat",
+        model_type="chat",
         api_provider="openai",
         api_key=os.environ["DEEPSEEK_API_KEY"],
         base_url=os.environ["DEEPSEEK_BASE_URL"],
-        system_prompt=system_prompt_dspy,
+        system_prompt=system_prompt,
         temperature=1.0,
         top_p=1,
-        max_tokens=4096
+        max_tokens=8192
     )
-    dspy.settings.configure(lm=lm)
+    dspy.settings.configure(lm=lm, experimental=True)
     graph_storage_with_dspy, time_with_dspy = await benchmark_entity_extraction(text, system_prompt_dspy, use_dspy=True)
     print(f"Execution time with DSPy-AI: {time_with_dspy:.2f} seconds")
     print_extraction_results(graph_storage_with_dspy)
 
+    import pdb; pdb.set_trace()
     print("Running benchmark without DSPy-AI:")
     system_prompt_no_dspy = f"{system_prompt} Time: {time.time()}."
     graph_storage_without_dspy, time_without_dspy = await benchmark_entity_extraction(text, system_prompt_no_dspy, use_dspy=False)
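A note on the hunk above: the added `import pdb; pdb.set_trace()` pauses the script right after the DSPy run, which is handy for inspecting `graph_storage_with_dspy` but blocks unattended benchmark runs. A minimal sketch of an opt-in alternative, assuming a hypothetical `DSPY_BENCH_DEBUG` environment variable (not part of this commit):

import os
import pdb

# Drop into the debugger only when explicitly requested, so the benchmark
# can still run end-to-end by default.
if os.environ.get("DSPY_BENCH_DEBUG"):
    pdb.set_trace()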
14,525 changes: 14,146 additions & 379 deletions examples/finetune_entity_relationship_dspy.ipynb

Large diffs are not rendered by default.

2,062 changes: 2,062 additions & 0 deletions examples/generate_entity_relationship_dspy.ipynb

Large diffs are not rendered by default.

39 changes: 22 additions & 17 deletions examples/graphml_visualize.py
@@ -1,18 +1,19 @@
 import networkx as nx
 import json
 import os
 import webbrowser
 import http.server
 import socketserver
 import threading
 
-# 读取GraphML文件并转换为JSON
+# load GraphML file and transfer to JSON
 def graphml_to_json(graphml_file):
     G = nx.read_graphml(graphml_file)
     data = nx.node_link_data(G)
     return json.dumps(data)
 
+
-# 创建HTML文件
+# create HTML file
 def create_html(html_path):
     html_content = '''
 <!DOCTYPE html>
@@ -242,36 +243,40 @@ def create_json(json_data, json_path):
         f.write(json_data)
 
 
-# 启动简单的HTTP服务器
-def start_server():
+# start simple HTTP server
+def start_server(port):
     handler = http.server.SimpleHTTPRequestHandler
-    with socketserver.TCPServer(("", 8000), handler) as httpd:
-        print("Server started at http://localhost:8000")
+    with socketserver.TCPServer(("", port), handler) as httpd:
+        print(f"Server started at http://localhost:{port}")
         httpd.serve_forever()
 
-# 主函数
-def visualize_graphml(graphml_file, html_path):
+# main function
+def visualize_graphml(graphml_file, html_path, port=8000):
     json_data = graphml_to_json(graphml_file)
-    create_json(json_data, 'graph_json.js')
+    html_dir = os.path.dirname(html_path)
+    if not os.path.exists(html_dir):
+        os.makedirs(html_dir)
+    json_path = os.path.join(html_dir, 'graph_json.js')
+    create_json(json_data, json_path)
     create_html(html_path)
-    # 在后台启动服务器
-    server_thread = threading.Thread(target=start_server)
+    # start server in background
+    server_thread = threading.Thread(target=start_server(port))
     server_thread.daemon = True
     server_thread.start()
 
-    # 打开默认浏览器
-    webbrowser.open('http://localhost:8000/graph_visualization.html')
+    # open default browser
+    webbrowser.open(f'http://localhost:{port}/{html_path}')
 
     print("Visualization is ready. Press Ctrl+C to exit.")
     try:
-        # 保持主线程运行
+        # keep main thread running
        while True:
            pass
    except KeyboardInterrupt:
        print("Shutting down...")
 
-# 使用示例
+# usage
 if __name__ == "__main__":
-    graphml_file = r"nano_graphrag_cache_azure_openai_TEST\graph_chunk_entity_relation.graphml"  # 替换为您的GraphML文件路径
+    graphml_file = r"nano_graphrag_cache_azure_openai_TEST\graph_chunk_entity_relation.graphml"  # replace with your GraphML file path
     html_path = "graph_visualization.html"
-    visualize_graphml(graphml_file, html_path)
+    visualize_graphml(graphml_file, html_path, 11236)
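One caveat in the refactored script above: `threading.Thread(target=start_server(port))` calls `start_server(port)` immediately, and since that function blocks in `serve_forever()`, the thread object never receives a callable and the `webbrowser.open` line is never reached on a fresh run. A small sketch of the conventional idiom, with a stub standing in for the script's `start_server` (not part of this commit):

import threading
import time

def start_server(port):
    # stub standing in for the script's start_server(port)
    ...

port = 8000
# Pass the callable and its arguments; Thread invokes it on the new thread.
server_thread = threading.Thread(target=start_server, args=(port,), daemon=True)
server_thread.start()

try:
    # sleep instead of a busy `while True: pass` loop to keep the CPU idle
    while True:
        time.sleep(1)
except KeyboardInterrupt:
    print("Shutting down...")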
3 changes: 3 additions & 0 deletions examples/no_openai_key_at_all.py
@@ -32,6 +32,9 @@ async def local_embedding(texts: list[str]) -> np.ndarray:
 async def ollama_model_if_cache(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
+    # remove kwargs that are not supported by ollama
+    kwargs.pop("max_tokens", None)
+
     ollama_client = ollama.AsyncClient()
     messages = []
     if system_prompt:
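Dropping `max_tokens` keeps the Ollama call from failing, but the limit is silently lost. Ollama accepts a `num_predict` option for the same purpose; a hedged sketch of passing the limit through instead (the model name and the omission of the example's caching logic are placeholders, not part of this commit):

import ollama

MODEL = "qwen2"  # placeholder model name

async def ollama_model_with_limit(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    # Map the OpenAI-style max_tokens kwarg onto Ollama's num_predict option.
    options = {}
    max_tokens = kwargs.pop("max_tokens", None)
    if max_tokens is not None:
        options["num_predict"] = max_tokens
    kwargs.pop("hashing_kv", None)  # caching from the example is elided here

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    client = ollama.AsyncClient()
    response = await client.chat(model=MODEL, messages=messages, options=options)
    return response["message"]["content"]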
81 changes: 28 additions & 53 deletions examples/using_custom_chunking_method.py
@@ -1,68 +1,43 @@
-
-
 from nano_graphrag._utils import encode_string_by_tiktoken
 from nano_graphrag.base import QueryParam
 from nano_graphrag.graphrag import GraphRAG
+from nano_graphrag._op import chunking_by_seperators
 
 
-def chunking_by_specific_separators(
-    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o",
+def chunking_by_token_size(
+    tokens_list: list[list[int]],  # nano-graphrag may pass a batch of docs' tokens
+    doc_keys: list[str],  # nano-graphrag may pass a batch of docs' key ids
+    tiktoken_model,  # a tiktoken model
+    overlap_token_size=128,
+    max_token_size=1024,
 ):
-    from langchain_text_splitters import RecursiveCharacterTextSplitter
-
-    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=max_token_size,
-        chunk_overlap=overlap_token_size,
-        # length_function=lambda x: len(encode_string_by_tiktoken(x)),
-        model_name=tiktoken_model,
-        is_separator_regex=False,
-        separators=[
-            # Paragraph separators
-            "\n\n",
-            "\r\n\r\n",
-            # Line breaks
-            "\n",
-            "\r\n",
-            # Sentence ending punctuation
-            "。",  # Chinese period
-            "．",  # Full-width dot
-            ".",  # English period
-            "！",  # Chinese exclamation mark
-            "!",  # English exclamation mark
-            "？",  # Chinese question mark
-            "?",  # English question mark
-            # Whitespace characters
-            " ",  # Space
-            "\t",  # Tab
-            "\u3000",  # Full-width space
-            # Special characters
-            "\u200b",  # Zero-width space (used in some Asian languages)
-            # Final fallback
-            "",
-        ])
-    texts = text_splitter.split_text(content)
-
     results = []
-    for index, chunk_content in enumerate(texts):
-        results.append(
-            {
-                # "tokens": None,
-                "content": chunk_content.strip(),
-                "chunk_order_index": index,
-            }
-        )
+    for index, tokens in enumerate(tokens_list):
+        chunk_token = []
+        lengths = []
+        for start in range(0, len(tokens), max_token_size - overlap_token_size):
+            chunk_token.append(tokens[start : start + max_token_size])
+            lengths.append(min(max_token_size, len(tokens) - start))
+
+        chunk_token = tiktoken_model.decode_batch(chunk_token)
+        for i, chunk in enumerate(chunk_token):
+            results.append(
+                {
+                    "tokens": lengths[i],
+                    "content": chunk.strip(),
+                    "chunk_order_index": i,
+                    "full_doc_id": doc_keys[index],
+                }
+            )
 
     return results
 
 
 WORKING_DIR = "./nano_graphrag_cache_local_embedding_TEST"
 rag = GraphRAG(
     working_dir=WORKING_DIR,
-    chunk_func=chunking_by_specific_separators,
+    chunk_func=chunking_by_seperators,
 )
 
 with open("../tests/mock_data.txt", encoding="utf-8-sig") as f:
     FAKE_TEXT = f.read()
 
 # rag.insert(FAKE_TEXT)
 print(rag.query("What the main theme of this story?", param=QueryParam(mode="local")))
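The new `chunking_by_token_size` mirrors nano-graphrag's default token-window chunker (the example itself now plugs `chunking_by_seperators` from `nano_graphrag._op` into GraphRAG). A quick sanity-check sketch, assuming it runs in the same module so `chunking_by_token_size` is in scope; the toy text and parameters are illustrative only:

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
tokens = enc.encode("one two three four five " * 100)  # toy document

chunks = chunking_by_token_size(
    [tokens],
    doc_keys=["doc-0"],
    tiktoken_model=enc,
    overlap_token_size=16,
    max_token_size=64,
)
# Each chunk records its token count, order index, and source document id.
for c in chunks:
    print(c["chunk_order_index"], c["tokens"], c["full_doc_id"])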
8 changes: 4 additions & 4 deletions examples/using_dspy_entity_extraction.py
@@ -138,14 +138,14 @@ def query():
"""
    lm = dspy.OpenAI(
        model="deepseek-chat",
-        model_type="chat",
+        model_type="chat",
        api_provider="openai",
        api_key=os.environ["DEEPSEEK_API_KEY"],
        base_url=os.environ["DEEPSEEK_BASE_URL"],
        system_prompt=system_prompt,
        temperature=1.0,
        top_p=1,
-        max_tokens=4096
+        max_tokens=8192
    )
-    dspy.settings.configure(lm=lm)
+    dspy.settings.configure(lm=lm, experimental=True)
    insert()
    query()
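The DeepSeek `dspy.OpenAI` block above is now duplicated verbatim between `examples/benchmarks/dspy_entity.py` and this file. A small sketch of factoring it into a shared helper so the two examples stay in sync (the helper name is hypothetical, not part of this commit):

import os
import dspy

def configure_deepseek_lm(system_prompt: str, max_tokens: int = 8192) -> dspy.OpenAI:
    # Same settings as both examples; only the system prompt and token limit vary.
    lm = dspy.OpenAI(
        model="deepseek-chat",
        model_type="chat",
        api_provider="openai",
        api_key=os.environ["DEEPSEEK_API_KEY"],
        base_url=os.environ["DEEPSEEK_BASE_URL"],
        system_prompt=system_prompt,
        temperature=1.0,
        top_p=1,
        max_tokens=max_tokens,
    )
    dspy.settings.configure(lm=lm, experimental=True)
    return lm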
122 changes: 122 additions & 0 deletions examples/using_llm_api_as_llm+ollama_embedding.py
@@ -0,0 +1,122 @@
import os
import logging
import ollama
import numpy as np
from openai import AsyncOpenAI
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag.base import BaseKVStorage
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs

logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.INFO)

# Assumed LLM model settings
LLM_BASE_URL = "https://your.api.url"
LLM_API_KEY = "your_api_key"
MODEL = "your_model_name"

# Assumed embedding model settings
EMBEDDING_MODEL = "nomic-embed-text"
EMBEDDING_MODEL_DIM = 768
EMBEDDING_MODEL_MAX_TOKENS = 8192


async def llm_model_if_cache(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    openai_async_client = AsyncOpenAI(
        api_key=LLM_API_KEY, base_url=LLM_BASE_URL
    )
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Return the cached response if there is one -------------------
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(MODEL, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    # ---------------------------------------------------------------

    response = await openai_async_client.chat.completions.create(
        model=MODEL, messages=messages, **kwargs
    )

    # Cache the response if caching is enabled ----------------------
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
        )
    # ---------------------------------------------------------------
    return response.choices[0].message.content


def remove_if_exist(file):
    if os.path.exists(file):
        os.remove(file)


WORKING_DIR = "./nano_graphrag_cache_llm_TEST"


def query():
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        best_model_func=llm_model_if_cache,
        cheap_model_func=llm_model_if_cache,
        embedding_func=ollama_embedding,
    )
    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="global")
        )
    )


def insert():
    from time import time

    with open("./tests/mock_data.txt", encoding="utf-8-sig") as f:
        FAKE_TEXT = f.read()

    remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
    remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")

    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        best_model_func=llm_model_if_cache,
        cheap_model_func=llm_model_if_cache,
        embedding_func=ollama_embedding,
    )
    start = time()
    rag.insert(FAKE_TEXT)
    print("indexing time:", time() - start)
    # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
    # rag.insert(FAKE_TEXT[half_len:])


# Use Ollama to generate embeddings (nomic-embed-text by default)
@wrap_embedding_func_with_attrs(
    embedding_dim=EMBEDDING_MODEL_DIM,
    max_token_size=EMBEDDING_MODEL_MAX_TOKENS,
)
async def ollama_embedding(texts: list[str]) -> np.ndarray:
    embed_text = []
    for text in texts:
        data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
        embed_text.append(data["embedding"])

    return embed_text


if __name__ == "__main__":
    insert()
    query()
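In the new example above, `ollama_embedding` is annotated `-> np.ndarray` but returns a plain Python list, and the embeddings are fetched one request at a time. A hedged sketch of a variant that returns an actual array and issues the requests concurrently (it omits the `wrap_embedding_func_with_attrs` decorator, which would be applied the same way as in the example):

import asyncio
import numpy as np
import ollama

EMBEDDING_MODEL = "nomic-embed-text"

async def ollama_embedding(texts: list[str]) -> np.ndarray:
    client = ollama.AsyncClient()

    async def embed_one(text: str) -> list[float]:
        data = await client.embeddings(model=EMBEDDING_MODEL, prompt=text)
        return data["embedding"]

    # Fire the embedding requests concurrently and stack them into an ndarray.
    vectors = await asyncio.gather(*(embed_one(t) for t in texts))
    return np.array(vectors)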
3 changes: 3 additions & 0 deletions examples/using_ollama_as_llm.py
@@ -16,6 +16,9 @@
 async def ollama_model_if_cache(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
+    # remove kwargs that are not supported by ollama
+    kwargs.pop("max_tokens", None)
+
     ollama_client = ollama.AsyncClient()
     messages = []
     if system_prompt: