Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: introduce SparseEmbedding #7382

Merged
merged 3 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/pydoc/config/data_classess_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/dataclasses]
modules:
["answer", "byte_stream", "chat_message", "document", "streaming_chunk"]
["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
Expand Down
2 changes: 2 additions & 0 deletions haystack/dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from haystack.dataclasses.byte_stream import ByteStream
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
from haystack.dataclasses.document import Document
from haystack.dataclasses.sparse_embedding import SparseEmbedding
from haystack.dataclasses.streaming_chunk import StreamingChunk

__all__ = [
Expand All @@ -13,4 +14,5 @@
"ChatMessage",
"ChatRole",
"StreamingChunk",
"SparseEmbedding",
]
16 changes: 14 additions & 2 deletions haystack/dataclasses/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from haystack import logging
from haystack.dataclasses.byte_stream import ByteStream
from haystack.dataclasses.sparse_embedding import SparseEmbedding

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,7 +58,8 @@ class Document(metaclass=_BackwardCompatible):
:param blob: Binary data associated with the document, if the document has any binary data associated with it.
:param meta: Additional custom metadata for the document. Must be JSON-serializable.
:param score: Score of the document. Used for ranking, usually assigned by retrievers.
:param embedding: Vector representation of the document.
:param embedding: dense vector representation of the document.
:param sparse_embedding: sparse vector representation of the document.
"""

id: str = field(default="")
Expand All @@ -67,6 +69,7 @@ class Document(metaclass=_BackwardCompatible):
meta: Dict[str, Any] = field(default_factory=dict)
score: Optional[float] = field(default=None)
embedding: Optional[List[float]] = field(default=None)
sparse_embedding: Optional[SparseEmbedding] = field(default=None)

def __repr__(self):
fields = []
Expand All @@ -84,6 +87,8 @@ def __repr__(self):
fields.append(f"score: {self.score}")
if self.embedding is not None:
fields.append(f"embedding: vector of size {len(self.embedding)}")
if self.sparse_embedding is not None:
fields.append(f"sparse_embedding: vector with {len(self.sparse_embedding.indices)} non-zero elements")
fields_str = ", ".join(fields)
return f"{self.__class__.__name__}(id={self.id}, {fields_str})"

Expand Down Expand Up @@ -114,7 +119,8 @@ def _create_id(self):
mime_type = self.blob.mime_type if self.blob is not None else None
meta = self.meta or {}
embedding = self.embedding if self.embedding is not None else None
data = f"{text}{dataframe}{blob}{mime_type}{meta}{embedding}"
sparse_embedding = self.sparse_embedding.to_dict() if self.sparse_embedding is not None else ""
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This differs a bit from the other ones to not alter the id of existing Documents.
I can change it if you think it's better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! This approach looks good to me 👍

data = f"{text}{dataframe}{blob}{mime_type}{meta}{embedding}{sparse_embedding}"
return hashlib.sha256(data.encode("utf-8")).hexdigest()

def to_dict(self, flatten=True) -> Dict[str, Any]:
Expand All @@ -132,6 +138,9 @@ def to_dict(self, flatten=True) -> Dict[str, Any]:
if (blob := data.get("blob")) is not None:
data["blob"] = {"data": list(blob["data"]), "mime_type": blob["mime_type"]}

if (sparse_embedding := data.get("sparse_embedding")) is not None:
data["sparse_embedding"] = sparse_embedding.to_dict()

if flatten:
meta = data.pop("meta")
return {**data, **meta}
Expand All @@ -149,6 +158,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "Document":
data["dataframe"] = read_json(io.StringIO(dataframe))
if blob := data.get("blob"):
data["blob"] = ByteStream(data=bytes(blob["data"]), mime_type=blob["mime_type"])
if sparse_embedding := data.get("sparse_embedding"):
data["sparse_embedding"] = SparseEmbedding.from_dict(sparse_embedding)

# Store metadata for a moment while we try un-flattening allegedly flatten metadata.
# We don't expect both a `meta=` keyword and flatten metadata keys so we'll raise a
# ValueError later if this is the case.
Expand Down
26 changes: 26 additions & 0 deletions haystack/dataclasses/sparse_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import List


class SparseEmbedding:
    """
    Class representing a sparse embedding.

    Only the non-zero elements of the vector are stored, as two parallel
    lists: ``indices`` (positions of the non-zero elements) and ``values``
    (the corresponding element values).
    """

    def __init__(self, indices: List[int], values: List[float]):
        """
        :param indices: List of indices of non-zero elements in the embedding.
        :param values: List of values of non-zero elements in the embedding.
        :raises ValueError: If the indices and values lists are not of the same length.
        """
        if len(indices) != len(values):
            raise ValueError("Length of indices and values must be the same.")
        self.indices = indices
        self.values = values

    def __eq__(self, other: object) -> bool:
        # Value equality: two embeddings are equal when their indices and
        # values match. This mirrors the dataclass-style equality of Document,
        # which holds SparseEmbedding instances. Defining __eq__ makes
        # instances unhashable (CPython sets __hash__ to None), which is
        # appropriate since the underlying lists are mutable.
        if not isinstance(other, SparseEmbedding):
            return NotImplemented
        return self.indices == other.indices and self.values == other.values

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(indices={self.indices}, values={self.values})"

    def to_dict(self) -> dict:
        """
        Convert the sparse embedding to a dictionary.

        :returns: A dictionary with keys ``"indices"`` and ``"values"``.
        """
        return {"indices": self.indices, "values": self.values}

    @classmethod
    def from_dict(cls, sparse_embedding_dict: dict) -> "SparseEmbedding":
        """
        Deserialize a sparse embedding from a dictionary.

        :param sparse_embedding_dict: A dictionary with keys ``"indices"`` and ``"values"``.
        :returns: A new ``SparseEmbedding`` instance.
        """
        return cls(indices=sparse_embedding_dict["indices"], values=sparse_embedding_dict["values"])
7 changes: 7 additions & 0 deletions releasenotes/notes/sparse-embedding-fd55b670437492be.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
features:
- |
Introduce a new `SparseEmbedding` class which can be used to store a sparse
vector representation of a Document.
It will be instrumental to support Sparse Embedding Retrieval with
the subsequent introduction of Sparse Embedders and Sparse Embedding Retrievers.
19 changes: 18 additions & 1 deletion test/dataclasses/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from haystack import Document
from haystack.dataclasses.byte_stream import ByteStream
from haystack.dataclasses.sparse_embedding import SparseEmbedding


@pytest.mark.parametrize(
Expand Down Expand Up @@ -37,6 +38,7 @@ def test_init():
assert doc.meta == {}
assert doc.score == None
assert doc.embedding == None
assert doc.sparse_embedding == None


def test_init_with_wrong_parameters():
Expand All @@ -46,15 +48,17 @@ def test_init_with_wrong_parameters():

def test_init_with_parameters():
blob_data = b"some bytes"
sparse_embedding = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3])
doc = Document(
content="test text",
dataframe=pd.DataFrame([0]),
blob=ByteStream(data=blob_data, mime_type="text/markdown"),
meta={"text": "test text"},
score=0.812,
embedding=[0.1, 0.2, 0.3],
sparse_embedding=sparse_embedding,
)
assert doc.id == "ec92455f3f4576d40031163c89b1b4210b34ea1426ee0ff68ebed86cb7ba13f8"
assert doc.id == "967b7bd4a21861ad9e863f638cefcbdd6bf6306bebdd30aa3fedf0c26bc636ed"
assert doc.content == "test text"
assert doc.dataframe is not None
assert doc.dataframe.equals(pd.DataFrame([0]))
Expand All @@ -63,6 +67,7 @@ def test_init_with_parameters():
assert doc.meta == {"text": "test text"}
assert doc.score == 0.812
assert doc.embedding == [0.1, 0.2, 0.3]
assert doc.sparse_embedding == sparse_embedding


def test_init_with_legacy_fields():
Expand All @@ -76,6 +81,7 @@ def test_init_with_legacy_fields():
assert doc.meta == {}
assert doc.score == 0.812
assert doc.embedding == [0.1, 0.2, 0.3]
assert doc.sparse_embedding == None


def test_init_with_legacy_field():
Expand All @@ -93,6 +99,7 @@ def test_init_with_legacy_field():
assert doc.meta == {"date": "10-10-2023", "type": "article"}
assert doc.score == 0.812
assert doc.embedding == [0.1, 0.2, 0.3]
assert doc.sparse_embedding == None


def test_basic_equality_type_mismatch():
Expand Down Expand Up @@ -121,6 +128,7 @@ def test_to_dict():
"blob": None,
"score": None,
"embedding": None,
"sparse_embedding": None,
}


Expand All @@ -134,6 +142,7 @@ def test_to_dict_without_flattening():
"meta": {},
"score": None,
"embedding": None,
"sparse_embedding": None,
}


Expand All @@ -145,6 +154,7 @@ def test_to_dict_with_custom_parameters():
meta={"some": "values", "test": 10},
score=0.99,
embedding=[10.0, 10.0],
sparse_embedding=SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]),
)

assert doc.to_dict() == {
Expand All @@ -156,6 +166,7 @@ def test_to_dict_with_custom_parameters():
"test": 10,
"score": 0.99,
"embedding": [10.0, 10.0],
"sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
}


Expand All @@ -167,6 +178,7 @@ def test_to_dict_with_custom_parameters_without_flattening():
meta={"some": "values", "test": 10},
score=0.99,
embedding=[10.0, 10.0],
sparse_embedding=SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]),
)

assert doc.to_dict(flatten=False) == {
Expand All @@ -177,6 +189,7 @@ def test_to_dict_with_custom_parameters_without_flattening():
"meta": {"some": "values", "test": 10},
"score": 0.99,
"embedding": [10, 10],
"sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
}


Expand All @@ -194,6 +207,7 @@ def from_from_dict_with_parameters():
"meta": {"text": "test text"},
"score": 0.812,
"embedding": [0.1, 0.2, 0.3],
"sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
}
) == Document(
content="test text",
Expand All @@ -202,6 +216,7 @@ def from_from_dict_with_parameters():
meta={"text": "test text"},
score=0.812,
embedding=[0.1, 0.2, 0.3],
sparse_embedding=SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]),
)


Expand Down Expand Up @@ -249,6 +264,7 @@ def test_from_dict_with_flat_meta():
"blob": {"data": list(blob_data), "mime_type": "text/markdown"},
"score": 0.812,
"embedding": [0.1, 0.2, 0.3],
"sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
"date": "10-10-2023",
"type": "article",
}
Expand All @@ -258,6 +274,7 @@ def test_from_dict_with_flat_meta():
blob=ByteStream(blob_data, mime_type="text/markdown"),
score=0.812,
embedding=[0.1, 0.2, 0.3],
sparse_embedding=SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]),
meta={"date": "10-10-2023", "type": "article"},
)

Expand Down
23 changes: 23 additions & 0 deletions test/dataclasses/test_sparse_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

from haystack.dataclasses.sparse_embedding import SparseEmbedding


class TestSparseEmbedding:
    def test_init(self):
        embedding = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3])
        assert [0, 2, 4] == embedding.indices
        assert [0.1, 0.2, 0.3] == embedding.values

    def test_init_with_wrong_parameters(self):
        # Mismatched lengths of indices and values must be rejected.
        with pytest.raises(ValueError):
            SparseEmbedding(indices=[0, 2], values=[0.1, 0.2, 0.3, 0.4])

    def test_to_dict(self):
        embedding = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3])
        expected = {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]}
        assert embedding.to_dict() == expected

    def test_from_dict(self):
        data = {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]}
        embedding = SparseEmbedding.from_dict(data)
        assert embedding.indices == [0, 2, 4]
        assert embedding.values == [0.1, 0.2, 0.3]
6 changes: 3 additions & 3 deletions test/tracing/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ class TestTypeCoercion:
(NonSerializableClass(), "NonSerializableClass"),
(
Document(id="1", content="text"),
'{"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null}',
'{"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null, "sparse_embedding": null}',
),
(
[Document(id="1", content="text")],
'[{"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null}]',
'[{"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null, "sparse_embedding": null}]',
),
(
{"key": Document(id="1", content="text")},
'{"key": {"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null}}',
'{"key": {"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null, "sparse_embedding": null}}',
),
],
)
Expand Down