Skip to content

Commit

Permalink
feat: add faiss storage and graphml visualization (#12)
Browse files Browse the repository at this point in the history
* 增加了faiss作为向量数据库

* 增加了faiss作为向量数据库

* 增加了faiss作为向量数据库

* 增加了faiss作为向量数据库

* 增加了简单的网络可视化

* 修改了faiss的id处理方式
  • Loading branch information
handsomecaoyu authored Sep 6, 2024
1 parent 2038a5a commit b985264
Show file tree
Hide file tree
Showing 4 changed files with 368 additions and 2 deletions.
97 changes: 97 additions & 0 deletions examples/using_faiss_as_vextorDB.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os
import asyncio
import numpy as np
from nano_graphrag.graphrag import GraphRAG, QueryParam
from nano_graphrag._utils import logger
from nano_graphrag.base import BaseVectorStorage
from dataclasses import dataclass
import faiss
import pickle
import logging
import xxhash
logging.getLogger('msal').setLevel(logging.WARNING)
logging.getLogger('azure').setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)

WORKING_DIR = "./nano_graphrag_cache_faiss_TEST"

@dataclass
class FAISSStorage(BaseVectorStorage):

def __post_init__(self):
self._index_file_name = os.path.join(
self.global_config["working_dir"], f"{self.namespace}_faiss.index"
)
self._metadata_file_name = os.path.join(
self.global_config["working_dir"], f"{self.namespace}_metadata.pkl"
)
self._max_batch_size = self.global_config["embedding_batch_num"]

if os.path.exists(self._index_file_name) and os.path.exists(self._metadata_file_name):
self._index = faiss.read_index(self._index_file_name)
with open(self._metadata_file_name, 'rb') as f:
self._metadata = pickle.load(f)
else:
self._index = faiss.IndexIDMap(faiss.IndexFlatIP(self.embedding_func.embedding_dim))
self._metadata = {}

async def upsert(self, data: dict[str, dict]):
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")

contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)

ids = []
for k, v in data.items():
id = xxhash.xxh32_intdigest(k.encode())
metadata = {k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}
metadata['id'] = k
self._metadata[id] = metadata
ids.append(id)

ids = np.array(ids, dtype=np.int64)
self._index.add_with_ids(embeddings, ids)


return len(data)

async def query(self, query, top_k=5):
embedding = await self.embedding_func([query])
distances, indices = self._index.search(embedding, top_k)

results = []
for _, (distance, id) in enumerate(zip(distances[0], indices[0])):
if id != -1: # FAISS returns -1 for empty slots
if id in self._metadata:
metadata = self._metadata[id]
results.append({**metadata, "distance": 1 - distance}) # Convert to cosine distance

return results

async def index_done_callback(self):
faiss.write_index(self._index, self._index_file_name)
with open(self._metadata_file_name, 'wb') as f:
pickle.dump(self._metadata, f)

if __name__ == "__main__":

graph_func = GraphRAG(
working_dir=WORKING_DIR,
enable_llm_cache=True,
vector_db_storage_cls=FAISSStorage,
)

with open(r"tests/mock_data.txt", encoding='utf-8') as f:
graph_func.insert(f.read()[:30000])

# Perform global graphrag search
print(graph_func.query("What are the top themes in this story?"))


270 changes: 270 additions & 0 deletions examples/visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
import networkx as nx
import json
import webbrowser
import os
import http.server
import socketserver
import threading

# 读取GraphML文件并转换为JSON
def graphml_to_json(graphml_file):
G = nx.read_graphml(graphml_file)
data = nx.node_link_data(G)
return json.dumps(data)

