
Commit

Merge pull request #48 from dev-jpnobrega/feature/addChatHistory
Add Chat History - Sql Chain and OpenAPI Chain
zanova authored Feb 8, 2024
2 parents df8841f + 408f1dd commit 2712fb6
Showing 5 changed files with 284 additions and 165 deletions.
83 changes: 60 additions & 23 deletions src/agent.ts
@@ -3,22 +3,28 @@ import { BufferMemory } from 'langchain/memory';
import { VectorStore } from 'langchain/vectorstores/base';

import AgentBaseCommand from './agent.base';
import { IAgentConfig, IDatabaseConfig, IInputProps, IAgent, IVectorStoreConfig } from './interface/agent.interface';
import {
IAgent,
IAgentConfig,
IDatabaseConfig,
IInputProps,
IVectorStoreConfig,
} from './interface/agent.interface';

import VectorStoreFactory from './services/vector-store';
import ChatHistoryFactory from './services/chat-history';
import LLMFactory from './services/llm';
import { ChainService, IChainService } from './services/chain';
import { nanoid } from 'ai';
import { interpolate } from './helpers/string.helpers';
import { ChainService, IChainService } from './services/chain';
import ChatHistoryFactory from './services/chat-history';
import LLMFactory from './services/llm';
import VectorStoreFactory from './services/vector-store';

const EVENTS_NAME = {
onMessage: 'onMessage',
onToken: 'onToken',
onEnd: 'onEnd',
onError: 'onError',
onMessageSystem: 'onMessageSystem',
onMessageHuman: 'onMessageHuman'
onMessageHuman: 'onMessageHuman',
};

class Agent extends AgentBaseCommand implements IAgent {
@@ -46,15 +52,23 @@ class Agent extends AgentBaseCommand implements IAgent {
this._chainService = new ChainService(settings);

if (settings?.vectorStoreConfig)
this._vectorService = VectorStoreFactory.create(settings.vectorStoreConfig, settings.llmConfig);
this._vectorService = VectorStoreFactory.create(
settings.vectorStoreConfig,
settings.llmConfig
);
}

private async buildHistory(userSessionId: string, settings: IDatabaseConfig): Promise<BufferMemory> {
if (this._bufferMemory && !settings)
return this._bufferMemory;
private async buildHistory(
userSessionId: string,
settings: IDatabaseConfig
): Promise<BufferMemory> {
if (this._bufferMemory && !settings) return this._bufferMemory;

if (!this._bufferMemory && !settings) {
this._bufferMemory = new BufferMemory({ returnMessages: true, memoryKey: 'chat_history' });
this._bufferMemory = new BufferMemory({
returnMessages: true,
memoryKey: 'chat_history',
});

return this._bufferMemory;
}
@@ -71,17 +85,28 @@ class Agent extends AgentBaseCommand implements IAgent {
return this._bufferMemory;
}

private async buildRelevantDocs(args: IInputProps, settings: IVectorStoreConfig): Promise<any> {
private async buildRelevantDocs(
args: IInputProps,
settings: IVectorStoreConfig
): Promise<any> {
if (!settings) return { relevantDocs: [], referenciesDocs: [] };

const { customFilters = null } = settings;

const relevantDocs = await this._vectorService.similaritySearch(args.question, 10, {
vectorFields: settings.vectorFieldName,
filter: customFilters ? interpolate<IInputProps>(customFilters, args) : '',
});

const referenciesDocs = relevantDocs.map((doc: { metadata: unknown; }) => doc.metadata).join(', ');
const relevantDocs = await this._vectorService.similaritySearch(
args.question,
10,
{
vectorFields: settings.vectorFieldName,
filter: customFilters
? interpolate<IInputProps>(customFilters, args)
: '',
}
);

const referenciesDocs = relevantDocs
.map((doc: { metadata: unknown }) => doc.metadata)
.join(', ');

return { relevantDocs, referenciesDocs };
}
@@ -90,20 +115,33 @@
const { question, chatThreadID } = args;

try {
const memoryChat = await this.buildHistory(chatThreadID, this._settings.dbHistoryConfig);
const memoryChat = await this.buildHistory(
chatThreadID,
this._settings.dbHistoryConfig
);

memoryChat.chatHistory?.addUserMessage(question);

const { relevantDocs, referenciesDocs } = await this.buildRelevantDocs(args, this._settings.vectorStoreConfig);
const { relevantDocs, referenciesDocs } = await this.buildRelevantDocs(
args,
this._settings.vectorStoreConfig
);

const chain = await this._chainService.build(this._llm, question);
const chain = await this._chainService.build(
this._llm,
question,
memoryChat
);
const chat_history = await memoryChat.chatHistory?.getMessages();

const result = await chain.call({
referencies: referenciesDocs,
input_documents: relevantDocs,
query: question,
question: question,
chat_history: await memoryChat.chatHistory?.getMessages(),
chat_history: chat_history?.slice(
-(this._settings?.dbHistoryConfig?.limit || 5)
),
});

await memoryChat.chatHistory?.addAIChatMessage(result?.text);
@@ -124,5 +162,4 @@
}
}


export default Agent;
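
Note: for context, a minimal consumer-side sketch of the history flow introduced in this file. It is illustrative only: it assumes the Agent constructor accepts the IAgentConfig (the constructor body visible above implies this but the signature is outside the hunks), that call() ultimately emits the onMessage/onError events declared in EVENTS_NAME (the emit calls are not in the visible hunks), and it uses placeholder values where the LLM_TYPE and DATABASE_TYPE enum members are not shown in this diff.

import Agent from './agent';
import { IAgentConfig } from './interface/agent.interface';

// Placeholder configuration: enum members, paths, and credentials are illustrative only.
const settings: IAgentConfig = {
  llmConfig: {
    type: 'azure' as any,        // assumed LLM_TYPE member; not shown in this diff
    model: 'gpt-4',
    instance: 'my-llm-instance',
    apiKey: process.env.LLM_API_KEY ?? '',
    apiVersion: '2023-07-01-preview',
  },
  chatConfig: { temperature: 0 },
  dbHistoryConfig: {
    type: 'redis' as any,        // assumed DATABASE_TYPE member; not shown in this diff
    host: 'localhost',
    port: 6379,
    sessionTTL: 3600,
    limit: 5,                    // new field: how many recent messages reach the chain
  },
};

async function main(): Promise<void> {
  const agent = new Agent(settings);

  agent.on('onMessage', (message) => console.log('assistant:', message));
  agent.on('onError', (error) => console.error(error));

  // chatThreadID selects the per-user history; without dbHistoryConfig,
  // buildHistory falls back to an in-memory BufferMemory.
  await agent.call({ question: 'What did I ask before?', chatThreadID: 'thread-42' });
}

main().catch(console.error);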
101 changes: 51 additions & 50 deletions src/interface/agent.interface.ts
@@ -18,85 +18,86 @@ export const SYSTEM_MESSAGE_DEFAULT = `
`;

export interface IDatabaseConfig {
type: DATABASE_TYPE,
host: string,
port: number,
ssl?: boolean,
sessionId?: string,
sessionTTL?: number,
username?: string,
password?: string,
database?: string | number,
container?: string,
synchronize?: boolean,
type: DATABASE_TYPE;
host: string;
port: number;
ssl?: boolean;
sessionId?: string;
sessionTTL?: number;
username?: string;
password?: string;
database?: string | number;
container?: string;
synchronize?: boolean;
limit?: number;
}

export interface IDataSourceConfig {
dataSource: DataSource,
includesTables?: string[],
ignoreTables?: string[],
customizeSystemMessage?: string,
ssl?: boolean,
dataSource: DataSource;
includesTables?: string[];
ignoreTables?: string[];
customizeSystemMessage?: string;
ssl?: boolean;
}

export interface IOpenAPIConfig {
data: string,
customizeSystemMessage?: string,
xApiKey?: string,
authorization?: string,
data: string;
customizeSystemMessage?: string;
xApiKey?: string;
authorization?: string;
}

export interface IChatConfig {
temperature: number,
topP?: number,
frequencyPenalty?: number,
presencePenalty?: number,
maxTokens?: number,
temperature: number;
topP?: number;
frequencyPenalty?: number;
presencePenalty?: number;
maxTokens?: number;
}

export interface ILLMConfig {
type: LLM_TYPE,
model: string,
instance: string,
apiKey: string,
apiVersion: string,
type: LLM_TYPE;
model: string;
instance: string;
apiKey: string;
apiVersion: string;
}

export interface IVectorStoreConfig {
name: string,
type: LLM_TYPE,
apiKey: string,
apiVersion: string,
indexes: string[] | string,
vectorFieldName: string,
model?: string,
customFilters?: string,
name: string;
type: LLM_TYPE;
apiKey: string;
apiVersion: string;
indexes: string[] | string;
vectorFieldName: string;
model?: string;
customFilters?: string;
}

export interface IAgentConfig {
name?: string,
debug?: boolean,
systemMesssage?: string | typeof SYSTEM_MESSAGE_DEFAULT,
name?: string;
debug?: boolean;
systemMesssage?: string | typeof SYSTEM_MESSAGE_DEFAULT;
llmConfig: ILLMConfig;
chatConfig: IChatConfig,
dbHistoryConfig?: IDatabaseConfig,
chatConfig: IChatConfig;
dbHistoryConfig?: IDatabaseConfig;
vectorStoreConfig?: IVectorStoreConfig;
dataSourceConfig?: IDataSourceConfig;
openAPIConfig?: IOpenAPIConfig;
};
}

export interface IInputProps {
question?: string,
userSessionId?: string,
chatThreadID?: string,
question?: string;
userSessionId?: string;
chatThreadID?: string;
}

export interface TModel extends Record<string, unknown> { }
export interface TModel extends Record<string, unknown> {}

export interface IAgent {
call(input: IInputProps): Promise<void>;

emit(event: string, ...args: any[]): void;

on(eventName: string | symbol, listener: (...args: any[]) => void): this
}
on(eventName: string | symbol, listener: (...args: any[]) => void): this;
}
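
Note: the only behavioural addition in this file is the optional limit field on IDatabaseConfig; the remaining changes swap trailing commas for semicolons. A tiny illustrative snippet of how agent.ts consumes it — the string array stands in for real chat messages:

import { IDatabaseConfig } from './interface/agent.interface';

// With limit = 3, only the three most recent history entries are kept.
const dbHistoryConfig: Partial<IDatabaseConfig> = { sessionTTL: 3600, limit: 3 };

const chatHistory = ['m1', 'm2', 'm3', 'm4', 'm5'];
const trimmed = chatHistory.slice(-(dbHistoryConfig.limit || 5));

console.log(trimmed); // ['m3', 'm4', 'm5']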
50 changes: 35 additions & 15 deletions src/services/chain/index.ts
@@ -1,19 +1,31 @@

import {
BaseChain,
SequentialChain,
loadQAMapReduceChain,
} from 'langchain/chains';
import { BaseChatModel } from 'langchain/chat_models/base';
import { BaseChain, SequentialChain, loadQAMapReduceChain } from 'langchain/chains';

import { BasePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate } from 'langchain/prompts';

import { IAgentConfig, SYSTEM_MESSAGE_DEFAULT } from '../../interface/agent.interface';
import SqlChain from './sql-chain';
import {
BasePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
} from 'langchain/prompts';

import {
IAgentConfig,
SYSTEM_MESSAGE_DEFAULT,
} from '../../interface/agent.interface';
import OpenAPIChain from './openapi-chain';
import SqlChain from './sql-chain';

interface IChain {
create(llm: BaseChatModel, ...args: any): Promise<BaseChain>
create(llm: BaseChatModel, ...args: any): Promise<BaseChain>;
}

interface IChainService {
build(llm: BaseChatModel, ...args: any): Promise<BaseChain>
build(llm: BaseChatModel, ...args: any): Promise<BaseChain>;
}

class ChainService {
@@ -60,12 +72,12 @@ class ChainService {
--------------------------------------
`;
}

if (this._isOpenAPIChainEnabled) {
builtMessage += `
--------------------------------------
API Result: {openAPIResult}\n
--------------------------------------
--------------------------------------
`;
}

@@ -87,12 +99,15 @@ class ChainService {
return CHAT_COMBINE_PROMPT;
}

private async buildChains(llm: BaseChatModel, ...args: any): Promise<BaseChain[]> {
private async buildChains(
llm: BaseChatModel,
...args: any
): Promise<BaseChain[]> {
const enabledChains = this.checkEnabledChains(this._settings);

const chainQA = loadQAMapReduceChain(llm, {
combinePrompt: this.buildPromptTemplate(
this._settings.systemMesssage || SYSTEM_MESSAGE_DEFAULT,
this._settings.systemMesssage || SYSTEM_MESSAGE_DEFAULT
),
});

@@ -106,19 +121,24 @@ class ChainService {
}

public async build(llm: BaseChatModel, ...args: any): Promise<BaseChain> {
const { memoryChat } = args;
const chains = await this.buildChains(llm, args);

const enhancementChain = new SequentialChain({
chains,
inputVariables: [
'query', 'referencies', 'input_documents', 'question', 'chat_history',
'query',
'referencies',
'input_documents',
'question',
'chat_history',
],
verbose: this._settings.debug || false,
memory: memoryChat,
});


return enhancementChain;
}
}

export { ChainService, IChainService, IChain };
export { ChainService, IChain, IChainService };
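
Note: a sketch of how Agent.call drives the rebuilt chain, based only on the hunks above. memoryChat, relevantDocs, and referenciesDocs are assumed to be prepared exactly as in agent.ts, the input keys mirror the inputVariables declared for the SequentialChain, and import paths follow the ones agent.ts uses.

import { BaseChatModel } from 'langchain/chat_models/base';
import { BufferMemory } from 'langchain/memory';
import { IAgentConfig } from './interface/agent.interface';
import { ChainService } from './services/chain';

export async function runChain(
  settings: IAgentConfig,
  llm: BaseChatModel,
  memoryChat: BufferMemory,
  question: string,
  relevantDocs: unknown[],
  referenciesDocs: string
) {
  // memoryChat is now threaded through build() so the SequentialChain can hold it.
  const chain = await new ChainService(settings).build(llm, question, memoryChat);

  return chain.call({
    query: question,
    question,
    referencies: referenciesDocs,
    input_documents: relevantDocs,
    // Only the most recent `limit` messages (default 5) are forwarded to the prompt.
    chat_history: (await memoryChat.chatHistory?.getMessages())?.slice(
      -(settings.dbHistoryConfig?.limit || 5)
    ),
  });
}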
