From aac129398f92e2e9b639919f97949b2abe8ca249 Mon Sep 17 00:00:00 2001 From: Joao Paulo Nobrega Date: Tue, 20 Feb 2024 16:29:06 -0300 Subject: [PATCH] - adjust README - add new prompt scope - add prompt user - add CHAT Hisotry in prompt --- README.md | 2 +- src/agent.ts | 60 ++++++++--------- src/services/chain/index.ts | 15 +++++ src/services/chain/openapi-base-chain.ts | 57 ++++++++++++---- src/services/chain/sql-database-chain.ts | 13 ++-- src/services/chat-history/index.ts | 30 +++++++-- .../chat-history/memory-chat-history.ts | 65 +++++++++++++++++++ .../chat-history/redis-chat-history.ts | 55 ++++++++++++++-- 8 files changed, 238 insertions(+), 59 deletions(-) create mode 100644 src/services/chat-history/memory-chat-history.ts diff --git a/README.md b/README.md index f499ba3..73e5df9 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ The documents found are used for the context of the Agent. systemMesssage: '', chatConfig: { temperature: 0, - } + }, llmConfig: { type: '', // Check availability at model: '', diff --git a/src/agent.ts b/src/agent.ts index b26c9ad..b17b5a5 100644 --- a/src/agent.ts +++ b/src/agent.ts @@ -14,9 +14,10 @@ import { import { nanoid } from 'ai'; import { interpolate } from './helpers/string.helpers'; import { ChainService, IChainService } from './services/chain'; -import ChatHistoryFactory from './services/chat-history'; +import { ChatHistoryFactory, IChatHistory } from './services/chat-history'; import LLMFactory from './services/llm'; import VectorStoreFactory from './services/vector-store'; +import { BaseMessage } from 'langchain/schema'; const EVENTS_NAME = { onMessage: 'onMessage', @@ -33,6 +34,9 @@ class Agent extends AgentBaseCommand implements IAgent { private _vectorService: VectorStore; private _chainService: IChainService; + + + private _chatHistory: IChatHistory; private _bufferMemory: BufferMemory; private _logger: Console; private _settings: IAgentConfig; @@ -61,28 +65,15 @@ class Agent extends AgentBaseCommand implements IAgent { private async buildHistory( userSessionId: string, settings: IDatabaseConfig - ): Promise { - if (this._bufferMemory && !settings) return this._bufferMemory; - - if (!this._bufferMemory && !settings) { - this._bufferMemory = new BufferMemory({ - returnMessages: true, - memoryKey: 'chat_history', - }); - - return this._bufferMemory; - } + ): Promise { + if (this._chatHistory) return this._chatHistory; - this._bufferMemory = new BufferMemory({ - returnMessages: true, - memoryKey: 'chat_history', - chatHistory: await ChatHistoryFactory.create({ - ...settings, - sessionId: userSessionId || nanoid(), // TODO - }), - }); + this._chatHistory = await ChatHistoryFactory.create({ + ...settings, + sessionId: userSessionId || nanoid(), // TODO + }) - return this._bufferMemory; + return this._chatHistory; } private async buildRelevantDocs( @@ -115,13 +106,11 @@ class Agent extends AgentBaseCommand implements IAgent { const { question, chatThreadID } = args; try { - const memoryChat = await this.buildHistory( + const chatHistory = await this.buildHistory( chatThreadID, this._settings.dbHistoryConfig ); - memoryChat.chatHistory?.addUserMessage(question); - const { relevantDocs, referenciesDocs } = await this.buildRelevantDocs( args, this._settings.vectorStoreConfig @@ -130,21 +119,23 @@ class Agent extends AgentBaseCommand implements IAgent { const chain = await this._chainService.build( this._llm, question, - memoryChat + chatHistory.getBufferMemory(), ); - const chat_history = await memoryChat.chatHistory?.getMessages(); + + const chatMessages = await chatHistory.getMessages(); const result = await chain.call({ referencies: referenciesDocs, input_documents: relevantDocs, query: question, question: question, - chat_history: chat_history?.slice( - -(this._settings?.dbHistoryConfig?.limit || 5) - ), + chat_history: chatMessages, + format_chat_messages: await chatHistory.getFormatedMessages(), + user_prompt: this._settings.systemMesssage, }); - await memoryChat.chatHistory?.addAIChatMessage(result?.text); + await chatHistory.addUserMessage(question); + await chatHistory.addAIChatMessage(result?.text); this.emit(EVENTS_NAME.onMessage, result?.text); @@ -157,6 +148,15 @@ class Agent extends AgentBaseCommand implements IAgent { } } + getMessageFormat(messages: BaseMessage[]): string { + const cut = messages + .slice(-(this._settings?.dbHistoryConfig?.limit || 5)); + + const formated = cut.map((message) => `${message._getType().toUpperCase()}: ${message.content}`).join('\n'); + + return formated; + } + execute(args: any): Promise { throw new Error(args); } diff --git a/src/services/chain/index.ts b/src/services/chain/index.ts index a6938f5..eb8afdb 100644 --- a/src/services/chain/index.ts +++ b/src/services/chain/index.ts @@ -58,6 +58,19 @@ class ChainService { builtMessage += '\n'; builtMessage += ` + Given the user prompt and conversation log, the document context, the API output, and the following database output, formulate a response from a knowledge base.\n + You must follow the following rules and priorities when generating and responding:\n + - Always prioritize user prompt over conversation record.\n + - Ignore any conversation logs that are not directly related to the user prompt.\n + - Only try to answer if a question is asked.\n + - The question must be a single sentence.\n + - You must remove any punctuation from the question.\n + - You must remove any words that are not relevant to the question.\n + - If you are unable to formulate a question, respond in a friendly manner so the user can rephrase the question.\n\n + + USER PROMPT: {user_prompt}\n + -------------------------------------- + CHAT HISTORY: {format_chat_messages}\n -------------------------------------- Context found in documents: {summaries}\n -------------------------------------- @@ -132,6 +145,8 @@ class ChainService { 'input_documents', 'question', 'chat_history', + 'format_chat_messages', + 'user_prompt' ], verbose: this._settings.debug || false, memory: memoryChat, diff --git a/src/services/chain/openapi-base-chain.ts b/src/services/chain/openapi-base-chain.ts index 32b0424..cc1b685 100644 --- a/src/services/chain/openapi-base-chain.ts +++ b/src/services/chain/openapi-base-chain.ts @@ -39,14 +39,16 @@ export class OpenApiBaseChain extends BaseChain { private getOpenApiPrompt(): string { return `You are an AI with expertise in OpenAPI and Swagger.\n - Always answer the question in the language in which the question was asked.\n - - Always respond with the URL;\n - - Never put information or explanations in the answer;\n - ${this._input.customizeSystemMessage || ''} + You should follow the following rules when generating and answer:\n + - Only execute the request on the service if the question is not in CHAT HISTORY, if the question has already been answered, use the same answer and do not make a request on the service. + - Only attempt to answer if a question was posed.\n + - Always answer the question in the language in which the question was asked.\n\n + ------------------------------------------- + USER PROMPT: ${this._input.customizeSystemMessage || ''} -------------------------------------------\n SCHEMA: {schema}\n -------------------------------------------\n - CHAT HISTORY: {chat_history}\n + CHAT HISTORY: {format_chat_messages}\n -------------------------------------------\n QUESTION: {question}\n ------------------------------------------\n @@ -66,6 +68,22 @@ export class OpenApiBaseChain extends BaseChain { return CHAT_COMBINE_PROMPT; } + private tryParseText(text: string): string { + if (text.includes('No function_call in message')) { + try { + const txtSplitJson = text.split('No function_call in message ')[1]; + const txtJson = JSON.parse(txtSplitJson); + + return txtJson[0]?.text; + } catch (error) { + return `Sorry, I could not find the answer to your question.`; + } + } + + return text; + } + + async _call( values: ChainValues, runManager?: CallbackManagerForChainRun @@ -83,15 +101,26 @@ export class OpenApiBaseChain extends BaseChain { verbose: true, }); - const answer = await chain.invoke({ - question, - schema, - chat_history: values?.chat_history, - }); - - console.log('OPENAPI Resposta: ', answer); - - return { [this.outputKey]: answer?.response }; + let answer:string = ''; + + try { + const rs = await chain.invoke({ + question, + schema, + chat_history: values?.chat_history, + format_chat_messages: values?.format_chat_messages, + }); + + console.log('OPENAPI Resposta: ', answer); + + answer = rs?.response; + } catch (error) { + console.error('OPENAPI Error: ', error); + + answer = this.tryParseText(error?.message); + } finally { + return { [this.outputKey]: answer }; + } } _chainType(): string { diff --git a/src/services/chain/sql-database-chain.ts b/src/services/chain/sql-database-chain.ts index ff903e8..39e1fd3 100644 --- a/src/services/chain/sql-database-chain.ts +++ b/src/services/chain/sql-database-chain.ts @@ -91,11 +91,12 @@ export default class SqlDatabaseChain extends BaseChain { Your response must only be a valid SQL query, based on the schema provided.\n -------------------------------------------\n Here are some important observations for generating the query:\n + - Only execute the request on the service if the question is not in CHAT HISTORY, if the question has already been answered, use the same answer and do not make a query on the database. ${this.customMessage}\n -------------------------------------------\n SCHEMA: {schema}\n -------------------------------------------\n - CHAT HISTORY: {chat_history}\n + CHAT HISTORY: {format_chat_messages}\n -------------------------------------------\n QUESTION: {question}\n ------------------------------------------\n @@ -123,8 +124,7 @@ export default class SqlDatabaseChain extends BaseChain { return sqlBlock; } - - throw new Error(MESSAGES_ERRORS.dataEmpty); + return; } // TODO: check implementation for big data @@ -173,6 +173,7 @@ export default class SqlDatabaseChain extends BaseChain { schema: () => table_schema, question: (input: { question: string }) => input.question, chat_history: () => values?.chat_history, + format_chat_messages: () => values?.format_chat_messages, }, this.buildPromptTemplate(this.getSQLPrompt()), this.llm.bind({ stop: ['\nSQLResult:'] }), @@ -190,12 +191,12 @@ export default class SqlDatabaseChain extends BaseChain { question: (input) => input.question, query: (input) => input.query, response: async (input) => { - const sql = input.query.content; + const text = input.query.content; try { - const sqlParserd = this.parserSQL(sql); + const sqlParserd = this.parserSQL(text); - if (!sqlParserd) return null; + if (!sqlParserd) return text; console.log(`SQL`, sqlParserd); diff --git a/src/services/chat-history/index.ts b/src/services/chat-history/index.ts index 434aa3d..9817c96 100644 --- a/src/services/chat-history/index.ts +++ b/src/services/chat-history/index.ts @@ -1,17 +1,39 @@ import { IDatabaseConfig } from '../../interface/agent.interface'; -import { BaseChatMessageHistory } from 'langchain/schema'; +import { BaseChatMessageHistory, BaseMessage } from 'langchain/schema'; import RedisChatHistory from './redis-chat-history'; +import { BufferMemory } from 'langchain/memory'; +import MemoryChatHistory from './memory-chat-history'; + +interface IChatHistory { + addUserMessage(message: string): Promise; + addAIChatMessage(message: string): Promise; + getMessages(): Promise; + getFormatedMessages(): Promise; + clear(): Promise; + getChatHistory(): BaseChatMessageHistory; + getBufferMemory(): BufferMemory; +} const Services = { redis: RedisChatHistory, + memory: MemoryChatHistory, } as any; class ChatHistoryFactory { - public static async create(settings: IDatabaseConfig): Promise { - return await new Services[settings.type](settings).build(); + public static async create(settings: IDatabaseConfig): Promise { + const service = new Services[settings?.type](settings); + + if (!service) { + return await new MemoryChatHistory(settings).build(); + } + + return await service.build(); } } -export default ChatHistoryFactory; +export { + IChatHistory, + ChatHistoryFactory, +}; diff --git a/src/services/chat-history/memory-chat-history.ts b/src/services/chat-history/memory-chat-history.ts new file mode 100644 index 0000000..1c65329 --- /dev/null +++ b/src/services/chat-history/memory-chat-history.ts @@ -0,0 +1,65 @@ +import { BaseChatMessageHistory, BaseMessage } from 'langchain/schema'; +import { IDatabaseConfig } from '../../interface/agent.interface'; +import { IChatHistory } from '.'; +import { BufferMemory } from 'langchain/memory'; + +class MemoryChatHistory implements IChatHistory { + private _settings: IDatabaseConfig; + private _history: BaseChatMessageHistory; + private _bufferMemory: BufferMemory; + + constructor(settings: IDatabaseConfig) { + this._settings = settings; + } + + addUserMessage(message: string): Promise { + return this._history?.addUserMessage(message); + } + + addAIChatMessage(message: string): Promise { + return this._history?.addAIChatMessage(message); + } + + getMessages(): Promise { + return this._history?.getMessages(); + } + + async getFormatedMessages(): Promise { + const messages = await this._history?.getMessages(); + + const cut = messages + .slice(-(this._settings?.limit || 5)); + + const formated = cut.map((message) => `${message._getType().toUpperCase()}: ${message.content}`).join('\n'); + + return formated; + } + + getChatHistory(): BaseChatMessageHistory { + return this._history; + } + + getBufferMemory(): BufferMemory { + return this._bufferMemory; + } + + clear(): Promise { + return this._history?.clear(); + } + + async build(): Promise { + const { ChatMessageHistory } = (await import('langchain/stores/message/in_memory')); + + this._history = new ChatMessageHistory(); + + this._bufferMemory = new BufferMemory({ + returnMessages: true, + memoryKey: 'chat_history', + chatHistory: this._history, + }); + + return this; + } +} + +export default MemoryChatHistory; \ No newline at end of file diff --git a/src/services/chat-history/redis-chat-history.ts b/src/services/chat-history/redis-chat-history.ts index eb1b3a7..dd470a6 100644 --- a/src/services/chat-history/redis-chat-history.ts +++ b/src/services/chat-history/redis-chat-history.ts @@ -1,9 +1,13 @@ -import { BaseChatMessageHistory } from 'langchain/schema'; +import { BaseChatMessageHistory, BaseMessage } from 'langchain/schema'; import { IDatabaseConfig } from '../../interface/agent.interface'; +import { IChatHistory } from '.'; +import { BufferMemory } from 'langchain/memory'; -class RedisChatHistory { +class RedisChatHistory implements IChatHistory { private _settings: IDatabaseConfig; private _redisClientInstance: any; + private _history: BaseChatMessageHistory; + private _bufferMemory: BufferMemory; constructor(settings: IDatabaseConfig) { this._settings = settings; @@ -26,17 +30,60 @@ class RedisChatHistory { return this._redisClientInstance; } + + addUserMessage(message: string): Promise { + return this._history?.addUserMessage(message); + } + + addAIChatMessage(message: string): Promise { + return this._history?.addAIChatMessage(message); + } + + getMessages(): Promise { + return this._history?.getMessages(); + } + + async getFormatedMessages(): Promise { + const messages = await this._history?.getMessages(); + + const cut = messages + .slice(-(this._settings?.limit || 5)); + + const formated = cut.map((message) => `${message._getType().toUpperCase()}: ${message.content}`).join('\n'); + + return formated; + } + + getChatHistory(): BaseChatMessageHistory { + return this._history; + } + + getBufferMemory(): BufferMemory { + return this._bufferMemory; + } + + clear(): Promise { + return this._history?.clear(); + } - async build(): Promise { + async build(): Promise { const { RedisChatMessageHistory } = (await import('langchain/stores/message/ioredis')); const client = await this.createClient(); - return new RedisChatMessageHistory({ + this._history = new RedisChatMessageHistory({ sessionTTL: this._settings.sessionTTL, sessionId: this._settings.sessionId || new Date().toISOString(), client, }); + + this._bufferMemory = new BufferMemory({ + returnMessages: true, + memoryKey: 'chat_history', + chatHistory: this._history, + }); + + return this; } }