# 创建HTML文件
def create_html(json_data, html_path):
json_data = json_data.replace('\\"', '')
html_content = '''
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Graph Visualization</title>
<script src="https://d3js.org/d3.v7.min.js"></script>
<style>
body, html {
margin: 0;
padding: 0;
width: 100%;
height: 100%;
overflow: hidden;
}
svg {
width: 100%;
height: 100%;
}
.links line {
stroke: #999;
stroke-opacity: 0.6;
}
.nodes circle {
stroke: #fff;
stroke-width: 1.5px;
}
.node-label {
font-size: 12px;
pointer-events: none;
}
.link-label {
font-size: 10px;
fill: #666;
pointer-events: none;
opacity: 0;
transition: opacity 0.3s;
}
.link:hover .link-label {
opacity: 1;
}
.tooltip {
position: absolute;
text-align: left;
padding: 10px;
font: 12px sans-serif;
background: lightsteelblue;
border: 0px;
border-radius: 8px;
pointer-events: none;
opacity: 0;
transition: opacity 0.3s;
max-width: 300px;
}
.legend {
position: absolute;
top: 10px;
right: 10px;
background-color: rgba(255, 255, 255, 0.8);
padding: 10px;
border-radius: 5px;
}
.legend-item {
margin: 5px 0;
}
.legend-color {
display: inline-block;
width: 20px;
height: 20px;
margin-right: 5px;
vertical-align: middle;
}
</style>
</head>
<body>
<svg></svg>
<div class="tooltip"></div>
<div class="legend"></div>
<script>
const graphData = JSON.parse('{json_data}');
const svg = d3.select("svg"),
width = window.innerWidth,
height = window.innerHeight;
svg.attr("viewBox", [0, 0, width, height]);
const g = svg.append("g");
const entityTypes = [...new Set(graphData.nodes.map(d => d.entity_type))];
const color = d3.scaleOrdinal(d3.schemeCategory10).domain(entityTypes);
const simulation = d3.forceSimulation(graphData.nodes)
.force("link", d3.forceLink(graphData.links).id(d => d.id).distance(150))
.force("charge", d3.forceManyBody().strength(-300))
.force("center", d3.forceCenter(width / 2, height / 2))
.force("collide", d3.forceCollide().radius(30));
const linkGroup = g.append("g")
.attr("class", "links")
.selectAll("g")
.data(graphData.links)
.enter().append("g")
.attr("class", "link");
const link = linkGroup.append("line")
.attr("stroke-width", d => Math.sqrt(d.value));
const linkLabel = linkGroup.append("text")
.attr("class", "link-label")
.text(d => d.description || "");
const node = g.append("g")
.attr("class", "nodes")
.selectAll("circle")
.data(graphData.nodes)
.enter().append("circle")
.attr("r", 5)
.attr("fill", d => color(d.entity_type))
.call(d3.drag()
.on("start", dragstarted)
.on("drag", dragged)
.on("end", dragended));
const nodeLabel = g.append("g")
.attr("class", "node-labels")
.selectAll("text")
.data(graphData.nodes)
.enter().append("text")
.attr("class", "node-label")
.text(d => d.id);
const tooltip = d3.select(".tooltip");
node.on("mouseover", function(event, d) {
tooltip.transition()
.duration(200)
.style("opacity", .9);
tooltip.html(`<strong>${d.id}</strong><br>Entity Type: ${d.entity_type}<br>Description: ${d.description || "N/A"}`)
.style("left", (event.pageX + 10) + "px")
.style("top", (event.pageY - 28) + "px");
})
.on("mouseout", function(d) {
tooltip.transition()
.duration(500)
.style("opacity", 0);
});
const legend = d3.select(".legend");
entityTypes.forEach(type => {
legend.append("div")
.attr("class", "legend-item")
.html(`<span class="legend-color" style="background-color: ${color(type)}"></span>${type}`);
});
simulation
.nodes(graphData.nodes)
.on("tick", ticked);
simulation.force("link")
.links(graphData.links);
function ticked() {
link
.attr("x1", d => d.source.x)
.attr("y1", d => d.source.y)
.attr("x2", d => d.target.x)
.attr("y2", d => d.target.y);
linkLabel
.attr("x", d => (d.source.x + d.target.x) / 2)
.attr("y", d => (d.source.y + d.target.y) / 2)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "middle");
node
.attr("cx", d => d.x)
.attr("cy", d => d.y);
nodeLabel
.attr("x", d => d.x + 8)
.attr("y", d => d.y + 3);
}
function dragstarted(event) {
if (!event.active) simulation.alphaTarget(0.3).restart();
event.subject.fx = event.subject.x;
event.subject.fy = event.subject.y;
}
function dragged(event) {
event.subject.fx = event.x;
event.subject.fy = event.y;
}
function dragended(event) {
if (!event.active) simulation.alphaTarget(0);
event.subject.fx = null;
event.subject.fy = null;
}
const zoom = d3.zoom()
.scaleExtent([0.1, 10])
.on("zoom", zoomed);
svg.call(zoom);
function zoomed(event) {
g.attr("transform", event.transform);
}
</script>
</body>
</html>
'''.replace("{json_data}", json_data.replace("'", "\\'").replace("\n", ""))

with open(html_path, 'w', encoding='utf-8') as f:
f.write(html_content)

# 启动简单的HTTP服务器
def start_server():
handler = http.server.SimpleHTTPRequestHandler
with socketserver.TCPServer(("", 8000), handler) as httpd:
print("Server started at http://localhost:8000")
httpd.serve_forever()

# 主函数
def visualize_graphml(graphml_file, html_path):
json_data = graphml_to_json(graphml_file)
create_html(json_data, html_path)

# 在后台启动服务器
server_thread = threading.Thread(target=start_server)
server_thread.daemon = True
server_thread.start()

# 打开默认浏览器
webbrowser.open('http://localhost:8000/graph_visualization.html')

print("Visualization is ready. Press Ctrl+C to exit.")
try:
# 保持主线程运行
while True:
pass
except KeyboardInterrupt:
print("Shutting down...")

# 使用示例
if __name__ == "__main__":
graphml_file = r"nano_graphrag_cache_azure_openai_TEST\graph_chunk_entity_relation.graphml" # 替换为您的GraphML文件路径
html_path = "graph_visualization.html"
visualize_graphml(graphml_file, html_path)
1 change: 0 additions & 1 deletion nano_graphrag/_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
)
from .prompt import GRAPH_FIELD_SEP


@dataclass
class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
Expand Down
2 changes: 1 addition & 1 deletion nano_graphrag/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def compute_mdhash_id(content, prefix: str = ""):


def write_json(json_obj, file_name):
with open(file_name, "w") as f:
with open(file_name, "w", encoding='utf-8') as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False)


Expand Down

0 comments on commit b985264

Please sign in to comment.