From c1e82bd77eb396ebc2fa85981d5f8403415c7970 Mon Sep 17 00:00:00 2001 From: Thuc Pham <51660321+thucpn@users.noreply.github.com> Date: Fri, 19 Jul 2024 12:18:50 +0700 Subject: [PATCH] fix: do not return false immediately with negative operators --- .../storage/vectorStore/SimpleVectorStore.ts | 28 ++- .../vectorStores/SimpleVectorStore.test.ts | 188 ++++++++++++++++-- 2 files changed, 197 insertions(+), 19 deletions(-) diff --git a/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts b/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts index 7daac5b276..99b4ad836d 100644 --- a/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts +++ b/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts @@ -62,23 +62,32 @@ const OPERATOR_TO_FILTER: { return parseArrayValue(value).every((v) => metadata[key].includes(v)); }, [FilterOperator.TEXT_MATCH]: ({ key, value }, metadata) => { - return metadata[key].includes(parsePrimitiveValue(value)); + if (typeof metadata[key] !== "string") return false; + return metadata[key].includes(parsePrimitiveValue(value).toString()); }, [FilterOperator.CONTAINS]: ({ key, value }, metadata) => { if (!Array.isArray(metadata[key])) return false; return !!parseArrayValue(metadata[key]).find((v) => v === value); }, [FilterOperator.GT]: ({ key, value }, metadata) => { - return metadata[key] > parsePrimitiveValue(value); + const val = metadata[key]; + if (typeof val !== "string" && typeof val !== "number") return false; + return val > parsePrimitiveValue(value); }, [FilterOperator.LT]: ({ key, value }, metadata) => { - return metadata[key] < parsePrimitiveValue(value); + const val = metadata[key]; + if (typeof val !== "string" && typeof val !== "number") return false; + return val < parsePrimitiveValue(value); }, [FilterOperator.GTE]: ({ key, value }, metadata) => { - return metadata[key] >= parsePrimitiveValue(value); + const val = metadata[key]; + if (typeof val !== "string" && typeof val !== "number") return false; + return val >= parsePrimitiveValue(value); }, [FilterOperator.LTE]: ({ key, value }, metadata) => { - return metadata[key] <= parsePrimitiveValue(value); + const val = metadata[key]; + if (typeof val !== "string" && typeof val !== "number") return false; + return val <= parsePrimitiveValue(value); }, }; @@ -94,7 +103,14 @@ const buildFilterFn = ( const queryCondition = condition || "and"; // default to and const itemFilterFn = (filter: MetadataFilter): boolean => { - if (metadata[filter.key] === undefined) return false; // always return false if the metadata key is not present + // for all operators except != and nin, if the metadata key is not present, return false + if ( + metadata[filter.key] === undefined && + filter.operator !== FilterOperator.NE && + filter.operator !== FilterOperator.NIN + ) { + return false; + } const metadataLookupFn = OPERATOR_TO_FILTER[filter.operator]; if (!metadataLookupFn) throw new Error(`Unsupported operator: ${filter.operator}`); diff --git a/packages/llamaindex/tests/vectorStores/SimpleVectorStore.test.ts b/packages/llamaindex/tests/vectorStores/SimpleVectorStore.test.ts index c27cef6027..07d0db0832 100644 --- a/packages/llamaindex/tests/vectorStores/SimpleVectorStore.test.ts +++ b/packages/llamaindex/tests/vectorStores/SimpleVectorStore.test.ts @@ -31,6 +31,7 @@ describe("SimpleVectorStore", () => { private: "true", weight: 1.2, type: ["husky", "puppy"], + height: 50, }, }), new TextNode({ @@ -87,19 +88,6 @@ describe("SimpleVectorStore", () => { title: "No filter", expected: 3, }, - { - title: "Filter with non-exist key", - filters: { - filters: [ - { - key: "non-exist-key", - value: "cat", - operator: "==", - }, - ], - }, - expected: 0, - }, { title: "Filter EQ", filters: { @@ -309,4 +297,178 @@ describe("SimpleVectorStore", () => { }); }); }); + + describe("[SimpleVectorStore] query nodes with optional key in metadata", () => { + const testcases: FilterTestCase[] = [ + { + title: "Filter EQ with an optional key", + filters: { + filters: [ + { + key: "height", + value: 50, + operator: "==", + }, + ], + }, + expected: 1, + }, + { + title: "Filter NE with an optional key", + filters: { + filters: [ + { + key: "height", + value: 50, + operator: "!=", + }, + ], + }, + expected: 2, + }, + { + title: "Filter GT with an optional key", + filters: { + filters: [ + { + key: "height", + value: 48, + operator: ">", + }, + ], + }, + expected: 1, + }, + { + title: "Filter GTE with an optional key", + filters: { + filters: [ + { + key: "height", + value: 50, + operator: ">=", + }, + ], + }, + expected: 1, + }, + { + title: "Filter LT with an optional key", + filters: { + filters: [ + { + key: "height", + value: 50, + operator: "<", + }, + ], + }, + expected: 0, + }, + { + title: "Filter LTE with an optional key", + filters: { + filters: [ + { + key: "height", + value: 50, + operator: "<=", + }, + ], + }, + expected: 1, + }, + { + title: "Filter IN with an optional key", + filters: { + filters: [ + { + key: "non-existing-key", + value: ["a", "b"], + operator: "in", + }, + ], + }, + expected: 0, + }, + { + title: "Filter NIN with an optional key", + filters: { + filters: [ + { + key: "non-existing-key", + value: ["a", "b"], + operator: "nin", + }, + ], + }, + expected: 3, + }, + { + title: "Filter ANY with an optional key", + filters: { + filters: [ + { + key: "non-existing-key", + value: ["a", "b"], + operator: "any", + }, + ], + }, + expected: 0, + }, + { + title: "Filter ALL with an optional key", + filters: { + filters: [ + { + key: "non-existing-key", + value: ["a", "b"], + operator: "all", + }, + ], + }, + expected: 0, + }, + { + title: "Filter CONTAINS with an optional key", + filters: { + filters: [ + { + key: "non-existing-key", + value: "a", + operator: "contains", + }, + ], + }, + expected: 0, + }, + { + title: "Filter TEXT_MATCH with an optional key", + filters: { + filters: [ + { + key: "non-existing-key", + value: "a", + operator: "text_match", + }, + ], + }, + expected: 0, + }, + ]; + + testcases.forEach((tc) => { + it(`[${tc.title}] should return ${tc.expected} nodes`, async () => { + await store.add(nodes); + const result = await store.query({ + queryEmbedding: [0.1, 0.2], + similarityTopK: 3, + mode: VectorStoreQueryMode.DEFAULT, + filters: tc.filters, + }); + expect(result.ids).length(tc.expected); + }); + }); + }); });