Skip to content

Commit

Permalink
- adjust README
Browse files Browse the repository at this point in the history
- add new prompt scope
- add prompt user
- add CHAT History in prompt
  • Loading branch information
Joao Paulo Nobrega committed Feb 20, 2024
1 parent 408f1dd commit aac1293
Show file tree
Hide file tree
Showing 8 changed files with 238 additions and 59 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ The documents found are used for the context of the Agent.
systemMesssage: '<a message that will specialize your agent>',
chatConfig: {
temperature: 0,
}
},
llmConfig: {
type: '<cloud-provider-llm-service>', // Check availability at <link>
model: '<llm-model>',
Expand Down
60 changes: 30 additions & 30 deletions src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ import {
import { nanoid } from 'ai';
import { interpolate } from './helpers/string.helpers';
import { ChainService, IChainService } from './services/chain';
import ChatHistoryFactory from './services/chat-history';
import { ChatHistoryFactory, IChatHistory } from './services/chat-history';
import LLMFactory from './services/llm';
import VectorStoreFactory from './services/vector-store';
import { BaseMessage } from 'langchain/schema';

const EVENTS_NAME = {
onMessage: 'onMessage',
Expand All @@ -33,6 +34,9 @@ class Agent extends AgentBaseCommand implements IAgent {
private _vectorService: VectorStore;

private _chainService: IChainService;


private _chatHistory: IChatHistory;
private _bufferMemory: BufferMemory;
private _logger: Console;
private _settings: IAgentConfig;
Expand Down Expand Up @@ -61,28 +65,15 @@ class Agent extends AgentBaseCommand implements IAgent {
private async buildHistory(
userSessionId: string,
settings: IDatabaseConfig
): Promise<BufferMemory> {
if (this._bufferMemory && !settings) return this._bufferMemory;

if (!this._bufferMemory && !settings) {
this._bufferMemory = new BufferMemory({
returnMessages: true,
memoryKey: 'chat_history',
});

return this._bufferMemory;
}
): Promise<IChatHistory> {
if (this._chatHistory) return this._chatHistory;

this._bufferMemory = new BufferMemory({
returnMessages: true,
memoryKey: 'chat_history',
chatHistory: await ChatHistoryFactory.create({
...settings,
sessionId: userSessionId || nanoid(), // TODO
}),
});
this._chatHistory = await ChatHistoryFactory.create({
...settings,
sessionId: userSessionId || nanoid(), // TODO
})

return this._bufferMemory;
return this._chatHistory;
}

private async buildRelevantDocs(
Expand Down Expand Up @@ -115,13 +106,11 @@ class Agent extends AgentBaseCommand implements IAgent {
const { question, chatThreadID } = args;

try {
const memoryChat = await this.buildHistory(
const chatHistory = await this.buildHistory(
chatThreadID,
this._settings.dbHistoryConfig
);

memoryChat.chatHistory?.addUserMessage(question);

const { relevantDocs, referenciesDocs } = await this.buildRelevantDocs(
args,
this._settings.vectorStoreConfig
Expand All @@ -130,21 +119,23 @@ class Agent extends AgentBaseCommand implements IAgent {
const chain = await this._chainService.build(
this._llm,
question,
memoryChat
chatHistory.getBufferMemory(),
);
const chat_history = await memoryChat.chatHistory?.getMessages();

const chatMessages = await chatHistory.getMessages();

const result = await chain.call({
referencies: referenciesDocs,
input_documents: relevantDocs,
query: question,
question: question,
chat_history: chat_history?.slice(
-(this._settings?.dbHistoryConfig?.limit || 5)
),
chat_history: chatMessages,
format_chat_messages: await chatHistory.getFormatedMessages(),
user_prompt: this._settings.systemMesssage,
});

await memoryChat.chatHistory?.addAIChatMessage(result?.text);
await chatHistory.addUserMessage(question);
await chatHistory.addAIChatMessage(result?.text);

this.emit(EVENTS_NAME.onMessage, result?.text);

Expand All @@ -157,6 +148,15 @@ class Agent extends AgentBaseCommand implements IAgent {
}
}

/**
 * Renders a window of chat messages as plain text for prompt injection.
 *
 * Keeps only the most recent `dbHistoryConfig.limit` messages (default 5)
 * and formats each as "TYPE: content", one per line (e.g. "HUMAN: hi").
 *
 * NOTE(review): this duplicates the formatting done by the chat-history
 * services' getFormatedMessages() — consider consolidating.
 *
 * @param messages full conversation history, in stored order.
 * @returns newline-joined "TYPE: content" lines for the recent window.
 */
getMessageFormat(messages: BaseMessage[]): string {
  const historyLimit = this._settings?.dbHistoryConfig?.limit || 5;
  const recentMessages = messages.slice(-historyLimit);

  return recentMessages
    .map((msg) => `${msg._getType().toUpperCase()}: ${msg.content}`)
    .join('\n');
}

execute(args: any): Promise<void> {
throw new Error(args);
}
Expand Down
15 changes: 15 additions & 0 deletions src/services/chain/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,19 @@ class ChainService {

builtMessage += '\n';
builtMessage += `
Given the user prompt and conversation log, the document context, the API output, and the following database output, formulate a response from a knowledge base.\n
You must follow the following rules and priorities when generating and responding:\n
- Always prioritize user prompt over conversation record.\n
- Ignore any conversation logs that are not directly related to the user prompt.\n
- Only try to answer if a question is asked.\n
- The question must be a single sentence.\n
- You must remove any punctuation from the question.\n
- You must remove any words that are not relevant to the question.\n
- If you are unable to formulate a question, respond in a friendly manner so the user can rephrase the question.\n\n
USER PROMPT: {user_prompt}\n
--------------------------------------
CHAT HISTORY: {format_chat_messages}\n
--------------------------------------
Context found in documents: {summaries}\n
--------------------------------------
Expand Down Expand Up @@ -132,6 +145,8 @@ class ChainService {
'input_documents',
'question',
'chat_history',
'format_chat_messages',
'user_prompt'
],
verbose: this._settings.debug || false,
memory: memoryChat,
Expand Down
57 changes: 43 additions & 14 deletions src/services/chain/openapi-base-chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,16 @@ export class OpenApiBaseChain extends BaseChain {

private getOpenApiPrompt(): string {
return `You are an AI with expertise in OpenAPI and Swagger.\n
Always answer the question in the language in which the question was asked.\n
- Always respond with the URL;\n
- Never put information or explanations in the answer;\n
${this._input.customizeSystemMessage || ''}
You should follow the following rules when generating and answer:\n
- Only execute the request on the service if the question is not in CHAT HISTORY, if the question has already been answered, use the same answer and do not make a request on the service.
- Only attempt to answer if a question was posed.\n
- Always answer the question in the language in which the question was asked.\n\n
-------------------------------------------
USER PROMPT: ${this._input.customizeSystemMessage || ''}
-------------------------------------------\n
SCHEMA: {schema}\n
-------------------------------------------\n
CHAT HISTORY: {chat_history}\n
CHAT HISTORY: {format_chat_messages}\n
-------------------------------------------\n
QUESTION: {question}\n
------------------------------------------\n
Expand All @@ -66,6 +68,22 @@ export class OpenApiBaseChain extends BaseChain {
return CHAT_COMBINE_PROMPT;
}

/**
 * Recovers a usable answer from a known LLM error message.
 *
 * Some providers fail with "No function_call in message <json>" where the
 * trailing payload is a JSON array whose first element carries the actual
 * text. Extract that text when possible; otherwise return a friendly
 * apology. Any other input is passed through unchanged.
 *
 * @param text raw error (or answer) text.
 * @returns the original text, the recovered answer, or an apology string.
 */
private tryParseText(text: string): string {
  if (!text.includes('No function_call in message')) {
    return text;
  }

  try {
    const jsonPayload = text.split('No function_call in message ')[1];
    const parsedPayload = JSON.parse(jsonPayload);

    return parsedPayload[0]?.text;
  } catch (error) {
    return `Sorry, I could not find the answer to your question.`;
  }
}


async _call(
values: ChainValues,
runManager?: CallbackManagerForChainRun
Expand All @@ -83,15 +101,26 @@ export class OpenApiBaseChain extends BaseChain {
verbose: true,
});

const answer = await chain.invoke({
question,
schema,
chat_history: values?.chat_history,
});

console.log('OPENAPI Resposta: ', answer);

return { [this.outputKey]: answer?.response };
let answer:string = '';

try {
const rs = await chain.invoke({
question,
schema,
chat_history: values?.chat_history,
format_chat_messages: values?.format_chat_messages,
});

console.log('OPENAPI Resposta: ', answer);

answer = rs?.response;
} catch (error) {
console.error('OPENAPI Error: ', error);

answer = this.tryParseText(error?.message);
} finally {
return { [this.outputKey]: answer };
}
}

_chainType(): string {
Expand Down
13 changes: 7 additions & 6 deletions src/services/chain/sql-database-chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,12 @@ export default class SqlDatabaseChain extends BaseChain {
Your response must only be a valid SQL query, based on the schema provided.\n
-------------------------------------------\n
Here are some important observations for generating the query:\n
- Only execute the request on the service if the question is not in CHAT HISTORY, if the question has already been answered, use the same answer and do not make a query on the database.
${this.customMessage}\n
-------------------------------------------\n
SCHEMA: {schema}\n
-------------------------------------------\n
CHAT HISTORY: {chat_history}\n
CHAT HISTORY: {format_chat_messages}\n
-------------------------------------------\n
QUESTION: {question}\n
------------------------------------------\n
Expand Down Expand Up @@ -123,8 +124,7 @@ export default class SqlDatabaseChain extends BaseChain {

return sqlBlock;
}

throw new Error(MESSAGES_ERRORS.dataEmpty);
return;
}

// TODO: check implementation for big data
Expand Down Expand Up @@ -173,6 +173,7 @@ export default class SqlDatabaseChain extends BaseChain {
schema: () => table_schema,
question: (input: { question: string }) => input.question,
chat_history: () => values?.chat_history,
format_chat_messages: () => values?.format_chat_messages,
},
this.buildPromptTemplate(this.getSQLPrompt()),
this.llm.bind({ stop: ['\nSQLResult:'] }),
Expand All @@ -190,12 +191,12 @@ export default class SqlDatabaseChain extends BaseChain {
question: (input) => input.question,
query: (input) => input.query,
response: async (input) => {
const sql = input.query.content;
const text = input.query.content;

try {
const sqlParserd = this.parserSQL(sql);
const sqlParserd = this.parserSQL(text);

if (!sqlParserd) return null;
if (!sqlParserd) return text;

console.log(`SQL`, sqlParserd);

Expand Down
30 changes: 26 additions & 4 deletions src/services/chat-history/index.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,39 @@

import { IDatabaseConfig } from '../../interface/agent.interface';
import { BaseChatMessageHistory } from 'langchain/schema';
import { BaseChatMessageHistory, BaseMessage } from 'langchain/schema';

import RedisChatHistory from './redis-chat-history';
import { BufferMemory } from 'langchain/memory';
import MemoryChatHistory from './memory-chat-history';

/**
 * Contract implemented by every chat-history backend (e.g. Redis, in-memory).
 * Wraps a LangChain message store plus the BufferMemory built on top of it.
 */
interface IChatHistory {
  /** Appends a human/user message to the underlying store. */
  addUserMessage(message: string): Promise<void>;
  /** Appends an AI-generated message to the underlying store. */
  addAIChatMessage(message: string): Promise<void>;
  /** Returns the stored messages. */
  getMessages(): Promise<BaseMessage[]>;
  /** Returns recent messages rendered as "TYPE: content" lines for prompts. */
  getFormatedMessages(): Promise<string>;
  /** Removes all messages from the store. */
  clear(): Promise<void>;
  /** Exposes the raw LangChain message-history instance. */
  getChatHistory(): BaseChatMessageHistory;
  /** Exposes the BufferMemory wired to this history. */
  getBufferMemory(): BufferMemory;
}

// Registry mapping IDatabaseConfig.type to its chat-history implementation.
// NOTE(review): typed `as any` — an unknown `type` key yields `undefined`
// at runtime instead of a compile error; a typed Record would catch this.
const Services = {
  redis: RedisChatHistory,
  memory: MemoryChatHistory,
} as any;

class ChatHistoryFactory {
public static async create(settings: IDatabaseConfig): Promise<BaseChatMessageHistory> {
return await new Services[settings.type](settings).build();
/**
 * Builds a ready-to-use chat-history service for the configured backend.
 *
 * Falls back to the in-memory implementation when `settings.type` does not
 * match a registered service. (The previous code instantiated
 * `new Services[settings?.type](settings)` first, which throws a TypeError
 * for unknown types before the `!service` fallback could ever run.)
 *
 * @param settings history configuration; `type` selects the backend.
 * @returns an IChatHistory whose `build()` has already completed.
 */
public static async create(settings: IDatabaseConfig): Promise<IChatHistory> {
  const ServiceClass = Services[settings?.type];

  // Unknown or missing backend type — default to in-memory history.
  if (!ServiceClass) {
    return await new MemoryChatHistory(settings).build();
  }

  return await new ServiceClass(settings).build();
}
}

export default ChatHistoryFactory;
export {
IChatHistory,
ChatHistoryFactory,
};
65 changes: 65 additions & 0 deletions src/services/chat-history/memory-chat-history.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { BaseChatMessageHistory, BaseMessage } from 'langchain/schema';
import { IDatabaseConfig } from '../../interface/agent.interface';
import { IChatHistory } from '.';
import { BufferMemory } from 'langchain/memory';

/**
 * In-memory IChatHistory backed by LangChain's ChatMessageHistory.
 *
 * Instances are produced by ChatHistoryFactory and must be initialized via
 * `build()` before use: `build()` creates the message store and the
 * BufferMemory that exposes it under the 'chat_history' memory key.
 */
class MemoryChatHistory implements IChatHistory {
  // Raw configuration; only `limit` is read here (history window size).
  private _settings: IDatabaseConfig;
  // Underlying LangChain message store; assigned by build().
  private _history: BaseChatMessageHistory;
  // BufferMemory wrapping _history; assigned by build().
  private _bufferMemory: BufferMemory;

  constructor(settings: IDatabaseConfig) {
    this._settings = settings;
  }

  addUserMessage(message: string): Promise<void> {
    return this._history?.addUserMessage(message);
  }

  addAIChatMessage(message: string): Promise<void> {
    return this._history?.addAIChatMessage(message);
  }

  async getMessages(): Promise<BaseMessage[]> {
    // Guard: before build() runs, _history is undefined. Honor the declared
    // return type by yielding an empty list instead of undefined (the
    // previous version could resolve to undefined here).
    return (await this._history?.getMessages()) ?? [];
  }

  async getFormatedMessages(): Promise<string> {
    // Same pre-build guard as getMessages(): avoid calling .slice on
    // undefined when the store has not been created yet.
    const messages = (await this._history?.getMessages()) ?? [];

    // Keep only the most recent messages (default window: 5).
    const cut = messages.slice(-(this._settings?.limit || 5));

    // Render each message as "TYPE: content", e.g. "HUMAN: hi".
    const formated = cut
      .map((message) => `${message._getType().toUpperCase()}: ${message.content}`)
      .join('\n');

    return formated;
  }

  getChatHistory(): BaseChatMessageHistory {
    return this._history;
  }

  getBufferMemory(): BufferMemory {
    return this._bufferMemory;
  }

  clear(): Promise<void> {
    return this._history?.clear();
  }

  async build(): Promise<IChatHistory> {
    // Dynamic import of the in-memory store implementation.
    const { ChatMessageHistory } = await import('langchain/stores/message/in_memory');

    this._history = new ChatMessageHistory();

    this._bufferMemory = new BufferMemory({
      returnMessages: true,
      memoryKey: 'chat_history',
      chatHistory: this._history,
    });

    return this;
  }
}

export default MemoryChatHistory;
Loading

0 comments on commit aac1293

Please sign in to comment.