Skip to content

Commit

Permalink
Merge pull request #5764 from darrellwarde/fix/vector-sort
Browse files Browse the repository at this point in the history
Fix issues with vector search sorting
  • Loading branch information
darrellwarde authored Nov 8, 2024
2 parents e1fc2cf + 09632d5 commit f7d01aa
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 34 deletions.
5 changes: 5 additions & 0 deletions .changeset/three-planes-train.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@neo4j/graphql": patch
---

Fix issues #5759 and #5760 to do with sorting vector search results
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import type { QueryASTContext } from "../QueryASTContext";
import type { QueryASTNode } from "../QueryASTNode";
import type { ScoreField } from "../fields/ScoreField";
import type { EntitySelection } from "../selection/EntitySelection";
import { ScoreSort } from "../sort/ScoreSort";
import { ConnectionReadOperation } from "./ConnectionReadOperation";

export type VectorOptions = {
Expand Down Expand Up @@ -77,7 +78,7 @@ export class VectorOperation extends ConnectionReadOperation {
edgeVar: Cypher.Variable,
edgesVar: Cypher.Variable
): Cypher.With {
if (this.scoreField && context.neo4jGraphQLContext.vector) {
if ((this.scoreField || this.hasScoreSort()) && context.neo4jGraphQLContext.vector) {
// No relationship, so we directly unwind node and score
return new Cypher.Unwind([edgesVar, edgeVar]).with(
[edgeVar.property("node"), context.target],
Expand All @@ -93,7 +94,7 @@ export class VectorOperation extends ConnectionReadOperation {
edgesVar: Cypher.Variable,
totalCount: Cypher.Variable
): Cypher.With {
if (this.scoreField && nestedContext.neo4jGraphQLContext.vector) {
if ((this.scoreField || this.hasScoreSort()) && nestedContext.neo4jGraphQLContext.vector) {
const nodeAndRelationshipMap = new Cypher.Map({
node: nestedContext.target,
});
Expand All @@ -102,10 +103,7 @@ export class VectorOperation extends ConnectionReadOperation {
nodeAndRelationshipMap.set("relationship", nestedContext.relationship);
}

const scoreProjection = this.scoreField.getProjectionField();
for (const [key, value] of Object.entries(scoreProjection)) {
nodeAndRelationshipMap.set(key, value);
}
nodeAndRelationshipMap.set("score", nestedContext.neo4jGraphQLContext.vector.scoreVariable);

return new Cypher.With([Cypher.collect(nodeAndRelationshipMap), edgesVar]).with(edgesVar, [
Cypher.size(edgesVar),
Expand All @@ -115,4 +113,8 @@ export class VectorOperation extends ConnectionReadOperation {
return super.getWithCollectEdgesAndTotalCount(nestedContext, edgesVar, totalCount);
}
}

private hasScoreSort(): boolean {
return this.sortFields.some(({ node }) => node.some((sort) => sort instanceof ScoreSort));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,14 @@
*/

import type { ResolveTree } from "graphql-parse-resolve-info";
import { SCORE_FIELD } from "../../../../constants";
import type { ConcreteEntityAdapter } from "../../../../schema-model/entity/model-adapters/ConcreteEntityAdapter";
import type { SortDirection } from "../../../../types";
import type { Neo4jGraphQLTranslationContext } from "../../../../types/neo4j-graphql-translation-context";
import { checkEntityAuthentication } from "../../../authorization/check-authentication";
import { ScoreField } from "../../ast/fields/ScoreField";
import { ScoreFilter } from "../../ast/filters/property-filters/ScoreFilter";
import type { VectorOptions } from "../../ast/operations/VectorOperation";
import { VectorOperation } from "../../ast/operations/VectorOperation";
import { VectorSelection } from "../../ast/selection/VectorSelection";
import { ScoreSort } from "../../ast/sort/ScoreSort";
import type { QueryASTFactory } from "../QueryASTFactory";
import { findFieldsByNameInFieldsByTypeNameField } from "../parsers/find-fields-by-name-in-fields-by-type-name-field";
import { getFieldsByTypeName } from "../parsers/get-fields-by-type-name";
Expand Down Expand Up @@ -90,8 +87,6 @@ export class VectorFactory {
whereArgs: resolveTreeWhere,
});

this.addScoreSort(operation, resolveTree, context);

this.queryASTFactory.operationsFactory.hydrateConnectionOperation({
target: entity,
resolveTree: resolveTree,
Expand All @@ -104,24 +99,6 @@ export class VectorFactory {
return operation;
}

private addScoreSort(
operation: VectorOperation,
resolveTree: ResolveTree,
context: Neo4jGraphQLTranslationContext
) {
const sortArguments: Record<string, SortDirection>[] = (resolveTree.args.sort ?? []) as any;

for (const sortArgument of sortArguments) {
if (sortArgument[SCORE_FIELD] && context?.vector) {
const scoreSort = new ScoreSort({
scoreVariable: context.vector.scoreVariable,
direction: sortArgument[SCORE_FIELD],
});
operation.addSort({ node: [scoreSort], edge: [] });
}
}
}

private addVectorScoreFilter({
operation,
whereArgs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ export class SortAndPaginationFactory {
context,
});

if (options[SCORE_FIELD] && context?.vector) {
const scoreSort = new ScoreSort({
scoreVariable: context.vector.scoreVariable,
direction: options[SCORE_FIELD],
});
nodeSortFields.push(scoreSort);
}

return {
edge: [],
node: nodeSortFields,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ describe("@vector directive - Query", () => {

Movie = testHelper.createUniqueType("Movie");

const typeDefs = `
type ${Movie.name} @vector(indexes: [{ indexName: "${Movie}Index", embeddingProperty: "embedding", queryName: "${queryName}" }]) @node {
const typeDefs = /* GraphQL */ `
type ${Movie.name} @vector(indexes: [{ indexName: "${Movie}Index", embeddingProperty: "embedding", queryName: "${queryName}" }]) @node {
title: String!
released: Int!
}`;
Expand Down Expand Up @@ -137,7 +137,7 @@ describe("@vector directive - Query", () => {
return;
}

const query = `
const query = /* GraphQL */ `
query($vector: [Float!]) {
${queryName}(vector: $vector, sort: {score: DESC} ) {
edges {
Expand Down Expand Up @@ -185,7 +185,7 @@ describe("@vector directive - Query", () => {
return;
}

const query = `
const query = /* GraphQL */ `
query($vector: [Float!]) {
${queryName}(vector: $vector, sort: {score: ASC} ) {
edges {
Expand Down Expand Up @@ -220,6 +220,51 @@ describe("@vector directive - Query", () => {
});
});

test("Retrieve nodes ordered by score DESC without score in selection set", async () => {
// Skip if vector not supported
if (!VECTOR_SUPPORT) {
console.log("VECTOR SUPPORT NOT AVAILABLE - SKIPPING");
return;
}

// Skip if multi-db not supported
if (!MULTIDB_SUPPORT) {
console.log("MULTIDB_SUPPORT NOT AVAILABLE - SKIPPING");
return;
}

const query = /* GraphQL */ `
query($vector: [Float!]) {
${queryName}(vector: $vector, sort: {score: DESC} ) {
edges {
node {
title
}
}
}
}
`;
const gqlResult = await testHelper.executeGraphQL(query, { variableValues: { vector: testVectors[0] } });

expect(gqlResult.errors).toBeFalsy();
expect(gqlResult.data).toEqual({
[queryName]: {
edges: [
{
node: {
title: "Some Title",
},
},
{
node: {
title: "Another Title",
},
},
],
},
});
});

test("Retrieve nodes ordered by node property", async () => {
// Skip if vector not supported
if (!VECTOR_SUPPORT) {
Expand All @@ -233,7 +278,7 @@ describe("@vector directive - Query", () => {
return;
}

const query = `
const query = /* GraphQL */ `
query($vector: [Float!]) {
${queryName}(vector: $vector, sort: {node: {title: ASC}} ) {
edges {
Expand Down Expand Up @@ -267,4 +312,100 @@ describe("@vector directive - Query", () => {
},
});
});

test("Retrieve nodes ordered by node property first and score second", async () => {
// Skip if vector not supported
if (!VECTOR_SUPPORT) {
console.log("VECTOR SUPPORT NOT AVAILABLE - SKIPPING");
return;
}

// Skip if multi-db not supported
if (!MULTIDB_SUPPORT) {
console.log("MULTIDB_SUPPORT NOT AVAILABLE - SKIPPING");
return;
}

const query = /* GraphQL */ `
query($vector: [Float!]) {
${queryName}(vector: $vector, sort: [{node: {title: DESC}}, { score: ASC }] ) {
edges {
score
node {
title
}
}
}
}
`;
const gqlResult = await testHelper.executeGraphQL(query, { variableValues: { vector: testVectors[0] } });

expect(gqlResult.errors).toBeFalsy();
expect(gqlResult.data).toEqual({
[queryName]: {
edges: [
{
node: {
title: "Some Title",
},
score: expect.closeTo(1),
},
{
node: {
title: "Another Title",
},
score: expect.closeTo(0.56),
},
],
},
});
});

test("Retrieve nodes ordered by score first and node property second", async () => {
// Skip if vector not supported
if (!VECTOR_SUPPORT) {
console.log("VECTOR SUPPORT NOT AVAILABLE - SKIPPING");
return;
}

// Skip if multi-db not supported
if (!MULTIDB_SUPPORT) {
console.log("MULTIDB_SUPPORT NOT AVAILABLE - SKIPPING");
return;
}

const query = /* GraphQL */ `
query($vector: [Float!]) {
${queryName}(vector: $vector, sort: [{ score: ASC }, {node: {title: DESC}}] ) {
edges {
score
node {
title
}
}
}
}
`;
const gqlResult = await testHelper.executeGraphQL(query, { variableValues: { vector: testVectors[0] } });

expect(gqlResult.errors).toBeFalsy();
expect(gqlResult.data).toEqual({
[queryName]: {
edges: [
{
node: {
title: "Another Title",
},
score: expect.closeTo(0.56),
},
{
node: {
title: "Some Title",
},
score: expect.closeTo(1),
},
],
},
});
});
});

0 comments on commit f7d01aa

Please sign in to comment.