Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP #644

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft

WIP #644

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions libs/knowledge-store/ragstack_knowledge_store/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def mmr_traversal_search(
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
metadata_filter: dict[str, Any] = {}, # noqa: B006
tag_filter: set[tuple[str, str]] = {},
) -> Iterable[Node]:
"""Retrieve documents from this graph store using MMR-traversal.

Expand Down Expand Up @@ -398,6 +399,7 @@ def mmr_traversal_search(
score_threshold: Only documents with a score greater than or equal
this threshold will be chosen. Defaults to -infinity.
metadata_filter: Optional metadata to filter the results.
tag_filter: Optional tags to filter graph edges to be traversed.
"""
query_embedding = self._embedding.embed_query(query)
helper = MmrHelper(
Expand Down Expand Up @@ -444,9 +446,14 @@ def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
new_candidates = {}
for adjacent in adjacents:
if adjacent.target_content_id not in outgoing_tags:
outgoing_tags[adjacent.target_content_id] = (
adjacent.target_link_to_tags
)
if len(tag_filter) == 0:
outgoing_tags[adjacent.target_content_id] = (
adjacent.target_link_to_tags
)
else:
outgoing_tags[adjacent.target_content_id] = (
tag_filter.intersection(adjacent.target_link_to_tags)
)

new_candidates[adjacent.target_content_id] = (
adjacent.target_text_embedding
Expand Down Expand Up @@ -474,7 +481,12 @@ def fetch_initial_candidates() -> None:
for row in fetched:
if row.content_id not in outgoing_tags:
candidates[row.content_id] = row.text_embedding
outgoing_tags[row.content_id] = set(row.link_to_tags or [])
if len(tag_filter) == 0:
outgoing_tags[row.content_id] = set(row.link_to_tags or [])
else:
outgoing_tags[row.content_id] = tag_filter.intersection(
set(row.link_to_tags or [])
)
helper.add_candidates(candidates)

if initial_roots:
Expand Down Expand Up @@ -522,9 +534,14 @@ def fetch_initial_candidates() -> None:
new_candidates = {}
for adjacent in adjacents:
if adjacent.target_content_id not in outgoing_tags:
outgoing_tags[adjacent.target_content_id] = (
adjacent.target_link_to_tags
)
if len(tag_filter) == 0:
outgoing_tags[adjacent.target_content_id] = (
adjacent.target_link_to_tags
)
else:
outgoing_tags[adjacent.target_content_id] = (
tag_filter.intersection(adjacent.target_link_to_tags)
)
new_candidates[adjacent.target_content_id] = (
adjacent.target_text_embedding
)
Expand Down Expand Up @@ -553,6 +570,7 @@ def traversal_search(
k: int = 4,
depth: int = 1,
metadata_filter: dict[str, Any] = {}, # noqa: B006
tag_filter: set[tuple[str, str]] = {},
) -> Iterable[Node]:
"""Retrieve documents from this knowledge store.

Expand All @@ -566,6 +584,7 @@ def traversal_search(
Defaults to 4.
depth: The maximum depth of edges to traverse. Defaults to 1.
metadata_filter: Optional metadata to filter the results.
tag_filter: Optional tags to filter graph edges to be traversed.

Returns:
Collection of retrieved documents.
Expand Down Expand Up @@ -630,7 +649,11 @@ def visit_nodes(d: int, nodes: Sequence[Any]) -> None:
# given depth, so we don't fetch it again
# (unless we find it an earlier depth)
visited_tags[(kind, value)] = d
outgoing_tags.add((kind, value))
if (
len(tag_filter) == 0
or (kind, value) in tag_filter
):
outgoing_tags.add((kind, value))

if outgoing_tags:
# If there are new tags to visit at the next depth, query for the
Expand Down
18 changes: 18 additions & 0 deletions libs/knowledge-store/tests/integration_tests/test_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,16 @@ def test_mmr_traversal(
results = gs.mmr_traversal_search("0.0", fetch_k=2, k=4, initial_roots=["v0"])
assert _result_ids(results) == ["v1", "v3", "v2"]

results = gs.mmr_traversal_search(
"0.0", k=2, fetch_k=2, tag_filter={("explicit", "link")}
)
assert _result_ids(results) == ["v0", "v2"]

results = gs.mmr_traversal_search(
"0.0", k=2, fetch_k=2, tag_filter={("no", "match")}
)
assert _result_ids(results) == []


def test_write_retrieve_keywords(
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],
Expand Down Expand Up @@ -282,6 +292,14 @@ def test_write_retrieve_keywords(
results = gs.traversal_search("Earth", k=1, depth=1)
assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"}

results = gs.traversal_search(
"Earth", k=1, depth=1, tag_filter={("parent", "parent")}
)
assert set(_result_ids(results)) == {"doc2", "greetings"}

results = gs.traversal_search("Earth", k=1, depth=1, tag_filter={("no", "match")})
assert _result_ids(results) == []


def test_metadata(
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],
Expand Down
Loading