diff --git a/__tests__/r2rClientIntegrationSuperUser.test.ts b/__tests__/r2rClientIntegrationSuperUser.test.ts index eebb96a..0f0f9a0 100644 --- a/__tests__/r2rClientIntegrationSuperUser.test.ts +++ b/__tests__/r2rClientIntegrationSuperUser.test.ts @@ -59,7 +59,48 @@ describe("r2rClient Integration Tests", () => { test("Generate RAG response", async () => { await expect(client.rag({ query: "test" })).resolves.not.toThrow(); - }, 10000); + }, 30000); + + test("Generate RAG Chat response", async () => { + const messages = [ + { role: "system", content: "You are a helpful assistant." }, + { role: "user", content: "Tell me about Raskolnikov." }, + ]; + + await expect(client.ragChat({ messages })).resolves.not.toThrow(); + }, 30000); + + test("Generate RAG Chat response with streaming", async () => { + const messages = [ + { role: "system", content: "You are a helpful assistant." }, + { role: "user", content: "Tell me about Raskolnikov." }, + ]; + + const streamingConfig = { + messages, + rag_generation_config: { stream: true }, + }; + + const stream = await client.ragChat(streamingConfig); + + expect(stream).toBeDefined(); + expect(stream instanceof ReadableStream).toBe(true); + + let fullResponse = ""; + const reader = stream.getReader(); + + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + + const chunk = new TextDecoder().decode(value); + fullResponse += chunk; + } + + expect(fullResponse.length).toBeGreaterThan(0); + }, 30000); test("Delete document", async () => { await expect( diff --git a/__tests__/r2rClientIntegrationUser.test.ts b/__tests__/r2rClientIntegrationUser.test.ts index 978cfab..51324a4 100644 --- a/__tests__/r2rClientIntegrationUser.test.ts +++ b/__tests__/r2rClientIntegrationUser.test.ts @@ -70,7 +70,7 @@ describe("r2rClient Integration Tests", () => { test("Generate RAG response", async () => { await expect(client.rag({ query: "test" })).resolves.not.toThrow(); - }, 10000); + }, 30000); test("Delete document", async () => { await expect( diff --git a/package-lock.json b/package-lock.json index 437a1a7..199635a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,6 +1,6 @@ { "name": "r2r-js", - "version": "1.2.11", + "version": "1.2.12", "lockfileVersion": 3, "requires": true, "packages": { diff --git a/package.json b/package.json index b7710e7..7d8b093 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "r2r-js", - "version": "1.2.11", + "version": "1.2.12", "description": "", "main": "dist/index.js", "browser": "dist/index.browser.js", diff --git a/src/models.tsx b/src/models.tsx index a43eea1..aa21638 100644 --- a/src/models.tsx +++ b/src/models.tsx @@ -59,6 +59,11 @@ export interface KGSearchSettings { agent_generation_config?: Record; } +export interface Message { + role: string; + content: string; +} + export interface R2RUpdatePromptRequest { name: string; template?: string; @@ -127,6 +132,23 @@ export interface R2RLogsRequest { max_runs_requested: number; } +export interface R2RRAGChatRequest { + messages: Message[]; + vector_search_settings?: { + use_vector_search: boolean; + search_filters?: Record; + search_limit: number; + do_hybrid_search: boolean; + }; + kg_search_settings?: { + use_kg_search: boolean; + kg_search_generation_config?: Record; + }; + rag_generation_config?: GenerationConfig; + task_prompt_override?: string; + include_title_if_available?: boolean; +} + export interface R2RPrintRelationshipRequest { limit: number; } diff --git a/src/r2rClient.ts b/src/r2rClient.ts index 899c0f8..481b29c 100644 --- a/src/r2rClient.ts +++ b/src/r2rClient.ts @@ -18,10 +18,12 @@ import { feature, initializeTelemetry } from "./feature"; import { LoginResponse, UserCreate, + Message, RefreshTokenResponse, R2RUpdatePromptRequest, R2RIngestFilesRequest, R2RSearchRequest, + R2RRAGChatRequest, R2RRAGRequest, R2RDeleteRequest, R2RAnalyticsRequest, @@ -52,7 +54,7 @@ function handleRequestError(response: AxiosResponse): void { errorContent !== null && "detail" in errorContent ) { - const detail = errorContent.detail; + const { detail } = errorContent; if (typeof detail === "object" && detail !== null) { message = (detail as { message?: string }).message || response.statusText; } else { @@ -685,6 +687,7 @@ export class r2rClient { return response; } + @feature("deleteUser") async deleteUser(password: string): Promise { this._ensureAuthenticated(); const response = await this._makeRequest("DELETE", "user", { @@ -694,6 +697,74 @@ export class r2rClient { this.refreshToken = null; return response; } + + @feature("ragChat") + async ragChat(params: { + messages: Message[]; + use_vector_search?: boolean; + search_filters?: Record; + search_limit?: number; + do_hybrid_search?: boolean; + use_kg_search?: boolean; + kg_search_generation_config?: Record; + rag_generation_config?: GenerationConfig; + task_prompt_override?: string; + include_title_if_available?: boolean; + }): Promise { + this._ensureAuthenticated(); + + const { + messages, + use_vector_search = true, + search_filters = {}, + search_limit = 10, + do_hybrid_search = false, + use_kg_search = false, + kg_search_generation_config, + rag_generation_config, + task_prompt_override, + include_title_if_available = true, + } = params; + + const request: R2RRAGChatRequest = { + messages, + vector_search_settings: { + use_vector_search, + search_filters, + search_limit, + do_hybrid_search, + }, + kg_search_settings: { + use_kg_search, + kg_search_generation_config, + }, + rag_generation_config, + task_prompt_override, + include_title_if_available, + }; + + if (rag_generation_config && rag_generation_config.stream) { + return this.streamRagChat(request); + } else { + console.log("RAG Chat Request:", JSON.stringify(request, null, 2)); + return await this._makeRequest("POST", "rag_chat", { data: request }); + } + } + + @feature("streamingRagChat") + private async streamRagChat( + request: R2RRAGChatRequest, + ): Promise> { + this._ensureAuthenticated(); + + return this._makeRequest>("POST", "rag_chat", { + data: request, + headers: { + "Content-Type": "application/json", + }, + responseType: "stream", + }); + } } export default r2rClient;