diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index 02110a62e..0e391214c 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -367,6 +367,7 @@ def mmr_traversal_search( lambda_mult: float = 0.5, score_threshold: float = float("-inf"), metadata_filter: dict[str, Any] = {}, # noqa: B006 + tag_filter: set[tuple[str, str]] = {}, ) -> Iterable[Node]: """Retrieve documents from this graph store using MMR-traversal. @@ -398,6 +399,7 @@ def mmr_traversal_search( score_threshold: Only documents with a score greater than or equal this threshold will be chosen. Defaults to -infinity. metadata_filter: Optional metadata to filter the results. + tag_filter: Optional tags to filter graph edges to be traversed. """ query_embedding = self._embedding.embed_query(query) helper = MmrHelper( @@ -444,9 +446,14 @@ def fetch_neighborhood(neighborhood: Sequence[str]) -> None: new_candidates = {} for adjacent in adjacents: if adjacent.target_content_id not in outgoing_tags: - outgoing_tags[adjacent.target_content_id] = ( - adjacent.target_link_to_tags - ) + if len(tag_filter) == 0: + outgoing_tags[adjacent.target_content_id] = ( + adjacent.target_link_to_tags + ) + else: + outgoing_tags[adjacent.target_content_id] = ( + tag_filter.intersection(adjacent.target_link_to_tags) + ) new_candidates[adjacent.target_content_id] = ( adjacent.target_text_embedding @@ -474,7 +481,12 @@ def fetch_initial_candidates() -> None: for row in fetched: if row.content_id not in outgoing_tags: candidates[row.content_id] = row.text_embedding - outgoing_tags[row.content_id] = set(row.link_to_tags or []) + if len(tag_filter) == 0: + outgoing_tags[row.content_id] = set(row.link_to_tags or []) + else: + outgoing_tags[row.content_id] = tag_filter.intersection( + set(row.link_to_tags or []) + ) helper.add_candidates(candidates) if initial_roots: @@ -522,9 +534,14 @@ def fetch_initial_candidates() -> None: new_candidates = {} for adjacent in adjacents: if adjacent.target_content_id not in outgoing_tags: - outgoing_tags[adjacent.target_content_id] = ( - adjacent.target_link_to_tags - ) + if len(tag_filter) == 0: + outgoing_tags[adjacent.target_content_id] = ( + adjacent.target_link_to_tags + ) + else: + outgoing_tags[adjacent.target_content_id] = ( + tag_filter.intersection(adjacent.target_link_to_tags) + ) new_candidates[adjacent.target_content_id] = ( adjacent.target_text_embedding ) @@ -553,6 +570,7 @@ def traversal_search( k: int = 4, depth: int = 1, metadata_filter: dict[str, Any] = {}, # noqa: B006 + tag_filter: set[tuple[str, str]] = {}, ) -> Iterable[Node]: """Retrieve documents from this knowledge store. @@ -566,6 +584,7 @@ def traversal_search( Defaults to 4. depth: The maximum depth of edges to traverse. Defaults to 1. metadata_filter: Optional metadata to filter the results. + tag_filter: Optional tags to filter graph edges to be traversed. Returns: Collection of retrieved documents. @@ -630,7 +649,11 @@ def visit_nodes(d: int, nodes: Sequence[Any]) -> None: # given depth, so we don't fetch it again # (unless we find it an earlier depth) visited_tags[(kind, value)] = d - outgoing_tags.add((kind, value)) + if ( + len(tag_filter) == 0 + or (kind, value) in tag_filter + ): + outgoing_tags.add((kind, value)) if outgoing_tags: # If there are new tags to visit at the next depth, query for the diff --git a/libs/knowledge-store/tests/integration_tests/test_graph_store.py b/libs/knowledge-store/tests/integration_tests/test_graph_store.py index 17d5a7a77..6bb8ba506 100644 --- a/libs/knowledge-store/tests/integration_tests/test_graph_store.py +++ b/libs/knowledge-store/tests/integration_tests/test_graph_store.py @@ -211,6 +211,16 @@ def test_mmr_traversal( results = gs.mmr_traversal_search("0.0", fetch_k=2, k=4, initial_roots=["v0"]) assert _result_ids(results) == ["v1", "v3", "v2"] + results = gs.mmr_traversal_search( + "0.0", k=2, fetch_k=2, tag_filter={("explicit", "link")} + ) + assert _result_ids(results) == ["v0", "v2"] + + results = gs.mmr_traversal_search( + "0.0", k=2, fetch_k=2, tag_filter={("no", "match")} + ) + assert _result_ids(results) == [] + def test_write_retrieve_keywords( graph_store_factory: Callable[[MetadataIndexingType], GraphStore], @@ -282,6 +292,14 @@ def test_write_retrieve_keywords( results = gs.traversal_search("Earth", k=1, depth=1) assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"} + results = gs.traversal_search( + "Earth", k=1, depth=1, tag_filter={("parent", "parent")} + ) + assert set(_result_ids(results)) == {"doc2", "greetings"} + + results = gs.traversal_search("Earth", k=1, depth=1, tag_filter={("no", "match")}) + assert _result_ids(results) == [] + def test_metadata( graph_store_factory: Callable[[MetadataIndexingType], GraphStore],