From 844dbd8c2f794a2a23631fe44d49c22252a52bfe Mon Sep 17 00:00:00 2001
From: Enrico Ros
Date: Fri, 7 Jul 2023 23:20:15 -0700
Subject: [PATCH] Improve Local models List functionality

---
 .../llms/localai/LocalAISourceSetup.tsx       |  4 +++-
 src/modules/llms/openai/OpenAISourceSetup.tsx |  9 +++++--
 src/modules/llms/openai/openai.client.ts      | 10 ++++----
 src/modules/llms/openai/openai.router.ts      | 24 ++++++++++++-------
 src/modules/llms/openai/openai.types.ts       |  4 ++--
 5 files changed, 33 insertions(+), 18 deletions(-)

diff --git a/src/modules/llms/localai/LocalAISourceSetup.tsx b/src/modules/llms/localai/LocalAISourceSetup.tsx
index a5a0b249a..974f6806a 100644
--- a/src/modules/llms/localai/LocalAISourceSetup.tsx
+++ b/src/modules/llms/localai/LocalAISourceSetup.tsx
@@ -31,7 +31,9 @@ export function LocalAISourceSetup(props: { sourceId: DModelSourceId }) {
   const hasModels = !!sourceLLMs.length;
 
   // fetch models
-  const { isFetching, refetch, isError } = apiQuery.openai.listModels.useQuery({ oaiKey: '', oaiHost: hostUrl, oaiOrg: '', heliKey: '', moderationCheck: false }, {
+  const { isFetching, refetch, isError } = apiQuery.openai.listModels.useQuery({
+    access: { oaiKey: '', oaiHost: hostUrl, oaiOrg: '', heliKey: '', moderationCheck: false },
+  }, {
     enabled: false, //!sourceLLMs.length && shallFetchSucceed,
     onSuccess: models => {
       const llms = source ? models.map(model => localAIToDLLM(model, source)) : [];
diff --git a/src/modules/llms/openai/OpenAISourceSetup.tsx b/src/modules/llms/openai/OpenAISourceSetup.tsx
index b770cac03..7a3fc8de3 100644
--- a/src/modules/llms/openai/OpenAISourceSetup.tsx
+++ b/src/modules/llms/openai/OpenAISourceSetup.tsx
@@ -1,6 +1,6 @@
 import * as React from 'react';
 
-import { Box, Button, FormControl, FormHelperText, FormLabel, Input, Switch } from '@mui/joy';
+import { Alert, Box, Button, FormControl, FormHelperText, FormLabel, Input, Switch, Typography } from '@mui/joy';
 import SyncIcon from '@mui/icons-material/Sync';
 
 import { apiQuery } from '~/modules/trpc/trpc.client';
@@ -35,7 +35,10 @@ export function OpenAISourceSetup(props: { sourceId: DModelSourceId }) {
   const shallFetchSucceed = oaiKey ? keyValid : !needsUserKey;
 
   // fetch models
-  const { isFetching, refetch, isError } = apiQuery.openai.listModels.useQuery({ oaiKey, oaiHost, oaiOrg, heliKey, moderationCheck }, {
+  const { isFetching, refetch, isError, error } = apiQuery.openai.listModels.useQuery({
+    access: { oaiKey, oaiHost, oaiOrg, heliKey, moderationCheck },
+    filterGpt: true,
+  }, {
     enabled: !hasModels && shallFetchSucceed,
     onSuccess: models => {
       const llms = source ? models.map(model => openAIModelToDLLM(model, source)) : [];
@@ -145,6 +148,8 @@ export function OpenAISourceSetup(props: { sourceId: DModelSourceId }) {
+    {isError && <Alert><Typography>Issue: {error?.message || error?.toString() || 'unknown'}</Typography></Alert>}
+
   </Box>;
 }
diff --git a/src/modules/llms/openai/openai.client.ts b/src/modules/llms/openai/openai.client.ts
index 0d94b1673..a16219feb 100644
--- a/src/modules/llms/openai/openai.client.ts
+++ b/src/modules/llms/openai/openai.client.ts
@@ -10,11 +10,13 @@ export const hasServerKeyOpenAI = !!process.env.HAS_SERVER_KEY_OPENAI;
 
 export const isValidOpenAIApiKey = (apiKey?: string) => !!apiKey && apiKey.startsWith('sk-') && apiKey.length > 40;
 
-export const callChat = async (llm: DLLM, messages: VChatMessageIn[], maxTokens?: number) =>
-  callChatOverloaded(llm, messages, null, maxTokens);
+export async function callChat(llm: DLLM, messages: VChatMessageIn[], maxTokens?: number) {
+  return callChatOverloaded(llm, messages, null, maxTokens);
+}
 
-export const callChatWithFunctions = async (llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number) =>
-  callChatOverloaded(llm, messages, functions, maxTokens);
+export async function callChatWithFunctions(llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number) {
+  return callChatOverloaded(llm, messages, functions, maxTokens);
+}
 
 
 /**
diff --git a/src/modules/llms/openai/openai.router.ts b/src/modules/llms/openai/openai.router.ts
index fe576358c..d88ff57c3 100644
--- a/src/modules/llms/openai/openai.router.ts
+++ b/src/modules/llms/openai/openai.router.ts
@@ -48,6 +48,8 @@ const functionsSchema = z.array(z.object({
 
 export const chatGenerateSchema = z.object({ access: accessSchema, model: modelSchema, history: historySchema, functions: functionsSchema.optional() });
 export type ChatGenerateSchema = z.infer<typeof chatGenerateSchema>;
 
+const listModelsSchema = z.object({ access: accessSchema, filterGpt: z.boolean().optional() });
+
 const chatModerationSchema = z.object({ access: accessSchema, text: z.string() });
 
@@ -66,8 +68,6 @@ const chatGenerateWithFunctionsOutputSchema = z.union([
 ]);
 
 
-
-
 export const openAIRouter = createTRPCRouter({
 
   /**
@@ -103,8 +103,8 @@ export const openAIRouter = createTRPCRouter({
    */
  moderation: publicProcedure
     .input(chatModerationSchema)
-    .mutation(async ({ input }): Promise => {
-      const { access, text, } = input;
+    .mutation(async ({ input }): Promise => {
+      const { access, text } = input;
 
       try {
         return await openaiPOST(access, {
@@ -125,14 +125,20 @@ export const openAIRouter = createTRPCRouter({
    * List the Models available
    */
  listModels: publicProcedure
-    .input(accessSchema)
+    .input(listModelsSchema)
     .query(async ({ input }): Promise => {
 
-      let wireModels: OpenAI.Wire.Models.Response;
-      wireModels = await openaiGET(input, '/v1/models');
+      const wireModels: OpenAI.Wire.Models.Response = await openaiGET(input.access, '/v1/models');
+
+      // filter out the non-gpt models, if requested
+      let llms = (wireModels.data || [])
+        .filter(model => !input.filterGpt || model.id.includes('gpt'));
 
-      // filter out the non-gpt models
-      const llms = wireModels.data?.filter(model => model.id.includes('gpt')) ?? [];
+      // remove models with duplicate ids (can happen for local servers)
+      const preFilterCount = llms.length;
+      llms = llms.filter((model, index) => llms.findIndex(m => m.id === model.id) === index);
+      if (preFilterCount !== llms.length)
+        console.warn(`openai.router.listModels: Duplicate model ids found, removed ${preFilterCount - llms.length} models`);
 
       // sort by which model has the least number of '-' in the name, and then by id, decreasing
       llms.sort((a, b) => {
diff --git a/src/modules/llms/openai/openai.types.ts b/src/modules/llms/openai/openai.types.ts
index 14405c1a7..305047f90 100644
--- a/src/modules/llms/openai/openai.types.ts
+++ b/src/modules/llms/openai/openai.types.ts
@@ -115,7 +115,7 @@ export namespace OpenAI {
       type: 'server_error' | string;
       param: string | null;
       code: string | null;
-    }
+    };
   }
 
 }
@@ -145,7 +145,7 @@ export namespace OpenAI {
         category_scores: { [key in ModerationCategory]: number };
         flagged: boolean;
       }
-    ]
+    ];
   }
 
 }