Skip to content

Commit

Permalink
Merge branch 'lobehub:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
sxjeru authored Apr 14, 2024
2 parents 77af1f3 + 3078c2b commit 2ff4035
Show file tree
Hide file tree
Showing 10 changed files with 251 additions and 149 deletions.
17 changes: 17 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,23 @@

# Changelog

### [Version 0.147.10](https://github.com/lobehub/lobe-chat/compare/v0.147.9...v0.147.10)

<sup>Released on **2024-04-13**</sup>

<br/>

<details>
<summary><kbd>Improvements and Fixes</kbd></summary>

</details>

<div align="right">

[![](https://img.shields.io/badge/-BACK_TO_TOP-151515?style=flat-square)](#readme-top)

</div>

### [Version 0.147.9](https://github.com/lobehub/lobe-chat/compare/v0.147.8...v0.147.9)

<sup>Released on **2024-04-12**</sup>
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@lobehub/chat",
"version": "0.147.9",
"version": "0.147.10",
"description": "Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.",
"keywords": [
"framework",
Expand Down
18 changes: 18 additions & 0 deletions src/config/modelProviders/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { ChatModelCard, ModelProviderCard } from '@/types/llm';

import AnthropicProvider from './anthropic';
import AzureProvider from './azure';
import BedrockProvider from './bedrock';
import GoogleProvider from './google';
import GroqProvider from './groq';
Expand Down Expand Up @@ -30,6 +31,23 @@ export const LOBE_DEFAULT_MODEL_LIST: ChatModelCard[] = [
ZeroOneProvider.chatModels,
].flat();

/**
 * Default display order of the model providers.
 * NOTE(review): Azure is spread with an empty `chatModels` list — its models
 * appear to be supplied only at runtime (see `mergeModels('azure', [])` in
 * refreshDefaultModelProviderList) — confirm before relying on this.
 */
export const DEFAULT_MODEL_PROVIDER_LIST = [
  OpenAIProvider,
  // Azure ships no static model catalogue
  { ...AzureProvider, chatModels: [] },
  OllamaProvider,
  AnthropicProvider,
  GoogleProvider,
  OpenRouterProvider,
  TogetherAIProvider,
  BedrockProvider,
  PerplexityProvider,
  MistralProvider,
  GroqProvider,
  MoonshotProvider,
  ZeroOneProvider,
  ZhiPuProvider,
];

/**
 * Collect the ids of all models the given provider has marked as enabled.
 */
export const filterEnabledModels = (provider: ModelProviderCard) => {
  const enabledIds: string[] = [];

  for (const model of provider.chatModels) {
    if (model.enabled) enabledIds.push(model.id);
  }

  return enabledIds;
};
Expand Down
1 change: 1 addition & 0 deletions src/features/AgentSetting/AgentConfig/ModelSelect.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ const ModelSelect = memo(() => {
modelProviderSelectors.modelProviderListForModelSelect,
isEqual,
);

const { styles } = useStyles();

const options = useMemo<SelectProps['options']>(() => {
Expand Down
9 changes: 9 additions & 0 deletions src/store/global/slices/common/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ export const createCommonSlice: StateCreator<

// Re-fetch the user config via SWR mutate, then sync derived state.
refreshUserConfig: async () => {
  await mutate([USER_CONFIG_FETCH_KEY, true]);

  // after the user config has been refreshed, sync the model provider list to the latest state
  get().refreshModelProviderList();
},

switchBackToChat: (sessionId) => {
Expand Down Expand Up @@ -159,7 +162,10 @@ export const createCommonSlice: StateCreator<
};

const defaultSettings = merge(get().defaultSettings, serverSettings);

set({ defaultSettings, serverConfig: data }, false, n('initGlobalConfig'));

get().refreshDefaultModelProviderList();
}
},
revalidateOnFocus: false,
Expand All @@ -181,6 +187,9 @@ export const createCommonSlice: StateCreator<
n('fetchUserConfig', data),
);

// after the user config has been fetched, refresh the model provider list to the latest state
get().refreshModelProviderList();

const { language } = settingsSelectors.currentSettings(get());
if (language === 'auto') {
switchLang('auto');
Expand Down
96 changes: 91 additions & 5 deletions src/store/global/slices/settings/actions/llm.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,18 @@ import { act, renderHook } from '@testing-library/react';
import { describe, expect, it, vi } from 'vitest';

import { userService } from '@/services/user';
import { useGlobalStore } from '@/store/global';
import { modelConfigSelectors, settingsSelectors } from '@/store/global/slices/settings/selectors';
import { GlobalStore, useGlobalStore } from '@/store/global';
import {
GlobalSettingsState,
initialSettingsState,
} from '@/store/global/slices/settings/initialState';
import {
modelConfigSelectors,
modelProviderSelectors,
settingsSelectors,
} from '@/store/global/slices/settings/selectors';
import { GeneralModelProviderConfig } from '@/types/settings';
import { merge } from '@/utils/merge';

import { CustomModelCardDispatch, customModelCardsReducer } from '../reducers/customModelCard';

Expand All @@ -15,9 +24,6 @@ vi.mock('@/services/user', () => ({
resetUserSettings: vi.fn(),
},
}));
vi.mock('../reducers/customModelCard', () => ({
customModelCardsReducer: vi.fn().mockReturnValue([]),
}));

describe('LLMSettingsSliceAction', () => {
describe('setModelProviderConfig', () => {
Expand Down Expand Up @@ -57,4 +63,84 @@ describe('LLMSettingsSliceAction', () => {
expect(result.current.setModelProviderConfig).not.toHaveBeenCalled();
});
});

describe('refreshDefaultModelProviderList', () => {
  it('default', async () => {
    const { result } = renderHook(() => useGlobalStore());

    act(() => {
      useGlobalStore.setState({
        serverConfig: {
          languageModel: {
            azure: { serverModelCards: [{ id: 'abc', deploymentName: 'abc' }] },
          },
          telemetry: {},
        },
      });
    });

    act(() => {
      result.current.refreshDefaultModelProviderList();
    });

    // server-side model cards take priority: azure's chatModels should be
    // replaced by the serverModelCards configured above
    const azure = result.current.defaultModelProviderList.find((m) => m.id === 'azure');
    expect(azure?.chatModels).toEqual([{ id: 'abc', deploymentName: 'abc' }]);
  });
});

describe('refreshModelProviderList', () => {
  it('visible', async () => {
    const { result } = renderHook(() => useGlobalStore());
    act(() => {
      useGlobalStore.setState({
        settings: {
          languageModel: {
            ollama: { enabledModels: ['llava'] },
          },
        },
      });
    });

    act(() => {
      result.current.refreshModelProviderList();
    });

    const ollamaList = result.current.modelProviderList.find((r) => r.id === 'ollama');
    // the user-enabled model id ('llava') should be marked enabled on its card
    expect(ollamaList?.chatModels.find((c) => c.id === 'llava')).toEqual({
      displayName: 'LLaVA 7B',
      functionCall: false,
      enabled: true,
      id: 'llava',
      tokens: 4000,
      vision: true,
    });
  });

  it('modelProviderListForModelSelect should return only enabled providers', () => {
    const { result } = renderHook(() => useGlobalStore());

    act(() => {
      useGlobalStore.setState({
        settings: {
          languageModel: {
            perplexity: { enabled: true },
            azure: { enabled: false },
          },
        },
      });
    });

    act(() => {
      result.current.refreshModelProviderList();
    });

    // NOTE(review): expects exactly 2 enabled providers — presumably one provider
    // (openai?) is enabled by default alongside perplexity; confirm against defaults
    const enabledProviders = modelProviderSelectors.modelProviderListForModelSelect(
      result.current,
    );
    expect(enabledProviders).toHaveLength(2);
    expect(enabledProviders[1].id).toBe('perplexity');
  });
});
});
96 changes: 96 additions & 0 deletions src/store/global/slices/settings/actions/llm.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,28 @@
import useSWR, { SWRResponse } from 'swr';
import type { StateCreator } from 'zustand/vanilla';

import {
AnthropicProviderCard,
AzureProviderCard,
BedrockProviderCard,
GoogleProviderCard,
GroqProviderCard,
MistralProviderCard,
MoonshotProviderCard,
OllamaProviderCard,
OpenAIProviderCard,
OpenRouterProviderCard,
PerplexityProviderCard,
TogetherAIProviderCard,
ZeroOneProviderCard,
ZhiPuProviderCard,
} from '@/config/modelProviders';
import { GlobalStore } from '@/store/global';
import { ChatModelCard } from '@/types/llm';
import { GlobalLLMConfig, GlobalLLMProviderKey } from '@/types/settings';

import { CustomModelCardDispatch, customModelCardsReducer } from '../reducers/customModelCard';
import { modelProviderSelectors } from '../selectors/modelProvider';
import { settingsSelectors } from '../selectors/settings';

/**
Expand All @@ -16,12 +33,18 @@ export interface LLMSettingsAction {
provider: GlobalLLMProviderKey,
payload: CustomModelCardDispatch,
) => Promise<void>;
/**
* make sure the default model provider list is sync to latest state
*/
refreshDefaultModelProviderList: () => void;
refreshModelProviderList: () => void;
removeEnabledModels: (provider: GlobalLLMProviderKey, model: string) => Promise<void>;
setModelProviderConfig: <T extends GlobalLLMProviderKey>(
provider: T,
config: Partial<GlobalLLMConfig[T]>,
) => Promise<void>;
toggleEditingCustomModelCard: (params?: { id: string; provider: GlobalLLMProviderKey }) => void;

toggleProviderEnabled: (provider: GlobalLLMProviderKey, enabled: boolean) => Promise<void>;

useFetchProviderModelList: (
Expand All @@ -46,6 +69,76 @@ export const llmSettingsSlice: StateCreator<
await get().setModelProviderConfig(provider, { customModelCards: nextState });
},

refreshDefaultModelProviderList: () => {
  /**
   * Model cards come from several sources; merge them with the following
   * priority (first non-null source wins):
   * 1 - server-side model cards
   * 2 - remote model cards
   * 3 - default (built-in) model cards
   */

  // eslint-disable-next-line unicorn/consistent-function-scoping
  const mergeModels = (provider: GlobalLLMProviderKey, defaultChatModels: ChatModelCard[]) => {
    // if the chat model is config in the server side, use the server side model cards
    const serverChatModels = modelProviderSelectors.serverProviderModelCards(provider)(get());
    const remoteChatModels = modelProviderSelectors.remoteProviderModelCards(provider)(get());

    return serverChatModels ?? remoteChatModels ?? defaultChatModels;
  };

  // order of this list is the display order of providers in the UI
  const defaultModelProviderList = [
    {
      ...OpenAIProviderCard,
      chatModels: mergeModels('openai', OpenAIProviderCard.chatModels),
    },
    // Azure has no built-in model catalogue: models come only from server/remote config
    { ...AzureProviderCard, chatModels: mergeModels('azure', []) },
    { ...OllamaProviderCard, chatModels: mergeModels('ollama', OllamaProviderCard.chatModels) },
    AnthropicProviderCard,
    GoogleProviderCard,
    {
      ...OpenRouterProviderCard,
      chatModels: mergeModels('openrouter', OpenRouterProviderCard.chatModels),
    },
    {
      ...TogetherAIProviderCard,
      chatModels: mergeModels('togetherai', TogetherAIProviderCard.chatModels),
    },
    BedrockProviderCard,
    PerplexityProviderCard,
    MistralProviderCard,
    GroqProviderCard,
    MoonshotProviderCard,
    ZeroOneProviderCard,
    ZhiPuProviderCard,
  ];

  set({ defaultModelProviderList }, false, 'refreshDefaultModelProviderList');

  // the user-visible provider list is derived from the default list,
  // so it must be recomputed whenever the default list changes
  get().refreshModelProviderList();
},

refreshModelProviderList: () => {
  /**
   * Rebuild the user-visible provider list from the default provider list,
   * applying the user's per-provider settings (enabled models, enabled flag).
   */
  const modelProviderList = get().defaultModelProviderList.map((provider) => {
    // hoisted out of the per-model map below: the enabled-model list is
    // per-provider, not per-model, so compute it once per provider instead
    // of once per model card
    const enabledModelIds = modelProviderSelectors.getEnableModelsById(provider.id)(get());

    return {
      ...provider,
      chatModels: modelProviderSelectors
        .getModelCardsById(provider.id)(get())
        ?.map((model) => {
          // no explicit user selection for this provider — keep the card as-is
          if (!enabledModelIds) return model;

          return {
            ...model,
            enabled: enabledModelIds.some((id) => id === model.id),
          };
        }),
      enabled: modelProviderSelectors.isProviderEnabled(provider.id as any)(get()),
    };
  });

  set({ modelProviderList }, false, 'refreshModelProviderList');
},

removeEnabledModels: async (provider, model) => {
const config = settingsSelectors.providerConfig(provider)(get());

Expand All @@ -60,6 +153,7 @@ export const llmSettingsSlice: StateCreator<
// Open (params set) or close (params undefined) the custom model card editor.
toggleEditingCustomModelCard: (params) => {
  set({ editingCustomCardModel: params }, false, 'toggleEditingCustomModelCard');
},

// Persist the enabled/disabled switch for a single provider into user settings.
toggleProviderEnabled: async (provider, enabled) => {
  await get().setSettings({ languageModel: { [provider]: { enabled } } });
},
Expand All @@ -79,6 +173,8 @@ export const llmSettingsSlice: StateCreator<
latestFetchTime: Date.now(),
remoteModelCards: data,
});

get().refreshDefaultModelProviderList();
}
},
revalidateOnFocus: false,
Expand Down
6 changes: 6 additions & 0 deletions src/store/global/slices/settings/initialState.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
import { DeepPartial } from 'utility-types';

