Skip to content

Commit

Permalink
feat: add queryMode option
Browse files Browse the repository at this point in the history
  • Loading branch information
CahidArda committed Dec 31, 2024
1 parent 5e4c3de commit 55184cc
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 5 deletions.
Binary file modified bun.lockb
Binary file not shown.
3 changes: 1 addition & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
"@langchain/community": "^0.3.4",
"@langchain/core": "^0.2.9",
"@langchain/mistralai": "^0.0.28",
"@upstash/vector": "^1.1.3",
"@upstash/vector": "v1.2.0-canary-hybrid-3",
"ai": "^3.1.1",
"cheerio": "^1.0.0-rc.12",
"d3-dsv": "^3.0.1",
Expand All @@ -89,7 +89,6 @@
"@langchain/openai": "^0.2.8",
"@upstash/ratelimit": "^1 || ^2",
"@upstash/redis": "^1.34.0",
"@upstash/vector": "^1.1.5",
"react": "^18 || ^19",
"react-dom": "^18 || ^19"
}
Expand Down
1 change: 1 addition & 0 deletions src/context-service/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ export class ContextService {
topK: optionsWithDefault.topK,
namespace: optionsWithDefault.namespace,
contextFilter: optionsWithDefault.contextFilter,
queryMode: optionsWithDefault.queryMode,
});

// Log the result, which will be captured by the outer traceable
Expand Down
5 changes: 4 additions & 1 deletion src/database.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { WebBaseLoaderParams } from "@langchain/community/document_loaders/web/cheerio";
import type { Index } from "@upstash/vector";
import type { Index, QueryMode } from "@upstash/vector";
import type { RecursiveCharacterTextSplitterParams } from "langchain/text_splitter";
import { nanoid } from "nanoid";
import { DEFAULT_SIMILARITY_THRESHOLD, DEFAULT_TOP_K } from "./constants";
Expand Down Expand Up @@ -74,6 +74,7 @@ export type VectorPayload = {
topK?: number;
namespace?: string;
contextFilter?: string;
queryMode?: QueryMode;
};

export type ResetOptions = {
Expand Down Expand Up @@ -108,6 +109,7 @@ export class Database {
topK = DEFAULT_TOP_K,
namespace,
contextFilter,
queryMode,
}: VectorPayload): Promise<{ data: string; id: string; metadata: TMetadata }[]> {
const index = this.index;
const result = await index.query<Record<string, string>>(
Expand All @@ -117,6 +119,7 @@ export class Database {
includeData: true,
includeMetadata: true,
...(typeof contextFilter === "string" && { filter: contextFilter }),
queryMode,
},
{ namespace }
);
Expand Down
44 changes: 43 additions & 1 deletion src/rag-chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { ChatOpenAI } from "@langchain/openai";
import { openai } from "@ai-sdk/openai";
import { Ratelimit } from "@upstash/ratelimit";
import { Redis } from "@upstash/redis";
import { Index } from "@upstash/vector";
import { Index, QueryMode } from "@upstash/vector";
import { LangChainAdapter, StreamingTextResponse } from "ai";
import {
afterAll,
Expand Down Expand Up @@ -49,6 +49,11 @@ describe("RAG Chat with advance configs and direct instances", () => {
url: process.env.UPSTASH_REDIS_REST_URL!,
});

const hybridVector = new Index({
url: process.env.HYBRID_EMBEDDING_UPSTASH_VECTOR_REST_URL!,
token: process.env.HYBRID_EMBEDDING_UPSTASH_VECTOR_REST_TOKEN!,
});

const ragChat = new RAGChat({
model: upstashOpenai("gpt-3.5-turbo"),
vector,
Expand All @@ -69,6 +74,7 @@ describe("RAG Chat with advance configs and direct instances", () => {
await vector.reset({ namespace });
await vector.deleteNamespace(namespace);
await redis.flushdb();
await hybridVector.reset({ all: true });
});

test("should get result without streaming", async () => {
Expand All @@ -89,6 +95,42 @@ describe("RAG Chat with advance configs and direct instances", () => {
"Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall."
);
});

test("should retrieve with query mode", async () => {
const ragChat = new RAGChat({
vector: hybridVector,
streaming: true,
model: upstash("meta-llama/Meta-Llama-3-8B-Instruct"),
});

await ragChat.context.add({
type: "text",
data: "foo is bar",
});
await ragChat.context.add({
type: "text",
data: "foo is zed",
});
await awaitUntilIndexed(hybridVector);

const result = await ragChat.chat<{ unit: string }>("what is foo or bar?", {
topK: 1,
similarityThreshold: 0,
queryMode: QueryMode.SPARSE,
onContextFetched(context) {
expect(context.length).toBe(1);
return context;
},
});

expect(result.context).toEqual([
{
data: "foo is bar",
id: expect.any(String) as string,
metadata: undefined,
},
]);
});
});

describe("RAG Chat with ratelimit", () => {
Expand Down
1 change: 1 addition & 0 deletions src/rag-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ export class RAGChat {
? DEFAULT_PROMPT_WITHOUT_RAG
: (options?.promptFn ?? this.config.prompt),
contextFilter: options?.contextFilter ?? undefined,
queryMode: options?.queryMode ?? undefined,
};
}
}
10 changes: 9 additions & 1 deletion src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import type { ChatOpenAI } from "@langchain/openai";
import type { openai } from "@ai-sdk/openai";
import type { Ratelimit } from "@upstash/ratelimit";
import type { Redis } from "@upstash/redis";
import type { Index } from "@upstash/vector";
import type { Index, QueryMode } from "@upstash/vector";
import type { CustomPrompt } from "./rag-chat";
import type { ChatMistralAI } from "@langchain/mistralai";
import type { ChatAnthropic } from "@langchain/anthropic";
Expand Down Expand Up @@ -92,6 +92,14 @@ export type ChatOptions = {
* https://upstash.com/docs/vector/features/filtering#metadata-filtering
*/
contextFilter?: string;

/**
* Query mode to use when querying a hybrid index.
*
 * This is useful when your index is a hybrid index and you want to query
 * only the sparse or only the dense component when you pass `data`.
*/
queryMode?: QueryMode;
} & CommonChatAndRAGOptions;

export type PrepareChatResult = { data: string; id: string; metadata: unknown }[];
Expand Down

0 comments on commit 55184cc

Please sign in to comment.