|
| 1 | +import { runAI, runAIStream, validateAIToken } from "~api/cloudflare"; |
1 | 2 | import type { LLMProviders, Session } from "~llms";
|
2 | 3 |
|
3 | 4 | export default class CloudflareAI implements LLMProviders {
|
4 | 5 |
|
5 | 6 | constructor(
|
6 | 7 | private readonly accountId: string,
|
7 |
| - private readonly apiToken: String, |
| 8 | + private readonly apiToken: string, |
8 | 9 | private readonly model = '@cf/facebook/bart-large-cnn' // text summarization model
|
9 | 10 | ) { }
|
10 | 11 |
|
11 | 12 | async validate(): Promise<void> {
|
12 |
| - const res = await fetch(`https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/models/search?per_page=1`, { |
13 |
| - headers: { |
14 |
| - Authorization: `Bearer ${this.apiToken}` |
15 |
| - } |
16 |
| - }) |
17 |
| - const json = await res.json() |
18 |
| - if (!json.success) throw new Error('Cloudflare API 验证失败') |
| 13 | + const success = await validateAIToken(this.accountId, this.apiToken) |
| 14 | + if (!success) throw new Error('Cloudflare API 验证失败') |
19 | 15 | }
|
20 | 16 |
|
21 | 17 | async prompt(chat: string): Promise<string> {
|
22 |
| - const res = await fetch(`https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/run/${this.model}`, { |
23 |
| - headers: { |
24 |
| - Authorization: `Bearer ${this.apiToken}` |
25 |
| - }, |
26 |
| - body: JSON.stringify({ prompt: chat }) |
27 |
| - }) |
28 |
| - const json = await res.json() |
29 |
| - return json.response |
| 18 | + const res = await runAI(chat, { token: this.apiToken, account: this.accountId, model: this.model }) |
| 19 | + if (!res.result) throw new Error(res.errors.join(', ')) |
| 20 | + return res.result.response |
30 | 21 | }
|
31 | 22 |
|
32 | 23 | async *promptStream(chat: string): AsyncGenerator<string> {
|
33 |
| - const res = await fetch(`https://api.cloudflare.com/client/v4/accounts/${this.accountId}/ai/run/${this.model}`, { |
34 |
| - headers: { |
35 |
| - Authorization: `Bearer ${this.apiToken}` |
36 |
| - }, |
37 |
| - body: JSON.stringify({ prompt: chat, stream: true }) |
38 |
| - }) |
39 |
| - if (!res.body) throw new Error('Cloudflare AI response body is not readable') |
40 |
| - const reader = res.body.getReader() |
41 |
| - const decoder = new TextDecoder('utf-8', { ignoreBOM: true }) |
42 |
| - while (true) { |
43 |
| - const { done, value } = await reader.read() |
44 |
| - if (done) break |
45 |
| - const { response } = JSON.parse(decoder.decode(value, { stream: true })) |
46 |
| - yield response |
47 |
| - } |
| 24 | + return runAIStream(chat, { token: this.apiToken, account: this.accountId, model: this.model }) |
48 | 25 | }
|
49 | 26 |
|
50 | 27 | async asSession(): Promise<Session<LLMProviders>> {
|
|
0 commit comments