-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathcassandra.py
144 lines (132 loc) · 4.56 KB
/
cassandra.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from typing import (
Any,
Iterable,
List,
Optional,
Type,
)
from cassandra.cluster import Session
from langchain_community.utilities.cassandra import SetupMode
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from .base import GraphStore, Node, nodes_to_documents
from .embedding_adapter import EmbeddingAdapter
from ragstack_knowledge_store import graph_store
class CassandraGraphStore(GraphStore):
    """Adapter exposing ``ragstack_knowledge_store``'s graph store as a
    LangChain-style ``GraphStore``.

    Every operation is delegated to the wrapped
    ``graph_store.GraphStore`` instance held in ``self.store``; nodes coming
    back from it are converted to ``Document`` objects with
    ``nodes_to_documents``.
    """

    def __init__(
        self,
        embedding: Embeddings,
        *,
        node_table: str = "graph_nodes",
        targets_table: str = "graph_targets",
        session: Optional[Session] = None,
        keyspace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
    ) -> None:
        """Create the graph store and its underlying Cassandra-backed store.

        Args:
            embedding: The embeddings to use for the document content. It is
                kept on this object and also wrapped in an
                ``EmbeddingAdapter`` for the underlying store.
            node_table: Forwarded as ``node_table`` to the underlying store
                (presumably the name of the table holding nodes -- confirm
                against ``ragstack_knowledge_store``).
            targets_table: Forwarded as ``targets_table`` to the underlying
                store (presumably the name of the table holding link
                targets -- confirm against ``ragstack_knowledge_store``).
            session: Optional Cassandra session, forwarded unchanged. How a
                ``None`` session is resolved is decided by the underlying
                store, not here.
            keyspace: Optional keyspace name, forwarded unchanged.
            setup_mode: Mode used to create the Cassandra table (SYNC,
                ASYNC or OFF).

        Raises:
            AttributeError: If ``graph_store.SetupMode`` has no member with
                the same name as ``setup_mode``.
        """
        self._embedding = embedding

        # The LangChain and ragstack SetupMode enums are distinct types;
        # translate between them by member name.
        store_setup_mode = getattr(graph_store.SetupMode, setup_mode.name)

        self.store = graph_store.GraphStore(
            embedding=EmbeddingAdapter(embedding),
            node_table=node_table,
            targets_table=targets_table,
            session=session,
            keyspace=keyspace,
            setup_mode=store_setup_mode,
        )

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Return the embeddings object supplied at construction time."""
        return self._embedding

    def add_nodes(
        self,
        nodes: Iterable[Node],
        **kwargs: Any,
    ) -> Iterable[str]:
        """Convert ``nodes`` to the underlying store's node type and add them.

        Args:
            nodes: Nodes to add. The iterable is fully consumed before the
                underlying store is called.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Whatever the underlying store's ``add_nodes`` returns
            (annotated as an iterable of strings; presumably the node
            ids -- confirm against ``ragstack_knowledge_store``).
        """
        store_nodes = [
            graph_store.Node(
                id=node.id,
                text=node.text,
                mime_type=node.mime_type,
                mime_encoding=node.mime_encoding,
                metadata=node.metadata,
            )
            for node in nodes
        ]
        return self.store.add_nodes(store_nodes)

    @classmethod
    def from_texts(
        cls: Type["CassandraGraphStore"],
        texts: Iterable[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> "CassandraGraphStore":
        """Return CassandraGraphStore initialized from texts and embeddings.

        Args:
            texts: Texts to add to the new store.
            embedding: Embeddings passed to the constructor.
            metadatas: Optional metadata, passed positionally to
                ``add_texts`` after ``texts``.
            ids: Optional ids, passed to ``add_texts`` as ``ids``.
            **kwargs: Forwarded to the constructor (e.g. ``session``,
                ``keyspace``, ``setup_mode``).

        Returns:
            The newly created store, after ``add_texts`` has been called.
        """
        # add_texts is not defined in this class; presumably inherited from
        # the GraphStore base.
        store = cls(embedding, **kwargs)
        store.add_texts(texts, metadatas, ids=ids)
        return store

    @classmethod
    def from_documents(
        cls: Type["CassandraGraphStore"],
        documents: Iterable[Document],
        embedding: Embeddings,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> "CassandraGraphStore":
        """Return CassandraGraphStore initialized from documents and embeddings.

        Args:
            documents: Documents to add to the new store.
            embedding: Embeddings passed to the constructor.
            ids: Optional ids, passed to ``add_documents`` as ``ids``.
            **kwargs: Forwarded to the constructor (e.g. ``session``,
                ``keyspace``, ``setup_mode``).

        Returns:
            The newly created store, after ``add_documents`` has been called.
        """
        # add_documents is not defined in this class; presumably inherited
        # from the GraphStore base.
        store = cls(embedding, **kwargs)
        store.add_documents(documents, ids=ids)
        return store

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Embed ``query`` and return the ``k`` most similar documents.

        Args:
            query: Text to embed with the configured embeddings.
            k: Number of results requested from the underlying store.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The result of ``similarity_search_by_vector`` for the query's
            embedding.
        """
        embedding_vector = self._embedding.embed_query(query)
        return self.similarity_search_by_vector(
            embedding_vector,
            k=k,
        )

    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return documents most similar to an already-computed embedding.

        Args:
            embedding: Query vector, forwarded to the underlying store.
            k: Number of results requested from the underlying store.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The matching nodes converted to documents, materialized as a
            list.
        """
        nodes = self.store.similarity_search(embedding, k=k)
        return list(nodes_to_documents(nodes))

    def traversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 1,
        **kwargs: Any,
    ) -> Iterable[Document]:
        """Run the underlying store's traversal search for ``query``.

        The query text is passed through as-is; it is not embedded here.

        Args:
            query: Query text forwarded to the underlying store.
            k: Forwarded as ``k`` (presumably the number of initial
                matches -- confirm against ``ragstack_knowledge_store``).
            depth: Forwarded as ``depth`` (presumably the maximum traversal
                depth -- confirm against ``ragstack_knowledge_store``).
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The result of ``nodes_to_documents`` on the returned nodes,
            not materialized into a list.
        """
        nodes = self.store.traversal_search(query, k=k, depth=depth)
        return nodes_to_documents(nodes)

    def mmr_traversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 2,
        fetch_k: int = 100,
        adjacent_k: int = 10,
        lambda_mult: float = 0.5,
        score_threshold: float = float("-inf"),
        **kwargs: Any,
    ) -> Iterable[Document]:
        """Run the underlying store's MMR traversal search for ``query``.

        All tuning parameters are forwarded unchanged under the same names;
        their exact semantics are defined by ``ragstack_knowledge_store``
        ("MMR" presumably stands for maximal marginal relevance -- confirm).
        The query text is passed through as-is; it is not embedded here.

        Args:
            query: Query text forwarded to the underlying store.
            k: Forwarded as ``k``.
            depth: Forwarded as ``depth``.
            fetch_k: Forwarded as ``fetch_k``.
            adjacent_k: Forwarded as ``adjacent_k``.
            lambda_mult: Forwarded as ``lambda_mult``.
            score_threshold: Forwarded as ``score_threshold``; the default
                is negative infinity.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The result of ``nodes_to_documents`` on the returned nodes,
            not materialized into a list.
        """
        nodes = self.store.mmr_traversal_search(
            query,
            k=k,
            depth=depth,
            fetch_k=fetch_k,
            adjacent_k=adjacent_k,
            lambda_mult=lambda_mult,
            score_threshold=score_threshold,
        )
        return nodes_to_documents(nodes)