Skip to content

Commit 3c18d81

Browse files
committed
e2e test
1 parent 9f78626 commit 3c18d81

File tree

8 files changed

+83
-47
lines changed

8 files changed

+83
-47
lines changed

src/api/cloudflare.ts

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import type { AIResponse, Result } from "~types/cloudflare";
2+
3+
4+
// Root of the Cloudflare REST API (v4); every endpoint below is built relative to this.
const BASE_URL = 'https://api.cloudflare.com/client/v4'
5+
6+
export async function runAI(chat: string, { token, account, model }: { token: string, account: string, model: string }): Promise<Result<AIResponse>> {
7+
const res = await fetch(`${BASE_URL}/accounts/${account}/ai/run/${model}`, {
8+
method: 'POST',
9+
headers: {
10+
Authorization: `Bearer ${token}`
11+
},
12+
body: JSON.stringify({ prompt: chat })
13+
})
14+
return await res.json()
15+
}
16+
17+
export async function *runAIStream(chat: string, { token, account, model }: { token: string, account: string, model: string }): AsyncGenerator<string> {
18+
const res = await fetch(`${BASE_URL}/accounts/${account}/ai/run/${model}`, {
19+
method: 'POST',
20+
headers: {
21+
Authorization: `Bearer ${token}`
22+
},
23+
body: JSON.stringify({ prompt: chat, stream: true })
24+
})
25+
if (!res.body) throw new Error('Cloudflare AI response body is not readable')
26+
const reader = res.body.getReader()
27+
const decoder = new TextDecoder('utf-8', { ignoreBOM: true })
28+
while (true) {
29+
const { done, value } = await reader.read()
30+
if (done) break
31+
const { response } = JSON.parse(decoder.decode(value, { stream: true }))
32+
yield response
33+
}
34+
}
35+
36+
export async function validateAIToken(accountId: string, token: string): Promise<boolean> {
37+
const res = await fetch(`${BASE_URL}/accounts/${accountId}/ai/models/search?per_page=1`, {
38+
headers: {
39+
Authorization: `Bearer ${this.apiToken}`
40+
}
41+
})
42+
const data = await res.json() as Result<any>
43+
return data.success
44+
}

src/llms/cloudflare-ai.ts

+8-31
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,27 @@
1+
import { runAI, runAIStream, validateAIToken } from "~api/cloudflare";
12
import type { LLMProviders, Session } from "~llms";
23

