diff --git a/src/entity.ts b/src/entity.ts index 0437e2397..b52b2abf3 100644 --- a/src/entity.ts +++ b/src/entity.ts @@ -22,6 +22,7 @@ import {protobuf as Protobuf} from 'google-gax'; import * as path from 'path'; import {google} from '../protos/protos'; import {and, PropertyFilter} from './filter'; +import {Vector} from './vector'; // eslint-disable-next-line @typescript-eslint/no-namespace export namespace entity { @@ -1268,6 +1269,30 @@ export namespace entity { queryProto.filter = and(allFilters).toProto(); } + if (query.vectorSearch && query.vectorOptions) { + function queryVectorToArray( + queryVector: Vector | number[] | undefined + ): google.datastore.v1.IValue | undefined { + if (queryVector instanceof Vector) { + return queryVector.value as google.datastore.v1.IValue; + } else { + return queryVector as google.datastore.v1.IValue; + } + } + // Typed as the plain-object interface so the generated FindNearest class typings need no hand edits. + const vectorProto: google.datastore.v1.IFindNearest = { + vectorProperty: {name: query.vectorOptions.vectorProperty}, + queryVector: queryVectorToArray(query.vectorOptions.queryVector), + distanceMeasure: query.vectorOptions + .distanceMeasure as google.datastore.v1.FindNearest.DistanceMeasure, + limit: {value: query.vectorOptions.limit}, + distanceResultProperty: query.vectorOptions.distanceResultProperty, + distanceThreshold: {value: query.vectorOptions.distanceThreshold}, + }; + + queryProto.findNearest = vectorProto; + } + return queryProto; } diff --git a/src/index.ts b/src/index.ts index 20cad99c4..1d4510845 100644 --- a/src/index.ts +++ b/src/index.ts @@ -39,7 +39,7 @@ import { import * as is from 'is'; import {Transform, pipeline} from 
'stream'; -import {entity, Entities, Entity, EntityProto, ValueProto} from './entity'; +import {entity, Entities, Entity, ValueProto} from './entity'; import {AggregateField} from './aggregate'; import Key = entity.Key; export {Entity, Key, AggregateField}; @@ -489,6 +489,8 @@ class Datastore extends DatastoreRequest { options.projectId = options.projectId || process.env.DATASTORE_PROJECT_ID; + // prod: datastore.googleapis.com + // nightly: nightly-datastore.sandbox.googleapis.com this.defaultBaseUrl_ = 'datastore.googleapis.com'; this.determineBaseUrl_(options.apiEndpoint); diff --git a/src/query.ts b/src/query.ts index 557ae12a7..c2d5aff59 100644 --- a/src/query.ts +++ b/src/query.ts @@ -24,6 +24,7 @@ import {CallOptions} from 'google-gax'; import {RunQueryStreamOptions} from '../src/request'; import * as gaxInstance from 'google-gax'; import {google} from '../protos/protos'; +import {VectorQueryOptions} from './vector'; export type Operator = | '=' @@ -76,6 +77,7 @@ export interface Filter { class Query { scope?: Datastore | Transaction; namespace?: string | null; + vectorOptions?: VectorQueryOptions; kinds: string[]; filters: Filter[]; entityFilters: EntityFilter[]; @@ -86,6 +88,7 @@ class Query { endVal: string | Buffer | null; limitVal: number; offsetVal: number; + vectorSearch = false; constructor(scope?: Datastore | Transaction, kinds?: string[] | null); constructor( @@ -256,6 +259,40 @@ class Query { return this; } + /** + * Returns a query that can perform vector distance (similarity) search with given parameters. + * + * The returned query, when executed, performs a distance (similarity) search on the specified + * `vectorProperty` against the given `queryVector` and returns the top entities that are closest + * to the `queryVector`. + * + * @example + * ``` + * // Returns the 10 entities whose 'embedding' vectors are closest to [41, 42] by Euclidean distance. 
+ * const vectorQuery = query.findNearest({vectorProperty: 'embedding', queryVector: [41, 42], limit: 10, distanceMeasure: DistanceMeasure.EUCLIDEAN, distanceResultProperty: 'distance'}); + * + * const [entities] = await datastore.runQuery(vectorQuery); + * ``` + * + * @param {VectorQueryOptions} options - Options that control the vector query. `limit` specifies the upper bound of entities to return, must + * be a positive integer with a maximum value of 1000. `distanceMeasure` specifies what type of distance is calculated + * when performing the query. + * @throws {Error} If `limit` is zero or negative, or if `queryVector` is empty. + */ + findNearest(options: VectorQueryOptions): Query { + if (typeof options.limit === 'number' && options.limit <= 0) { + throw new Error('limit should be a positive limit number'); + } + + if (options.queryVector && options.queryVector.length === 0) { + throw new Error('vector size must be larger than 0'); + } + + this.vectorOptions = options; + this.vectorSearch = true; + return this; + } + /** * Filter a query by ancestors. * @@ -584,6 +621,7 @@ export interface QueryProto { limit?: {}; offset?: number; filter?: {}; + findNearest?: {}; } /** diff --git a/src/vector.ts b/src/vector.ts new file mode 100644 index 000000000..d62e356ac --- /dev/null +++ b/src/vector.ts @@ -0,0 +1,121 @@ +/*! + * Copyright 2024 Google LLC. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +const VECTOR_VALUE = 31; + +export enum DistanceMeasure { + DISTANCE_MEASURE_UNSPECIFIED = 0, + EUCLIDEAN = 1, + COSINE = 2, + DOT_PRODUCT = 3, +} + +interface VectorDict { + array_value: {values: {double_value: number}[]}; + meaning: number; + exclude_from_indexes: boolean; +} + +/* A class to represent a Vector for use in query.findNearest. + * The underlying object is converted to a map representation for the Datastore API. + */ +export class Vector { + value: number[]; + + constructor(value: number[]) { + this.value = value.map(v => parseFloat(v.toString())); + } + + get(index: number): number { + return this.value[index]; + } + + slice(start?: number, end?: number): Vector { + return new Vector(this.value.slice(start, end)); + } + + get length(): number { + return this.value.length; + } + + equals(other: Vector): boolean { + if (!(other instanceof Vector)) { + throw new Error('Cannot compare Vector to a non-Vector object.'); + } + return ( + this.value.length === other.value.length && + this.value.every((v, i) => v === other.value[i]) + ); + } + + toString(): string { + return `Vector<${this.value.join(', ')}>`; + } + + _toDict(): VectorDict { + return { + array_value: { + values: this.value.map(v => ({double_value: v})), + }, + meaning: VECTOR_VALUE, + exclude_from_indexes: true, + }; + } +} + +/** + * Specifies the behavior of a vector search query generated by a call to {@link Query.findNearest}. + */ +export interface VectorQueryOptions { + /** + * A string specifying the vector property to search on. + */ + vectorProperty?: string; + + /** + * The value used to measure the distance from `vectorProperty` values in the entities. + */ + queryVector?: Vector | number[]; + + /** + * Specifies what type of distance is calculated when performing the query. + */ + distanceMeasure: DistanceMeasure; + + /** + * Specifies the upper bound of entities to return, must be a positive integer with a maximum value of 1000. 
+ */ + limit?: number; + + /** + * Optionally specifies the name of a property that will be set on each returned entity, + * which will contain the computed distance for that entity. + */ + distanceResultProperty: string; + + /** + * Specifies a threshold for which no less similar documents will be returned. The behavior + * of the specified `distanceMeasure` will affect the meaning of the distance threshold. + * + * - For `distanceMeasure: "EUCLIDEAN"`, the meaning of `distanceThreshold` is: + * SELECT docs WHERE euclidean_distance <= distanceThreshold + * - For `distanceMeasure: "COSINE"`, the meaning of `distanceThreshold` is: + * SELECT docs WHERE cosine_distance <= distanceThreshold + * - For `distanceMeasure: "DOT_PRODUCT"`, the meaning of `distanceThreshold` is: + * SELECT docs WHERE dot_product_distance >= distanceThreshold + */ + distanceThreshold?: number; +} diff --git a/system-test/datastore.ts b/system-test/datastore.ts index c8eeecebb..3bbe422a9 100644 --- a/system-test/datastore.ts +++ b/system-test/datastore.ts @@ -26,6 +26,8 @@ import {Entities, entity, Entity} from '../src/entity'; import {Query, RunQueryInfo, ExecutionStats} from '../src/query'; import KEY_SYMBOL = entity.KEY_SYMBOL; import {transactionExpiredError} from '../src/request'; +import {DistanceMeasure, Vector, VectorQueryOptions} from '../src/vector'; +import {startServer} from '../mock-server/datastore-server'; const async = require('async'); @@ -3296,5 +3298,34 @@ async.each( }); }); }); + + describe('vector search query', () => { + it('should complete a request successfully with vector search options', async () => { + startServer(async () => { + const customDatastore = new Datastore({ + namespace: `${Date.now()}`, + apiEndpoint: 'localhost:50051', + }); + + const vectorOptions: VectorQueryOptions = { + vectorProperty: 'embedding', + queryVector: [1.0, 2.0, 3.0], + limit: 2, + distanceMeasure: DistanceMeasure.EUCLIDEAN, + distanceResultProperty: 'distance', + 
distanceThreshold: 0.5, + }; + + const query = customDatastore + .createQuery('Kind') + .findNearest(vectorOptions); + + const [entities] = await customDatastore.runQuery(query); + + console.log(entities); + assert.deepEqual(entities, new Vector([1.0, 2.0, 3.0])); + }); + }); + }); } ); diff --git a/test/entity.ts b/test/entity.ts index e4b738cac..8a8337e1a 100644 --- a/test/entity.ts +++ b/test/entity.ts @@ -19,11 +19,12 @@ import * as sinon from 'sinon'; import {Datastore} from '../src'; import {Entity, entity} from '../src/entity'; import {IntegerTypeCastOptions} from '../src/query'; -import {PropertyFilter, EntityFilter, and} from '../src/filter'; +import {PropertyFilter, and} from '../src/filter'; import { entityObject, expectedEntityProto, } from './fixtures/entityObjectAndProto'; +import {DistanceMeasure} from '../src/vector'; export function outOfBoundsError(opts: { propertyName?: string; @@ -1502,12 +1503,26 @@ describe('entity', () => { op: 'AND', }, }, + findNearest: { + distanceMeasure: 1, + distanceResultField: 'vector_distance', + limit: 3, + queryVector: [1, 2, 3], + vectorField: 'embedding_field', + }, }; it('should support all configurations of a query', () => { const ancestorKey = new entity.Key({ path: ['Kind2', 'somename'], }); + const vectorOptions = { + vectorProperty: 'embedding_property', + queryVector: [1.0, 2.0, 3.0], + limit: 3, + distanceMeasure: DistanceMeasure.EUCLIDEAN, + distanceResultProperty: 'vector_distance', + }; const ds = new Datastore({projectId: 'project-id'}); @@ -1521,7 +1536,8 @@ describe('entity', () => { .select('name') .limit(1) .offset(1) - .hasAncestor(ancestorKey); + .hasAncestor(ancestorKey) + .findNearest(vectorOptions); assert.deepStrictEqual(testEntity.queryToQueryProto(query), queryProto); }); @@ -1589,6 +1605,13 @@ describe('entity', () => { const ancestorKey = new entity.Key({ path: ['Kind2', 'somename'], }); + const vectorOptions = { + vectorProperty: 'embedding_property', + queryVector: [1.0, 2.0, 3.0], + 
limit: 3, + distanceMeasure: DistanceMeasure.EUCLIDEAN, + distanceResultProperty: 'vector_distance', + }; const ds = new Datastore({projectId: 'project-id'}); @@ -1602,7 +1625,8 @@ describe('entity', () => { .select('name') .limit(1) .offset(1) - .hasAncestor(ancestorKey); + .hasAncestor(ancestorKey) + .findNearest(vectorOptions); assert.deepStrictEqual(testEntity.queryToQueryProto(query), queryProto); }); diff --git a/test/query.ts b/test/query.ts index a6d32f696..b1bafa902 100644 --- a/test/query.ts +++ b/test/query.ts @@ -19,9 +19,10 @@ const {Query} = require('../src/query'); // eslint-disable-next-line @typescript-eslint/no-var-requires import {Datastore} from '../src'; import {AggregateField, AggregateQuery} from '../src/aggregate'; -import {PropertyFilter, EntityFilter, or} from '../src/filter'; +import {PropertyFilter, or} from '../src/filter'; import {entity} from '../src/entity'; import {SECOND_DATABASE_ID} from './index'; +import {google} from '../protos/protos'; describe('Query', () => { const SCOPE = {} as Datastore; @@ -327,6 +328,7 @@ describe('Query', () => { done(); }); }); + describe('filter with Filter class', () => { it('should support filter with Filter', () => { const now = new Date(); @@ -371,6 +373,35 @@ describe('Query', () => { }); }); + describe('findNearest', () => { + it('should successfully build a vector search query', () => { + const VectorOptions = { + vectorField: 'embedding', + queryVector: [1.0, 2.0, 3.0], + limit: 2, + distanceMeasure: + google.datastore.v1.FindNearest.DistanceMeasure.EUCLIDEAN, + distanceResultField: 'distance', + distanceThreshold: 0.5, + }; + + const vectorQuery = new Query(['kind1']).findNearest({ + vectorField: 'embedding', + queryVector: [1.0, 2.0, 3.0], + limit: { + value: 2, + }, + distanceMeasure: + google.datastore.v1.FindNearest.DistanceMeasure.EUCLIDEAN, + distanceResultField: 'distance', + distanceThreshold: 0.5, + }); + + assert.ok(vectorQuery.vectorSearch); + assert.deepEqual(VectorOptions, 
vectorQuery.vectorOptions); + }); + }); + describe('hasAncestor', () => { it('should support ancestor filtering', () => { const query = new Query(['kind1']).hasAncestor(['kind2', 123]);