From 66deb6b01d75383c50e5f0a446df4d3ef5d5e99b Mon Sep 17 00:00:00 2001 From: BrandonStudio Date: Fri, 8 Nov 2024 15:53:12 +0800 Subject: [PATCH 1/2] Fix GitHub models Refactor OpenAICompatibleFactory --- src/config/modelProviders/github.ts | 29 +++++++---- src/libs/agent-runtime/github/index.test.ts | 35 ------------- src/libs/agent-runtime/github/index.ts | 52 ++++++++++++++++++- src/libs/agent-runtime/togetherai/index.ts | 3 +- .../utils/openaiCompatibleFactory/index.ts | 21 ++++---- 5 files changed, 83 insertions(+), 57 deletions(-) diff --git a/src/config/modelProviders/github.ts b/src/config/modelProviders/github.ts index 0394b2b64d88..e23b0a22b05c 100644 --- a/src/config/modelProviders/github.ts +++ b/src/config/modelProviders/github.ts @@ -15,7 +15,8 @@ const Github: ModelProviderCard = { vision: true, }, { - description: '专注于高级推理和解决复杂问题,包括数学和科学任务。非常适合需要深度上下文理解和自主工作流程的应用。', + description: + '专注于高级推理和解决复杂问题,包括数学和科学任务。非常适合需要深度上下文理解和自主工作流程的应用。', displayName: 'OpenAI o1-preview', enabled: true, functionCall: false, @@ -45,7 +46,8 @@ const Github: ModelProviderCard = { vision: true, }, { - description: '一个52B参数(12B活跃)的多语言模型,提供256K长上下文窗口、函数调用、结构化输出和基于事实的生成。', + description: + '一个52B参数(12B活跃)的多语言模型,提供256K长上下文窗口、函数调用、结构化输出和基于事实的生成。', displayName: 'AI21 Jamba 1.5 Mini', functionCall: true, id: 'ai21-jamba-1.5-mini', @@ -53,7 +55,8 @@ const Github: ModelProviderCard = { tokens: 262_144, }, { - description: '一个398B参数(94B活跃)的多语言模型,提供256K长上下文窗口、函数调用、结构化输出和基于事实的生成。', + description: + '一个398B参数(94B活跃)的多语言模型,提供256K长上下文窗口、函数调用、结构化输出和基于事实的生成。', displayName: 'AI21 Jamba 1.5 Large', functionCall: true, id: 'ai21-jamba-1.5-large', @@ -61,7 +64,8 @@ const Github: ModelProviderCard = { tokens: 262_144, }, { - description: 'Command R是一个可扩展的生成模型,旨在针对RAG和工具使用,使企业能够实现生产级AI。', + description: + 'Command R是一个可扩展的生成模型,旨在针对RAG和工具使用,使企业能够实现生产级AI。', displayName: 'Cohere Command R', id: 'cohere-command-r', maxOutput: 4096, @@ -75,7 +79,8 @@ const Github: ModelProviderCard = { tokens: 131_072, }, { - description: 'Mistral Nemo是一种尖端的语言模型(LLM),在其尺寸类别中拥有最先进的推理、世界知识和编码能力。', + description: + 'Mistral Nemo是一种尖端的语言模型(LLM),在其尺寸类别中拥有最先进的推理、世界知识和编码能力。', displayName: 'Mistral Nemo', id: 'mistral-nemo', maxOutput: 4096, @@ -89,7 +94,8 @@ const Github: ModelProviderCard = { tokens: 131_072, }, { - description: 'Mistral的旗舰模型,适合需要大规模推理能力或高度专业化的复杂任务(合成文本生成、代码生成、RAG或代理)。', + description: + 'Mistral的旗舰模型,适合需要大规模推理能力或高度专业化的复杂任务(合成文本生成、代码生成、RAG或代理)。', displayName: 'Mistral Large', id: 'mistral-large', maxOutput: 4096, @@ -112,21 +118,24 @@ const Github: ModelProviderCard = { vision: true, }, { - description: 'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。', + description: + 'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。', displayName: 'Meta Llama 3.1 8B', id: 'meta-llama-3.1-8b-instruct', maxOutput: 4096, tokens: 131_072, }, { - description: 'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。', + description: + 'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。', displayName: 'Meta Llama 3.1 70B', id: 'meta-llama-3.1-70b-instruct', maxOutput: 4096, tokens: 131_072, }, { - description: 'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。', + description: + 'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。', displayName: 'Meta Llama 3.1 405B', id: 'meta-llama-3.1-405b-instruct', maxOutput: 4096, @@ -209,7 +218,7 @@ const Github: ModelProviderCard = { description: '通过GitHub模型,开发人员可以成为AI工程师,并使用行业领先的AI模型进行构建。', enabled: true, id: 'github', - // modelList: { showModelFetcher: true }, + modelList: { showModelFetcher: true }, // I'm not sure if it is good to show the model fetcher, as remote list is not complete. name: 'GitHub', url: 'https://github.com/marketplace/models', }; diff --git a/src/libs/agent-runtime/github/index.test.ts b/src/libs/agent-runtime/github/index.test.ts index e466ac155389..377cc45ac01a 100644 --- a/src/libs/agent-runtime/github/index.test.ts +++ b/src/libs/agent-runtime/github/index.test.ts @@ -119,41 +119,6 @@ describe('LobeGithubAI', () => { } }); - it('should return GithubBizError with an cause response with desensitize Url', async () => { - // Arrange - const errorInfo = { - stack: 'abc', - cause: { message: 'api is undefined' }, - }; - const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {}); - - instance = new LobeGithubAI({ - apiKey: 'test', - baseURL: 'https://api.abc.com/v1', - }); - - vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError); - - // Act - try { - await instance.chat({ - messages: [{ content: 'Hello', role: 'user' }], - model: 'meta-llama-3-70b-instruct', - temperature: 0.7, - }); - } catch (e) { - expect(e).toEqual({ - endpoint: 'https://api.***.com/v1', - error: { - cause: { message: 'api is undefined' }, - stack: 'abc', - }, - errorType: bizErrorType, - provider, - }); - } - }); - it('should throw an InvalidGithubToken error type on 401 status code', async () => { // Mock the API call to simulate a 401 error const error = new Error('InvalidApiKey') as any; diff --git a/src/libs/agent-runtime/github/index.ts b/src/libs/agent-runtime/github/index.ts index fd7fa6f280bc..781adcd84f3c 100644 --- a/src/libs/agent-runtime/github/index.ts +++ b/src/libs/agent-runtime/github/index.ts @@ -1,7 +1,35 @@ +import { LOBE_DEFAULT_MODEL_LIST } from '@/config/modelProviders'; +import type { ChatModelCard } from '@/types/llm'; + import { AgentRuntimeErrorType } from '../error'; import { o1Models, pruneO1Payload } from '../openai'; import { ModelProvider } from '../types'; -import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory'; +import { + CHAT_MODELS_BLOCK_LIST, + LobeOpenAICompatibleFactory, +} from '../utils/openaiCompatibleFactory'; + +enum Task { + 'chat-completion', + 'embeddings', +} + +/* eslint-disable typescript-sort-keys/interface */ +type Model = { + id: string; + name: string; + friendly_name: string; + model_version: number; + publisher: string; + model_family: string; + model_registry: string; + license: string; + task: Task; + description: string; + summary: string; + tags: string[]; +}; +/* eslint-enable typescript-sort-keys/interface */ export const LobeGithubAI = LobeOpenAICompatibleFactory({ baseURL: 'https://models.inference.ai.azure.com', @@ -23,5 +51,27 @@ export const LobeGithubAI = LobeOpenAICompatibleFactory({ bizError: AgentRuntimeErrorType.ProviderBizError, invalidAPIKey: AgentRuntimeErrorType.InvalidGithubToken, }, + models: async ({ client }) => { + const modelsPage = (await client.models.list()) as any; + const modelList: Model[] = modelsPage.body; + return modelList + .filter((model) => { + return CHAT_MODELS_BLOCK_LIST.every( + (keyword) => !model.name.toLowerCase().includes(keyword), + ); + }) + .map((model) => { + const knownModel = LOBE_DEFAULT_MODEL_LIST.find((m) => m.id === model.name); + + if (knownModel) return knownModel; + + return { + description: model.description, + displayName: model.friendly_name, + id: model.name, + }; + }) + .filter(Boolean) as ChatModelCard[]; + }, provider: ModelProvider.Github, }); diff --git a/src/libs/agent-runtime/togetherai/index.ts b/src/libs/agent-runtime/togetherai/index.ts index b291aa5faad7..73b06cbcd314 100644 --- a/src/libs/agent-runtime/togetherai/index.ts +++ b/src/libs/agent-runtime/togetherai/index.ts @@ -16,7 +16,8 @@ export const LobeTogetherAI = LobeOpenAICompatibleFactory({ debug: { chatCompletion: () => process.env.DEBUG_TOGETHERAI_CHAT_COMPLETION === '1', }, - models: async ({ apiKey }) => { + models: async ({ client }) => { + const apiKey = client.apiKey; const data = await fetch(`${baseURL}/api/models`, { headers: { Authorization: `Bearer ${apiKey}`, diff --git a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts index 74f76d70a8bf..5653e0f9a93e 100644 --- a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts +++ b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts @@ -2,17 +2,18 @@ import OpenAI, { ClientOptions } from 'openai'; import { Stream } from 'openai/streaming'; import { LOBE_DEFAULT_MODEL_LIST } from '@/config/modelProviders'; -import { ChatModelCard } from '@/types/llm'; +import type { ChatModelCard } from '@/types/llm'; import { LobeRuntimeAI } from '../../BaseAI'; import { AgentRuntimeErrorType, ILobeAgentRuntimeErrorType } from '../../error'; -import { +import type { ChatCompetitionOptions, ChatCompletionErrorPayload, ChatStreamPayload, Embeddings, EmbeddingsOptions, EmbeddingsPayload, + ModelProvider, TextToImagePayload, TextToSpeechOptions, TextToSpeechPayload, @@ -26,7 +27,7 @@ import { StreamingResponse } from '../response'; import { OpenAIStream, OpenAIStreamOptions } from '../streams'; // the model contains the following keywords is not a chat model, so we should filter them out -const CHAT_MODELS_BLOCK_LIST = [ +export const CHAT_MODELS_BLOCK_LIST = [ 'embedding', 'davinci', 'curie', @@ -77,7 +78,7 @@ interface OpenAICompatibleFactoryOptions = any> { invalidAPIKey: ILobeAgentRuntimeErrorType; }; models?: - | ((params: { apiKey: string }) => Promise) + | ((params: { client: OpenAI }) => Promise) | { transformModel?: (model: OpenAI.Model) => ChatModelCard; }; @@ -157,7 +158,7 @@ export const LobeOpenAICompatibleFactory = = any> client!: OpenAI; baseURL!: string; - private _options: ConstructorOptions; + protected _options: ConstructorOptions; constructor(options: ClientOptions & Record = {}) { const _options = { @@ -249,7 +250,7 @@ export const LobeOpenAICompatibleFactory = = any> } async models() { - if (typeof models === 'function') return models({ apiKey: this.client.apiKey }); + if (typeof models === 'function') return models({ client: this.client }); const list = await this.client.models.list(); @@ -312,7 +313,7 @@ export const LobeOpenAICompatibleFactory = = any> } } - private handleError(error: any): ChatCompletionErrorPayload { + protected handleError(error: any): ChatCompletionErrorPayload { let desensitizedEndpoint = this.baseURL; // refs: https://github.com/lobehub/lobe-chat/issues/842 @@ -326,7 +327,7 @@ export const LobeOpenAICompatibleFactory = = any> if (errorResult) return AgentRuntimeError.chat({ ...errorResult, - provider, + provider: provider, } as ChatCompletionErrorPayload); } @@ -337,7 +338,7 @@ export const LobeOpenAICompatibleFactory = = any> endpoint: desensitizedEndpoint, error: error as any, errorType: ErrorType.invalidAPIKey, - provider: provider as any, + provider: provider as ModelProvider, }); } @@ -353,7 +354,7 @@ export const LobeOpenAICompatibleFactory = = any> endpoint: desensitizedEndpoint, error: errorResult, errorType: RuntimeError || ErrorType.bizError, - provider: provider as any, + provider: provider as ModelProvider, }); } }; From 3d0149e2feeb367c47bbad9132ee06a70fd0aab5 Mon Sep 17 00:00:00 2001 From: BrandonStudio Date: Mon, 11 Nov 2024 09:39:43 +0800 Subject: [PATCH 2/2] Restore unnecessary change --- src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts index 5653e0f9a93e..1c8acb884db7 100644 --- a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts +++ b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts @@ -327,7 +327,7 @@ export const LobeOpenAICompatibleFactory = = any> if (errorResult) return AgentRuntimeError.chat({ ...errorResult, - provider: provider, + provider, } as ChatCompletionErrorPayload); }