Skip to content

Commit

Permalink
feat: add google chat model
Browse files Browse the repository at this point in the history
  • Loading branch information
TUCEGA authored and TUCEGA committed Feb 23, 2024
1 parent 71401c1 commit 8d2b345
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 11 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"node": ">=16"
},
"dependencies": {
"@langchain/google-genai": "^0.0.10",
"ai": "^2.2.10",
"ioredis": "^5.3.2",
"langchain": "^0.0.178",
Expand Down
6 changes: 3 additions & 3 deletions src/interface/agent.interface.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { DataSource } from 'typeorm';

export type LLM_TYPE = 'azure' | 'gpt' | 'aws';
export type LLM_TYPE = 'azure' | 'gpt' | 'aws' | 'google';
export type DATABASE_TYPE = 'cosmos' | 'redis' | 'postgres';

export const SYSTEM_MESSAGE_DEFAULT = `
Expand Down Expand Up @@ -58,9 +58,9 @@ export interface IChatConfig {
export interface ILLMConfig {
type: LLM_TYPE;
model: string;
instance: string;
instance?: string;
apiKey: string;
apiVersion: string;
apiVersion?: string;
}

export interface IVectorStoreConfig {
Expand Down
4 changes: 3 additions & 1 deletion src/services/chain/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
SystemMessagePromptTemplate,
} from 'langchain/prompts';

import { AIMessage } from 'langchain/schema';
import {
IAgentConfig,
SYSTEM_MESSAGE_DEFAULT,
Expand Down Expand Up @@ -90,11 +91,12 @@ class ChainService {
this.buildSystemMessages(systemMessages)
),
new MessagesPlaceholder('chat_history'),
new AIMessage('Olá! Em que posso ajudar?'),
HumanMessagePromptTemplate.fromTemplate('{question}'),
];

const CHAT_COMBINE_PROMPT =
ChatPromptTemplate.fromPromptMessages(combine_messages);
ChatPromptTemplate.fromMessages(combine_messages);

return CHAT_COMBINE_PROMPT;
}
Expand Down
4 changes: 3 additions & 1 deletion src/services/chain/sql-database-chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import {
MessagesPlaceholder,
SystemMessagePromptTemplate,
} from 'langchain/prompts';
import { ChainValues } from 'langchain/schema';
import { AIMessage, ChainValues } from 'langchain/schema';
import { StringOutputParser } from 'langchain/schema/output_parser';
import { RunnableSequence } from 'langchain/schema/runnable';
import { SqlDatabase } from 'langchain/sql_db';
Expand Down Expand Up @@ -88,6 +88,7 @@ export default class SqlDatabaseChain extends BaseChain {
getSQLPrompt(): string {
return `
Based on the SQL table schema provided below, write an SQL query that answers the user's question.\n
      Remember to put double quotes around database table names\n
Your response must only be a valid SQL query, based on the schema provided.\n
-------------------------------------------\n
Here are some important observations for generating the query:\n
Expand Down Expand Up @@ -152,6 +153,7 @@ export default class SqlDatabaseChain extends BaseChain {
const combine_messages = [
SystemMessagePromptTemplate.fromTemplate(systemMessages),
new MessagesPlaceholder('chat_history'),
new AIMessage('Aguarde! Estamos pesquisando em nossa base de dados.'),
HumanMessagePromptTemplate.fromTemplate('{question}'),
];

Expand Down
24 changes: 24 additions & 0 deletions src/services/llm/google-llm-service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
import { IChatConfig, ILLMConfig } from '../../interface/agent.interface';

/**
 * Factory wrapper that produces a LangChain `ChatGoogleGenerativeAI`
 * chat model from the agent's chat and LLM configuration objects.
 */
class GoogleLLMService {
  constructor(
    private readonly chatSettings: IChatConfig,
    private readonly llmSettings: ILLMConfig
  ) {}

  /**
   * Builds a streaming Google Generative AI chat model using the stored
   * settings. Output is capped at 2048 tokens.
   *
   * @returns a configured `ChatGoogleGenerativeAI` instance
   */
  public build(): ChatGoogleGenerativeAI {
    const { temperature } = this.chatSettings;
    const { model, apiKey } = this.llmSettings;

    return new ChatGoogleGenerativeAI({
      temperature,
      modelName: model,
      apiKey,
      maxOutputTokens: 2048,
      streaming: true,
    });
  }
}

export default GoogleLLMService;
14 changes: 8 additions & 6 deletions src/services/llm/index.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import { IChatConfig, ILLMConfig } from '../../interface/agent.interface';
import { BaseChatModel } from 'langchain/chat_models/base';
import { IChatConfig, ILLMConfig } from '../../interface/agent.interface';

import AzureLLMService from './azure-llm-service';
import GoogleLLMService from './google-llm-service';

const ServiceLLM = {
azure: AzureLLMService,
google: GoogleLLMService,
} as any;

class LLMFactory {
public static create(chatSettings: IChatConfig, llmSettings: ILLMConfig): BaseChatModel {
return new ServiceLLM[llmSettings.type](
chatSettings,
llmSettings,
).build();
public static create(
chatSettings: IChatConfig,
llmSettings: ILLMConfig
): BaseChatModel {
return new ServiceLLM[llmSettings.type](chatSettings, llmSettings).build();
}
}

Expand Down

0 comments on commit 8d2b345

Please sign in to comment.