fix: adding new documents should update the existing index within the file collection instead of creating a new one (Cinnamon#561)
varunsharma27 committed Jan 13, 2025
1 parent b454f37 commit 25f2f31
Showing 2 changed files with 73 additions and 12 deletions.
25 changes: 25 additions & 0 deletions libs/ktem/ktem/index/file/graph/light_graph_index.py
@@ -1,17 +1,40 @@
 from typing import Any
+from uuid import uuid4

+from ktem.db.engine import engine
+from sqlalchemy.orm import Session
 from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
 from .graph_index import GraphRAGIndex
 from .lightrag_pipelines import LightRAGIndexingPipeline, LightRAGRetrieverPipeline


 class LightRAGIndex(GraphRAGIndex):
+    def __init__(self, app, id: int, name: str, config: dict):
+        super().__init__(app, id, name, config)
+        self._collection_graph_id = None
+
     def _setup_indexing_cls(self):
         self._indexing_pipeline_cls = LightRAGIndexingPipeline

     def _setup_retriever_cls(self):
         self._retriever_pipeline_cls = [LightRAGRetrieverPipeline]

+    def _get_or_create_collection_graph_id(self):
+        if not self._collection_graph_id:
+            # Try to find existing graph ID for this collection
+            with Session(engine) as session:
+                result = (
+                    session.query(self._resources["Index"].target_id)
+                    .filter(self._resources["Index"].relation_type == "graph")
+                    .first()
+                )
+                if result:
+                    self._collection_graph_id = result[0]
+                else:
+                    self._collection_graph_id = str(uuid4())
+        return self._collection_graph_id
+

     def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
         pipeline = super().get_indexing_pipeline(settings, user_id)
         # indexing settings
@@ -23,6 +46,8 @@ def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
         }
         # set the prompts
         pipeline.prompts = striped_settings
+        # set collection graph id
+        pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
         return pipeline

     def get_retriever_pipelines(
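The new `_get_or_create_collection_graph_id` method reuses a single graph per file collection instead of minting a new graph on every upload. Below is a minimal, standalone sketch of that get-or-create pattern; the `Index` model, table name, and engine URL are illustrative stand-ins, not ktem's actual resources.

from uuid import uuid4

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Index(Base):
    """Hypothetical stand-in for ktem's per-collection Index resource."""

    __tablename__ = "index_table"
    id = Column(Integer, primary_key=True)
    source_id = Column(String)  # file id
    target_id = Column(String)  # graph id
    relation_type = Column(String)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)


def get_or_create_graph_id() -> str:
    """Return the collection's existing graph id, or mint one if absent."""
    with Session(engine) as session:
        row = (
            session.query(Index.target_id)
            .filter(Index.relation_type == "graph")
            .first()
        )
    # Reusing the stored id is what makes later uploads extend the same
    # graph; a fresh UUID is created only when no graph exists yet.
    return row[0] if row else str(uuid4())

Note that a freshly minted id is not persisted at this point; in the commit it is recorded later, when the indexing pipeline writes the file-to-graph mappings.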
60 changes: 48 additions & 12 deletions libs/ktem/ktem/index/file/graph/lightrag_pipelines.py
@@ -242,6 +242,37 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
"""GraphRAG specific indexing pipeline"""

prompts: dict[str, str] = {}
collection_graph_id: str

def store_file_id_with_graph_id(self, file_ids: list[str | None]):
# Use the collection-wide graph ID for LightRAG
graph_id = self.collection_graph_id

# Record all files under this graph_id
with Session(engine) as session:
for file_id in file_ids:
if not file_id:
continue
# Check if mapping already exists
existing = (
session.query(self.Index)
.filter(
self.Index.source_id == file_id,
self.Index.target_id == graph_id,
self.Index.relation_type == "graph",
)
.first()
)
if not existing:
node = self.Index(
source_id=file_id,
target_id=graph_id,
relation_type="graph",
)
session.add(node)
session.commit()

return graph_id

@classmethod
def get_user_settings(cls) -> dict:
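`store_file_id_with_graph_id` makes the file-to-graph mapping idempotent: re-indexing a file already attached to the collection graph adds no duplicate rows. A sketch of that check-then-insert behavior, reusing the illustrative `Index` model and `engine` from the sketch above (again assumptions, not ktem's real schema):

from sqlalchemy.orm import Session


def map_files_to_graph(file_ids: list[str | None], graph_id: str) -> str:
    """Attach files to a graph, skipping ids that are already mapped."""
    with Session(engine) as session:
        for file_id in file_ids:
            if not file_id:
                continue  # skip placeholders for files that failed to index
            already_mapped = (
                session.query(Index)
                .filter(
                    Index.source_id == file_id,
                    Index.target_id == graph_id,
                    Index.relation_type == "graph",
                )
                .first()
            )
            if not already_mapped:
                session.add(
                    Index(
                        source_id=file_id,
                        target_id=graph_id,
                        relation_type="graph",
                    )
                )
        session.commit()
    return graph_id


# Overlapping calls leave exactly three rows: file-a, file-b, file-c.
map_files_to_graph(["file-a", "file-b"], "graph-1")
map_files_to_graph(["file-b", "file-c"], "graph-1")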
@@ -294,46 +325,51 @@ def call_graphrag_index(self, graph_id: str, docs: list[Document]):

         yield Document(
             channel="debug",
-            text="[GraphRAG] Creating index... This can take a long time.",
+            text="[GraphRAG] Creating/Updating index... This can take a long time.",
         )

-        # remove all .json files in the input_path directory (previous cache)
-        json_files = glob.glob(f"{input_path}/*.json")
-        for json_file in json_files:
-            os.remove(json_file)
+        # Check if graph already exists
+        graph_file = input_path / "graph_chunk_entity_relation.graphml"
+        is_incremental = graph_file.exists()
+
+        # Only clear cache if it's a new graph
+        if not is_incremental:
+            json_files = glob.glob(f"{input_path}/*.json")
+            for json_file in json_files:
+                os.remove(json_file)

-        # indexing
+        # Initialize or load existing GraphRAG
         graphrag_func = build_graphrag(
             input_path,
             llm_func=llm_func,
             embedding_func=embedding_func,
         )
-        # output must be contain: Loaded graph from
-        # ..input/graph_chunk_entity_relation.graphml with xxx nodes, xxx edges

         total_docs = len(all_docs)
         process_doc_count = 0
         yield Document(
             channel="debug",
-            text=f"[GraphRAG] Indexed {process_doc_count} / {total_docs} documents.",
+            text=f"[GraphRAG] {'Updating' if is_incremental else 'Creating'} index: {process_doc_count} / {total_docs} documents.",
         )

         for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
             cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
             combined_doc = "\n".join(cur_docs)

+            # Use insert for incremental updates
             graphrag_func.insert(combined_doc)
             process_doc_count += len(cur_docs)
             yield Document(
                 channel="debug",
                 text=(
-                    f"[GraphRAG] Indexed {process_doc_count} "
-                    f"/ {total_docs} documents."
+                    f"[GraphRAG] {'Updated' if is_incremental else 'Indexed'} "
+                    f"{process_doc_count} / {total_docs} documents."
                 ),
             )

         yield Document(
             channel="debug",
-            text="[GraphRAG] Indexing finished.",
+            text=f"[GraphRAG] {'Update' if is_incremental else 'Indexing'} finished.",
         )

     def stream(
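The core behavioral change in `call_graphrag_index` is deciding between a fresh build and an incremental update by probing for LightRAG's persisted graph file, and clearing the JSON cache only on a fresh build. A standalone sketch of that decision follows; the file name matches the one in the diff, while the function name `prepare_index_dir` is an assumption for illustration.

import glob
import os
from pathlib import Path


def prepare_index_dir(input_path: Path) -> bool:
    """Return True for an incremental update, False for a fresh build."""
    # LightRAG persists its graph under this name; its presence means an
    # index already exists and must not be wiped.
    is_incremental = (input_path / "graph_chunk_entity_relation.graphml").exists()

    if not is_incremental:
        # Fresh build: stale JSON caches from earlier runs would otherwise
        # be picked up as if they belonged to the new index.
        for json_file in glob.glob(f"{input_path}/*.json"):
            os.remove(json_file)

    return is_incremental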

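Incremental updates then fall out of LightRAG's additive `insert()`: the pipeline batches documents and inserts each batch into the same graph object, whether that graph was just created or loaded from disk. A simplified sketch of the loop; the `INDEX_BATCHSIZE` value and the `rag` object are placeholders (any object with an additive `insert(str)` method fits).

INDEX_BATCHSIZE = 4  # assumed value; the real constant lives in the pipeline module


def index_in_batches(rag, docs: list[str]) -> None:
    """Feed documents to the graph in batches, reporting progress."""
    total_docs = len(docs)
    processed = 0
    for start in range(0, total_docs, INDEX_BATCHSIZE):
        batch = docs[start : start + INDEX_BATCHSIZE]
        # insert() extends the existing graph instead of rebuilding it,
        # which is what makes re-running this on new documents an update.
        rag.insert("\n".join(batch))
        processed += len(batch)
        print(f"[GraphRAG] {processed} / {total_docs} documents processed.")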