diff --git a/google/cloud/datastore/helpers.py b/google/cloud/datastore/helpers.py index 6eaa3b89..713d1c1c 100644 --- a/google/cloud/datastore/helpers.py +++ b/google/cloud/datastore/helpers.py @@ -29,6 +29,8 @@ from google.cloud.datastore_v1.types import entity as entity_pb2 from google.cloud.datastore.entity import Entity from google.cloud.datastore.key import Key +from google.cloud.datastore.vector import Vector +from google.cloud.datastore.vector import _VECTOR_VALUE from google.protobuf import timestamp_pb2 @@ -401,6 +403,8 @@ def _pb_attr_value(val): name, value = "array", val elif isinstance(val, GeoPoint): name, value = "geo_point", val.to_protobuf() + elif isinstance(val, Vector): + name, value = "vector", entity_pb2.Value(**val._to_dict()) elif val is None: name, value = "null", struct_pb2.NULL_VALUE else: @@ -457,6 +461,11 @@ def _get_value_from_value_pb(pb): result = [ _get_value_from_value_pb(item_value) for item_value in pb.array_value.values ] + # check for vector values + if pb.meaning == _VECTOR_VALUE and all( + isinstance(item, float) for item in result + ): + result = Vector(result, exclude_from_indexes=bool(pb.exclude_from_indexes)) elif value_type == "geo_point_value": result = GeoPoint( @@ -507,6 +516,8 @@ def _set_protobuf_value(value_pb, val): for item in val: i_pb = l_pb.add() _set_protobuf_value(i_pb, item) + elif attr == "vector_value": + value_pb.CopyFrom(val._pb) elif attr == "geo_point_value": value_pb.geo_point_value.CopyFrom(val) else: # scalar, just assign diff --git a/google/cloud/datastore/query.py b/google/cloud/datastore/query.py index 5ff27366..940c342d 100644 --- a/google/cloud/datastore/query.py +++ b/google/cloud/datastore/query.py @@ -183,6 +183,11 @@ class Query(object): this query. When set, explain_metrics will be available on the iterator returned by query.fetch(). + :type find_nearest: :class:`~google.cloud.datastore.vector.FindNearest` + :param find_nearest: (Optional) Options to perform a vector search for + entities in the query. When set, the query will return entities + sorted by distance from the query vector. + :raises: ValueError if ``project`` is not passed and no implicit default is set. """ @@ -211,6 +216,7 @@ def __init__( order=(), distinct_on=(), explain_options=None, + find_nearest=None, ): self._client = client self._kind = kind @@ -232,6 +238,7 @@ def __init__( self._explain_options = explain_options self._ancestor = ancestor self._filters = [] + self.find_nearest = find_nearest # Verify filters passed in. for filter in filters: @@ -944,6 +951,9 @@ def _pb_from_query(query): ref.name = distinct_on_name pb.distinct_on.append(ref) + if query.find_nearest: + pb.find_nearest = query_pb2.FindNearest(**query.find_nearest._to_dict()) + return pb diff --git a/google/cloud/datastore/vector.py b/google/cloud/datastore/vector.py new file mode 100644 index 00000000..b9e19e4c --- /dev/null +++ b/google/cloud/datastore/vector.py @@ -0,0 +1,134 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import collections + +from typing import Sequence +from dataclasses import dataclass +from enum import Enum + +_VECTOR_VALUE = 31 + + +class DistanceMeasure(Enum): + EUCLIDEAN = 1 + COSINE = 2 + DOT_PRODUCT = 3 + + +class Vector(collections.abc.Sequence): + """A class to represent a Vector for use in query.find_nearest. + Underlying object will be converted to a map representation in Firestore API. + """ + + def __init__(self, value: Sequence[float], *, exclude_from_indexes: bool = False): + self.exclude_from_indexes = exclude_from_indexes + self._value = tuple([float(v) for v in value]) + + def __getitem__(self, arg: int | slice): + if isinstance(arg, int): + return self._value[arg] + elif isinstance(arg, slice): + return Vector( + self._value[arg], exclude_from_indexes=self.exclude_from_indexes + ) + else: + raise NotImplementedError + + def __len__(self): + return len(self._value) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Vector): + raise NotImplementedError + return self._value == other._value + + def __repr__(self): + return f"Vector<{str(self._value)[1:-1]}>" + + def _to_dict(self): + return { + "array_value": {"values": [{"double_value": v} for v in self._value]}, + "meaning": _VECTOR_VALUE, + "exclude_from_indexes": self.exclude_from_indexes, + } + + +@dataclass +class FindNearest: + """ + Represents configuration for a find_nearest vector query. + + :type vector_field: str + :param vector_field: + An indexed vector property to search upon. + Only documents which contain vectors whose dimensionality match + the query_vector can be returned. + + :type query_vector: Union[Vector, Sequence[float]] + :param query_vector: + The query vector that we are searching on. + Must be a vector of no more than 2048 dimensions. + + :type limit: int + :param limit: + The number of nearest neighbors to return. + Must be a positive integer of no more than 100. + + :type distance_measure: DistanceMeasure + :param distance_measure: + The distance measure to use when comparing vectors. + + :type distance_result_property: Optional[str] + :param distance_result_property: + Optional name of the field to output the result of the vector distance + calculation. + + :type distance_threshold: Optional[float] + :param distance_threshold: + Threshold value for which no less similar documents will be returned. + The behavior of the specified ``distance_measure`` will affect the + meaning of the distance threshold: + For EUCLIDEAN, COSINE: WHERE distance <= distance_threshold + For DOT_PRODUCT: WHERE distance >= distance_threshold + + Optional threshold to apply to the distance measure. + If set, only documents whose distance measure is less than this value + will be returned. + """ + + vector_property: str + query_vector: Vector | Sequence[float] + limit: int + distance_measure: DistanceMeasure + distance_result_property: str | None = None + distance_threshold: float | None = None + + def __post_init__(self): + if not isinstance(self.query_vector, Vector): + self.query_vector = Vector(self.query_vector) + + def _to_dict(self): + output = { + "vector_property": {"name": self.vector_property}, + "query_vector": self.query_vector._to_dict(), + "distance_measure": self.distance_measure.value, + "limit": self.limit, + } + if self.distance_result_property is not None: + output["distance_result_property"] = self.distance_result_property + if self.distance_threshold is not None: + output["distance_threshold"] = float(self.distance_threshold) + return output diff --git a/tests/system/index.yaml b/tests/system/index.yaml index 1f27c246..0cabaf0e 100644 --- a/tests/system/index.yaml +++ b/tests/system/index.yaml @@ -45,3 +45,10 @@ indexes: - name: family - name: appearances +- kind: LargeCharacter + properties: + - name: __key__ + - name: vector + vectorConfig: + dimension: 10 + flat: {} \ No newline at end of file diff --git a/tests/system/test_put.py b/tests/system/test_put.py index 4cb5f6e8..161f62b3 100644 --- a/tests/system/test_put.py +++ b/tests/system/test_put.py @@ -176,3 +176,41 @@ def test_client_put_w_empty_array(datastore_client, entities_to_delete, database retrieved = local_client.get(entity.key) assert entity["children"] == retrieved["children"] + + +@pytest.mark.parametrize("data", [[0], (1.0, 2.0, 3.0), range(100)]) +@pytest.mark.parametrize("exclude_from_indexes", [True, False]) +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +def test_client_put_w_vector( + datastore_client, entities_to_delete, database_id, data, exclude_from_indexes +): + local_client = _helpers.clone_client(datastore_client) + + key = local_client.key("VectorArray", 1234) + entity = datastore.Entity(key=key) + entity["vec"] = datastore.vector.Vector( + data, exclude_from_indexes=exclude_from_indexes + ) + local_client.put(entity) + entities_to_delete.append(entity) + + retrieved = local_client.get(entity.key) + + assert entity["vec"] == retrieved["vec"] + assert entity["vec"]._to_dict() == retrieved["vec"]._to_dict() + assert entity["vec"].exclude_from_indexes == exclude_from_indexes + assert retrieved["vec"].exclude_from_indexes == exclude_from_indexes + + +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +def test_client_put_w_empty_vector(datastore_client, entities_to_delete, database_id): + from google.api_core.exceptions import BadRequest + + local_client = _helpers.clone_client(datastore_client) + + key = local_client.key("VectorArray", 1234) + entity = datastore.Entity(key=key) + entity["vec"] = datastore.vector.Vector([]) + with pytest.raises(BadRequest) as e: + local_client.put(entity) + assert "Cannot have a zero length vector" in str(e) diff --git a/tests/system/test_query.py b/tests/system/test_query.py index 99dce2ec..24fa7f28 100644 --- a/tests/system/test_query.py +++ b/tests/system/test_query.py @@ -22,6 +22,7 @@ from . import _helpers from google.cloud.datastore.query import PropertyFilter, And, Or +from google.cloud.datastore.vector import FindNearest, DistanceMeasure, Vector retry_503 = RetryErrors(exceptions.ServiceUnavailable) @@ -647,3 +648,95 @@ def test_query_explain_in_transaction(query_client, ancestor_key, database_id): # check for stats stats = iterator.explain_metrics assert isinstance(stats, ExplainMetrics) + + +@pytest.mark.parametrize( + "distance_measure", + [DistanceMeasure.EUCLIDEAN, DistanceMeasure.COSINE, DistanceMeasure.DOT_PRODUCT], +) +@pytest.mark.parametrize("limit", [5, 10, 20]) +@pytest.mark.parametrize("database_id", [_helpers.TEST_DATABASE], indirect=True) +def test_query_vector_find_nearest(query_client, database_id, limit, distance_measure): + q = query_client.query(kind="LargeCharacter", namespace="LargeCharacterEntity") + vector = [v / 10 for v in range(10)] + q.find_nearest = FindNearest( + vector_property="vector", + query_vector=vector, + limit=limit, + distance_measure=distance_measure, + distance_result_property="distance", + ) + iterator = q.fetch() + results = list(iterator) + # verify limit was applied + assert len(results) == limit + # verify distance property is present + assert all(r["distance"] for r in results) + distance_list = [r["distance"] for r in results] + assert all(isinstance(d, float) for d in distance_list) + # verify distances are sorted + if distance_measure == DistanceMeasure.DOT_PRODUCT: + # dot product sorts high to low + expected = sorted(distance_list, reverse=True) + else: + expected = sorted(distance_list) + assert expected == distance_list + + +@pytest.mark.parametrize("exclude_from_indexes", [True, False]) +@pytest.mark.parametrize("database_id", [_helpers.TEST_DATABASE], indirect=True) +def test_query_vector_find_nearest_w_vector_class( + query_client, database_id, exclude_from_indexes +): + """ + ensure passing Vector instance works as expected + + exclude_from_indexes field should be ignored + """ + q = query_client.query(kind="LargeCharacter", namespace="LargeCharacterEntity") + vector = Vector( + [v / 10 for v in range(10)], exclude_from_indexes=exclude_from_indexes + ) + q.find_nearest = FindNearest( + vector_property="vector", + query_vector=vector, + limit=5, + distance_measure=DistanceMeasure.EUCLIDEAN, + distance_result_property="distance", + ) + iterator = q.fetch() + results = list(iterator) + assert len(results) == 5 + + +@pytest.mark.parametrize("database_id", [_helpers.TEST_DATABASE], indirect=True) +def test_query_empty_find_nearest(query_client, database_id): + """ + vector search with empty query_vector should fail + """ + q = query_client.query(kind="LargeCharacter", namespace="LargeCharacterEntity") + q.find_nearest = FindNearest( + vector_property="vector", + query_vector=[], + limit=5, + distance_measure=DistanceMeasure.EUCLIDEAN, + ) + with pytest.raises(ValueError): + list(q.fetch()) + + +@pytest.mark.parametrize("database_id", [_helpers.TEST_DATABASE], indirect=True) +def test_query_find_nearest_wrong_size(query_client, database_id): + """ + vector search with mismatched vector size should fail + """ + q = query_client.query(kind="LargeCharacter", namespace="LargeCharacterEntity") + vector = [v / 10 for v in range(11)] + q.find_nearest = FindNearest( + vector_property="vector", + query_vector=vector, + limit=5, + distance_measure=DistanceMeasure.EUCLIDEAN, + ) + with pytest.raises(ValueError): + list(q.fetch()) diff --git a/tests/system/utils/populate_datastore.py b/tests/system/utils/populate_datastore.py index 0eea15fb..2a9631fc 100644 --- a/tests/system/utils/populate_datastore.py +++ b/tests/system/utils/populate_datastore.py @@ -22,6 +22,7 @@ import sys import time import uuid +import random from google.cloud import datastore @@ -104,6 +105,10 @@ def put_objects(count): task["name"] = "{0:05d}".format(i) task["family"] = "Stark" task["alive"] = False + random.seed(i) + task["vector"] = datastore.vector.Vector( + [random.random() for _ in range(10)] + ) for i in string.ascii_lowercase: task["space-{}".format(i)] = MAX_STRING diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py index 38702dba..f473af81 100644 --- a/tests/unit/test_helpers.py +++ b/tests/unit/test_helpers.py @@ -815,6 +815,23 @@ def test__pb_attr_value_w_geo_point(): assert value == geo_pt_pb +@pytest.mark.parametrize("exclude", [True, False]) +def test__pb_attr_value_w_vector(exclude): + from google.cloud.datastore.vector import Vector + from google.cloud.datastore.helpers import _pb_attr_value + from google.cloud.datastore_v1.types import entity as entity_pb2 + + vector = Vector([1.0, 2.0, 3.0], exclude_from_indexes=exclude) + name, value = _pb_attr_value(vector) + assert name == "vector_value" + assert isinstance(value, entity_pb2.Value) + assert value.array_value.values[0].double_value == 1.0 + assert value.array_value.values[1].double_value == 2.0 + assert value.array_value.values[2].double_value == 3.0 + assert value.meaning == 31 + assert value.exclude_from_indexes == exclude + + def test__pb_attr_value_w_null(): from google.protobuf import struct_pb2 from google.cloud.datastore.helpers import _pb_attr_value @@ -949,6 +966,26 @@ def test__get_value_from_value_pb_w_geo_point(): assert result.longitude == lng +@pytest.mark.parametrize("exclude", [None, True, False]) +def test__get_value_from_value_pb_w_vector(exclude): + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore.helpers import _get_value_from_value_pb + from google.cloud.datastore.vector import Vector + + vector_pb = entity_pb2.Value() + vector_pb.array_value.values.append(entity_pb2.Value(double_value=1.0)) + vector_pb.array_value.values.append(entity_pb2.Value(double_value=2.0)) + vector_pb.array_value.values.append(entity_pb2.Value(double_value=3.0)) + vector_pb.meaning = 31 + if exclude is not None: + vector_pb.exclude_from_indexes = exclude + + result = _get_value_from_value_pb(vector_pb._pb) + assert isinstance(result, Vector) + assert result == Vector([1.0, 2.0, 3.0]) + assert result.exclude_from_indexes == bool(exclude) + + def test__get_value_from_value_pb_w_null(): from google.protobuf import struct_pb2 from google.cloud.datastore_v1.types import entity as entity_pb2 @@ -1133,6 +1170,21 @@ def test__set_protobuf_value_w_geo_point(): assert pb.geo_point_value == geo_pt_pb +@pytest.mark.parametrize("exclude", [True, False]) +def test__set_protobuf_value_w_vector(exclude): + from google.cloud.datastore.vector import Vector + from google.cloud.datastore.helpers import _set_protobuf_value + + pb = _make_empty_value_pb() + vector = Vector([1.0, 2.0, 3.0], exclude_from_indexes=exclude) + _set_protobuf_value(pb, vector) + assert pb.array_value.values[0].double_value == 1.0 + assert pb.array_value.values[1].double_value == 2.0 + assert pb.array_value.values[2].double_value == 3.0 + assert pb.meaning == 31 # Vector meaning + assert pb.exclude_from_indexes == exclude + + def test__get_meaning_w_no_meaning(): from google.cloud.datastore_v1.types import entity as entity_pb2 from google.cloud.datastore.helpers import _get_meaning diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index 75fa31fa..54884f1b 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -24,6 +24,7 @@ Or, BaseCompositeFilter, ) +from google.cloud.datastore.vector import FindNearest, DistanceMeasure, Vector from google.cloud.datastore.helpers import set_database_id_to_request @@ -44,6 +45,8 @@ def test_query_ctor_defaults(database_id): assert query.projection == [] assert query.order == [] assert query.distinct_on == [] + assert query._explain_options is None + assert query.find_nearest is None @pytest.mark.parametrize( @@ -68,6 +71,8 @@ def test_query_ctor_explicit(filters, database_id): PROJECTION = ["foo", "bar", "baz"] ORDER = ["foo", "bar"] DISTINCT_ON = ["foo"] + explain_options = object() + find_nearest = object() query = _make_query( client, @@ -79,6 +84,8 @@ def test_query_ctor_explicit(filters, database_id): projection=PROJECTION, order=ORDER, distinct_on=DISTINCT_ON, + explain_options=explain_options, + find_nearest=find_nearest, ) assert query._client is client assert query._client.database == database_id @@ -90,6 +97,8 @@ def test_query_ctor_explicit(filters, database_id): assert query.projection == PROJECTION assert query.order == ORDER assert query.distinct_on == DISTINCT_ON + assert query._explain_options is explain_options + assert query.find_nearest is find_nearest @pytest.mark.parametrize("database_id", [None, "somedb"]) @@ -739,6 +748,20 @@ def test_query_transaction_begin_later(database_id): assert read_options.new_transaction == transaction._options +@pytest.mark.parametrize("database_id", [None, "somedb"]) +def test_query_find_nearest_setter(database_id): + client = _make_client(database=database_id) + query = _make_query(client) + find_nearest = FindNearest( + vector_property="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.COSINE, + limit=5, + ) + query.find_nearest = find_nearest + assert query.find_nearest == find_nearest + + def test_iterator_constructor_defaults(): query = object() client = object() @@ -1255,6 +1278,7 @@ def test_pb_from_query_empty(): assert pb.end_cursor == b"" assert pb._pb.limit.value == 0 assert pb.offset == 0 + assert pb.find_nearest == query_pb2.FindNearest() def test_pb_from_query_projection(): @@ -1271,6 +1295,33 @@ def test_pb_from_query_kind(): assert [item.name for item in pb.kind] == ["KIND"] +def test_pb_from_query_find_nearest(): + from google.cloud.datastore.query import _pb_from_query + from google.cloud.datastore.vector import FindNearest, DistanceMeasure, Vector + from google.cloud.datastore_v1.types import query as query_pb2 + + find_nearest = FindNearest( + vector_property="embedding", + query_vector=Vector([1, 2, 3]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=10, + distance_threshold=0.5, + ) + pb = _pb_from_query(_make_stub_query(find_nearest=find_nearest)) + + assert pb.find_nearest.vector_property.name == "embedding" + assert pb.find_nearest.query_vector.array_value.values[0].double_value == 1.0 + assert pb.find_nearest.query_vector.array_value.values[1].double_value == 2.0 + assert pb.find_nearest.query_vector.array_value.values[2].double_value == 3.0 + assert pb.find_nearest.query_vector.meaning == 31 + assert ( + pb.find_nearest.distance_measure + == query_pb2.FindNearest.DistanceMeasure.EUCLIDEAN + ) + assert pb.find_nearest.limit == 10 + assert pb.find_nearest.distance_threshold == 0.5 + + @pytest.mark.parametrize("database_id", [None, "somedb"]) def test_pb_from_query_ancestor(database_id): from google.cloud.datastore.key import Key @@ -1447,25 +1498,11 @@ def test_pb_from_query_distinct_on(): def _make_stub_query( client=object(), - kind=None, - project=None, - namespace=None, - ancestor=None, - filters=(), - projection=(), - order=(), - distinct_on=(), + **kwargs, ): query = Query( client, - kind=kind, - project=project, - namespace=namespace, - ancestor=ancestor, - filters=filters, - projection=projection, - order=order, - distinct_on=distinct_on, + **kwargs, ) return query diff --git a/tests/unit/test_vector.py b/tests/unit/test_vector.py new file mode 100644 index 00000000..49ee8e36 --- /dev/null +++ b/tests/unit/test_vector.py @@ -0,0 +1,221 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from google.cloud.datastore.vector import Vector, FindNearest, DistanceMeasure + + +class TestVector: + """ + tests for google.cloud.datastore.vector.Vector + """ + + def test_vector_ctor(self): + v = Vector([1.0, 2.0, 3.0]) + assert len(v) == 3 + assert v[0] == 1.0 + assert v[1] == 2.0 + assert v[2] == 3.0 + assert v.exclude_from_indexes is False + + def test_vector_ctor_w_ints(self): + v = Vector([1, 2, 3]) + assert len(v) == 3 + assert v[0] == 1.0 + assert v[1] == 2.0 + assert v[2] == 3.0 + + @pytest.mark.parametrize("exclude", [True, False]) + def test_vector_ctor_w_exclude_from_indexes(self, exclude): + v = Vector([1], exclude_from_indexes=exclude) + assert v.exclude_from_indexes == exclude + + def test_vector_empty_ctor(self): + v = Vector([]) + assert len(v) == 0 + assert v._value == () + + def test_vector_equality(self): + v1 = Vector([1.0, 2.0, 3.0]) + v2 = Vector([1.0, 2.0, 3.0]) + v3 = Vector([3.0, 2.0, 1.0]) + assert v1 == v2 + assert v1 != v3 + + def test_vector_representation(self): + v = Vector([1, 9.4, 3.1234]) + assert repr(v) == "Vector<1.0, 9.4, 3.1234>" + + @pytest.mark.parametrize("exclude", [True, False]) + def test_vector_to_dict(self, exclude): + v = Vector([1.0, 2.0, 3.0], exclude_from_indexes=exclude) + expected = { + "array_value": { + "values": [ + {"double_value": 1.0}, + {"double_value": 2.0}, + {"double_value": 3.0}, + ] + }, + "meaning": 31, + "exclude_from_indexes": exclude, + } + assert v._to_dict() == expected + + def test_vector_iteration(self): + v = Vector(range(10)) + assert v[0] == 0.0 + assert v[3] == 3.0 + assert v[-1] == 9.0 + for i, val in enumerate(v): + assert i == val + + @pytest.mark.parametrize("exclude", [True, False]) + def test_vector_slicing(self, exclude): + v = Vector(range(10), exclude_from_indexes=exclude) + assert v[1:3] == Vector([1.0, 2.0]) + assert v[:] == Vector([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]) + assert v[::-1] == Vector([9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0]) + assert v[3:7].exclude_from_indexes == exclude + + @pytest.mark.parametrize("exclude", [True, False]) + def test_vector_to_proto(self, exclude): + from google.cloud.datastore_v1.types import Value + + v = Vector([1.0, 2.0, 3.0], exclude_from_indexes=exclude) + proto = Value(**v._to_dict()) + assert proto.array_value.values[0].double_value == 1.0 + assert proto.array_value.values[1].double_value == 2.0 + assert proto.array_value.values[2].double_value == 3.0 + assert proto.meaning == 31 + assert proto.exclude_from_indexes == exclude + + def test_empty_vector_to_proto(self): + from google.cloud.datastore_v1.types import Value + + v = Vector([]) + proto = Value(**v._to_dict()) + assert proto.array_value.values == [] + assert proto.meaning == 31 + + +class TestFindNearest: + """ + tests for google.cloud.datastore.vector.FindNearest + """ + + def test_ctor_defaults(self): + expected_property = "embeddings" + expected_vector = [1.0, 2.0, 3.0] + expected_limit = 5 + expected_distance_measure = DistanceMeasure.DOT_PRODUCT + fn = FindNearest( + expected_property, + expected_vector, + expected_limit, + expected_distance_measure, + ) + assert fn.vector_property == expected_property + assert fn.query_vector == Vector(expected_vector) + assert fn.limit == expected_limit + assert fn.distance_measure == expected_distance_measure + assert fn.distance_result_property is None + assert fn.distance_threshold is None + + def test_ctor_explicit(self): + expected_property = "embeddings" + expected_vector = Vector([1.0, 2.0, 3.0]) + expected_limit = 10 + expected_distance_measure = DistanceMeasure.EUCLIDEAN + expected_distance_result_property = "distance" + expected_distance_threshold = 0.5 + fn = FindNearest( + expected_property, + expected_vector, + expected_limit, + expected_distance_measure, + expected_distance_result_property, + expected_distance_threshold, + ) + assert fn.vector_property == expected_property + assert fn.query_vector == expected_vector + assert fn.limit == expected_limit + assert fn.distance_measure == expected_distance_measure + assert fn.distance_result_property == expected_distance_result_property + assert fn.distance_threshold == expected_distance_threshold + + def test_find_nearest_to_dict(self): + fn = FindNearest( + vector_property="embeddings", + query_vector=[1.0, 2.0, 3.0], + limit=10, + distance_measure=DistanceMeasure.EUCLIDEAN, + distance_result_property="distance", + distance_threshold=0.5, + ) + expected = { + "vector_property": {"name": "embeddings"}, + "query_vector": { + "array_value": { + "values": [ + {"double_value": 1.0}, + {"double_value": 2.0}, + {"double_value": 3.0}, + ] + }, + "meaning": 31, + "exclude_from_indexes": False, + }, + "distance_measure": 1, + "limit": 10, + "distance_result_property": "distance", + "distance_threshold": 0.5, + } + assert fn._to_dict() == expected + + def test_limited_find_nearest_to_dict(self): + fn = FindNearest( + vector_property="embeddings", + query_vector=[3, 2, 1], + limit=99, + distance_measure=DistanceMeasure.DOT_PRODUCT, + ) + expected = { + "vector_property": {"name": "embeddings"}, + "query_vector": { + "array_value": { + "values": [ + {"double_value": 3.0}, + {"double_value": 2.0}, + {"double_value": 1.0}, + ] + }, + "meaning": 31, + "exclude_from_indexes": False, + }, + "distance_measure": 3, + "limit": 99, + } + assert fn._to_dict() == expected + + def test_find_nearest_representation(self): + fn = FindNearest( + vector_property="embeddings", + query_vector=[1.0, 2.0, 3.0], + limit=10, + distance_measure=DistanceMeasure.EUCLIDEAN, + ) + expected = "FindNearest(vector_property='embeddings', query_vector=Vector<1.0, 2.0, 3.0>, limit=10, distance_measure=, distance_result_property=None, distance_threshold=None)" + assert repr(fn) == expected