diff --git a/Makefile b/Makefile index 170733caeb..3d000848d0 100644 --- a/Makefile +++ b/Makefile @@ -358,7 +358,7 @@ test-python-universal-milvus-online: FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.milvus_online_store.milvus_repo_configuration \ PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.milvus \ python -m pytest -n 8 --integration \ - -k "test_retrieve_online_documents2" \ + -k "test_retrieve_online_milvus_ocuments" \ sdk/python/tests --ignore=sdk/python/tests/integration/offline_store/test_dqm_validation.py test-python-universal-singlestore-online: diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 6ad3313d95..b5a836fc4f 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1843,6 +1843,7 @@ def retrieve_online_documents( document_feature_vals = [feature[4] for feature in document_features] document_feature_distance_vals = [feature[5] for feature in document_features] online_features_response = GetOnlineFeaturesResponse(results=[]) + requested_feature = requested_feature or requested_features[0] utils._populate_result_rows_from_columnar( online_features_response=online_features_response, data={ diff --git a/sdk/python/feast/infra/key_encoding_utils.py b/sdk/python/feast/infra/key_encoding_utils.py index 1f9ffeef14..18127896bd 100644 --- a/sdk/python/feast/infra/key_encoding_utils.py +++ b/sdk/python/feast/infra/key_encoding_utils.py @@ -1,5 +1,7 @@ import struct -from typing import List, Tuple +from typing import List, Tuple, Union + +from google.protobuf.internal.containers import RepeatedScalarFieldContainer from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto @@ -163,3 +165,16 @@ def get_list_val_str(val): if val.HasField(accept_type): return str(getattr(val, accept_type).val) return None + + +def serialize_f32( + vector: Union[RepeatedScalarFieldContainer[float], List[float]], vector_length: int +) -> bytes: + """serializes a list of floats into a compact "raw bytes" format""" + return struct.pack(f"{vector_length}f", *vector) + + +def deserialize_f32(byte_vector: bytes, vector_length: int) -> List[float]: + """deserializes a list of floats from a compact "raw bytes" format""" + num_floats = vector_length // 4 # 4 bytes per float + return list(struct.unpack(f"{num_floats}f", byte_vector)) diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py index ab41072c84..a0a0ab56f6 100644 --- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py +++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py @@ -14,7 +14,6 @@ from feast.feature_view import FeatureView from feast.infra.infra_object import InfraObject from feast.infra.key_encoding_utils import ( - deserialize_entity_key, serialize_entity_key, ) from feast.infra.online_stores.online_store import OnlineStore @@ -34,6 +33,7 @@ ) from feast.utils import ( _build_retrieve_online_document_record, + _serialize_vector_to_float_list, to_naive_utc, ) @@ -317,13 +317,19 @@ def retrieve_online_documents( output_fields = ( [composite_key_name] + requested_features + ["created_ts", "event_ts"] ) - assert all(field for field in output_fields if field in [f.name for f in collection.schema.fields]), \ - f"field(s) [{[field for field in output_fields if field not in [f.name for f in collection.schema.fields]]}'] not found in collection schema" + assert all( + field + for field in output_fields + if field in [f.name for f in collection.schema.fields] + ), f"field(s) [{[field for field in output_fields if field not in [f.name for f in collection.schema.fields]]}'] not found in collection schema" # Note we choose the first vector field as the field to search on. Not ideal but it's something. ann_search_field = None for field in collection.schema.fields: - if field.dtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: + if ( + field.dtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR] + and field.name in output_fields + ): ann_search_field = field.name break @@ -342,36 +348,23 @@ def retrieve_online_documents( for hit in hits: single_record = {} for field in output_fields: - val = hit.entity.get(field) - if field == composite_key_name: - val = deserialize_entity_key( - bytes.fromhex(val), - config.entity_key_serialization_version, - ) - entity_key_proto = val - single_record[field] = val - + single_record[field] = hit.entity.get(field) - entity_key_str = hit.entity.get(composite_key_name) - val_bin = hit.entity.get("value") - val = ValueProto() - val.ParseFromString(val_bin) + entity_key_bytes = bytes.fromhex(hit.entity.get(composite_key_name)) + embedding = hit.entity.get(ann_search_field) + serialized_embedding = _serialize_vector_to_float_list(embedding) distance = hit.distance event_ts = datetime.fromtimestamp(hit.entity.get("event_ts") / 1e6) - entity_key = deserialize_entity_key( - bytes.fromhex(entity_key_str), + prepared_result = _build_retrieve_online_document_record( + entity_key_bytes, + # This may have a bug + serialized_embedding.SerializeToString(), + embedding, + distance, + event_ts, config.entity_key_serialization_version, ) - result_list.append( - _build_retrieve_online_document_record( - entity_key_proto, - val.SerializeToString(), - embedding, - distance, - event_ts, - config.entity_key_serialization_version, - ) - ) + result_list.append(prepared_result) return result_list diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 51d4bf4f2c..8e90291056 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -1192,6 +1192,10 @@ def _utc_now() -> datetime: return datetime.now(tz=timezone.utc) +def _serialize_vector_to_float_list(vector: List[float]) -> FloatListProto: + return ValueProto(float_list_val=FloatListProto(val=vector)) + + def _build_retrieve_online_document_record( entity_key: Union[str, bytes], feature_value: Union[str, bytes], diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 5de41062f1..f113e25555 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -895,7 +895,7 @@ def test_retrieve_online_documents(vectordb_environment, fake_document_data): @pytest.mark.integration @pytest.mark.universal_online_stores(only=["milvus"]) -def test_retrieve_online_documents2(environment, fake_document_data): +def test_retrieve_online_milvus_ocuments(environment, fake_document_data): print(environment.online_store) fs = environment.feature_store df, data_source = fake_document_data @@ -914,22 +914,6 @@ def test_retrieve_online_documents2(environment, fake_document_data): distance_metric="L2", ).to_dict() assert len(documents["embedding_float"]) == 2 - # - # # assert returned the entity_id - # assert len(documents["item_id"]) == 2 - # - # documents = fs.retrieve_online_documents( - # feature="item_embeddings:embedding_float", - # query=[1.0, 2.0], - # top_k=2, - # distance_metric="L1", - # ).to_dict() - # assert len(documents["embedding_float"]) == 2 - # - # with pytest.raises(ValueError): - # fs.retrieve_online_documents( - # feature="item_embeddings:embedding_float", - # query=[1.0, 2.0], - # top_k=2, - # distance_metric="wrong", - # ).to_dict() + + assert len(documents["item_id"]) == 2 + assert documents["item_id"] == [2, 3]