From 1ec1bd35a375d3f42e394fffe2d74f62bb4c8359 Mon Sep 17 00:00:00 2001 From: Darrell Warde Date: Thu, 7 Nov 2024 16:58:44 +0000 Subject: [PATCH] Fix issues with vector search --- .changeset/three-planes-train.md | 5 + .../ast/operations/VectorOperation.ts | 14 +- .../factory/Operations/VectorFactory.ts | 23 --- .../factory/SortAndPaginationFactory.ts | 8 + .../vector/vector-sorting.int.test.ts | 151 +++++++++++++++++- 5 files changed, 167 insertions(+), 34 deletions(-) create mode 100644 .changeset/three-planes-train.md diff --git a/.changeset/three-planes-train.md b/.changeset/three-planes-train.md new file mode 100644 index 0000000000..ba9568a75f --- /dev/null +++ b/.changeset/three-planes-train.md @@ -0,0 +1,5 @@ +--- +"@neo4j/graphql": patch +--- + +Fix issues #5759 and #5760 to do with sorting vector search results diff --git a/packages/graphql/src/translate/queryAST/ast/operations/VectorOperation.ts b/packages/graphql/src/translate/queryAST/ast/operations/VectorOperation.ts index f7e37d8828..141b5b5fe8 100644 --- a/packages/graphql/src/translate/queryAST/ast/operations/VectorOperation.ts +++ b/packages/graphql/src/translate/queryAST/ast/operations/VectorOperation.ts @@ -26,6 +26,7 @@ import type { QueryASTContext } from "../QueryASTContext"; import type { QueryASTNode } from "../QueryASTNode"; import type { ScoreField } from "../fields/ScoreField"; import type { EntitySelection } from "../selection/EntitySelection"; +import { ScoreSort } from "../sort/ScoreSort"; import { ConnectionReadOperation } from "./ConnectionReadOperation"; export type VectorOptions = { @@ -77,7 +78,7 @@ export class VectorOperation extends ConnectionReadOperation { edgeVar: Cypher.Variable, edgesVar: Cypher.Variable ): Cypher.With { - if (this.scoreField && context.neo4jGraphQLContext.vector) { + if ((this.scoreField || this.hasScoreSort()) && context.neo4jGraphQLContext.vector) { // No relationship, so we directly unwind node and score return new Cypher.Unwind([edgesVar, edgeVar]).with( [edgeVar.property("node"), context.target], @@ -93,7 +94,7 @@ export class VectorOperation extends ConnectionReadOperation { edgesVar: Cypher.Variable, totalCount: Cypher.Variable ): Cypher.With { - if (this.scoreField && nestedContext.neo4jGraphQLContext.vector) { + if ((this.scoreField || this.hasScoreSort()) && nestedContext.neo4jGraphQLContext.vector) { const nodeAndRelationshipMap = new Cypher.Map({ node: nestedContext.target, }); @@ -102,10 +103,7 @@ export class VectorOperation extends ConnectionReadOperation { nodeAndRelationshipMap.set("relationship", nestedContext.relationship); } - const scoreProjection = this.scoreField.getProjectionField(); - for (const [key, value] of Object.entries(scoreProjection)) { - nodeAndRelationshipMap.set(key, value); - } + nodeAndRelationshipMap.set("score", nestedContext.neo4jGraphQLContext.vector.scoreVariable); return new Cypher.With([Cypher.collect(nodeAndRelationshipMap), edgesVar]).with(edgesVar, [ Cypher.size(edgesVar), @@ -115,4 +113,8 @@ export class VectorOperation extends ConnectionReadOperation { return super.getWithCollectEdgesAndTotalCount(nestedContext, edgesVar, totalCount); } } + + private hasScoreSort(): boolean { + return this.sortFields.some(({ node }) => node.some((sort) => sort instanceof ScoreSort)); + } } diff --git a/packages/graphql/src/translate/queryAST/factory/Operations/VectorFactory.ts b/packages/graphql/src/translate/queryAST/factory/Operations/VectorFactory.ts index f2c057857d..0b79ced3d2 100644 --- a/packages/graphql/src/translate/queryAST/factory/Operations/VectorFactory.ts +++ b/packages/graphql/src/translate/queryAST/factory/Operations/VectorFactory.ts @@ -18,9 +18,7 @@ */ import type { ResolveTree } from "graphql-parse-resolve-info"; -import { SCORE_FIELD } from "../../../../constants"; import type { ConcreteEntityAdapter } from "../../../../schema-model/entity/model-adapters/ConcreteEntityAdapter"; -import type { SortDirection } from "../../../../types"; import type { Neo4jGraphQLTranslationContext } from "../../../../types/neo4j-graphql-translation-context"; import { checkEntityAuthentication } from "../../../authorization/check-authentication"; import { ScoreField } from "../../ast/fields/ScoreField"; @@ -28,7 +26,6 @@ import { ScoreFilter } from "../../ast/filters/property-filters/ScoreFilter"; import type { VectorOptions } from "../../ast/operations/VectorOperation"; import { VectorOperation } from "../../ast/operations/VectorOperation"; import { VectorSelection } from "../../ast/selection/VectorSelection"; -import { ScoreSort } from "../../ast/sort/ScoreSort"; import type { QueryASTFactory } from "../QueryASTFactory"; import { findFieldsByNameInFieldsByTypeNameField } from "../parsers/find-fields-by-name-in-fields-by-type-name-field"; import { getFieldsByTypeName } from "../parsers/get-fields-by-type-name"; @@ -90,8 +87,6 @@ export class VectorFactory { whereArgs: resolveTreeWhere, }); - this.addScoreSort(operation, resolveTree, context); - this.queryASTFactory.operationsFactory.hydrateConnectionOperation({ target: entity, resolveTree: resolveTree, @@ -104,24 +99,6 @@ export class VectorFactory { return operation; } - private addScoreSort( - operation: VectorOperation, - resolveTree: ResolveTree, - context: Neo4jGraphQLTranslationContext - ) { - const sortArguments: Record[] = (resolveTree.args.sort ?? []) as any; - - for (const sortArgument of sortArguments) { - if (sortArgument[SCORE_FIELD] && context?.vector) { - const scoreSort = new ScoreSort({ - scoreVariable: context.vector.scoreVariable, - direction: sortArgument[SCORE_FIELD], - }); - operation.addSort({ node: [scoreSort], edge: [] }); - } - } - } - private addVectorScoreFilter({ operation, whereArgs, diff --git a/packages/graphql/src/translate/queryAST/factory/SortAndPaginationFactory.ts b/packages/graphql/src/translate/queryAST/factory/SortAndPaginationFactory.ts index f328120287..490659e159 100644 --- a/packages/graphql/src/translate/queryAST/factory/SortAndPaginationFactory.ts +++ b/packages/graphql/src/translate/queryAST/factory/SortAndPaginationFactory.ts @@ -80,6 +80,14 @@ export class SortAndPaginationFactory { context, }); + if (options[SCORE_FIELD] && context?.vector) { + const scoreSort = new ScoreSort({ + scoreVariable: context.vector.scoreVariable, + direction: options[SCORE_FIELD], + }); + nodeSortFields.push(scoreSort); + } + return { edge: [], node: nodeSortFields, diff --git a/packages/graphql/tests/integration/directives/vector/vector-sorting.int.test.ts b/packages/graphql/tests/integration/directives/vector/vector-sorting.int.test.ts index b1e3d125fa..3b7d6025a9 100644 --- a/packages/graphql/tests/integration/directives/vector/vector-sorting.int.test.ts +++ b/packages/graphql/tests/integration/directives/vector/vector-sorting.int.test.ts @@ -77,8 +77,8 @@ describe("@vector directive - Query", () => { Movie = testHelper.createUniqueType("Movie"); - const typeDefs = ` - type ${Movie.name} @vector(indexes: [{ indexName: "${Movie}Index", embeddingProperty: "embedding", queryName: "${queryName}" }]) @node { + const typeDefs = /* GraphQL */ ` + type ${Movie.name} @vector(indexes: [{ indexName: "${Movie}Index", embeddingProperty: "embedding", queryName: "${queryName}" }]) @node { title: String! released: Int! }`; @@ -137,7 +137,7 @@ describe("@vector directive - Query", () => { return; } - const query = ` + const query = /* GraphQL */ ` query($vector: [Float!]) { ${queryName}(vector: $vector, sort: {score: DESC} ) { edges { @@ -185,7 +185,7 @@ describe("@vector directive - Query", () => { return; } - const query = ` + const query = /* GraphQL */ ` query($vector: [Float!]) { ${queryName}(vector: $vector, sort: {score: ASC} ) { edges { @@ -220,6 +220,51 @@ describe("@vector directive - Query", () => { }); }); + test("Retrieve nodes ordered by score DESC without score in selection set", async () => { + // Skip if vector not supported + if (!VECTOR_SUPPORT) { + console.log("VECTOR SUPPORT NOT AVAILABLE - SKIPPING"); + return; + } + + // Skip if multi-db not supported + if (!MULTIDB_SUPPORT) { + console.log("MULTIDB_SUPPORT NOT AVAILABLE - SKIPPING"); + return; + } + + const query = /* GraphQL */ ` + query($vector: [Float!]) { + ${queryName}(vector: $vector, sort: {score: DESC} ) { + edges { + node { + title + } + } + } + } + `; + const gqlResult = await testHelper.executeGraphQL(query, { variableValues: { vector: testVectors[0] } }); + + expect(gqlResult.errors).toBeFalsy(); + expect(gqlResult.data).toEqual({ + [queryName]: { + edges: [ + { + node: { + title: "Some Title", + }, + }, + { + node: { + title: "Another Title", + }, + }, + ], + }, + }); + }); + test("Retrieve nodes ordered by node property", async () => { // Skip if vector not supported if (!VECTOR_SUPPORT) { @@ -233,7 +278,7 @@ describe("@vector directive - Query", () => { return; } - const query = ` + const query = /* GraphQL */ ` query($vector: [Float!]) { ${queryName}(vector: $vector, sort: {node: {title: ASC}} ) { edges { @@ -267,4 +312,100 @@ describe("@vector directive - Query", () => { }, }); }); + + test("Retrieve nodes ordered by node property first and score second", async () => { + // Skip if vector not supported + if (!VECTOR_SUPPORT) { + console.log("VECTOR SUPPORT NOT AVAILABLE - SKIPPING"); + return; + } + + // Skip if multi-db not supported + if (!MULTIDB_SUPPORT) { + console.log("MULTIDB_SUPPORT NOT AVAILABLE - SKIPPING"); + return; + } + + const query = /* GraphQL */ ` + query($vector: [Float!]) { + ${queryName}(vector: $vector, sort: [{node: {title: DESC}}, { score: ASC }] ) { + edges { + score + node { + title + } + } + } + } + `; + const gqlResult = await testHelper.executeGraphQL(query, { variableValues: { vector: testVectors[0] } }); + + expect(gqlResult.errors).toBeFalsy(); + expect(gqlResult.data).toEqual({ + [queryName]: { + edges: [ + { + node: { + title: "Some Title", + }, + score: expect.closeTo(1), + }, + { + node: { + title: "Another Title", + }, + score: expect.closeTo(0.56), + }, + ], + }, + }); + }); + + test("Retrieve nodes ordered by score first and node property second", async () => { + // Skip if vector not supported + if (!VECTOR_SUPPORT) { + console.log("VECTOR SUPPORT NOT AVAILABLE - SKIPPING"); + return; + } + + // Skip if multi-db not supported + if (!MULTIDB_SUPPORT) { + console.log("MULTIDB_SUPPORT NOT AVAILABLE - SKIPPING"); + return; + } + + const query = /* GraphQL */ ` + query($vector: [Float!]) { + ${queryName}(vector: $vector, sort: [{ score: ASC }, {node: {title: DESC}}] ) { + edges { + score + node { + title + } + } + } + } + `; + const gqlResult = await testHelper.executeGraphQL(query, { variableValues: { vector: testVectors[0] } }); + + expect(gqlResult.errors).toBeFalsy(); + expect(gqlResult.data).toEqual({ + [queryName]: { + edges: [ + { + node: { + title: "Another Title", + }, + score: expect.closeTo(0.56), + }, + { + node: { + title: "Some Title", + }, + score: expect.closeTo(1), + }, + ], + }, + }); + }); });