import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders';
import { DEFAULT_SETTINGS } from '@/const/settings';
import { ModelProviderCard } from '@/types/llm';
import { GlobalServerConfig } from '@/types/serverConfig';
import { GlobalSettings } from '@/types/settings';

export interface GlobalSettingsState {
  avatar?: string;
  // built-in provider catalogue, refreshed with server/remote model cards
  // by refreshDefaultModelProviderList
  defaultModelProviderList: ModelProviderCard[];
  defaultSettings: GlobalSettings;
  // the custom model card currently open in the editor, if any
  editingCustomCardModel?: { id: string; provider: string } | undefined;
  // user-visible provider list: defaultModelProviderList with the user's
  // enabled-model / enabled-provider settings applied (refreshModelProviderList)
  modelProviderList: ModelProviderCard[];
  serverConfig: GlobalServerConfig;
  settings: DeepPartial<GlobalSettings>;
  userId?: string;
}

export const initialSettingsState: GlobalSettingsState = {
defaultModelProviderList: DEFAULT_MODEL_PROVIDER_LIST,
defaultSettings: DEFAULT_SETTINGS,
modelProviderList: DEFAULT_MODEL_PROVIDER_LIST,
serverConfig: {
telemetry: {},
},
Expand Down
Loading

0 comments on commit 2ff4035

Please sign in to comment.