Skip to content

Commit

Permalink
✨ feat: support model select
Browse files Browse the repository at this point in the history
  • Loading branch information
rdmclin2 committed Jun 8, 2024
1 parent d899c16 commit 65fcaaa
Show file tree
Hide file tree
Showing 14 changed files with 377 additions and 123 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
"@dnd-kit/utilities": "^3.2.2",
"@gltf-transform/core": "^3.10.1",
"@icons-pack/react-simple-icons": "^9.5.0",
"@lobehub/icons": "^1.22.1",
"@lobehub/tts": "^1.24.1",
"@lobehub/ui": "^1.141.3",
"@pixiv/three-vrm": "2.1.1",
Expand Down
4 changes: 1 addition & 3 deletions src/app/chat/ChatMode/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ import classNames from 'classnames';
import React, { memo } from 'react';
import { Flexbox } from 'react-layout-kit';

import Alert from '@/features/Alert';
import MessageInput from '@/features/ChatInput/MessageInput';
import ChatList from '@/features/ChatList';
import MessageInput from '@/features/MessageInput';

import { useStyles } from './style';

Expand All @@ -25,7 +24,6 @@ const Chat = () => {
<Flexbox align={'center'} className={styles.docker} ref={ref}>
<div className={classNames(styles.input)}>
<MessageInput />
<Alert style={{ marginTop: 8 }} />
</div>
</Flexbox>
</Flexbox>
Expand Down
4 changes: 1 addition & 3 deletions src/app/chat/ViewerMode/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ import { Flexbox } from 'react-layout-kit';

import { HEADER_HEIGHT } from '@/constants/token';
import AgentViewer from '@/features/AgentViewer';
import Alert from '@/features/Alert';
import MessageInput from '@/features/ChatInput/MessageInput';
import MessageInput from '@/features/MessageInput';
import { sessionSelectors, useSessionStore } from '@/store/session';

import ChatDialog from './ChatDialog';
Expand All @@ -29,7 +28,6 @@ export default memo(() => {
<Flexbox align={'center'} className={styles.docker}>
<div className={classNames(styles.input, styles.content)}>
<MessageInput />
<Alert style={{ marginTop: 8 }} />
</div>
</Flexbox>
</Flexbox>
Expand Down
97 changes: 97 additions & 0 deletions src/components/ModelIcon/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import {
Adobe,
Ai21,
Aws,
Azure,
Baichuan,
ByteDance,
ChatGLM,
Claude,
Cohere,
Dbrx,
DeepSeek,
FishAudio,
Gemini,
Gemma,
Hunyuan,
LLaVA,
Meta,
Minimax,
Mistral,
Moonshot,
OpenAI,
OpenChat,
OpenRouter,
Perplexity,
Rwkv,
Spark,
Stability,
Tongyi,
Wenxin,
Yi,
} from '@lobehub/icons';
import { memo } from 'react';

interface ModelProviderIconProps {
model?: string;
size?: number;
}

const ModelIcon = memo<ModelProviderIconProps>(({ model: originModel, size = 12 }) => {
if (!originModel) return;

// lower case the origin model so to better match more model id case
const model = originModel.toLowerCase();

// currently supported models, maybe not in its own provider
if (model.includes('gpt-3')) return <OpenAI.Avatar size={size} type={'gpt3'} />;
if (model.includes('gpt-4')) return <OpenAI.Avatar size={size} type={'gpt4'} />;
if (model.startsWith('glm') || model.includes('chatglm')) return <ChatGLM.Avatar size={size} />;
if (model.includes('deepseek')) return <DeepSeek.Avatar size={size} />;
if (model.includes('claude')) return <Claude.Avatar size={size} />;
if (model.includes('titan')) return <Aws.Avatar size={size} />;
if (model.includes('llama')) return <Meta.Avatar size={size} />;
if (model.includes('llava')) return <LLaVA.Avatar size={size} />;
if (model.includes('gemini')) return <Gemini.Avatar size={size} />;
if (model.includes('gemma')) return <Gemma.Avatar size={size} />;
if (model.includes('moonshot')) return <Moonshot.Avatar size={size} />;
if (model.includes('qwen')) return <Tongyi.Avatar background={Tongyi.colorPrimary} size={size} />;
if (model.includes('minmax') || model.includes('abab')) return <Minimax.Avatar size={size} />;
if (model.includes('mistral') || model.includes('mixtral')) return <Mistral.Avatar size={size} />;
if (model.includes('pplx') || model.includes('sonar')) return <Perplexity.Avatar size={size} />;
if (model.includes('yi-')) return <Yi.Avatar size={size} />;
if (model.startsWith('openrouter')) return <OpenRouter.Avatar size={size} />; // only for Cinematika and Auto
if (model.startsWith('openchat')) return <OpenChat.Avatar size={size} />;
if (model.includes('command')) return <Cohere.Avatar size={size} />;
if (model.includes('dbrx')) return <Dbrx.Avatar size={size} />;

// below: To be supported in providers, move up if supported
if (model.includes('baichuan'))
return <Baichuan.Avatar background={Baichuan.colorPrimary} size={size} />;
if (model.includes('rwkv')) return <Rwkv.Avatar size={size} />;
if (model.includes('ernie')) return <Wenxin.Avatar size={size} />;
if (model.includes('spark')) return <Spark.Avatar size={size} />;
if (model.includes('hunyuan')) return <Hunyuan.Avatar size={size} />;
// ref https://github.com/fishaudio/Bert-VITS2/blob/master/train_ms.py#L702
if (model.startsWith('d_') || model.startsWith('g_') || model.startsWith('wd_'))
return <FishAudio.Avatar size={size} />;
if (model.includes('skylark')) return <ByteDance.Avatar size={size} />;

if (
model.includes('stable-diffusion') ||
model.includes('stable-video') ||
model.includes('stable-cascade') ||
model.includes('sdxl') ||
model.includes('stablelm') ||
model.startsWith('stable-') ||
model.startsWith('sd3')
)
return <Stability.Avatar size={size} />;

if (model.includes('wizardlm')) return <Azure.Avatar size={size} />;
if (model.includes('phi3')) return <Azure.Avatar size={size} />;
if (model.includes('firefly')) return <Adobe.Avatar size={size} />;
if (model.includes('jamba') || model.includes('j2-')) return <Ai21.Avatar size={size} />;
});

export default ModelIcon;
93 changes: 93 additions & 0 deletions src/components/ModelTag/ModelIcon.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import {
AdobeFirefly,
Ai21,
Aws,
Azure,
Baichuan,
ByteDance,
ChatGLM,
Claude,
Cohere,
Dbrx,
DeepSeek,
FishAudio,
Gemini,
Gemma,
Hunyuan,
LLaVA,
Meta,
Minimax,
Mistral,
Moonshot,
OpenAI,
OpenChat,
OpenRouter,
Perplexity,
Rwkv,
Spark,
Stability,
Tongyi,
Wenxin,
Yi,
} from '@lobehub/icons';
import { memo } from 'react';

interface ModelIconProps {
model?: string;
size?: number;
}

const ModelIcon = memo<ModelIconProps>(({ model, size = 12 }) => {
if (!model) return;

// currently supported models, maybe not in its own provider
if (model.startsWith('gpt')) return <OpenAI size={size} />;
if (model.startsWith('glm') || model.includes('chatglm')) return <ChatGLM size={size} />;
if (model.includes('claude')) return <Claude size={size} />;
if (model.includes('deepseek')) return <DeepSeek size={size} />;
if (model.includes('titan')) return <Aws size={size} />;
if (model.includes('llama')) return <Meta size={size} />;
if (model.includes('llava')) return <LLaVA size={size} />;
if (model.includes('gemini')) return <Gemini size={size} />;
if (model.includes('gemma')) return <Gemma.Simple size={size} />;
if (model.includes('moonshot')) return <Moonshot size={size} />;
if (model.includes('qwen')) return <Tongyi size={size} />;
if (model.includes('minmax')) return <Minimax size={size} />;
if (model.includes('abab')) return <Minimax size={size} />;
if (model.includes('mistral') || model.includes('mixtral')) return <Mistral size={size} />;
if (model.includes('pplx') || model.includes('sonar')) return <Perplexity size={size} />;
if (model.includes('yi-')) return <Yi size={size} />;
if (model.startsWith('openrouter')) return <OpenRouter size={size} />; // only for Cinematika and Auto
if (model.startsWith('openchat')) return <OpenChat size={size} />;
if (model.includes('command')) return <Cohere size={size} />;
if (model.includes('dbrx')) return <Dbrx size={size} />;

// below: To be supported in providers, move up if supported
if (model.includes('baichuan')) return <Baichuan size={size} />;
if (model.includes('rwkv')) return <Rwkv size={size} />;
if (model.includes('ernie')) return <Wenxin size={size} />;
if (model.includes('spark')) return <Spark size={size} />;
if (model.includes('hunyuan')) return <Hunyuan size={size} />;
// ref https://github.com/fishaudio/Bert-VITS2/blob/master/train_ms.py#L702
if (model.startsWith('d_') || model.startsWith('g_') || model.startsWith('wd_'))
return <FishAudio size={size} />;
if (model.includes('skylark')) return <ByteDance size={size} />;

if (
model.includes('stable-diffusion') ||
model.includes('stable-video') ||
model.includes('stable-cascade') ||
model.includes('sdxl') ||
model.includes('stablelm') ||
model.startsWith('stable-') ||
model.startsWith('sd3')
)
return <Stability size={size} />;

if (model.includes('wizardlm')) return <Azure size={size} />;
if (model.includes('phi3')) return <Azure size={size} />;
if (model.includes('firefly')) return <AdobeFirefly size={size} />;
if (model.includes('jamba') || model.includes('j2-')) return <Ai21 size={size} />;
});

export default ModelIcon;
13 changes: 13 additions & 0 deletions src/components/ModelTag/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import { Tag } from '@lobehub/ui';
import { memo } from 'react';

import ModelIcon from './ModelIcon';

interface ModelTagProps {
model?: string;
}
const ModelTag = memo<ModelTagProps>(({ model }) => (
<Tag icon={<ModelIcon model={model} />}>{model ? model : '请选择模型'}</Tag>
));

export default ModelTag;
9 changes: 7 additions & 2 deletions src/components/agent/AgentMeta/index.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { Avatar } from '@lobehub/ui';
import { Typography } from 'antd';
import { Space, Typography } from 'antd';
import React from 'react';

import ModelSelect from '@/features/Actions/ModelSelect';
import { AgentMeta } from '@/types/agent';

import { useStyles } from './style';
Expand All @@ -21,7 +22,11 @@ export default (props: AgentMetaProps) => {
<div className={cx(styles.container, className)} style={style}>
<Avatar avatar={avatar} size={36} />
<div className={styles.content}>
<div className={styles.title}>{name}</div>
<div className={styles.title}>
<Space size={4} align={'center'}>
{name} <ModelSelect />
</Space>
</div>
<Typography.Text className={styles.desc} ellipsis>
{description}
</Typography.Text>
Expand Down
2 changes: 2 additions & 0 deletions src/components/agent/AgentMeta/style.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ export const useStyles = createStyles(({ css, token }) => ({
`,
desc: css`
max-width: 480px;
margin-top: ${token.marginXXS}px;
font-size: ${token.fontSizeSM}px;
line-height: 18px;
color: ${token.colorTextDescription};
Expand Down
69 changes: 69 additions & 0 deletions src/features/Actions/ModelSelect.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import { Dropdown } from 'antd';
import { createStyles } from 'antd-style';
import { memo } from 'react';

import ModelIcon from '@/components/ModelIcon';
import ModelTag from '@/components/ModelTag';
import { OPENAI_MODEL_LIST } from '@/constants/openai';
import { configSelectors, useSettingStore } from '@/store/setting';

const useStyles = createStyles(({ css, prefixCls }) => ({
menu: css`
.${prefixCls}-dropdown-menu-item {
display: flex;
gap: 8px;
margin: 4px 0 !important;
}
.${prefixCls}-dropdown-menu {
&-item-group-title {
padding-inline: 8px;
}
&-item-group-list {
margin: 0 !important;
}
}
`,
tag: css`
cursor: pointer;
`,
}));

const ModelSelect = memo(() => {
const { styles } = useStyles();
const [model, setOpenAIConfig] = useSettingStore((s) => [
configSelectors.currentOpenAIConfig(s)?.model,
s.setOpenAIConfig,
]);

const items = OPENAI_MODEL_LIST.map((item) => {
return {
icon: <ModelIcon model={item.id} size={18} />,
key: item.id,
label: item.displayName,
onClick: () => setOpenAIConfig({ model: item.id }),
};
});

return (
<Dropdown
menu={{
items,
className: styles.menu,
activeKey: model,
style: {
maxHeight: 500,
overflowY: 'scroll',
},
}}
placement={'topLeft'}
trigger={['click']}
>
<div className={styles.tag}>
<ModelTag model={model} />
</div>
</Dropdown>
);
});

export default ModelSelect;
33 changes: 0 additions & 33 deletions src/features/Alert/index.tsx

This file was deleted.

Loading

0 comments on commit 65fcaaa

Please sign in to comment.