Skip to content

Commit

Permalink
Merge pull request #22 from karbasia/bug/serializing
Browse files Browse the repository at this point in the history
Resolve serialization issues
  • Loading branch information
zc277584121 authored Jun 8, 2024
2 parents eba5784 + 3b6df3a commit f0b8520
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 1 deletion.
31 changes: 30 additions & 1 deletion src/milvus_haystack/milvus_embedding_retriever.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional

from haystack import Document, component
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict

from milvus_haystack import MilvusDocumentStore

Expand All @@ -23,6 +23,35 @@ def __init__(self, document_store: MilvusDocumentStore, filters: Optional[Dict[s
self.top_k = top_k
self.document_store = document_store

def to_dict(self) -> Dict[str, Any]:
"""
Returns a dictionary representation of the retriever component.
:returns:
A dictionary representation of the retriever component.
"""
return default_to_dict(
self, document_store=self.document_store.to_dict(), filters=self.filters, top_k=self.top_k
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MilvusEmbeddingRetriever":
"""
Creates a new retriever from a dictionary.
:param data: The dictionary to use to create the retriever.
:return: A new retriever.
"""
init_params = data.get("init_parameters", {})
if "document_store" not in init_params:
err_msg = "Missing 'document_store' in serialization data"
raise DeserializationError(err_msg)

docstore = MilvusDocumentStore.from_dict(init_params["document_store"])
data["init_parameters"]["document_store"] = docstore

return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query_embedding: List[float]) -> Dict[str, List[Document]]:
"""
Expand Down
78 changes: 78 additions & 0 deletions tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,81 @@ def test_run(self, document_store: MilvusDocumentStore):
query_embedding = [-10.0] * 128
res = retriever.run(query_embedding)
assert res["documents"] == documents

def test_to_dict(self, document_store: MilvusDocumentStore):
expected_dict = {
"type": "src.milvus_haystack.document_store.MilvusDocumentStore",
"init_parameters": {
"collection_name": "HaystackCollection",
"collection_description": "",
"collection_properties": None,
"connection_args": {"host": "localhost", "port": "19530", "user": "", "password": "", "secure": False},
"consistency_level": "Session",
"index_params": None,
"search_params": None,
"drop_old": True,
"primary_field": "id",
"text_field": "text",
"vector_field": "vector",
"partition_key_field": None,
"partition_names": None,
"replica_number": 1,
"timeout": None,
},
}
retriever = MilvusEmbeddingRetriever(document_store)
result = retriever.to_dict()

assert result["type"] == "src.milvus_haystack.milvus_embedding_retriever.MilvusEmbeddingRetriever"
assert result["init_parameters"]["document_store"] == expected_dict

def test_from_dict(self, document_store: MilvusDocumentStore):
retriever_dict = {
"type": "src.milvus_haystack.milvus_embedding_retriever.MilvusEmbeddingRetriever",
"init_parameters": {
"document_store": {
"type": "milvus_haystack.document_store.MilvusDocumentStore",
"init_parameters": {
"collection_name": "HaystackCollection",
"collection_description": "",
"collection_properties": None,
"connection_args": {
"host": "localhost",
"port": "19530",
"user": "",
"password": "",
"secure": False,
},
"consistency_level": "Session",
"index_params": None,
"search_params": None,
"drop_old": True,
"primary_field": "id",
"text_field": "text",
"vector_field": "vector",
"partition_key_field": None,
"partition_names": None,
"replica_number": 1,
"timeout": None,
},
},
"filters": None,
"top_k": 10,
},
}

retriever = MilvusEmbeddingRetriever(document_store)

reconstructed_retriever = MilvusEmbeddingRetriever.from_dict(retriever_dict)
for field in vars(reconstructed_retriever):
if field.startswith("__"):
continue
elif field == "document_store":
for doc_store_field in vars(document_store):
if doc_store_field.startswith("__"):
continue
assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr(
document_store, doc_store_field
)
else:
assert getattr(reconstructed_retriever, field) == getattr(retriever, field)

0 comments on commit f0b8520

Please sign in to comment.