Skip to content

Commit 4d5e611

Browse files
committed
reshaped ai schema and model selectable
1 parent 3f2b241 commit 4d5e611

File tree

11 files changed

+204
-149
lines changed

11 files changed

+204
-149
lines changed

src/api/cloudflare.ts

+11-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ export async function runAI(data: any, { token, account, model }: { token: strin
1616
return json
1717
}
1818

19-
export async function *runAIStream(data: any, { token, account, model }: { token: string, account: string, model: string }): AsyncGenerator<string> {
19+
export async function* runAIStream(data: any, { token, account, model }: { token: string, account: string, model: string }): AsyncGenerator<string> {
2020
const res = await fetch(`${BASE_URL}/accounts/${account}/ai/run/${model}`, {
2121
method: 'POST',
2222
headers: {
@@ -35,12 +35,18 @@ export async function *runAIStream(data: any, { token, account, model }: { token
3535
}
3636
}
3737

38-
export async function validateAIToken(accountId: string, token: string): Promise<boolean> {
39-
const res = await fetch(`${BASE_URL}/accounts/${accountId}/ai/models/search?per_page=1`, {
38+
export async function validateAIToken(accountId: string, token: string, model: string): Promise<string | boolean> {
39+
const res = await fetch(`${BASE_URL}/accounts/${accountId}/ai/models/search?search=${model}&per_page=1`, {
4040
headers: {
41-
Authorization: `Bearer ${this.apiToken}`
41+
Authorization: `Bearer ${token}`
4242
}
4343
})
4444
const data = await res.json() as Result<any>
45-
return data.success
45+
if (!data.success) {
46+
return false
47+
} else if (data.result.length === 0) {
48+
return '找不到指定 AI 模型'
49+
} else {
50+
return true
51+
}
4652
}

src/features/jimaku/components/ButtonArea.tsx

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function ButtonArea({ clearJimaku, jimakus }: ButtonAreaProps): JSX.Element {
7373
弹出同传视窗
7474
</JimakuButton>
7575
}
76-
{aiZone.enabled && (
76+
{aiZone.summarizeEnabled && (
7777
<JimakuButton onClick={summerize}>
7878
同传字幕AI总结
7979
</JimakuButton>

src/llms/cf-qwen.ts src/llms/cloudflare-ai.ts

+22-10
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,38 @@
11
import { runAI, runAIStream, validateAIToken } from "~api/cloudflare";
22
import type { LLMProviders, Session } from "~llms";
3+
import type { SettingSchema } from "~options/fragments/llm";
34

4-
export default class CloudFlareQwen implements LLMProviders {
5+
export default class CloudFlareAI implements LLMProviders {
56

6-
private static readonly MODEL: string = '@cf/qwen/qwen1.5-14b-chat-awq'
7+
private static readonly DEFAULT_MODEL: string = '@cf/qwen/qwen1.5-14b-chat-awq'
78

8-
constructor(
9-
private readonly accountId: string,
10-
private readonly apiToken: string,
11-
) { }
9+
private readonly accountId: string
10+
private readonly apiToken: string
11+
12+
private readonly model: string
13+
14+
constructor(settings: SettingSchema) {
15+
this.accountId = settings.accountId
16+
this.apiToken = settings.apiToken
17+
18+
// only text generation model for now
19+
this.model = settings.model || CloudFlareAI.DEFAULT_MODEL
20+
}
1221

1322
async validate(): Promise<void> {
14-
const success = await validateAIToken(this.accountId, this.apiToken)
15-
if (!success) throw new Error('Cloudflare API 验证失败')
23+
const success = await validateAIToken(this.accountId, this.apiToken, this.model)
24+
if (typeof success === 'boolean' && !success) throw new Error('Cloudflare API 验证失败')
25+
if (typeof success === 'string') throw new Error(success)
1626
}
1727

1828
async prompt(chat: string): Promise<string> {
19-
const res = await runAI(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: CloudFlareQwen.MODEL })
29+
const res = await runAI(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: this.model })
2030
if (!res.result) throw new Error(res.errors.join(', '))
2131
return res.result.response
2232
}
2333

2434
async *promptStream(chat: string): AsyncGenerator<string> {
25-
return runAIStream(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: CloudFlareQwen.MODEL })
35+
return runAIStream(this.wrap(chat), { token: this.apiToken, account: this.accountId, model: this.model })
2636
}
2737

2838
async asSession(): Promise<Session<LLMProviders>> {
@@ -33,6 +43,8 @@ export default class CloudFlareQwen implements LLMProviders {
3343
}
3444
}
3545

46+
// text generation model input schema
47+
// so only text generation model for now
3648
private wrap(chat: string): any {
3749
return {
3850
max_tokens: 512,

src/llms/index.ts

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
import qwen from './cf-qwen'
1+
import type { SettingSchema as LLMSchema } from '~options/fragments/llm'
2+
3+
import cloudflare from './cloudflare-ai'
24
import nano from './gemini-nano'
35
import worker from './remote-worker'
46

@@ -12,7 +14,7 @@ export interface LLMProviders {
1214
export type Session<T> = Disposable & Omit<T, 'asSession' | 'validate'>
1315

1416
const llms = {
15-
qwen,
17+
cloudflare,
1618
nano,
1719
worker
1820
}
@@ -21,9 +23,10 @@ export type LLMs = typeof llms
2123

2224
export type LLMTypes = keyof LLMs
2325

24-
function createLLMProvider<K extends LLMTypes, M extends LLMs[K]>(type: K, ...args: ConstructorParameters<M>): LLMProviders {
25-
const LLM = llms[type].bind(this, ...args)
26-
return new LLM()
26+
function createLLMProvider(settings: LLMSchema): LLMProviders {
27+
const type = settings.provider
28+
const LLM = llms[type]
29+
return new LLM(settings)
2730
}
2831

2932
export default createLLMProvider

src/llms/remote-worker.ts

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
import type { LLMProviders, Session } from "~llms";
2+
import type { SettingSchema } from "~options/fragments/llm";
23
import { parseSSEResponses } from "~utils/binary";
34

4-
55
// for my worker, so limited usage
66
export default class RemoteWorker implements LLMProviders {
77

8+
private readonly model?: string
9+
10+
constructor(settings: SettingSchema) {
11+
this.model = settings.model || undefined
12+
}
13+
814
async validate(): Promise<void> {
915
const res = await fetch('https://llm.ericlamm.xyz/status')
1016
const json = await res.json()
@@ -19,7 +25,7 @@ export default class RemoteWorker implements LLMProviders {
1925
headers: {
2026
'Content-Type': 'application/json'
2127
},
22-
body: JSON.stringify({ prompt: chat })
28+
body: JSON.stringify({ prompt: chat, model: this.model })
2329
})
2430
if (!res.ok) throw new Error(await res.text())
2531
const json = await res.json()
@@ -32,7 +38,7 @@ export default class RemoteWorker implements LLMProviders {
3238
headers: {
3339
'Content-Type': 'application/json'
3440
},
35-
body: JSON.stringify({ prompt: chat, stream: true })
41+
body: JSON.stringify({ prompt: chat, stream: true, model: this.model })
3642
})
3743
if (!res.ok) throw new Error(await res.text())
3844
if (!res.body) throw new Error('Remote worker response body is not readable')

src/options/features/jimaku/components/AIFragment.tsx

+7-106
Original file line numberDiff line numberDiff line change
@@ -1,134 +1,35 @@
1-
import { Button, Input, List, Tooltip, Typography } from "@material-tailwind/react"
2-
import { type ChangeEvent, Fragment, useState } from "react"
3-
import { toast } from "sonner/dist"
1+
import { List } from "@material-tailwind/react"
2+
import { type ChangeEvent, Fragment } from "react"
43
import type { StateProxy } from "~hooks/binding"
5-
import type { LLMProviders, LLMTypes } from "~llms"
6-
import createLLMProvider from "~llms"
74
import ExperienmentFeatureIcon from "~options/components/ExperientmentFeatureIcon"
8-
import Selector from "~options/components/Selector"
95
import SwitchListItem from "~options/components/SwitchListItem"
106

11-
12-
137
export type AISchema = {
14-
enabled: boolean
15-
provider: LLMTypes
16-
17-
// cloudflare settings
18-
accountId?: string
19-
apiToken?: string
8+
summarizeEnabled: boolean
209
}
2110

2211

2312
export const aiDefaultSettings: Readonly<AISchema> = {
24-
enabled: false,
25-
provider: 'worker'
13+
summarizeEnabled: false
2614
}
2715

2816

2917
function AIFragment({ state, useHandler }: StateProxy<AISchema>): JSX.Element {
3018

31-
const [validating, setValidating] = useState(false)
32-
33-
const handler = useHandler<ChangeEvent<HTMLInputElement>, string>((e) => e.target.value)
3419
const checker = useHandler<ChangeEvent<HTMLInputElement>, boolean>((e) => e.target.checked)
3520

36-
const onValidate = async () => {
37-
setValidating(true)
38-
try {
39-
let provider: LLMProviders;
40-
if (state.provider === 'qwen') {
41-
provider = createLLMProvider(state.provider, state.accountId, state.apiToken)
42-
} else {
43-
provider = createLLMProvider(state.provider)
44-
}
45-
await provider.validate()
46-
toast.success('配置可用!')
47-
} catch (e) {
48-
toast.error('配置不可用: ' + e.message)
49-
} finally {
50-
setValidating(false)
51-
}
52-
}
53-
5421
return (
5522
<Fragment>
5623
<List className="col-span-2 border border-[#808080] rounded-md">
5724
<SwitchListItem
5825
data-testid="ai-enabled"
5926
label="启用同传字幕AI总结"
60-
hint="此功能将采用通义大模型对同传字幕进行总结"
61-
value={state.enabled}
62-
onChange={checker('enabled')}
27+
hint="此功能将采用大语言模型对同传字幕进行总结"
28+
value={state.summarizeEnabled}
29+
onChange={checker('summarizeEnabled')}
6330
marker={<ExperienmentFeatureIcon />}
6431
/>
6532
</List>
66-
{state.enabled && (
67-
<Fragment>
68-
<Selector<typeof state.provider>
69-
className="col-span-2"
70-
data-testid="ai-provider"
71-
label="技术来源"
72-
value={state.provider}
73-
onChange={e => state.provider = e}
74-
options={[
75-
{ label: 'Cloudflare AI', value: 'qwen' },
76-
{ label: '有限度服务器', value: 'worker' },
77-
{ label: 'Chrome 浏览器内置 AI', value: 'nano' }
78-
]}
79-
/>
80-
{state.provider === 'qwen' && (
81-
<Fragment>
82-
<Typography
83-
className="flex items-center gap-1 font-normal dark:text-gray-200 col-span-2"
84-
>
85-
<svg
86-
xmlns="http://www.w3.org/2000/svg"
87-
viewBox="0 0 24 24"
88-
fill="currentColor"
89-
className="-mt-px h-6 w-6"
90-
>
91-
<path
92-
fillRule="evenodd"
93-
d="M2.25 12c0-5.385 4.365-9.75 9.75-9.75s9.75 4.365 9.75 9.75-4.365 9.75-9.75 9.75S2.25 17.385 2.25 12zm8.706-1.442c1.146-.573 2.437.463 2.126 1.706l-.709 2.836.042-.02a.75.75 0 01.67 1.34l-.04.022c-1.147.573-2.438-.463-2.127-1.706l.71-2.836-.042.02a.75.75 0 11-.671-1.34l.041-.022zM12 9a.75.75 0 100-1.5.75.75 0 000 1.5z"
94-
clipRule="evenodd"
95-
/>
96-
</svg>
97-
<Typography className="underline" as="a" href="https://linux.do/t/topic/34037" target="_blank">点击此处</Typography>
98-
查看如何获得 Cloudflare API Token 和 Account ID
99-
</Typography>
100-
<Input
101-
data-testid="cf-account-id"
102-
crossOrigin="anonymous"
103-
variant="static"
104-
required
105-
label="Cloudflare Account ID"
106-
value={state.accountId}
107-
onChange={handler('accountId')}
108-
/>
109-
<Input
110-
data-testid="cf-api-token"
111-
crossOrigin="anonymous"
112-
variant="static"
113-
required
114-
label="Cloudflare API Token"
115-
value={state.apiToken}
116-
onChange={handler('apiToken')}
117-
/>
118-
</Fragment>
119-
)}
120-
</Fragment>
121-
)}
122-
<div className="col-span-2">
123-
<Button disabled={validating} onClick={onValidate} color="blue" size="lg" className="group flex items-center justify-center gap-3 text-[1rem] hover:shadow-lg">
124-
验证是否可用
125-
<Tooltip content="检查你目前的配置是否可用。若不可用,则无法启用AI总结功能。" placement="top-end">
126-
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" strokeWidth={1.5} stroke="currentColor" className="size-6">
127-
<path strokeLinecap="round" strokeLinejoin="round" d="m11.25 11.25.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Zm-9-3.75h.008v.008H12V8.25Z" />
128-
</svg>
129-
</Tooltip>
130-
</Button>
131-
</div>
13233
</Fragment>
13334
)
13435
}

src/options/fragments.ts

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import * as display from './fragments/display'
55
import * as features from './fragments/features'
66
import * as listings from './fragments/listings'
77
import * as version from './fragments/version'
8+
import * as llm from './fragments/llm'
89

910

1011
interface SettingFragment<T extends object> {
@@ -28,6 +29,7 @@ const fragments = {
2829
'settings.listings': listings,
2930
'settings.capture': capture,
3031
'settings.display': display,
32+
'settings.llm': llm,
3133
'settings.developer': developer,
3234
'settings.version': version
3335
}

0 commit comments

Comments
 (0)