Skip to content

Commit

Permalink
got the retrieval working now too :D
Browse files Browse the repository at this point in the history
Signed-off-by: Francisco Javier Arceo <[email protected]>
  • Loading branch information
franciscojavierarceo committed Dec 19, 2024
1 parent bcc69cf commit e006129
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 51 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ test-python-universal-milvus-online:
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.milvus_online_store.milvus_repo_configuration \
PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.milvus \
python -m pytest -n 8 --integration \
-k "test_retrieve_online_documents2" \
-k "test_retrieve_online_milvus_ocuments" \
sdk/python/tests --ignore=sdk/python/tests/integration/offline_store/test_dqm_validation.py

test-python-universal-singlestore-online:
Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1843,6 +1843,7 @@ def retrieve_online_documents(
document_feature_vals = [feature[4] for feature in document_features]
document_feature_distance_vals = [feature[5] for feature in document_features]
online_features_response = GetOnlineFeaturesResponse(results=[])
requested_feature = requested_feature or requested_features[0]
utils._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={
Expand Down
17 changes: 16 additions & 1 deletion sdk/python/feast/infra/key_encoding_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import struct
from typing import List, Tuple
from typing import List, Tuple, Union

from google.protobuf.internal.containers import RepeatedScalarFieldContainer

from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
Expand Down Expand Up @@ -163,3 +165,16 @@ def get_list_val_str(val):
if val.HasField(accept_type):
return str(getattr(val, accept_type).val)
return None


def serialize_f32(
vector: Union[RepeatedScalarFieldContainer[float], List[float]], vector_length: int
) -> bytes:
"""serializes a list of floats into a compact "raw bytes" format"""
return struct.pack(f"{vector_length}f", *vector)


def deserialize_f32(byte_vector: bytes, vector_length: int) -> List[float]:
"""deserializes a list of floats from a compact "raw bytes" format"""
num_floats = vector_length // 4 # 4 bytes per float
return list(struct.unpack(f"{num_floats}f", byte_vector))
51 changes: 22 additions & 29 deletions sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from feast.feature_view import FeatureView
from feast.infra.infra_object import InfraObject
from feast.infra.key_encoding_utils import (
deserialize_entity_key,
serialize_entity_key,
)
from feast.infra.online_stores.online_store import OnlineStore
Expand All @@ -34,6 +33,7 @@
)
from feast.utils import (
_build_retrieve_online_document_record,
_serialize_vector_to_float_list,
to_naive_utc,
)

Expand Down Expand Up @@ -317,13 +317,19 @@ def retrieve_online_documents(
output_fields = (
[composite_key_name] + requested_features + ["created_ts", "event_ts"]
)
assert all(field for field in output_fields if field in [f.name for f in collection.schema.fields]), \
f"field(s) [{[field for field in output_fields if field not in [f.name for f in collection.schema.fields]]}'] not found in collection schema"
assert all(
field
for field in output_fields
if field in [f.name for f in collection.schema.fields]
), f"field(s) [{[field for field in output_fields if field not in [f.name for f in collection.schema.fields]]}'] not found in collection schema"

# Note we choose the first vector field as the field to search on. Not ideal but it's something.
ann_search_field = None
for field in collection.schema.fields:
if field.dtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
if (
field.dtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]
and field.name in output_fields
):
ann_search_field = field.name
break

Expand All @@ -342,36 +348,23 @@ def retrieve_online_documents(
for hit in hits:
single_record = {}
for field in output_fields:
val = hit.entity.get(field)
if field == composite_key_name:
val = deserialize_entity_key(
bytes.fromhex(val),
config.entity_key_serialization_version,
)
entity_key_proto = val
single_record[field] = val

single_record[field] = hit.entity.get(field)

entity_key_str = hit.entity.get(composite_key_name)
val_bin = hit.entity.get("value")
val = ValueProto()
val.ParseFromString(val_bin)
entity_key_bytes = bytes.fromhex(hit.entity.get(composite_key_name))
embedding = hit.entity.get(ann_search_field)
serialized_embedding = _serialize_vector_to_float_list(embedding)
distance = hit.distance
event_ts = datetime.fromtimestamp(hit.entity.get("event_ts") / 1e6)
entity_key = deserialize_entity_key(
bytes.fromhex(entity_key_str),
prepared_result = _build_retrieve_online_document_record(
entity_key_bytes,
# This may have a bug
serialized_embedding.SerializeToString(),
embedding,
distance,
event_ts,
config.entity_key_serialization_version,
)
result_list.append(
_build_retrieve_online_document_record(
entity_key_proto,
val.SerializeToString(),
embedding,
distance,
event_ts,
config.entity_key_serialization_version,
)
)
result_list.append(prepared_result)
return result_list


Expand Down
4 changes: 4 additions & 0 deletions sdk/python/feast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,10 @@ def _utc_now() -> datetime:
return datetime.now(tz=timezone.utc)


def _serialize_vector_to_float_list(vector: List[float]) -> FloatListProto:
return ValueProto(float_list_val=FloatListProto(val=vector))


def _build_retrieve_online_document_record(
entity_key: Union[str, bytes],
feature_value: Union[str, bytes],
Expand Down
24 changes: 4 additions & 20 deletions sdk/python/tests/integration/online_store/test_universal_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,7 @@ def test_retrieve_online_documents(vectordb_environment, fake_document_data):

@pytest.mark.integration
@pytest.mark.universal_online_stores(only=["milvus"])
def test_retrieve_online_documents2(environment, fake_document_data):
def test_retrieve_online_milvus_ocuments(environment, fake_document_data):
print(environment.online_store)
fs = environment.feature_store
df, data_source = fake_document_data
Expand All @@ -914,22 +914,6 @@ def test_retrieve_online_documents2(environment, fake_document_data):
distance_metric="L2",
).to_dict()
assert len(documents["embedding_float"]) == 2
#
# # assert returned the entity_id
# assert len(documents["item_id"]) == 2
#
# documents = fs.retrieve_online_documents(
# feature="item_embeddings:embedding_float",
# query=[1.0, 2.0],
# top_k=2,
# distance_metric="L1",
# ).to_dict()
# assert len(documents["embedding_float"]) == 2
#
# with pytest.raises(ValueError):
# fs.retrieve_online_documents(
# feature="item_embeddings:embedding_float",
# query=[1.0, 2.0],
# top_k=2,
# distance_metric="wrong",
# ).to_dict()

assert len(documents["item_id"]) == 2
assert documents["item_id"] == [2, 3]

0 comments on commit e006129

Please sign in to comment.