Skip to content

[DRAFT] feat: vector search #568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions google/cloud/datastore/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from google.cloud.datastore_v1.types import entity as entity_pb2
from google.cloud.datastore.entity import Entity
from google.cloud.datastore.key import Key
from google.cloud.datastore.vector import Vector
from google.cloud.datastore.vector import _VECTOR_VALUE
from google.protobuf import timestamp_pb2


Expand Down Expand Up @@ -401,6 +403,8 @@ def _pb_attr_value(val):
name, value = "array", val
elif isinstance(val, GeoPoint):
name, value = "geo_point", val.to_protobuf()
elif isinstance(val, Vector):
name, value = "vector", entity_pb2.Value(**val._to_dict())
elif val is None:
name, value = "null", struct_pb2.NULL_VALUE
else:
Expand Down Expand Up @@ -457,6 +461,11 @@ def _get_value_from_value_pb(pb):
result = [
_get_value_from_value_pb(item_value) for item_value in pb.array_value.values
]
# check for vector values
if pb.meaning == _VECTOR_VALUE and all(
isinstance(item, float) for item in result
):
result = Vector(result, exclude_from_indexes=bool(pb.exclude_from_indexes))

elif value_type == "geo_point_value":
result = GeoPoint(
Expand Down Expand Up @@ -507,6 +516,8 @@ def _set_protobuf_value(value_pb, val):
for item in val:
i_pb = l_pb.add()
_set_protobuf_value(i_pb, item)
elif attr == "vector_value":
value_pb.CopyFrom(val._pb)
elif attr == "geo_point_value":
value_pb.geo_point_value.CopyFrom(val)
else: # scalar, just assign
Expand Down
10 changes: 10 additions & 0 deletions google/cloud/datastore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ class Query(object):
this query. When set, explain_metrics will be available on the iterator
returned by query.fetch().

:type find_nearest: :class:`~google.cloud.datastore.vector.FindNearest`
:param find_nearest: (Optional) Options to perform a vector search for
entities in the query. When set, the query will return entities
sorted by distance from the query vector.

:raises: ValueError if ``project`` is not passed and no implicit
default is set.
"""
Expand Down Expand Up @@ -211,6 +216,7 @@ def __init__(
order=(),
distinct_on=(),
explain_options=None,
find_nearest=None,
):
self._client = client
self._kind = kind
Expand All @@ -232,6 +238,7 @@ def __init__(
self._explain_options = explain_options
self._ancestor = ancestor
self._filters = []
self.find_nearest = find_nearest

# Verify filters passed in.
for filter in filters:
Expand Down Expand Up @@ -944,6 +951,9 @@ def _pb_from_query(query):
ref.name = distinct_on_name
pb.distinct_on.append(ref)

if query.find_nearest:
pb.find_nearest = query_pb2.FindNearest(**query.find_nearest._to_dict())

return pb


Expand Down
134 changes: 134 additions & 0 deletions google/cloud/datastore/vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import collections

from typing import Sequence
from dataclasses import dataclass
from enum import Enum

# Value stored in / compared against the ``meaning`` field of a value to tag an
# array of doubles as a vector (see ``Vector._to_dict``).
_VECTOR_VALUE = 31


class DistanceMeasure(Enum):
    """Distance functions selectable for a ``FindNearest`` vector query.

    The integer value of each member is what ``FindNearest._to_dict`` emits
    as ``distance_measure``.
    """

    EUCLIDEAN = 1
    COSINE = 2
    DOT_PRODUCT = 3


class Vector(collections.abc.Sequence):
    """An immutable sequence of floats representing a Datastore vector value.

    Instances can be stored as entity properties and used as the
    ``query_vector`` of a ``FindNearest`` query. When serialized (see
    ``_to_dict``) the vector becomes an array of double values tagged with
    the vector ``meaning``.

    :type value: Sequence[float]
    :param value: The components of the vector. Every item is converted with
        ``float()``.

    :type exclude_from_indexes: bool
    :param exclude_from_indexes: (Optional, keyword-only) Flag copied verbatim
        into the serialized value's ``exclude_from_indexes`` field.
    """

    def __init__(self, value: Sequence[float], *, exclude_from_indexes: bool = False):
        self.exclude_from_indexes = exclude_from_indexes
        self._value = tuple(float(v) for v in value)

    def __getitem__(self, arg: int | slice):
        """Return one component (int index) or a new ``Vector`` (slice)."""
        if isinstance(arg, int):
            return self._value[arg]
        elif isinstance(arg, slice):
            # A slice keeps the indexing flag of the vector it was taken from.
            return Vector(
                self._value[arg], exclude_from_indexes=self.exclude_from_indexes
            )
        else:
            raise NotImplementedError

    def __len__(self):
        return len(self._value)

    def __eq__(self, other: object) -> bool:
        """Compare components only; ``exclude_from_indexes`` is not considered."""
        if not isinstance(other, Vector):
            # Returning NotImplemented (rather than raising) lets Python fall
            # back to the other operand, so ``vec == None`` is simply False.
            return NotImplemented
        return self._value == other._value

    def __repr__(self):
        # Join the components explicitly so a one-element vector does not show
        # the trailing comma of a tuple literal.
        return f"Vector<{', '.join(repr(v) for v in self._value)}>"

    def _to_dict(self):
        """Return a dict with the keyword arguments for a ``Value`` message."""
        return {
            "array_value": {"values": [{"double_value": v} for v in self._value]},
            "meaning": _VECTOR_VALUE,
            "exclude_from_indexes": self.exclude_from_indexes,
        }


@dataclass
class FindNearest:
    """
    Configuration for a ``find_nearest`` vector search on a query.

    :type vector_property: str
    :param vector_property:
        Name of the indexed vector property to search upon.
        Only entities whose vector has the same dimensionality as
        ``query_vector`` can be returned.

    :type query_vector: Union[Vector, Sequence[float]]
    :param query_vector:
        The vector to search around. A plain sequence is wrapped in a
        :class:`Vector` after construction.
        Must be a vector of no more than 2048 dimensions.

    :type limit: int
    :param limit:
        The number of nearest neighbors to return.
        Must be a positive integer of no more than 100.

    :type distance_measure: DistanceMeasure
    :param distance_measure:
        The distance measure to use when comparing vectors.

    :type distance_result_property: Optional[str]
    :param distance_result_property:
        Optional name of the property in which to output the result of the
        vector distance calculation.

    :type distance_threshold: Optional[float]
    :param distance_threshold:
        Optional threshold beyond which less similar entities are not
        returned. Its meaning depends on ``distance_measure``:
        For EUCLIDEAN, COSINE: WHERE distance <= distance_threshold
        For DOT_PRODUCT: WHERE distance >= distance_threshold
    """

    vector_property: str
    query_vector: Vector | Sequence[float]
    limit: int
    distance_measure: DistanceMeasure
    distance_result_property: str | None = None
    distance_threshold: float | None = None

    def __post_init__(self):
        # Normalize so the rest of the class can rely on Vector._to_dict().
        if isinstance(self.query_vector, Vector):
            return
        self.query_vector = Vector(self.query_vector)

    def _to_dict(self):
        """Return a dict with the keyword arguments for a FindNearest message."""
        payload = {
            "vector_property": {"name": self.vector_property},
            "query_vector": self.query_vector._to_dict(),
            "distance_measure": self.distance_measure.value,
            "limit": self.limit,
        }
        # Optional settings are only emitted when explicitly provided.
        result_property = self.distance_result_property
        if result_property is not None:
            payload["distance_result_property"] = result_property
        threshold = self.distance_threshold
        if threshold is not None:
            payload["distance_threshold"] = float(threshold)
        return payload
7 changes: 7 additions & 0 deletions tests/system/index.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,10 @@ indexes:
- name: family
- name: appearances

- kind: LargeCharacter
properties:
- name: __key__
- name: vector
vectorConfig:
dimension: 10
flat: {}
38 changes: 38 additions & 0 deletions tests/system/test_put.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,41 @@ def test_client_put_w_empty_array(datastore_client, entities_to_delete, database
retrieved = local_client.get(entity.key)

assert entity["children"] == retrieved["children"]


@pytest.mark.parametrize("data", [[0], (1.0, 2.0, 3.0), range(100)])
@pytest.mark.parametrize("exclude_from_indexes", [True, False])
@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True)
def test_client_put_w_vector(
    datastore_client, entities_to_delete, database_id, data, exclude_from_indexes
):
    """Store a Vector property, read it back, and compare both copies."""
    client = _helpers.clone_client(datastore_client)

    stored = datastore.Entity(key=client.key("VectorArray", 1234))
    stored["vec"] = datastore.vector.Vector(
        data, exclude_from_indexes=exclude_from_indexes
    )
    client.put(stored)
    entities_to_delete.append(stored)

    fetched = client.get(stored.key)

    # Same components, same serialized form, and the indexing flag survives
    # the round trip on both sides.
    assert stored["vec"] == fetched["vec"]
    assert stored["vec"]._to_dict() == fetched["vec"]._to_dict()
    assert stored["vec"].exclude_from_indexes == exclude_from_indexes
    assert fetched["vec"].exclude_from_indexes == exclude_from_indexes


@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True)
def test_client_put_w_empty_vector(datastore_client, entities_to_delete, database_id):
    """Putting an entity with a zero-length Vector must be rejected."""
    from google.api_core.exceptions import BadRequest

    local_client = _helpers.clone_client(datastore_client)

    key = local_client.key("VectorArray", 1234)
    entity = datastore.Entity(key=key)
    entity["vec"] = datastore.vector.Vector([])
    with pytest.raises(BadRequest) as e:
        local_client.put(entity)
    # Inspect the raised exception itself (``e.value``), not the ExceptionInfo
    # wrapper, whose string form is not the error message.
    assert "Cannot have a zero length vector" in str(e.value)
93 changes: 93 additions & 0 deletions tests/system/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from . import _helpers

from google.cloud.datastore.query import PropertyFilter, And, Or
from google.cloud.datastore.vector import FindNearest, DistanceMeasure, Vector


retry_503 = RetryErrors(exceptions.ServiceUnavailable)
Expand Down Expand Up @@ -647,3 +648,95 @@ def test_query_explain_in_transaction(query_client, ancestor_key, database_id):
# check for stats
stats = iterator.explain_metrics
assert isinstance(stats, ExplainMetrics)


@pytest.mark.parametrize(
    "distance_measure",
    [DistanceMeasure.EUCLIDEAN, DistanceMeasure.COSINE, DistanceMeasure.DOT_PRODUCT],
)
@pytest.mark.parametrize("limit", [5, 10, 20])
@pytest.mark.parametrize("database_id", [_helpers.TEST_DATABASE], indirect=True)
def test_query_vector_find_nearest(query_client, database_id, limit, distance_measure):
    """Vector search honors ``limit`` and returns results ordered by distance."""
    q = query_client.query(kind="LargeCharacter", namespace="LargeCharacterEntity")
    vector = [v / 10 for v in range(10)]
    q.find_nearest = FindNearest(
        vector_property="vector",
        query_vector=vector,
        limit=limit,
        distance_measure=distance_measure,
        distance_result_property="distance",
    )
    iterator = q.fetch()
    results = list(iterator)
    # verify limit was applied
    assert len(results) == limit
    # verify distance property is present; use membership rather than
    # truthiness so a legitimate distance of 0.0 does not fail the check
    assert all("distance" in r for r in results)
    distance_list = [r["distance"] for r in results]
    assert all(isinstance(d, float) for d in distance_list)
    # verify distances are sorted
    if distance_measure == DistanceMeasure.DOT_PRODUCT:
        # dot product sorts high to low
        expected = sorted(distance_list, reverse=True)
    else:
        expected = sorted(distance_list)
    assert expected == distance_list


@pytest.mark.parametrize("exclude_from_indexes", [True, False])
@pytest.mark.parametrize("database_id", [_helpers.TEST_DATABASE], indirect=True)
def test_query_vector_find_nearest_w_vector_class(
    query_client, database_id, exclude_from_indexes
):
    """
    A Vector instance is accepted as ``query_vector``.

    The vector's exclude_from_indexes flag should not affect the search.
    """
    query_vector = Vector(
        [v / 10 for v in range(10)], exclude_from_indexes=exclude_from_indexes
    )
    search = FindNearest(
        vector_property="vector",
        query_vector=query_vector,
        limit=5,
        distance_measure=DistanceMeasure.EUCLIDEAN,
        distance_result_property="distance",
    )
    query = query_client.query(kind="LargeCharacter", namespace="LargeCharacterEntity")
    query.find_nearest = search
    results = list(query.fetch())
    assert len(results) == 5


@pytest.mark.parametrize("database_id", [_helpers.TEST_DATABASE], indirect=True)
def test_query_empty_find_nearest(query_client, database_id):
    """
    A vector search with an empty query_vector should fail when fetched.
    """
    search = FindNearest(
        vector_property="vector",
        query_vector=[],
        limit=5,
        distance_measure=DistanceMeasure.EUCLIDEAN,
    )
    query = query_client.query(kind="LargeCharacter", namespace="LargeCharacterEntity")
    query.find_nearest = search
    with pytest.raises(ValueError):
        list(query.fetch())


@pytest.mark.parametrize("database_id", [_helpers.TEST_DATABASE], indirect=True)
def test_query_find_nearest_wrong_size(query_client, database_id):
    """
    A vector search whose query_vector has 11 components should fail.
    """
    # One component more than the vectors searched by the other tests here.
    oversized = [i / 10 for i in range(11)]
    search = FindNearest(
        vector_property="vector",
        query_vector=oversized,
        limit=5,
        distance_measure=DistanceMeasure.EUCLIDEAN,
    )
    query = query_client.query(kind="LargeCharacter", namespace="LargeCharacterEntity")
    query.find_nearest = search
    with pytest.raises(ValueError):
        list(query.fetch())
5 changes: 5 additions & 0 deletions tests/system/utils/populate_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import sys
import time
import uuid
import random

from google.cloud import datastore

Expand Down Expand Up @@ -104,6 +105,10 @@ def put_objects(count):
task["name"] = "{0:05d}".format(i)
task["family"] = "Stark"
task["alive"] = False
random.seed(i)
task["vector"] = datastore.vector.Vector(
[random.random() for _ in range(10)]
)

for i in string.ascii_lowercase:
task["space-{}".format(i)] = MAX_STRING
Expand Down
Loading
Loading