diff --git a/.readme-partials.yaml b/.readme-partials.yaml index 3e8d3a465..ced21e49f 100644 --- a/.readme-partials.yaml +++ b/.readme-partials.yaml @@ -110,33 +110,28 @@ custom_content: | ------- In this feature launch, the [Java Datastore client](https://github.com/googleapis/java-datastore) now offers gRPC as a transport layer option with experimental support. Using [gRPC connection pooling](https://grpc.io/docs/guides/performance/) enables distributing RPCs over multiple connections which may improve performance. - #### Installation Instructions - The client can be built from the `grpc-experimental` branch on GitHub. For private preview, you can also download the artifact with the instructions provided below. - - 1. Download the datastore private preview package with dependencies: - ``` - curl -o https://datastore-sdk-feature-release.web.app/google-cloud-datastore-2.20.0-grpc-experimental-1-SNAPSHOT-jar-with-dependencies.jar - ``` - 2. Run the following commands to install JDK locally: - ``` - mvn install:install-file -Dfile= -DgroupId=com.google.cloud -DartifactId=google-cloud-datastore -Dversion=2.20.0-grpc - ``` - 3. Edit your pom.xml to add above package to `` section: - ```xml - + #### Download Instructions + Instructions: + 1. Clone the grpc-experimental branch from GitHub: + ```python + git clone -b grpc-experimental https://github.com/googleapis/java-datastore.git + ``` + 2. Run the following commands to build the library: + ```python + # Go to the directory the code was downloaded to + cd java-datastore/ + + # Build the library + mvn clean install -DskipTests=true + ``` + 3. 
Add the following dependency to your project: + ```xml + com.google.cloud google-cloud-datastore 2.20.0-grpc-experimental-1-SNAPSHOT - - ``` - - And if you have not yet, add below to `` section: - ```xml - - local-repo - file://${user.home}/.m2/repository - - ``` + + ``` #### How to Use To opt-in to the gRPC transport behavior, simply add the below line of code (`setTransportOptions`) to your Datastore client instantiation. @@ -186,44 +181,11 @@ custom_content: | #### New Features There are new gRPC specific features available to use in this update. - ##### Connection Pool - A connection pool, also known as a channel pool, is a cache of database connections that are shared and reused to improve connection latency and performance. With this update, now you will be able to configure the channel pool to improve application performance. This section guides you in determining the optimal connection pool size and configuring it within the Java datastore client. - To customize the number of channels your client uses, you can update the channel provider in the DatastoreOptions. - ###### Determine the best connection pool size - The default connection pool size is right for most applications, and in most cases there's no need to change it. - - However sometimes you may want to change your connection pool size due to high throughput or buffered requests. Ideally, to leave room for traffic fluctuations, a connection pool has about twice the number of connections it takes for maximum saturation. Because a connection can handle a maximum of 100 concurrent requests, between 10 and 50 outstanding requests per connection is optimal. The limit of 100 concurrent streams per gRPC connection is enforced in Google's middleware layer, and you are not able to reconfigure this number. - - The following steps help you calculate the optimal number of connections in your channel pool using estimate per-client QPS and average latency numbers. 
- - To calculate the optimal connections, gather the following information: - - 1. The maximum number of queries per second (QPS) per client when your application is running a typical workload. - 2. The average latency (the response time for a single request) in ms. - 3. Determine the number of requests that you can send serially per second by dividing 1,000 by the average latency value. - 4. Divide the QPS in seconds by the number of serial requests per second. - 5. Divide the result by 50 requests per channel to determine the minimum optimal channel pool size. (If your calculation is less than 2, use at least 2 channels anyway, to ensure redundancy.) - 6. Divide the same result by 10 requests per channel to determine the maximum optimal channel pool size. - - These steps are expressed in the following equations: - ```java - (QPS ÷ (1,000 ÷ latency ms)) ÷ 50 streams = Minimum optimal number of connections - (QPS ÷ (1,000 ÷ latency ms)) ÷ 10 streams = Maximum optimal number of connections - ``` - - ###### Example - Your application typically sends 50,000 requests per second, and the average latency is 10 ms. Divide 1,000 by 10 ms to determine that you can send 100 requests serially per second. - Divide that number into 50,000 to get the parallelism needed to send 50,000 QPS: 500. Each channel can have at most 100 requests out concurrently, and your target channel utilization - is between 10 and 50 concurrent streams. Therefore, to calculate the minimum, divide 500 by 50 to get 10. To find the maximum, divide 500 by 10 to get 50. This means that your channel - pool size for this example should be between 10 and 50 connections. - - It is also important to monitor your traffic after making changes and adjust the number of connections in your pool if necessary. - - ###### Set the pool size - The following code sample demonstrates how to configure the channel pool in the client libraries using `DatastoreOptions`. 
+ ##### Channel Pooling + To customize the number of channels your client uses, you can update the channel provider in the DatastoreOptions. See [ChannelPoolSettings](https://cloud.google.com/java/docs/reference/gax/latest/com.google.api.gax.grpc.ChannelPoolSettings) and [Performance Best Practices](https://grpc.io/docs/guides/performance/) for more information on channel pools and best practices for performance. - Code Example + Example: ```java InstantiatingGrpcChannelProvider channelProvider = DatastoreSettings.defaultGrpcTransportProviderBuilder() diff --git a/README.md b/README.md index 2fa4ffd04..4efd4bec7 100644 --- a/README.md +++ b/README.md @@ -208,33 +208,28 @@ gRPC Java Datastore Client User Guide ------- In this feature launch, the [Java Datastore client](https://github.com/googleapis/java-datastore) now offers gRPC as a transport layer option with experimental support. Using [gRPC connection pooling](https://grpc.io/docs/guides/performance/) enables distributing RPCs over multiple connections which may improve performance. -#### Installation Instructions -The client can be built from the `grpc-experimental` branch on GitHub. For private preview, you can also download the artifact with the instructions provided below. - -1. Download the datastore private preview package with dependencies: - ``` - curl -o https://datastore-sdk-feature-release.web.app/google-cloud-datastore-2.20.0-grpc-experimental-1-SNAPSHOT-jar-with-dependencies.jar - ``` -2. Run the following commands to install JDK locally: - ``` - mvn install:install-file -Dfile= -DgroupId=com.google.cloud -DartifactId=google-cloud-datastore -Dversion=2.20.0-grpc - ``` -3. Edit your pom.xml to add above package to `` section: - ```xml - +#### Download Instructions +Instructions: +1. Clone the grpc-experimental branch from GitHub: +```python +git clone -b grpc-experimental https://github.com/googleapis/java-datastore.git +``` +2. 
Run the following commands to build the library: +```python +# Go to the directory the code was downloaded to +cd java-datastore/ + +# Build the library +mvn clean install -DskipTests=true +``` +3. Add the following dependency to your project: +```xml + com.google.cloud google-cloud-datastore 2.20.0-grpc-experimental-1-SNAPSHOT - - ``` - -And if you have not yet, add below to `` section: - ```xml - - local-repo - file://${user.home}/.m2/repository - - ``` + +``` #### How to Use To opt-in to the gRPC transport behavior, simply add the below line of code (`setTransportOptions`) to your Datastore client instantiation. @@ -284,44 +279,11 @@ boolean isHTTP = datastore.getOptions().getTransportOptions() instanceof HTTPTra #### New Features There are new gRPC specific features available to use in this update. -##### Connection Pool -A connection pool, also known as a channel pool, is a cache of database connections that are shared and reused to improve connection latency and performance. With this update, now you will be able to configure the channel pool to improve application performance. This section guides you in determining the optimal connection pool size and configuring it within the Java datastore client. -To customize the number of channels your client uses, you can update the channel provider in the DatastoreOptions. -###### Determine the best connection pool size -The default connection pool size is right for most applications, and in most cases there's no need to change it. - -However sometimes you may want to change your connection pool size due to high throughput or buffered requests. Ideally, to leave room for traffic fluctuations, a connection pool has about twice the number of connections it takes for maximum saturation. Because a connection can handle a maximum of 100 concurrent requests, between 10 and 50 outstanding requests per connection is optimal. 
The limit of 100 concurrent streams per gRPC connection is enforced in Google's middleware layer, and you are not able to reconfigure this number. - -The following steps help you calculate the optimal number of connections in your channel pool using estimate per-client QPS and average latency numbers. - -To calculate the optimal connections, gather the following information: - -1. The maximum number of queries per second (QPS) per client when your application is running a typical workload. -2. The average latency (the response time for a single request) in ms. -3. Determine the number of requests that you can send serially per second by dividing 1,000 by the average latency value. -4. Divide the QPS in seconds by the number of serial requests per second. -5. Divide the result by 50 requests per channel to determine the minimum optimal channel pool size. (If your calculation is less than 2, use at least 2 channels anyway, to ensure redundancy.) -6. Divide the same result by 10 requests per channel to determine the maximum optimal channel pool size. - -These steps are expressed in the following equations: -```java -(QPS ÷ (1,000 ÷ latency ms)) ÷ 50 streams = Minimum optimal number of connections -(QPS ÷ (1,000 ÷ latency ms)) ÷ 10 streams = Maximum optimal number of connections -``` - -###### Example -Your application typically sends 50,000 requests per second, and the average latency is 10 ms. Divide 1,000 by 10 ms to determine that you can send 100 requests serially per second. -Divide that number into 50,000 to get the parallelism needed to send 50,000 QPS: 500. Each channel can have at most 100 requests out concurrently, and your target channel utilization -is between 10 and 50 concurrent streams. Therefore, to calculate the minimum, divide 500 by 50 to get 10. To find the maximum, divide 500 by 10 to get 50. This means that your channel -pool size for this example should be between 10 and 50 connections. 
- -It is also important to monitor your traffic after making changes and adjust the number of connections in your pool if necessary. - -###### Set the pool size -The following code sample demonstrates how to configure the channel pool in the client libraries using `DatastoreOptions`. +##### Channel Pooling +To customize the number of channels your client uses, you can update the channel provider in the DatastoreOptions. See [ChannelPoolSettings](https://cloud.google.com/java/docs/reference/gax/latest/com.google.api.gax.grpc.ChannelPoolSettings) and [Performance Best Practices](https://grpc.io/docs/guides/performance/) for more information on channel pools and best practices for performance. -Code Example +Example: ```java InstantiatingGrpcChannelProvider channelProvider = DatastoreSettings.defaultGrpcTransportProviderBuilder() @@ -413,6 +375,13 @@ Samples are in the [`samples/`](https://github.com/googleapis/java-datastore/tre | Query Profile Explain Aggregation | [source code](https://github.com/googleapis/java-datastore/blob/main/samples/snippets/src/main/java/com/example/datastore/queryprofile/QueryProfileExplainAggregation.java) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/java-datastore&page=editor&open_in_editor=samples/snippets/src/main/java/com/example/datastore/queryprofile/QueryProfileExplainAggregation.java) | | Query Profile Explain Analyze | [source code](https://github.com/googleapis/java-datastore/blob/main/samples/snippets/src/main/java/com/example/datastore/queryprofile/QueryProfileExplainAnalyze.java) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/java-datastore&page=editor&open_in_editor=samples/snippets/src/main/java/com/example/datastore/queryprofile/QueryProfileExplainAnalyze.java) | | Query Profile Explain Analyze Aggregation | [source 
code](https://github.com/googleapis/java-datastore/blob/main/samples/snippets/src/main/java/com/example/datastore/queryprofile/QueryProfileExplainAnalyzeAggregation.java) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/java-datastore&page=editor&open_in_editor=samples/snippets/src/main/java/com/example/datastore/queryprofile/QueryProfileExplainAnalyzeAggregation.java) | +| Store Vectors | [source code](https://github.com/googleapis/java-datastore/blob/main/samples/snippets/src/main/java/com/example/datastore/vectorsearch/StoreVectors.java) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/java-datastore&page=editor&open_in_editor=samples/snippets/src/main/java/com/example/datastore/vectorsearch/StoreVectors.java) | +| Vector Search Basic | [source code](https://github.com/googleapis/java-datastore/blob/main/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchBasic.java) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/java-datastore&page=editor&open_in_editor=samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchBasic.java) | +| Vector Search Distance Result Property | [source code](https://github.com/googleapis/java-datastore/blob/main/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceResultProperty.java) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/java-datastore&page=editor&open_in_editor=samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceResultProperty.java) | +| Vector Search Distance Result Property Projection | [source 
code](https://github.com/googleapis/java-datastore/blob/main/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceResultPropertyProjection.java) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/java-datastore&page=editor&open_in_editor=samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceResultPropertyProjection.java) | +| Vector Search Distance Threshold | [source code](https://github.com/googleapis/java-datastore/blob/main/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceThreshold.java) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/java-datastore&page=editor&open_in_editor=samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceThreshold.java) | +| Vector Search Large Response | [source code](https://github.com/googleapis/java-datastore/blob/main/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchLargeResponse.java) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/java-datastore&page=editor&open_in_editor=samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchLargeResponse.java) | +| Vector Search Prefilter | [source code](https://github.com/googleapis/java-datastore/blob/main/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchPrefilter.java) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/java-datastore&page=editor&open_in_editor=samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchPrefilter.java) | | Task List | [source 
code](https://github.com/googleapis/java-datastore/blob/main/samples/snippets/src/main/java/com/google/datastore/snippets/TaskList.java) | [![Open in Cloud Shell][shell_img]](https://console.cloud.google.com/cloudshell/open?git_repo=https://github.com/googleapis/java-datastore&page=editor&open_in_editor=samples/snippets/src/main/java/com/google/datastore/snippets/TaskList.java) | diff --git a/samples/snippets/src/test/java/com/google/datastore/snippets/ConceptsTest.java b/com/google/datastore/snippets/ConceptsTest.java similarity index 98% rename from samples/snippets/src/test/java/com/google/datastore/snippets/ConceptsTest.java rename to com/google/datastore/snippets/ConceptsTest.java index 33aa63ab4..51dcd4b3f 100644 --- a/samples/snippets/src/test/java/com/google/datastore/snippets/ConceptsTest.java +++ b/com/google/datastore/snippets/ConceptsTest.java @@ -30,6 +30,7 @@ import com.google.cloud.datastore.DatastoreOptions; import com.google.cloud.datastore.Entity; import com.google.cloud.datastore.EntityQuery; +import com.google.cloud.datastore.FindNearest; import com.google.cloud.datastore.FullEntity; import com.google.cloud.datastore.IncompleteKey; import com.google.cloud.datastore.Key; @@ -47,6 +48,7 @@ import com.google.cloud.datastore.StructuredQuery.OrderBy; import com.google.cloud.datastore.StructuredQuery.PropertyFilter; import com.google.cloud.datastore.Transaction; +import com.google.cloud.datastore.VectorValue; import com.google.cloud.datastore.testing.LocalDatastoreHelper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -407,6 +409,7 @@ private void setUpQueryTests() { "description", StringValue.newBuilder("Learn Cloud Datastore").setExcludeFromIndexes(true).build()) .set("tag", "fun", "l", "programming", "learn") + .set("vector_property", VectorValue.newBuilder(3.0, 1.0, 2.0).build()) .build()); } @@ -1192,4 +1195,18 @@ public void testStaleReads() throws InterruptedException { // [END 
datastore_stale_read] assertValidQueryRealBackend(query); } + + @Test + public void testVectorSearch() { + setUpQueryTests(); + // [START datastore_vector_search] + VectorValue vectorValue = VectorValue.newBuilder(1.78, 2.56, 3.88).build(); + FindNearest vectorQuery = + new FindNearest( + "vector_property", vectorValue, FindNearest.DistanceMeasure.COSINE, 1, "distance"); + + Query query = Query.newEntityQueryBuilder().setFindNearest(vectorQuery).build(); + // [END datastore_vector_search] + assertValidQuery(query); + } } diff --git a/google-cloud-datastore/clirr-ignored-differences.xml b/google-cloud-datastore/clirr-ignored-differences.xml index 6275f05c6..e3cc028d8 100644 --- a/google-cloud-datastore/clirr-ignored-differences.xml +++ b/google-cloud-datastore/clirr-ignored-differences.xml @@ -14,7 +14,7 @@ com/google/cloud/datastore/DatastoreReader - com.google.cloud.datastore.AggregationResults runAggregation(com.google.cloud.datastore.AggregationQuery, com.google.cloud.datastore.models.ExplainOptions) + com.google.cloud.datastore.AggregationResults runAggregation(com.google.cloud.datastore.AggregationQuery, com.google.cloud.datastore.models.ExplainOptions) 7012 @@ -27,8 +27,13 @@ com.google.cloud.datastore.QueryResults run(com.google.cloud.datastore.Query, com.google.cloud.datastore.models.ExplainOptions) 7012 + + com/google/cloud/datastore/StructuredQuery* + com.google.cloud.datastore.StructuredQuery* setFindNearest(com.google.cloud.datastore.FindNearest) + 7012 + - + com/google/cloud/datastore/ReadOption$QueryConfig com.google.cloud.datastore.ReadOption$QueryConfig create(com.google.cloud.datastore.Query, java.util.List) diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/BaseEntity.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/BaseEntity.java index 608dc7187..d50770215 100644 --- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/BaseEntity.java +++ 
b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/BaseEntity.java
@@ -615,6 +615,17 @@ public <T extends Value<?>> List<T> getList(String name) {
     return (List<T>) getValue(name).get();
   }
 
+  /**
+   * Returns the property value as a vector.
+   *
+   * @throws DatastoreException if no such property
+   * @throws ClassCastException if value is not a vector
+   */
+  @SuppressWarnings("unchecked")
+  public List<DoubleValue> getVector(String name) {
+    return (List<DoubleValue>) getValue(name).get();
+  }
+
   /**
    * Returns the property value as a blob.
    *
diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/FindNearest.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/FindNearest.java
new file mode 100644
index 000000000..d8c2176f0
--- /dev/null
+++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/FindNearest.java
@@ -0,0 +1,207 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.datastore;
+
+import com.google.common.base.MoreObjects;
+import com.google.common.base.MoreObjects.ToStringHelper;
+import com.google.protobuf.DoubleValue;
+import com.google.protobuf.Int32Value;
+import java.io.Serializable;
+import java.util.Objects;
+import javax.annotation.Nullable;
+
+/**
+ * A query that finds the entities whose vector fields are closest to a certain query vector. Create
+ * an instance of `FindNearest` with {@link Query}.
+ */
+public final class FindNearest implements Serializable {
+
+  /** An indexed vector property to search upon. */
+  private final String vectorProperty;
+  /** The query vector that we are searching on. */
+  private final VectorValue queryVector;
+  /** The Distance Measure to use, required. */
+  private final DistanceMeasure measure;
+  /** The number of nearest neighbors to return. Must be a positive integer of no more than 100. */
+  private final int limit;
+
+  /**
+   * Optional. Optional name of the field to output the result of the vector distance calculation.
+   */
+  private final @Nullable String distanceResultField;
+
+  /**
+   * Optional. Option to specify a threshold for which no less similar documents will be returned.
+   * The behavior of the specified `distance_measure` will affect the meaning of the distance
+   * threshold.
+   */
+  private final @Nullable Double distanceThreshold;
+
+  private static final long serialVersionUID = 4688656124180403551L;
+
+  /** Creates a FindNearest query. */
+  public FindNearest(
+      String vectorProperty,
+      VectorValue queryVector,
+      DistanceMeasure measure,
+      int limit,
+      @Nullable String distanceResultField,
+      @Nullable Double distanceThreshold) {
+    this.vectorProperty = vectorProperty;
+    this.queryVector = queryVector;
+    this.measure = measure;
+    this.limit = limit;
+    this.distanceResultField = distanceResultField;
+    this.distanceThreshold = distanceThreshold;
+  }
+
+  public FindNearest(
+      String vectorProperty, VectorValue queryVector, DistanceMeasure measure, int limit) {
+    this(vectorProperty, queryVector, measure, limit, null, null);
+  }
+
+  public FindNearest(
+      String vectorProperty,
+      VectorValue queryVector,
+      DistanceMeasure measure,
+      int limit,
+      @Nullable String distanceResultField) {
+    this(vectorProperty, queryVector, measure, limit, distanceResultField, null);
+  }
+
+  public FindNearest(
+      String vectorProperty,
+      VectorValue queryVector,
+      DistanceMeasure measure,
+      int limit,
+      @Nullable Double distanceThreshold) {
+    this(vectorProperty, queryVector, measure, limit, null, distanceThreshold);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(
+        vectorProperty, queryVector, measure, limit, distanceResultField, distanceThreshold);
+  }
+
+  /**
+   * Returns true if this FindNearest query is equal to the provided object.
+   *
+   * @param obj The object to compare against.
+   * @return Whether this FindNearest query is equal to the provided object.
+   */
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) {
+      return true;
+    }
+    if (obj == null || !(obj instanceof FindNearest)) {
+      return false;
+    }
+    FindNearest otherQuery = (FindNearest) obj;
+    return Objects.equals(vectorProperty, otherQuery.vectorProperty)
+        && Objects.equals(queryVector, otherQuery.queryVector)
+        && Objects.equals(distanceResultField, otherQuery.distanceResultField)
+        && Objects.equals(distanceThreshold, otherQuery.distanceThreshold)
+        && limit == otherQuery.limit
+        && measure == otherQuery.measure;
+  }
+
+  @Override
+  public String toString() {
+    ToStringHelper toStringHelper = MoreObjects.toStringHelper(this);
+    toStringHelper.add("vectorProperty", vectorProperty);
+    toStringHelper.add("queryVector", queryVector);
+    toStringHelper.add("measure", measure);
+    toStringHelper.add("limit", limit);
+    toStringHelper.add("distanceResultField", distanceResultField);
+    toStringHelper.add("distanceThreshold", distanceThreshold);
+    return toStringHelper.toString();
+  }
+
+  static FindNearest fromPb(com.google.datastore.v1.FindNearest findNearestPb) {
+    String vectorProperty = findNearestPb.getVectorProperty().getName();
+    VectorValue queryVector =
+        VectorValue.MARSHALLER.fromProto(findNearestPb.getQueryVector()).build();
+    DistanceMeasure distanceMeasure =
+        DistanceMeasure.valueOf(findNearestPb.getDistanceMeasure().toString());
+    int limit = findNearestPb.getLimit().getValue();
+    String distanceResultField =
+        findNearestPb.getDistanceResultProperty() == null
+                || findNearestPb.getDistanceResultProperty().isEmpty()
+            ? null
+            : findNearestPb.getDistanceResultProperty();
+    Double distanceThreshold =
+        findNearestPb.getDistanceThreshold() == null
+                || findNearestPb.getDistanceThreshold() == DoubleValue.getDefaultInstance()
+            ? null
+            : findNearestPb.getDistanceThreshold().getValue();
+    return new FindNearest(
+        vectorProperty,
+        queryVector,
+        distanceMeasure,
+        limit,
+        distanceResultField,
+        distanceThreshold);
+  }
+
+  com.google.datastore.v1.FindNearest toPb() {
+    com.google.datastore.v1.FindNearest.Builder findNearestPb =
+        com.google.datastore.v1.FindNearest.newBuilder();
+    findNearestPb.getVectorPropertyBuilder().setName(vectorProperty);
+    findNearestPb.setQueryVector(queryVector.toPb());
+    findNearestPb.setDistanceMeasure(toProto(measure));
+    findNearestPb.setLimit(Int32Value.of(limit));
+    if (distanceResultField != null) {
+      findNearestPb.setDistanceResultProperty(distanceResultField);
+    }
+    if (distanceThreshold != null) {
+      findNearestPb.setDistanceThreshold(DoubleValue.of(distanceThreshold));
+    }
+    return findNearestPb.build();
+  }
+
+  protected static com.google.datastore.v1.FindNearest.DistanceMeasure toProto(
+      DistanceMeasure distanceMeasure) {
+    switch (distanceMeasure) {
+      case COSINE:
+        return com.google.datastore.v1.FindNearest.DistanceMeasure.COSINE;
+      case EUCLIDEAN:
+        return com.google.datastore.v1.FindNearest.DistanceMeasure.EUCLIDEAN;
+      case DOT_PRODUCT:
+        return com.google.datastore.v1.FindNearest.DistanceMeasure.DOT_PRODUCT;
+      default:
+        return com.google.datastore.v1.FindNearest.DistanceMeasure.UNRECOGNIZED;
+    }
+  }
+
+  /** The distance measure to use when comparing vectors in a {@link FindNearest query}. */
+  public enum DistanceMeasure {
+    DISTANCE_MEASURE_UNSPECIFIED,
+    /**
+     * COSINE distance compares vectors based on the angle between them, which allows you to measure
+     * similarity that isn't based on the vectors' magnitude. We recommend using DOT_PRODUCT with
+     * unit normalized vectors instead of COSINE distance, which is mathematically equivalent with
+     * better performance.
+     */
+    COSINE,
+    /** Measures the EUCLIDEAN distance between the vectors. */
+    EUCLIDEAN,
+    /** Similar to cosine but is affected by the magnitude of the vectors. */
+    DOT_PRODUCT
+  }
+}
diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQuery.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQuery.java
index 30cd05759..bd6b9f222 100644
--- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQuery.java
+++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQuery.java
@@ -101,6 +101,7 @@ public abstract class StructuredQuery<V> extends Query<V> implements RecordQuery
   private final Cursor endCursor;
   private final int offset;
   private final Integer limit;
+  private final FindNearest findNearest;
 
   private final ResultType<V> resultType;
 
@@ -731,6 +732,9 @@ public interface Builder<V> {
     /** Adds settings to the existing order by clause. */
     Builder<V> addOrderBy(OrderBy orderBy, OrderBy... others);
 
+    /** Sets the find_nearest for the query. */
+    Builder<V> setFindNearest(FindNearest findNearest);
+
     StructuredQuery<V> build();
   }
 
@@ -753,6 +757,7 @@ abstract static class BuilderImpl<V, B extends BuilderImpl<V, B>> implements Bui
     private Cursor endCursor;
     private int offset;
     private Integer limit;
+    private FindNearest findNearest;
 
     BuilderImpl(ResultType<V> resultType) {
       this.resultType = resultType;
@@ -770,6 +775,7 @@ abstract static class BuilderImpl<V, B extends BuilderImpl<V, B>> implements Bui
       endCursor = query.endCursor;
       offset = query.offset;
       limit = query.limit;
+      findNearest = query.findNearest;
     }
 
     @SuppressWarnings("unchecked")
@@ -841,6 +847,13 @@ public B addOrderBy(OrderBy orderBy, OrderBy... others) {
       return self();
     }
 
+    @Override
+    public B setFindNearest(FindNearest findNearest) {
+      Preconditions.checkArgument(findNearest != null, "vector query must not be null");
+      this.findNearest = findNearest;
+      return self();
+    }
+
     B clearProjection() {
       projection.clear();
       return self();
@@ -904,6 +917,10 @@ B mergeFrom(com.google.datastore.v1.Query queryPb) {
       for (com.google.datastore.v1.PropertyReference distinctOnPb : queryPb.getDistinctOnList()) {
         addDistinctOn(distinctOnPb.getName());
       }
+      if (queryPb.getFindNearest() != null
+          && queryPb.getFindNearest() != com.google.datastore.v1.FindNearest.getDefaultInstance()) {
+        setFindNearest(FindNearest.fromPb(queryPb.getFindNearest()));
+      }
       return self();
     }
   }
@@ -920,6 +937,7 @@ B mergeFrom(com.google.datastore.v1.Query queryPb) {
     endCursor = builder.endCursor;
     offset = builder.offset;
     limit = builder.limit;
+    findNearest = builder.findNearest;
   }
 
   @Override
@@ -935,6 +953,7 @@ public String toString() {
         .add("orderBy", orderBy)
         .add("projection", projection)
         .add("distinctOn", distinctOn)
+        .add("findNearest", findNearest)
         .toString();
   }
 
@@ -950,7 +969,8 @@ public int hashCode() {
         filter,
         orderBy,
         projection,
-        distinctOn);
+        distinctOn,
+        findNearest);
   }
 
   @Override
@@ -971,7 +991,8 @@ public boolean equals(Object obj) {
         && Objects.equals(filter, other.filter)
         && Objects.equals(orderBy, other.orderBy)
         && Objects.equals(projection, other.projection)
-        && Objects.equals(distinctOn, other.distinctOn);
+        && Objects.equals(distinctOn, other.distinctOn)
+        && Objects.equals(findNearest, other.findNearest);
   }
 
   /** Returns the kind for this query. */
@@ -1023,6 +1044,11 @@ public Integer getLimit() {
     return limit;
   }
 
+  /** Returns the vector query for this query. */
+  public FindNearest getFindNearest() {
+    return findNearest;
+  }
+
   public abstract Builder<V> toBuilder();
 
   @InternalApi
diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQueryProtoPreparer.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQueryProtoPreparer.java
index fda6f8f4a..c7e39f3d4 100644
--- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQueryProtoPreparer.java
+++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQueryProtoPreparer.java
@@ -60,6 +60,9 @@ public Query prepare(StructuredQuery<?> query) {
               .build();
       queryPb.addProjection(expressionPb);
     }
+    if (query.getFindNearest() != null) {
+      queryPb.setFindNearest(query.getFindNearest().toPb());
+    }
 
     return queryPb.build();
   }
diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/Value.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/Value.java
index 4bd0a5133..40b4e59a4 100644
--- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/Value.java
+++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/Value.java
@@ -16,6 +16,7 @@
 
 package com.google.cloud.datastore;
 
+import static com.google.cloud.datastore.VectorValue.VECTOR_MEANING;
 import static com.google.common.base.Preconditions.checkNotNull;
 
 import com.google.cloud.GcpLaunchStage;
@@ -214,8 +215,19 @@ com.google.datastore.v1.Value toPb() {
   public static Value<?> fromPb(com.google.datastore.v1.Value proto) {
     ValueTypeCase descriptorId = proto.getValueTypeCase();
     ValueType valueType = ValueType.getByDescriptorId(descriptorId.getNumber());
-    return valueType == null
-        ? RawValue.MARSHALLER.fromProto(proto).build()
-        : valueType.getMarshaller().fromProto(proto).build();
+    if (valueType == null) return RawValue.MARSHALLER.fromProto(proto).build();
+
+    Value<?> returnValue = valueType.getMarshaller().fromProto(proto).build();
+    // If the proto is a list of doubles with a meaning of 31, use the VectorValue marshaller.
+    if (valueType == ValueType.LIST && proto.getMeaning() == VECTOR_MEANING) {
+      for (com.google.datastore.v1.Value item : proto.getArrayValue().getValuesList()) {
+        if (item.getValueTypeCase() != ValueTypeCase.DOUBLE_VALUE) {
+          return returnValue;
+        }
+      }
+      returnValue = VectorValue.MARSHALLER.fromProto(proto).build();
+    }
+
+    return returnValue;
   }
 }
diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/ValueType.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/ValueType.java
index 13e3c7af6..d52b43236 100644
--- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/ValueType.java
+++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/ValueType.java
@@ -19,7 +19,7 @@
 import com.google.common.collect.ImmutableMap;
 
 /**
- * The type of a Datastore property.
+ * The type of Datastore property.
  *
  * @see Google
@@ -61,7 +61,10 @@ public enum ValueType {
   RAW_VALUE(RawValue.MARSHALLER),
 
   /** Represents a {@link LatLng} value. */
-  LAT_LNG(LatLngValue.MARSHALLER);
+  LAT_LNG(LatLngValue.MARSHALLER),
+
+  /** Represents a {@link VectorValue} value.
*/ + VECTOR(VectorValue.MARSHALLER); private static final ImmutableMap DESCRIPTOR_TO_TYPE_MAP; diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/VectorValue.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/VectorValue.java new file mode 100644 index 000000000..ce27018c4 --- /dev/null +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/VectorValue.java @@ -0,0 +1,154 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.datastore; + +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.List; + +/** + * A Google Cloud Datastore Vector value. A Vector value is a list of Double {@link Value} objects. 
+ */ +public final class VectorValue extends Value>> { + + private static final long serialVersionUID = -5121887228607148859L; + + public static final int VECTOR_MEANING = 31; + + static final BaseMarshaller>, VectorValue, Builder> MARSHALLER = + new BaseMarshaller>, VectorValue, Builder>() { + private static final long serialVersionUID = 7720473855548179943L; + + @Override + public int getProtoFieldId() { + return -1; + } + + @Override + public Builder newBuilder(List> values) { + return VectorValue.newBuilder().set(values); + } + + @Override + protected List> getValue(com.google.datastore.v1.Value from) { + List> properties = new ArrayList<>(from.getArrayValue().getValuesCount()); + for (com.google.datastore.v1.Value valuePb : from.getArrayValue().getValuesList()) { + properties.add((Value) Value.fromPb(valuePb)); + } + return properties; + } + + @Override + protected void setValue(VectorValue from, com.google.datastore.v1.Value.Builder to) { + List propertiesPb = new ArrayList<>(); + for (Value property : from.get()) { + propertiesPb.add(property.toPb()); + } + to.setArrayValue( + com.google.datastore.v1.ArrayValue.newBuilder().addAllValues(propertiesPb)); + } + }; + + public static final class Builder + extends Value.BaseBuilder>, VectorValue, Builder> { + private ImmutableList.Builder> vectorBuilder = ImmutableList.builder(); + + private Builder() { + super(ValueType.VECTOR); + } + + /** Adds the provided double values to the {@code VectorValue} builder. */ + public VectorValue.Builder addValue(Value first, Value... other) { + vectorBuilder.add(first); + for (Value value : other) { + vectorBuilder.add(value); + } + return this; + } + + public VectorValue.Builder addValue(double first, double... other) { + vectorBuilder.add(DoubleValue.of(first)); + for (double value : other) { + vectorBuilder.add(DoubleValue.of(value)); + } + return this; + } + + /** + * Sets the list of values of this {@code VectorValue} builder to {@code values}. 
The provided + * list is copied. + * + * @see com.google.cloud.datastore.Value.BaseBuilder#set(java.lang.Object) + */ + @Override + public Builder set(List> values) { + vectorBuilder = ImmutableList.builder(); + for (Value value : values) { + addValue(value); + } + return this; + } + + @Override + public List> get() { + return vectorBuilder.build(); + } + + /** Creates a {@code VectorValue} object. */ + @Override + public VectorValue build() { + return new VectorValue(this); + } + } + + public VectorValue(List> values) { + this(newBuilder().set(values)); + } + + private VectorValue(Builder builder) { + super(builder); + } + + /** Returns a builder for the vector value object. */ + @Override + public Builder toBuilder() { + return new Builder().mergeFrom(this); + } + + /** Creates a {@code VectorValue} object given a number of double values. */ + public static VectorValue of(double first, double... other) { + return newBuilder().addValue(first, other).build(); + } + + /** Creates a {@code VectorValue} object given a list of {@code Value} objects. */ + public static VectorValue of(List> values) { + return new VectorValue(values); + } + + /** Returns a builder for {@code VectorValue} objects. */ + public static Builder newBuilder() { + Builder builder = new VectorValue.Builder(); + builder.setMeaning(VECTOR_MEANING); + return builder; + } + + public static Builder newBuilder(double first, double... 
other) { + VectorValue.Builder builder = new VectorValue.Builder(); + builder.setMeaning(VECTOR_MEANING); + return builder.addValue(first, other); + } +} diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/BaseEntityTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/BaseEntityTest.java index 1b5380ab9..3e01999aa 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/BaseEntityTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/BaseEntityTest.java @@ -16,6 +16,7 @@ package com.google.cloud.datastore; +import static com.google.cloud.datastore.VectorValue.VECTOR_MEANING; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -35,6 +36,8 @@ public class BaseEntityTest { private static final Blob BLOB = Blob.copyFrom(new byte[] {1, 2}); private static final Timestamp TIMESTAMP = Timestamp.now(); private static final LatLng LAT_LNG = new LatLng(37.422035, -122.084124); + private static final VectorValue VECTOR = + VectorValue.newBuilder(1.78, 2.56, 3.88).setMeaning(VECTOR_MEANING).build(); private static final Key KEY = Key.newBuilder("ds1", "k1", "n1").build(); private static final Entity ENTITY = Entity.newBuilder(KEY).set("name", "foo").build(); private static final IncompleteKey INCOMPLETE_KEY = IncompleteKey.newBuilder("ds1", "k1").build(); @@ -76,6 +79,7 @@ public void setUp() { builder.set("stringList", "s1", "s2", "s3"); builder.set("longList", 1, 23, 456); builder.set("latLngList", LAT_LNG, LAT_LNG); + builder.set("vector", VECTOR); } @Test @@ -182,6 +186,16 @@ public void testGetEntity() { assertEquals(PARTIAL_ENTITY, entity.getEntity("entity")); } + @Test + public void testGetVector() { + BaseEntity entity = builder.build(); + List vectorList = entity.getVector("vector"); + assertEquals(3, vectorList.size()); + assertEquals(Double.valueOf(1.78), vectorList.get(0).get()); + 
assertEquals(Double.valueOf(2.56), vectorList.get(1).get()); + assertEquals(Double.valueOf(3.88), vectorList.get(2).get()); + } + @Test public void testGetList() { BaseEntity entity = builder.build(); @@ -229,7 +243,7 @@ public void testNames() { .add("entity", "partialEntity", "null", "timestamp", "blob", "key", "blobList") .add( "booleanList", "timestampList", "doubleList", "keyList", "entityList", "stringList") - .add("longList", "latLng", "latLngList") + .add("longList", "latLng", "latLngList", "vector") .build(); BaseEntity entity = builder.build(); assertEquals(names, entity.getNames()); diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ProtoTestData.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ProtoTestData.java index 8e2ba890a..57f71039c 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ProtoTestData.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ProtoTestData.java @@ -17,6 +17,7 @@ import static com.google.datastore.v1.PropertyOrder.Direction.ASCENDING; +import com.google.cloud.datastore.FindNearest.DistanceMeasure; import com.google.datastore.v1.AggregationQuery.Aggregation; import com.google.datastore.v1.AggregationQuery.Aggregation.Count; import com.google.datastore.v1.Filter; @@ -27,6 +28,9 @@ import com.google.datastore.v1.PropertyOrder; import com.google.datastore.v1.PropertyReference; import com.google.datastore.v1.Value; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.Int32Value; +import javax.annotation.Nullable; public class ProtoTestData { @@ -83,4 +87,33 @@ public static PropertyOrder propertyOrder(String value) { public static Projection projection(String value) { return Projection.newBuilder().setProperty(propertyReference(value)).build(); } + + public static com.google.datastore.v1.FindNearest FindNearest( + String vectorProperty, VectorValue queryVector, DistanceMeasure measure, int limit) { + return 
FindNearest(vectorProperty, queryVector, measure, limit, null, null); + } + + public static com.google.datastore.v1.FindNearest FindNearest( + String vectorProperty, + VectorValue queryVector, + DistanceMeasure measure, + int limit, + @Nullable String distanceResultField, + @Nullable Double distanceThreshold) { + com.google.datastore.v1.FindNearest.Builder builder = + com.google.datastore.v1.FindNearest.newBuilder() + .setVectorProperty(propertyReference(vectorProperty)) + .setQueryVector(queryVector.toPb()) + .setDistanceMeasure(FindNearest.toProto(measure)) + .setLimit(Int32Value.of(limit)); + + if (distanceResultField != null) { + builder.setDistanceResultProperty(distanceResultField); + } + if (distanceThreshold != null) { + builder.setDistanceThreshold(DoubleValue.of(distanceThreshold)); + } + + return builder.build(); + } } diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryProtoPreparerTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryProtoPreparerTest.java index 60937fc28..549a8876b 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryProtoPreparerTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryProtoPreparerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2022 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -15,6 +15,7 @@ */ package com.google.cloud.datastore; +import static com.google.cloud.datastore.ProtoTestData.FindNearest; import static com.google.cloud.datastore.ProtoTestData.booleanValue; import static com.google.cloud.datastore.ProtoTestData.projection; import static com.google.cloud.datastore.ProtoTestData.propertyFilter; @@ -86,6 +87,18 @@ public void testFilter() { assertThat(queryProto.getFilter()).isEqualTo(propertyFilter("done", EQUAL, booleanValue(true))); } + @Test + public void testFindNearest() { + VectorValue VECTOR_VALUE = VectorValue.newBuilder(1.78, 2.56, 3.88).build(); + FindNearest FIND_NEAREST = + new FindNearest("vector_property", VECTOR_VALUE, FindNearest.DistanceMeasure.COSINE, 1); + Query queryProto = + protoPreparer.prepare(newEntityQueryBuilder().setFindNearest(FIND_NEAREST).build()); + assertThat(queryProto.getFindNearest()) + .isEqualTo( + FindNearest("vector_property", VECTOR_VALUE, FindNearest.DistanceMeasure.COSINE, 1)); + } + @Test public void testOrderBy() { Query queryProto = diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryTest.java index c59337586..41b346766 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryTest.java @@ -50,6 +50,9 @@ public class StructuredQueryTest { private static final String DISTINCT_ON1 = "p6"; private static final String DISTINCT_ON2 = "p7"; private static final List DISTINCT_ON = ImmutableList.of(DISTINCT_ON1, DISTINCT_ON2); + private static final VectorValue VECTOR_VALUE = VectorValue.newBuilder(1.78, 2.56, 3.88).build(); + private static final FindNearest FIND_NEAREST = + new FindNearest("vector_property", VECTOR_VALUE, FindNearest.DistanceMeasure.COSINE, 1); private static final EntityQuery ENTITY_QUERY = 
Query.newEntityQueryBuilder() .setNamespace(NAMESPACE) @@ -60,6 +63,7 @@ public class StructuredQueryTest { .setLimit(LIMIT) .setFilter(AND_FILTER) .setOrderBy(ORDER_BY_1, ORDER_BY_2) + .setFindNearest(FIND_NEAREST) .build(); private static final KeyQuery KEY_QUERY = Query.newKeyQueryBuilder() @@ -71,6 +75,7 @@ public class StructuredQueryTest { .setLimit(LIMIT) .setFilter(OR_FILTER) .setOrderBy(ORDER_BY_1, ORDER_BY_2) + .setFindNearest(FIND_NEAREST) .build(); private static final ProjectionEntityQuery PROJECTION_QUERY = Query.newProjectionEntityQueryBuilder() @@ -82,6 +87,7 @@ public class StructuredQueryTest { .setLimit(LIMIT) .setFilter(AND_FILTER) .setOrderBy(ORDER_BY_1, ORDER_BY_2) + .setFindNearest(FIND_NEAREST) .setProjection(PROJECTION1, PROJECTION2) .setDistinctOn(DISTINCT_ON1, DISTINCT_ON2) .build(); @@ -123,6 +129,7 @@ private void compareBaseBuilderFields(StructuredQuery query) { assertEquals(LIMIT, query.getLimit()); assertEquals(AND_FILTER, query.getFilter()); assertEquals(ORDER_BY, query.getOrderBy()); + assertEquals(FIND_NEAREST, query.getFindNearest()); } @Test diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ValueTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ValueTest.java index 8d53dc736..773746281 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ValueTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ValueTest.java @@ -42,6 +42,7 @@ public class ValueTest { private static final RawValue RAW_VALUE = RawValue.of(STRING_VALUE.toPb()); private static final LatLngValue LAT_LNG_VALUE = LatLngValue.of(new LatLng(37.422035, -122.084124)); + private static final VectorValue VECTOR_VALUE = VectorValue.newBuilder(1.78, 2.56, 3.88).build(); private static final ImmutableMap TYPES = ImmutableMap.builder() .put(ValueType.NULL, new Object[] {NullValue.class, NULL_VALUE.get()}) @@ -57,6 +58,7 @@ public class ValueTest { .put(ValueType.LONG, new 
Object[] {LongValue.class, 123L}) .put(ValueType.RAW_VALUE, new Object[] {RawValue.class, RAW_VALUE.get()}) .put(ValueType.LAT_LNG, new Object[] {LatLngValue.class, LAT_LNG_VALUE.get()}) + .put(ValueType.VECTOR, new Object[] {VectorValue.class, VECTOR_VALUE.get()}) .put(ValueType.STRING, new Object[] {StringValue.class, STRING_VALUE.get()}) .buildOrThrow(); diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/VectorValueTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/VectorValueTest.java new file mode 100644 index 000000000..69330195d --- /dev/null +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/VectorValueTest.java @@ -0,0 +1,59 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.cloud.datastore; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import org.junit.Test; + +public class VectorValueTest { + private static final List> vectorList = + ImmutableList.of(DoubleValue.of(1.2), DoubleValue.of(3.6)); + + @Test + public void testToBuilder() { + VectorValue value = VectorValue.of(0.3, 4.2, 3.7); + assertEquals(value, value.toBuilder().build()); + } + + @Test + public void testOf() { + VectorValue value = VectorValue.of(0.3, 4.2, 3.7); + assertEquals( + ImmutableList.of(DoubleValue.of(0.3), DoubleValue.of(4.2), DoubleValue.of(3.7)), + value.get()); + assertEquals(31, value.getMeaning()); + VectorValue vectorListValue = VectorValue.of(vectorList); + assertEquals(vectorList, vectorListValue.get()); + assertEquals(31, vectorListValue.getMeaning()); + } + + @SuppressWarnings("deprecation") + @Test + public void testBuilder() { + VectorValue.Builder builder = VectorValue.newBuilder(0.3, 4.2, 3.7); + VectorValue value = builder.setExcludeFromIndexes(true).build(); + assertEquals( + ImmutableList.of(DoubleValue.of(0.3), DoubleValue.of(4.2), DoubleValue.of(3.7)), + value.get()); + assertEquals(31, value.getMeaning()); + assertTrue(value.excludeFromIndexes()); + } +} diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreConceptsTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreConceptsTest.java index 770065778..684730365 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreConceptsTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreConceptsTest.java @@ -23,29 +23,10 @@ import static org.junit.Assert.fail; import com.google.cloud.Timestamp; -import com.google.cloud.datastore.Cursor; -import com.google.cloud.datastore.Datastore; -import 
com.google.cloud.datastore.DatastoreException; -import com.google.cloud.datastore.DatastoreOptions; -import com.google.cloud.datastore.Entity; -import com.google.cloud.datastore.EntityQuery; -import com.google.cloud.datastore.FullEntity; -import com.google.cloud.datastore.IncompleteKey; -import com.google.cloud.datastore.Key; -import com.google.cloud.datastore.KeyFactory; -import com.google.cloud.datastore.KeyQuery; -import com.google.cloud.datastore.ListValue; -import com.google.cloud.datastore.PathElement; -import com.google.cloud.datastore.ProjectionEntity; -import com.google.cloud.datastore.Query; -import com.google.cloud.datastore.QueryResults; -import com.google.cloud.datastore.ReadOption; -import com.google.cloud.datastore.StringValue; -import com.google.cloud.datastore.StructuredQuery; +import com.google.cloud.datastore.*; import com.google.cloud.datastore.StructuredQuery.CompositeFilter; import com.google.cloud.datastore.StructuredQuery.OrderBy; import com.google.cloud.datastore.StructuredQuery.PropertyFilter; -import com.google.cloud.datastore.Transaction; import com.google.cloud.datastore.testing.RemoteDatastoreHelper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -175,6 +156,7 @@ private void setUpQueryTests() { "description", StringValue.newBuilder("Learn Cloud Datastore").setExcludeFromIndexes(true).build()) .set("tag", "fun", "l", "programming", "learn") + .set("embedding_field", VectorValue.newBuilder(3.0, 1.0, 2.0).build()) .build()); } @@ -591,6 +573,38 @@ public void testEqualAndInequalityRange() { assertValidQuery(query); } + @Test + public void testVectorSearch() { + VectorValue vectorValue = VectorValue.newBuilder(1.78, 2.56, 3.88).build(); + FindNearest vectorQuery = + new FindNearest( + "embedding_field", vectorValue, FindNearest.DistanceMeasure.COSINE, 1, "distance"); + + Query query = + Query.newEntityQueryBuilder().setKind(TASK_CONCEPTS).setFindNearest(vectorQuery).build(); + 
assertValidQuery(query); + } + + @Test + public void testVectorSearchWithEmptyVector() { + VectorValue emptyVector = VectorValue.newBuilder().build(); + FindNearest vectorQuery = + new FindNearest("embedding_field", emptyVector, FindNearest.DistanceMeasure.EUCLIDEAN, 1); + Query query = + Query.newEntityQueryBuilder().setKind(TASK_CONCEPTS).setFindNearest(vectorQuery).build(); + assertInvalidQuery(query); + } + + @Test + public void testVectorSearchWithUnmatchedVectorSize() { + VectorValue vectorValue = VectorValue.newBuilder(1.78, 2.56, 3.88, 4.33).build(); + FindNearest vectorQuery = + new FindNearest("embedding_field", vectorValue, FindNearest.DistanceMeasure.DOT_PRODUCT, 1); + Query query = + Query.newEntityQueryBuilder().setKind(TASK_CONCEPTS).setFindNearest(vectorQuery).build(); + assertInvalidQuery(query); + } + @Test public void testInequalitySort() { Query query = diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreTest.java index bf0c20dce..c869b52a6 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreTest.java @@ -32,44 +32,12 @@ import com.google.cloud.Timestamp; import com.google.cloud.Tuple; -import com.google.cloud.datastore.AggregationQuery; -import com.google.cloud.datastore.AggregationResult; -import com.google.cloud.datastore.AggregationResults; -import com.google.cloud.datastore.Batch; -import com.google.cloud.datastore.BooleanValue; -import com.google.cloud.datastore.Cursor; -import com.google.cloud.datastore.Datastore; +import com.google.cloud.datastore.*; import com.google.cloud.datastore.Datastore.TransactionCallable; -import com.google.cloud.datastore.DatastoreException; -import com.google.cloud.datastore.DatastoreOptions; -import com.google.cloud.datastore.DatastoreReaderWriter; 
-import com.google.cloud.datastore.Entity; -import com.google.cloud.datastore.EntityQuery; -import com.google.cloud.datastore.EntityValue; -import com.google.cloud.datastore.FullEntity; -import com.google.cloud.datastore.GqlQuery; -import com.google.cloud.datastore.IncompleteKey; -import com.google.cloud.datastore.Key; -import com.google.cloud.datastore.KeyFactory; -import com.google.cloud.datastore.KeyValue; -import com.google.cloud.datastore.LatLng; -import com.google.cloud.datastore.LatLngValue; -import com.google.cloud.datastore.ListValue; -import com.google.cloud.datastore.NullValue; -import com.google.cloud.datastore.PathElement; -import com.google.cloud.datastore.ProjectionEntity; -import com.google.cloud.datastore.Query; import com.google.cloud.datastore.Query.ResultType; -import com.google.cloud.datastore.QueryResults; -import com.google.cloud.datastore.ReadOption; -import com.google.cloud.datastore.StringValue; -import com.google.cloud.datastore.StructuredQuery; import com.google.cloud.datastore.StructuredQuery.CompositeFilter; import com.google.cloud.datastore.StructuredQuery.OrderBy; import com.google.cloud.datastore.StructuredQuery.PropertyFilter; -import com.google.cloud.datastore.TimestampValue; -import com.google.cloud.datastore.Transaction; -import com.google.cloud.datastore.ValueType; import com.google.cloud.datastore.models.ExecutionStats; import com.google.cloud.datastore.models.ExplainMetrics; import com.google.cloud.datastore.models.ExplainOptions; @@ -125,6 +93,7 @@ public class ITDatastoreTest { private static final String KIND1 = "kind1"; private static final String KIND2 = "kind2"; private static final String KIND3 = "kind3"; + private static final String VECTOR_KIND = "CoffeeBean"; private static final NullValue NULL_VALUE = NullValue.of(); private static final StringValue STR_VALUE = StringValue.of("str"); private static final BooleanValue BOOL_VALUE = @@ -145,6 +114,7 @@ public class ITDatastoreTest { private static Key KEY4; private 
static Key KEY5; private static Key KEY6; + private static Key VECTORKEY; private static final String MARKS_KIND = "Marks"; private static FullEntity PARTIAL_ENTITY1; private static FullEntity PARTIAL_ENTITY2; @@ -156,6 +126,9 @@ public class ITDatastoreTest { private static Entity AGGREGATION_ENTITY_1; private static Entity AGGREGATION_ENTITY_2; private static Entity AGGREGATION_ENTITY_3; + private static Entity VECTOR_ENTITY_1; + private static Entity VECTOR_ENTITY_2; + private static Entity VECTOR_ENTITY_3; @Rule public Timeout globalTimeout = Timeout.seconds(100); @@ -176,6 +149,13 @@ public ITDatastoreTest( PROJECT_ID = this.options.getProjectId(); NAMESPACE = this.options.getNamespace(); + System.out.println( + "Project: " + + PROJECT_ID + + ", Namespace: " + + NAMESPACE + + ", db: " + + options.getDatabaseId()); ROOT_KEY = Key.newBuilder(PROJECT_ID, "rootkey", "default", options.getDatabaseId()) @@ -198,6 +178,10 @@ public ITDatastoreTest( Key.newBuilder(options.getProjectId(), KIND2, 100, options.getDatabaseId()) .setNamespace(NAMESPACE) .build(); + VECTORKEY = + Key.newBuilder(PROJECT_ID, VECTOR_KIND, "bean1", options.getDatabaseId()) + .setNamespace(NAMESPACE) + .build(); LIST_VALUE2 = ListValue.of(Collections.singletonList(KeyValue.of(KEY1))); @@ -247,6 +231,22 @@ public ITDatastoreTest( .set("partial1", PARTIAL_ENTITY2) .set("partial2", ENTITY2) .build(); + VECTOR_ENTITY_1 = + Entity.newBuilder(VECTORKEY) + .set("name", "Arabica") + .set("embedding_field", VectorValue.newBuilder(1.0, 7.0, 11.1).build()) + .build(); + VECTOR_ENTITY_2 = + Entity.newBuilder(Key.newBuilder(VECTORKEY).setName("bean2").build()) + .set("name", "Robusta") + .set("embedding_field", VectorValue.newBuilder(1.0, 9.0, 11.1).build()) + .set("vector_distance", 0) + .build(); + VECTOR_ENTITY_3 = + Entity.newBuilder(Key.newBuilder(VECTORKEY).setName("bean3").build()) + .set("name", "Excelsa") + .set("embedding_field", VectorValue.newBuilder(4.0, 9.0, 11.1).build()) + .build(); Key 
aggregationKey1 = datastore.newKeyFactory().setKind(MARKS_KIND).newKey(1); Key aggregationKey2 = datastore.newKeyFactory().setKind(MARKS_KIND).newKey(2); @@ -2115,6 +2115,93 @@ public void testQueryWithStartCursor() { datastore.delete(entity1.getKey(), entity2.getKey(), entity3.getKey()); } + @Test + public void testVectorSearchQueryWithLimit() { + datastore.put(VECTOR_ENTITY_1, VECTOR_ENTITY_2, VECTOR_ENTITY_3); + // Test FindNearest query with limit + FindNearest findNearestQueryWithLimit = + new FindNearest( + "embedding_field", + VectorValue.newBuilder(1, 9, 11.1).build(), + FindNearest.DistanceMeasure.DOT_PRODUCT, + 3); + Query queryWithLimit = + Query.newEntityQueryBuilder() + .setKind(VECTOR_KIND) + .setFindNearest(findNearestQueryWithLimit) + .build(); + + QueryResults resultWithLimit = datastore.run(queryWithLimit); + + List resultsCopyWithLimit = makeResultsCopy(resultWithLimit); + + // Verify limit was applied + assertEquals(3, resultsCopyWithLimit.size()); + } + + @Test + public void testVectorSearchQueryWithDistanceThreshold() { + datastore.put(VECTOR_ENTITY_1, VECTOR_ENTITY_2, VECTOR_ENTITY_3); + + VectorValue vectorValue = VectorValue.newBuilder(1.78, 2.56, 3.88).build(); + FindNearest vectorQuery = + new FindNearest( + "embedding_field", vectorValue, FindNearest.DistanceMeasure.COSINE, 1, "distance"); + + Query query = Query.newEntityQueryBuilder().setFindNearest(vectorQuery).build(); + + // Test FindNearest query with distanceThreshold + FindNearest findNearestQueryWithThreshold = + new FindNearest( + "embedding_field", + VectorValue.newBuilder(1, 9, 11.1).build(), + FindNearest.DistanceMeasure.EUCLIDEAN, + 3, + "vector_distance", + 2.0); + Query queryWithWithThreshold = + Query.newEntityQueryBuilder() + .setKind(VECTOR_KIND) + .setFindNearest(findNearestQueryWithThreshold) + .build(); + QueryResults resultWithThreshold = datastore.run(queryWithWithThreshold); + List resultsCopyWithThreshold = makeResultsCopy(resultWithThreshold); + // Verify 
threshold was applied regardless of limit + assertEquals(2, resultsCopyWithThreshold.size()); + // Verify qualified EUCLIDEAN distance: d((1, 9, 11.1), (1, 9, 11.1)) = 0.0, d((1, 9, 11.1), (1, + // 7, 11.1)) = 2.0 + assertEquals(DoubleValue.of(0.0), resultsCopyWithThreshold.get(0).getValue("vector_distance")); + assertEquals(DoubleValue.of(2.0), resultsCopyWithThreshold.get(1).getValue("vector_distance")); + } + + @Test + public void testQueryWithVectorSearchWithDistanceField() { + datastore.put(VECTOR_ENTITY_1, VECTOR_ENTITY_2, VECTOR_ENTITY_3); + // Test FindNearest query with distanceField + FindNearest findNearestQueryWithDistanceField = + new FindNearest( + "embedding_field", + VectorValue.newBuilder(1, 9, 11.1).build(), + FindNearest.DistanceMeasure.DOT_PRODUCT, + 3, + "vector_distance", + 0.0); + Query queryWithWithDistanceField = + Query.newEntityQueryBuilder() + .setKind(VECTOR_KIND) + .setFindNearest(findNearestQueryWithDistanceField) + .build(); + QueryResults resultWithDistanceField = datastore.run(queryWithWithDistanceField); + List resultsCopyWithDistanceField = makeResultsCopy(resultWithDistanceField); + // Verify results count + assertEquals(3, resultsCopyWithDistanceField.size()); + for (int i = 0; i < resultsCopyWithDistanceField.size(); i++) { + // Verify distance field was not 0 + assertNotEquals( + DoubleValue.of(0.0), resultsCopyWithDistanceField.get(i).getValue("vector_distance")); + } + } + @Test public void testQueryWithReadTime() throws InterruptedException { Entity entity1 = @@ -2164,7 +2251,7 @@ public void testQueryWithReadTime() throws InterruptedException { assertEquals(entity2, withReadTime.next()); assertFalse(withReadTime.hasNext()); } finally { - datastore.delete(entity1.getKey(), entity2.getKey(), entity3.getKey()); + // datastore.delete(entity1.getKey(), entity2.getKey(), entity3.getKey()); } } diff --git a/google-cloud-datastore/src/test/resources/index.yaml b/google-cloud-datastore/src/test/resources/index.yaml index 
ff1b08626..54f47b764 100644 --- a/google-cloud-datastore/src/test/resources/index.yaml +++ b/google-cloud-datastore/src/test/resources/index.yaml @@ -45,4 +45,18 @@ indexes: properties: - name: done - name: priority - direction: desc \ No newline at end of file + direction: desc + - kind: TaskConcepts + properties: + - name: __key__ + - name: embedding_field + vectorConfig: + dimension: 3 + flat: { } + - kind: CoffeeBean + properties: + - name: __key__ + - name: embedding_field + vectorConfig: + dimension: 3 + flat: { } \ No newline at end of file diff --git a/samples/snippets/src/main/java/com/example/datastore/filters/OrderFieldsQuery.java b/samples/snippets/src/main/java/com/example/datastore/filters/OrderFieldsQuery.java index 24fc7901c..f55f2f7f0 100644 --- a/samples/snippets/src/main/java/com/example/datastore/filters/OrderFieldsQuery.java +++ b/samples/snippets/src/main/java/com/example/datastore/filters/OrderFieldsQuery.java @@ -17,7 +17,7 @@ package com.example.datastore.filters; // sample-metadata: -// title: Queries with order fileds +// title: Queries with order fields // description: The following query order properties // in the decreasing order of query constraint selectivity. diff --git a/samples/snippets/src/main/java/com/example/datastore/vectorsearch/StoreVectors.java b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/StoreVectors.java new file mode 100644 index 000000000..a1db1ca80 --- /dev/null +++ b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/StoreVectors.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.datastore.vectorsearch; + +// [START datastore_store_vectors] + +import com.google.cloud.datastore.Datastore; +import com.google.cloud.datastore.DatastoreOptions; +import com.google.cloud.datastore.Entity; +import com.google.cloud.datastore.Key; +import com.google.cloud.datastore.VectorValue; + +public class StoreVectors { + public static void invoke() throws Exception { + // Instantiates a client + Datastore datastore = DatastoreOptions.getDefaultInstance().getService(); + + // The Cloud Datastore key for the new entity + Key key = datastore.newKeyFactory().setKind("CoffeeBean").newKey("Kahawa"); + + // Prepares the entity with a vector embedding + Entity entity = + Entity.newBuilder(key) + .set("name", "Kahawa") + .set("description", "Information about the Kahawa coffee beans.") + .set("roast", "dark") + .set("embedding_field", VectorValue.newBuilder(1.0, 7.0, 11.1).build()) + .build(); + + // Saves the entity + datastore.put(entity); + System.out.printf("Saved %s: %s%n", entity.getKey().getName(), entity.getString("description")); + + // Retrieve entity + Entity retrieved = datastore.get(key); + System.out.printf( + "Retrieved %s with embedding_field: %s%n", + key.getName(), retrieved.getVector("embedding_field")); + } +} +// [END datastore_store_vectors] diff --git a/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchBasic.java b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchBasic.java new file mode 100644 index 000000000..a0ca8b953 --- /dev/null +++ 
b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchBasic.java @@ -0,0 +1,59 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.datastore.vectorsearch; + +// [START datastore_vector_search_basic] + +import com.google.cloud.datastore.Datastore; +import com.google.cloud.datastore.DatastoreOptions; +import com.google.cloud.datastore.Entity; +import com.google.cloud.datastore.FindNearest; +import com.google.cloud.datastore.Query; +import com.google.cloud.datastore.QueryResults; +import com.google.cloud.datastore.VectorValue; + +public class VectorSearchBasic { + public static void invoke() throws Exception { + // Instantiates a client + Datastore datastore = DatastoreOptions.getDefaultInstance().getService(); + + // Create vector search query + Query<Entity> vectorSearchQuery = + Query.newEntityQueryBuilder() + .setKind("CoffeeBean") + .setFindNearest( + new FindNearest( + "embedding_field", + VectorValue.newBuilder(1, 9, 11.1).build(), + FindNearest.DistanceMeasure.EUCLIDEAN, + 1)) + .build(); + + // Execute vector search query + QueryResults<Entity> results = datastore.run(vectorSearchQuery); + + if (!results.hasNext()) { + throw new Exception("query yielded no results"); + } + + while (results.hasNext()) { + Entity entity = results.next(); + System.out.printf("Entity: %s%n", entity.getKey().getName()); + } + } +} +// [END datastore_vector_search_basic] diff --git
a/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceResultProperty.java b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceResultProperty.java new file mode 100644 index 000000000..ca308828b --- /dev/null +++ b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceResultProperty.java @@ -0,0 +1,62 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.example.datastore.vectorsearch; + +// [START datastore_vector_search_distance_result_property] + +import com.google.cloud.datastore.Datastore; +import com.google.cloud.datastore.DatastoreOptions; +import com.google.cloud.datastore.Entity; +import com.google.cloud.datastore.FindNearest; +import com.google.cloud.datastore.Query; +import com.google.cloud.datastore.QueryResults; +import com.google.cloud.datastore.VectorValue; + +public class VectorSearchDistanceResultProperty { + public static void invoke() throws Exception { + // Instantiates a client + Datastore datastore = DatastoreOptions.getDefaultInstance().getService(); + + // Create vector search query with distance result property + Query<Entity> vectorSearchQuery = + Query.newEntityQueryBuilder() + .setKind("CoffeeBean") + .setFindNearest( + new FindNearest( + "embedding_field", + VectorValue.newBuilder(1, 9, 11.1).build(), + FindNearest.DistanceMeasure.DOT_PRODUCT, + 3, + "vector_distance")) + .build(); + + // Execute vector search query + QueryResults<Entity> results = datastore.run(vectorSearchQuery); + + if (!results.hasNext()) { + throw new Exception("query yielded no results"); + } + + while (results.hasNext()) { + Entity entity = results.next(); + System.out.printf( + "Entity: %s, Distance: %s%n", + entity.getKey().getName(), entity.getDouble("vector_distance")); + } + } +} +// [END datastore_vector_search_distance_result_property] diff --git a/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceResultPropertyProjection.java b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceResultPropertyProjection.java new file mode 100644 index 000000000..96ce2ccbb --- /dev/null +++ b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceResultPropertyProjection.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use
this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.datastore.vectorsearch; + +// [START datastore_vector_search_distance_result_property_projection] + +import com.google.cloud.datastore.Datastore; +import com.google.cloud.datastore.DatastoreOptions; +import com.google.cloud.datastore.FindNearest; +import com.google.cloud.datastore.ProjectionEntity; +import com.google.cloud.datastore.Query; +import com.google.cloud.datastore.QueryResults; +import com.google.cloud.datastore.VectorValue; + +public class VectorSearchDistanceResultPropertyProjection { + public static void invoke() throws Exception { + // Instantiates a client + Datastore datastore = DatastoreOptions.getDefaultInstance().getService(); + + // Create vector search query with projection + Query<ProjectionEntity> vectorSearchQuery = + Query.newProjectionEntityQueryBuilder() + .setKind("CoffeeBean") + .setFindNearest( + new FindNearest( + "embedding_field", + VectorValue.newBuilder(1, 9, 11.1).build(), + FindNearest.DistanceMeasure.EUCLIDEAN, + 3, + "vector_distance")) + .setProjection("roast") + .build(); + + // Execute vector search query + QueryResults<ProjectionEntity> results = datastore.run(vectorSearchQuery); + + if (!results.hasNext()) { + throw new Exception("query yielded no results"); + } + + while (results.hasNext()) { + ProjectionEntity entity = results.next(); + System.out.printf( + "Entity: %s, Distance: %s%n", + entity.getKey().getName(), entity.getDouble("vector_distance")); + } + } +} +// [END datastore_vector_search_distance_result_property_projection] diff --git
a/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceThreshold.java b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceThreshold.java new file mode 100644 index 000000000..b9e4c658d --- /dev/null +++ b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchDistanceThreshold.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.datastore.vectorsearch; + +// [START datastore_vector_search_distance_threshold] + +import com.google.cloud.datastore.Datastore; +import com.google.cloud.datastore.DatastoreOptions; +import com.google.cloud.datastore.Entity; +import com.google.cloud.datastore.FindNearest; +import com.google.cloud.datastore.Query; +import com.google.cloud.datastore.QueryResults; +import com.google.cloud.datastore.VectorValue; + +public class VectorSearchDistanceThreshold { + public static void invoke() throws Exception { + // Instantiates a client + Datastore datastore = DatastoreOptions.getDefaultInstance().getService(); + + // Create vector search query with distance threshold + Query<Entity> vectorSearchQuery = + Query.newEntityQueryBuilder() + .setKind("CoffeeBean") + .setFindNearest( + new FindNearest( + "embedding_field", + VectorValue.newBuilder(1, 9, 11.1).build(), + FindNearest.DistanceMeasure.EUCLIDEAN, + 3, + "vector_distance", + 2.0)) + .build(); + + // Execute vector search
query + QueryResults<Entity> results = datastore.run(vectorSearchQuery); + + if (!results.hasNext()) { + throw new Exception("query yielded no results"); + } + + while (results.hasNext()) { + Entity entity = results.next(); + System.out.printf( + "Entity: %s, Distance: %s%n", + entity.getKey().getName(), entity.getDouble("vector_distance")); + } + } +} +// [END datastore_vector_search_distance_threshold] diff --git a/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchLargeResponse.java b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchLargeResponse.java new file mode 100644 index 000000000..0744ca6cb --- /dev/null +++ b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchLargeResponse.java @@ -0,0 +1,76 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.example.datastore.vectorsearch; + +// [START datastore_vector_search_large_response] + +import com.google.cloud.datastore.Datastore; +import com.google.cloud.datastore.DatastoreOptions; +import com.google.cloud.datastore.Entity; +import com.google.cloud.datastore.FindNearest; +import com.google.cloud.datastore.Key; +import com.google.cloud.datastore.ProjectionEntity; +import com.google.cloud.datastore.Query; +import com.google.cloud.datastore.QueryResults; +import com.google.cloud.datastore.StructuredQuery; +import com.google.cloud.datastore.VectorValue; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterators; +import java.util.Iterator; + +public class VectorSearchLargeResponse { + public static void invoke() throws Exception { + // Instantiates a client + Datastore datastore = DatastoreOptions.getDefaultInstance().getService(); + + // Create a keys-only vector search query + StructuredQuery<ProjectionEntity> keyOnlyVectorQuery = + Query.newProjectionEntityQueryBuilder() + .setKind("CoffeeBean") + .setProjection("__key__") + .setFindNearest( + new FindNearest( + "embedding_field", + VectorValue.newBuilder(1, 9, 11.1).build(), + FindNearest.DistanceMeasure.EUCLIDEAN, + 3, + "vector_distance", + 2.0)) + .build(); + + QueryResults<ProjectionEntity> keyOnlyResults = datastore.run(keyOnlyVectorQuery); + ProjectionEntity[] keyEntities = Iterators.toArray(keyOnlyResults, ProjectionEntity.class); + Key[] keys = + ImmutableList.copyOf(keyEntities).stream().map(e -> e.getKey()).toArray(Key[]::new); + System.out.printf("Key query result size: %s%n", keys.length); + + // Lookup the full entities using the result of the keys only query.
+ Iterator<Entity> entities = datastore.get(keys); + Entity[] entitiesArray = Iterators.toArray(entities, Entity.class); + System.out.printf("Entity query result size: %s%n", entitiesArray.length); + + // Combine and print results + for (int i = 0; i < keyEntities.length; i++) { + System.out.printf( + "Entity: %s, Distance: %s, Roast: %s%n", + keyEntities[i].getKey().getName(), + keyEntities[i].getDouble("vector_distance"), + entitiesArray[i].getString("roast")); + } + } +} +// [END datastore_vector_search_large_response] diff --git a/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchPrefilter.java b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchPrefilter.java new file mode 100644 index 000000000..c00fc5f0f --- /dev/null +++ b/samples/snippets/src/main/java/com/example/datastore/vectorsearch/VectorSearchPrefilter.java @@ -0,0 +1,65 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.example.datastore.vectorsearch; + +// [START datastore_vector_search_prefilter] + +import com.google.cloud.datastore.Datastore; +import com.google.cloud.datastore.DatastoreOptions; +import com.google.cloud.datastore.Entity; +import com.google.cloud.datastore.FindNearest; +import com.google.cloud.datastore.Query; +import com.google.cloud.datastore.QueryResults; +import com.google.cloud.datastore.StructuredQuery.PropertyFilter; +import com.google.cloud.datastore.VectorValue; + +public class VectorSearchPrefilter { + public static void invoke() throws Exception { + // Instantiates a client + Datastore datastore = DatastoreOptions.getDefaultInstance().getService(); + + // Create vector search query with property filter + Query<Entity> vectorSearchQuery = + Query.newEntityQueryBuilder() + .setKind("CoffeeBean") + .setFilter(PropertyFilter.eq("roast", "dark")) + .setFindNearest( + new FindNearest( + "embedding_field", + VectorValue.newBuilder(1, 9, 11.1).build(), + FindNearest.DistanceMeasure.EUCLIDEAN, + 3, + "vector_distance", + 3.0)) + .build(); + + // Execute vector search query + QueryResults<Entity> results = datastore.run(vectorSearchQuery); + + if (!results.hasNext()) { + throw new Exception("query yielded no results"); + } + + while (results.hasNext()) { + Entity entity = results.next(); + System.out.printf( + "Entity: %s, Distance: %s%n", + entity.getKey().getName(), entity.getDouble("vector_distance")); + } + } +} +// [END datastore_vector_search_prefilter] diff --git a/samples/snippets/src/test/java/com/example/datastore/vectorsearch/VectorSearchSampleIT.java b/samples/snippets/src/test/java/com/example/datastore/vectorsearch/VectorSearchSampleIT.java new file mode 100644 index 000000000..792ab3c56 --- /dev/null +++ b/samples/snippets/src/test/java/com/example/datastore/vectorsearch/VectorSearchSampleIT.java @@ -0,0 +1,150 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.datastore.vectorsearch; + +import com.google.cloud.datastore.Datastore; +import com.google.cloud.datastore.DatastoreOptions; +import com.google.cloud.datastore.Entity; +import com.google.cloud.datastore.Key; +import com.google.cloud.datastore.VectorValue; +import com.rule.SystemsOutRule; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class VectorSearchSampleIT { + private final Datastore datastore = DatastoreOptions.getDefaultInstance().getService(); + + private Key coffeeBeanKey1; + private Key coffeeBeanKey2; + private Key coffeeBeanKey3; + + @Rule public final SystemsOutRule systemsOutRule = new SystemsOutRule(); + + @Before + public void setUp() { + // DatastoreOptions.getDefaultHttpTransportOptions() + coffeeBeanKey1 = datastore.newKeyFactory().setKind("CoffeeBean").newKey("Kahawa"); + // Prepares the entity with a vector embedding + Entity entity1 = + Entity.newBuilder(coffeeBeanKey1) + .set("name", "Arabica") + .set("description", "Information about the Arabica coffee beans.") + .set("roast", "dark") + .set("embedding_field", VectorValue.newBuilder(1.0, 7.0, 11.1).build()) + .build(); + + coffeeBeanKey2 = datastore.newKeyFactory().setKind("CoffeeBean").newKey("Robusta"); + Entity entity2 = + Entity.newBuilder(coffeeBeanKey2) + .set("name", "Robusta") + 
.set("description", "Information about the Robusta coffee beans.") + .set("roast", "light") + .set("embedding_field", VectorValue.newBuilder(1.0, 9.0, 11.1).build()) + .build(); + + coffeeBeanKey3 = datastore.newKeyFactory().setKind("CoffeeBean").newKey("Excelsa"); + Entity entity3 = + Entity.newBuilder(coffeeBeanKey3) + .set("name", "Excelsa") + .set("description", "Information about the Excelsa coffee beans.") + .set("roast", "dark") + .set("embedding_field", VectorValue.newBuilder(4.0, 9.0, 11.1).build()) + .build(); + + datastore.put(entity1); + datastore.put(entity2); + datastore.put(entity3); + } + + @After + public void tearDown() { + datastore.delete(coffeeBeanKey1); + datastore.delete(coffeeBeanKey2); + datastore.delete(coffeeBeanKey3); + } + + @Test + public void testStoreVectors() throws Exception { + // Act + StoreVectors.invoke(); + // Assert + systemsOutRule.assertContains("Retrieved Kahawa with embedding_field"); + } + + @Test + public void testVectorSearchBasic() throws Exception { + // Act + VectorSearchBasic.invoke(); + // Assert + systemsOutRule.assertContains("Entity: Robusta"); + } + + @Test + public void testVectorSearchDistanceResultProperty() throws Exception { + // Act + VectorSearchDistanceResultProperty.invoke(); + // Assert + systemsOutRule.assertContains("Entity: Excelsa, Distance: 208"); + systemsOutRule.assertContains("Entity: Robusta, Distance: 205"); + systemsOutRule.assertContains("Entity: Kahawa, Distance: 187"); + } + + @Test + public void testVectorSearchDistanceResultPropertyProjection() throws Exception { + // Act + VectorSearchDistanceResultPropertyProjection.invoke(); + // Assert + systemsOutRule.assertContains("Entity: Robusta, Distance: 0.0"); + systemsOutRule.assertContains("Entity: Kahawa, Distance: 2.0"); + systemsOutRule.assertContains("Entity: Excelsa, Distance: 3.0"); + } + + @Test + public void testVectorSearchDistanceThreshold() throws Exception { + // Act + VectorSearchDistanceThreshold.invoke(); + // Assert + 
systemsOutRule.assertContains("Entity: Robusta, Distance: 0.0"); + systemsOutRule.assertContains("Entity: Kahawa, Distance: 2.0"); + } + + @Test + public void testVectorSearchLargeResponse() throws Exception { + // Act + VectorSearchLargeResponse.invoke(); + // Assert + systemsOutRule.assertContains("Key query result size: 2"); + systemsOutRule.assertContains("Entity query result size: 2"); + systemsOutRule.assertContains("Entity: Robusta, Distance: 0.0, Roast: dark"); + systemsOutRule.assertContains("Entity: Kahawa, Distance: 2.0, Roast: light"); + } + + @Test + public void testVectorSearchPrefilter() throws Exception { + // Act + VectorSearchPrefilter.invoke(); + // Assert + systemsOutRule.assertContains("Entity: Kahawa, Distance: 2.0"); + systemsOutRule.assertContains("Entity: Excelsa, Distance: 3.0"); + } +} diff --git a/samples/snippets/src/test/resources/index.yaml b/samples/snippets/src/test/resources/index.yaml index 5f2f0c74a..8c9967e33 100644 --- a/samples/snippets/src/test/resources/index.yaml +++ b/samples/snippets/src/test/resources/index.yaml @@ -26,4 +26,12 @@ indexes: - kind: employees properties: - name: salary - - name: experience \ No newline at end of file + - name: experience +- kind: CoffeeBean + properties: + - name: roast + - name: __key__ + - name: embedding_field + vectorConfig: + dimension: 3 + flat: {} \ No newline at end of file
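
The distance values asserted in `VectorSearchSampleIT` can be checked by hand against the three `CoffeeBean` embeddings and the query vector `(1, 9, 11.1)`. The following is a minimal sketch in plain Java that does not call the Datastore client. It assumes that `EUCLIDEAN` corresponds to the ordinary L2 distance and `DOT_PRODUCT` to the inner product; the class `VectorDistanceSketch` and its methods are hypothetical names used only for this illustration.

```java
// Hypothetical helper for illustration only; not part of java-datastore.
final class VectorDistanceSketch {

  // Ordinary L2 distance (assumed to match FindNearest.DistanceMeasure.EUCLIDEAN).
  public static double euclidean(double[] a, double[] b) {
    double sum = 0.0;
    for (int i = 0; i < a.length; i++) {
      double diff = a[i] - b[i];
      sum += diff * diff;
    }
    return Math.sqrt(sum);
  }

  // Inner product (assumed to match FindNearest.DistanceMeasure.DOT_PRODUCT).
  public static double dotProduct(double[] a, double[] b) {
    double sum = 0.0;
    for (int i = 0; i < a.length; i++) {
      sum += a[i] * b[i];
    }
    return sum;
  }

  public static void main(String[] args) {
    double[] query = {1, 9, 11.1};
    String[] names = {"Kahawa", "Robusta", "Excelsa"};
    // Embeddings copied from the setUp() method of the sample test.
    double[][] embeddings = {{1.0, 7.0, 11.1}, {1.0, 9.0, 11.1}, {4.0, 9.0, 11.1}};

    for (int i = 0; i < names.length; i++) {
      System.out.printf(
          "%s: euclidean=%s, dotProduct=%s%n",
          names[i], euclidean(query, embeddings[i]), dotProduct(query, embeddings[i]));
    }
  }
}
```

Under these assumptions the Euclidean distances come out as 2.0 (Kahawa), 0.0 (Robusta), and 3.0 (Excelsa), and the dot products as roughly 187.21, 205.21, and 208.21. This is consistent with the `Distance: 187`, `Distance: 205`, and `Distance: 208` prefixes the test matches, and it explains why a threshold of 2.0 admits only Robusta and Kahawa.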