34
export default class CloudflareAI implements LLMProviders {
45

56
constructor(
67
private readonly accountId: string,
7-
private readonly apiToken: String,
8+
private readonly apiToken: string,
89
private readonly model = '@cf/facebook/bart-large-cnn' // text summarization model
910
) { }
1011

1112
async validate(): Promise<void> {
12-
const res = await fetch(`https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/models/search?per_page=1`, {
13-
headers: {
14-
Authorization: `Bearer ${this.apiToken}`
15-
}
16-
})
17-
const json = await res.json()
18-
if (!json.success) throw new Error('Cloudflare API 验证失败')
13+
const success = await validateAIToken(this.accountId, this.apiToken)
14+
if (!success) throw new Error('Cloudflare API 验证失败')
1915
}
2016

2117
async prompt(chat: string): Promise<string> {
22-
const res = await fetch(`https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/run/${this.model}`, {
23-
headers: {
24-
Authorization: `Bearer ${this.apiToken}`
25-
},
26-
body: JSON.stringify({ prompt: chat })
27-
})
28-
const json = await res.json()
29-
return json.response
18+
const res = await runAI(chat, { token: this.apiToken, account: this.accountId, model: this.model })
19+
if (!res.result) throw new Error(res.errors.join(', '))
20+
return res.result.response
3021
}
3122

3223
async *promptStream(chat: string): AsyncGenerator<string> {
33-
const res = await fetch(`https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/run/${this.model}`, {
34-
headers: {
35-
Authorization: `Bearer ${this.apiToken}`
36-
},
37-
body: JSON.stringify({ prompt: chat, stream: true })
38-
})
39-
if (!res.body) throw new Error('Cloudflare AI response body is not readable')
40-
const reader = res.body.getReader()
41-
const decoder = new TextDecoder('utf-8', { ignoreBOM: true })
42-
while (true) {
43-
const { done, value } = await reader.read()
44-
if (done) break
45-
const { response } = JSON.parse(decoder.decode(value, { stream: true }))
46-
yield response
47-
}
24+
return runAIStream(chat, { token: this.apiToken, account: this.accountId, model: this.model })
4825
}
4926

5027
async asSession(): Promise<Session<LLMProviders>> {

src/llms/gemini-nano.ts

+12-4
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,21 @@ export default class GeminiNano implements LLMProviders {
1111
}
1212

1313
async prompt(chat: string): Promise<string> {
14-
using session = await this.asSession()
15-
return session.prompt(chat)
14+
const session = await this.asSession()
15+
try {
16+
return session.prompt(chat)
17+
} finally {
18+
session[Symbol.dispose]()
19+
}
1620
}
1721

1822
async *promptStream(chat: string): AsyncGenerator<string> {
19-
using session = await this.asSession()
20-
return session.promptStream(chat)
23+
const session = await this.asSession()
24+
try {
25+
return session.promptStream(chat)
26+
} finally {
27+
session[Symbol.dispose]()
28+
}
2129
}
2230

2331
async asSession(): Promise<Session<LLMProviders>> {

src/llms/index.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ const llms = {
1919

2020
export type LLMTypes = keyof typeof llms
2121

22-
export async function createLLMProvider(type: LLMTypes, ...args: any[]): Promise<LLMProviders> {
22+
async function createLLMProvider(type: LLMTypes, ...args: any[]): Promise<LLMProviders> {
2323
const LLM = llms[type].bind(this, ...args)
2424
return new LLM()
2525
}

src/types/cloudflare/index.ts

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
export * from './workers-ai'
2+
3+
// Generic envelope in which the Cloudflare v4 API wraps every response body.
export type Result<T> = {
    success: boolean   // true when the call succeeded and `result` holds the payload
    result: T          // endpoint-specific payload (e.g. AIResponse for Workers AI runs)
    // NOTE(review): Cloudflare documents `errors`/`messages` entries as
    // { code, message } objects rather than plain strings — confirm before
    // relying on string operations (e.g. `errors.join`) over these arrays.
    errors: string[]
    messages: string[]
}

src/types/cloudflare/workers-ai.ts

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
3+
// Payload returned by the Workers AI text-generation `run` endpoint;
// when streaming, each SSE record carries one such object with a text fragment.
export type AIResponse = {
    response: string   // generated text
}

tests/modules/llm.js

-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
11
import createLLMProvider from '~llms'
22

3-
4-
console.log('llm.js loaded!')
5-
console.log(createLLMProvider)
6-
73
window.llms = { createLLMProvider }

tests/units/llm.spec.ts

+5-7
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,12 @@ test('嘗試使用 Cloudflare AI 對話', async ({ page, modules }) => {
99
await modules['llm'].loadToPage()
1010
await modules['utils'].loadToPage()
1111

12-
const ret = await page.evaluate(async () => {
12+
const ret = await page.evaluate(async ({ accountId, apiToken }) => {
1313
const { llms } = window as any
1414
console.log('llms: ', llms)
15-
const llm = await llms.createLLMProvider('cloudflare',
16-
process.env.CF_ACCOUNT_ID,
17-
process.env.CF_API_TOKEN
18-
)
15+
const llm = await llms.createLLMProvider('cloudflare', accountId, apiToken)
1916
return await llm.prompt('你好')
20-
})
17+
}, { accountId: process.env.CF_ACCOUNT_ID, apiToken: process.env.CF_API_TOKEN })
2118

2219
logger.info('response: ', ret)
2320
await expect(ret).not.toBeEmpty()
@@ -29,6 +26,7 @@ test('嘗試使用 Gemini Nano 對話', async ({ page, modules }) => {
2926
return !!window.ai;
3027
})
3128

29+
logger.debug('Gemini Nano supported: ', supported)
3230
test.skip(!supported, 'Gemini Nano 不支援此瀏覽器')
3331

3432
await modules['llm'].loadToPage()
@@ -59,5 +57,5 @@ test('嘗試使用 Remote Worker 對話', async ({ page, modules }) => {
5957

6058
logger.info('response: ', ret)
6159
await expect(ret).not.toBeEmpty()
62-
60+
6361
})

0 commit comments

Comments
 (0)