Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
Merge pull request #48 from SciPhi-AI/Nolan/IntroduceRagChat
Browse files Browse the repository at this point in the history
Introduce Rag Chat
  • Loading branch information
NolanTrem authored Jul 27, 2024
2 parents bb3b376 + 178e32a commit af1d06b
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 5 deletions.
43 changes: 42 additions & 1 deletion __tests__/r2rClientIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,48 @@ describe("r2rClient Integration Tests", () => {

test("Generate RAG response", async () => {
await expect(client.rag({ query: "test" })).resolves.not.toThrow();
}, 10000);
}, 30000);

test("Generate RAG Chat response", async () => {
const messages = [
{ role: "system", content: "You are a helpful assistant." },
{ role: "user", content: "Tell me about Raskolnikov." },
];

await expect(client.ragChat({ messages })).resolves.not.toThrow();
}, 30000);

test("Generate RAG Chat response with streaming", async () => {
const messages = [
{ role: "system", content: "You are a helpful assistant." },
{ role: "user", content: "Tell me about Raskolnikov." },
];

const streamingConfig = {
messages,
rag_generation_config: { stream: true },
};

const stream = await client.ragChat(streamingConfig);

expect(stream).toBeDefined();
expect(stream instanceof ReadableStream).toBe(true);

let fullResponse = "";
const reader = stream.getReader();

while (true) {
const { done, value } = await reader.read();
if (done) {
break;
}

const chunk = new TextDecoder().decode(value);
fullResponse += chunk;
}

expect(fullResponse.length).toBeGreaterThan(0);
}, 30000);

test("Delete document", async () => {
await expect(
Expand Down
2 changes: 1 addition & 1 deletion __tests__/r2rClientIntegrationUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ describe("r2rClient Integration Tests", () => {

test("Generate RAG response", async () => {
await expect(client.rag({ query: "test" })).resolves.not.toThrow();
}, 10000);
}, 30000);

test("Delete document", async () => {
await expect(
Expand Down
2 changes: 1 addition & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "r2r-js",
"version": "1.2.11",
"version": "1.2.12",
"description": "",
"main": "dist/index.js",
"browser": "dist/index.browser.js",
Expand Down
22 changes: 22 additions & 0 deletions src/models.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ export interface KGSearchSettings {
agent_generation_config?: Record<string, any>;
}

export interface Message {
role: string;
content: string;
}

export interface R2RUpdatePromptRequest {
name: string;
template?: string;
Expand Down Expand Up @@ -127,6 +132,23 @@ export interface R2RLogsRequest {
max_runs_requested: number;
}

export interface R2RRAGChatRequest {
messages: Message[];
vector_search_settings?: {
use_vector_search: boolean;
search_filters?: Record<string, any>;
search_limit: number;
do_hybrid_search: boolean;
};
kg_search_settings?: {
use_kg_search: boolean;
kg_search_generation_config?: Record<string, any>;
};
rag_generation_config?: GenerationConfig;
task_prompt_override?: string;
include_title_if_available?: boolean;
}

export interface R2RPrintRelationshipRequest {
limit: number;
}
73 changes: 72 additions & 1 deletion src/r2rClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ import { feature, initializeTelemetry } from "./feature";
import {
LoginResponse,
UserCreate,
Message,
RefreshTokenResponse,
R2RUpdatePromptRequest,
R2RIngestFilesRequest,
R2RSearchRequest,
R2RRAGChatRequest,
R2RRAGRequest,
R2RDeleteRequest,
R2RAnalyticsRequest,
Expand Down Expand Up @@ -52,7 +54,7 @@ function handleRequestError(response: AxiosResponse): void {
errorContent !== null &&
"detail" in errorContent
) {
const detail = errorContent.detail;
const { detail } = errorContent;
if (typeof detail === "object" && detail !== null) {
message = (detail as { message?: string }).message || response.statusText;
} else {
Expand Down Expand Up @@ -685,6 +687,7 @@ export class r2rClient {
return response;
}

@feature("deleteUser")
async deleteUser(password: string): Promise<any> {
this._ensureAuthenticated();
const response = await this._makeRequest("DELETE", "user", {
Expand All @@ -694,6 +697,74 @@ export class r2rClient {
this.refreshToken = null;
return response;
}

@feature("ragChat")
async ragChat(params: {
messages: Message[];
use_vector_search?: boolean;
search_filters?: Record<string, any>;
search_limit?: number;
do_hybrid_search?: boolean;
use_kg_search?: boolean;
kg_search_generation_config?: Record<string, any>;
rag_generation_config?: GenerationConfig;
task_prompt_override?: string;
include_title_if_available?: boolean;
}): Promise<any> {
this._ensureAuthenticated();

const {
messages,
use_vector_search = true,
search_filters = {},
search_limit = 10,
do_hybrid_search = false,
use_kg_search = false,
kg_search_generation_config,
rag_generation_config,
task_prompt_override,
include_title_if_available = true,
} = params;

const request: R2RRAGChatRequest = {
messages,
vector_search_settings: {
use_vector_search,
search_filters,
search_limit,
do_hybrid_search,
},
kg_search_settings: {
use_kg_search,
kg_search_generation_config,
},
rag_generation_config,
task_prompt_override,
include_title_if_available,
};

if (rag_generation_config && rag_generation_config.stream) {
return this.streamRagChat(request);
} else {
console.log("RAG Chat Request:", JSON.stringify(request, null, 2));
return await this._makeRequest("POST", "rag_chat", { data: request });
}
}

@feature("streamingRagChat")
private async streamRagChat(
request: R2RRAGChatRequest,
): Promise<ReadableStream<Uint8Array>> {
this._ensureAuthenticated();

return this._makeRequest<ReadableStream<Uint8Array>>("POST", "rag_chat", {
data: request,
headers: {
"Content-Type": "application/json",
},
responseType: "stream",
});
}
}

export default r2rClient;

0 comments on commit af1d06b

Please sign in to comment.