
Commit

Improve Local models List functionality
enricoros committed Jul 8, 2023
1 parent 8b3201b commit 844dbd8
Showing 5 changed files with 33 additions and 18 deletions.
4 changes: 3 additions & 1 deletion src/modules/llms/localai/LocalAISourceSetup.tsx
@@ -31,7 +31,9 @@ export function LocalAISourceSetup(props: { sourceId: DModelSourceId }) {
   const hasModels = !!sourceLLMs.length;
 
   // fetch models
-  const { isFetching, refetch, isError } = apiQuery.openai.listModels.useQuery({ oaiKey: '', oaiHost: hostUrl, oaiOrg: '', heliKey: '', moderationCheck: false }, {
+  const { isFetching, refetch, isError } = apiQuery.openai.listModels.useQuery({
+    access: { oaiKey: '', oaiHost: hostUrl, oaiOrg: '', heliKey: '', moderationCheck: false },
+  }, {
     enabled: false, //!sourceLLMs.length && shallFetchSucceed,
     onSuccess: models => {
       const llms = source ? models.map(model => localAIToDLLM(model, source)) : [];
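Note: the query input is now nested under an access object instead of being passed flat, matching the router's new listModelsSchema (see openai.router.ts below). A minimal sketch of the shape change, using only field names visible in this diff:

```ts
// Sketch of the listModels query input before and after this commit.
// Field names are copied from the diff above; nothing else is assumed.
type AccessFields = {
  oaiKey: string;
  oaiHost: string;
  oaiOrg: string;
  heliKey: string;
  moderationCheck: boolean;
};

// Before: AccessFields was the entire input object.
// After: it is nested under `access`, alongside an optional gpt-only filter.
type ListModelsInput = {
  access: AccessFields;
  filterGpt?: boolean;
};
```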
9 changes: 7 additions & 2 deletions src/modules/llms/openai/OpenAISourceSetup.tsx
@@ -1,6 +1,6 @@
 import * as React from 'react';
 
-import { Box, Button, FormControl, FormHelperText, FormLabel, Input, Switch } from '@mui/joy';
+import { Alert, Box, Button, FormControl, FormHelperText, FormLabel, Input, Switch, Typography } from '@mui/joy';
 import SyncIcon from '@mui/icons-material/Sync';
 
 import { apiQuery } from '~/modules/trpc/trpc.client';
@@ -35,7 +35,10 @@ export function OpenAISourceSetup(props: { sourceId: DModelSourceId }) {
   const shallFetchSucceed = oaiKey ? keyValid : !needsUserKey;
 
   // fetch models
-  const { isFetching, refetch, isError } = apiQuery.openai.listModels.useQuery({ oaiKey, oaiHost, oaiOrg, heliKey, moderationCheck }, {
+  const { isFetching, refetch, isError, error } = apiQuery.openai.listModels.useQuery({
+    access: { oaiKey, oaiHost, oaiOrg, heliKey, moderationCheck },
+    filterGpt: true,
+  }, {
     enabled: !hasModels && shallFetchSucceed,
     onSuccess: models => {
       const llms = source ? models.map(model => openAIModelToDLLM(model, source)) : [];
@@ -145,6 +148,8 @@ export function OpenAISourceSetup(props: { sourceId: DModelSourceId }) {
 
     </Box>
 
+    {isError && <Alert variant='soft' color='warning' sx={{ mt: 1 }}><Typography>Issue: {error?.message || error?.toString() || 'unknown'}</Typography></Alert>}
+
   </Box>;
 }
 
10 changes: 6 additions & 4 deletions src/modules/llms/openai/openai.client.ts
@@ -10,11 +10,13 @@ export const hasServerKeyOpenAI = !!process.env.HAS_SERVER_KEY_OPENAI;
 export const isValidOpenAIApiKey = (apiKey?: string) => !!apiKey && apiKey.startsWith('sk-') && apiKey.length > 40;
 
 
-export const callChat = async (llm: DLLM, messages: VChatMessageIn[], maxTokens?: number) =>
-  callChatOverloaded<VChatMessageOut>(llm, messages, null, maxTokens);
+export async function callChat(llm: DLLM, messages: VChatMessageIn[], maxTokens?: number) {
+  return callChatOverloaded<VChatMessageOut>(llm, messages, null, maxTokens);
+}
 
-export const callChatWithFunctions = async (llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number) =>
-  callChatOverloaded<VChatMessageOrFunctionCallOut>(llm, messages, functions, maxTokens);
+export async function callChatWithFunctions(llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number) {
+  return callChatOverloaded<VChatMessageOrFunctionCallOut>(llm, messages, functions, maxTokens);
+}
 
 
 /**
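This hunk only converts two async arrow constants into async function declarations; signatures and behavior are unchanged, so call sites need no edits. A hypothetical call site for illustration (the DLLM import path and the { role, content } shape of VChatMessageIn are assumptions, not shown in this diff):

```ts
// Hypothetical caller; unaffected by the const-to-function refactor.
import { callChat } from '~/modules/llms/openai/openai.client';
import type { DLLM } from '~/modules/llms/llm.types'; // path assumed

async function askModel(llm: DLLM): Promise<void> {
  // callChat(llm, messages, maxTokens?) resolves to a VChatMessageOut
  const out = await callChat(llm, [{ role: 'user', content: 'Hello' }], 256);
  console.log(out);
}
```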
24 changes: 15 additions & 9 deletions src/modules/llms/openai/openai.router.ts
@@ -48,6 +48,8 @@ const functionsSchema = z.array(z.object({
 export const chatGenerateSchema = z.object({ access: accessSchema, model: modelSchema, history: historySchema, functions: functionsSchema.optional() });
 export type ChatGenerateSchema = z.infer<typeof chatGenerateSchema>;
 
+const listModelsSchema = z.object({ access: accessSchema, filterGpt: z.boolean().optional() });
+
 const chatModerationSchema = z.object({ access: accessSchema, text: z.string() });
 
 
@@ -66,8 +68,6 @@ const chatGenerateWithFunctionsOutputSchema = z.union([
 ]);
 
 
-
-
 export const openAIRouter = createTRPCRouter({
 
   /**
@@ -103,8 +103,8 @@ export const openAIRouter = createTRPCRouter({
    */
   moderation: publicProcedure
     .input(chatModerationSchema)
-    .mutation(async ({ input }): Promise<OpenAI.Wire.Moderation.Response> => {
-      const { access, text, } = input;
+    .mutation(async ({ input }): Promise<OpenAI.Wire.Moderation.Response> => {
+      const { access, text } = input;
       try {
 
         return await openaiPOST<OpenAI.Wire.Moderation.Request, OpenAI.Wire.Moderation.Response>(access, {
@@ -125,14 +125,20 @@ export const openAIRouter = createTRPCRouter({
    * List the Models available
    */
   listModels: publicProcedure
-    .input(accessSchema)
+    .input(listModelsSchema)
     .query(async ({ input }): Promise<OpenAI.Wire.Models.ModelDescription[]> => {
 
-      let wireModels: OpenAI.Wire.Models.Response;
-      wireModels = await openaiGET<OpenAI.Wire.Models.Response>(input, '/v1/models');
+      const wireModels: OpenAI.Wire.Models.Response = await openaiGET<OpenAI.Wire.Models.Response>(input.access, '/v1/models');
 
-      // filter out the non-gpt models
-      const llms = wireModels.data?.filter(model => model.id.includes('gpt')) ?? [];
+      // filter out the non-gpt models, if requested
+      let llms = (wireModels.data || [])
+        .filter(model => !input.filterGpt || model.id.includes('gpt'));
+
+      // remove models with duplicate ids (can happen for local servers)
+      const preFilterCount = llms.length;
+      llms = llms.filter((model, index) => llms.findIndex(m => m.id === model.id) === index);
+      if (preFilterCount !== llms.length)
+        console.warn(`openai.router.listModels: Duplicate model ids found, removed ${preFilterCount - llms.length} models`);
 
       // sort by which model has the least number of '-' in the name, and then by id, decreasing
       llms.sort((a, b) => {
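The diff truncates the body of the sort comparator, so here is a self-contained sketch of the new listModels pipeline (optional gpt filter, duplicate-id removal, sort). The comparator is reconstructed from the comment alone (fewest '-' in the id first, then id descending) and is an assumption, not the author's exact code:

```ts
// Self-contained sketch; ModelDescription is narrowed to the one field used here.
type ModelDescription = { id: string };

function listModelsPipeline(models: ModelDescription[], filterGpt?: boolean): ModelDescription[] {
  // filter out the non-gpt models, if requested
  let llms = models.filter(model => !filterGpt || model.id.includes('gpt'));

  // remove models with duplicate ids (can happen for local servers):
  // keep only the first occurrence of each id
  llms = llms.filter((model, index) => llms.findIndex(m => m.id === model.id) === index);

  // assumed comparator: fewest '-' in the id first, then by id, decreasing
  llms.sort((a, b) => {
    const dashes = (id: string) => id.split('-').length - 1;
    return dashes(a.id) - dashes(b.id) || b.id.localeCompare(a.id);
  });
  return llms;
}

// Example: duplicates collapse, and 'gpt-4' (1 dash) sorts before 'gpt-3.5-turbo-16k' (3 dashes):
// listModelsPipeline([{ id: 'gpt-4' }, { id: 'gpt-4' }, { id: 'gpt-3.5-turbo-16k' }], true)
```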
4 changes: 2 additions & 2 deletions src/modules/llms/openai/openai.types.ts
@@ -115,7 +115,7 @@ export namespace OpenAI {
       type: 'server_error' | string;
       param: string | null;
       code: string | null;
-    }
+    };
   }
 }
 
@@ -145,7 +145,7 @@ export namespace OpenAI {
       category_scores: { [key in ModerationCategory]: number };
       flagged: boolean;
     }
-    ]
+    ];
   }
 }
 

1 comment on commit 844dbd8

@vercel (vercel bot) commented on 844dbd8, Jul 8, 2023:
Successfully deployed to the following URLs:

big-agi – ./

big-agi-enricoros.vercel.app
big-agi-git-main-enricoros.vercel.app
get.big-agi.com
