diff --git a/.changeset/giant-trainers-knock.md b/.changeset/giant-trainers-knock.md
new file mode 100644
index 0000000..df5e647
--- /dev/null
+++ b/.changeset/giant-trainers-knock.md
@@ -0,0 +1,5 @@
+---
+"chrome-ai": minor
+---
+
+break: stream text now returns text parts
diff --git a/.changeset/thick-apricots-march.md b/.changeset/thick-apricots-march.md
new file mode 100644
index 0000000..7d62d4a
--- /dev/null
+++ b/.changeset/thick-apricots-march.md
@@ -0,0 +1,5 @@
+---
+"chrome-ai": minor
+---
+
+feat: support stream object
diff --git a/src/language-model.ts b/src/language-model.ts
index 063e66f..07336e7 100644
--- a/src/language-model.ts
+++ b/src/language-model.ts
@@ -16,6 +16,7 @@ import {
 } from '@ai-sdk/provider';
 import { ChromeAISession, ChromeAISessionOptions } from './global';
 import createDebug from 'debug';
+import { StreamAI } from './stream-ai';
 
 const debug = createDebug('chromeai');
 
@@ -213,7 +214,7 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 {
   }> => {
     debug('stream options:', options);
 
-    if (['regular'].indexOf(options.mode.type) < 0) {
+    if (['regular', 'object-json'].indexOf(options.mode.type) < 0) {
       throw new UnsupportedFunctionalityError({
         functionality: `${options.mode.type} mode`,
       });
@@ -222,26 +223,7 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 {
     const session = await this.getSession();
     const message = this.formatMessages(options);
     const promptStream = session.promptStreaming(message);
-
-    let tempResult = '';
-    const transformStream = new TransformStream<
-      string,
-      LanguageModelV1StreamPart
-    >({
-      transform(textDelta, controller) {
-        controller.enqueue({ type: 'text-delta', textDelta });
-        tempResult = textDelta;
-      },
-      flush(controller) {
-        controller.enqueue({
-          type: 'finish',
-          finishReason: 'stop',
-          usage: { completionTokens: 0, promptTokens: 0 },
-        });
-        debug('stream result:', tempResult);
-        tempResult = '';
-      },
-    });
+    const transformStream = new StreamAI();
     const stream = promptStream.pipeThrough(transformStream);
 
     return {
diff --git a/src/stream-ai.test.ts b/src/stream-ai.test.ts
new file mode 100644
index 0000000..3142e32
--- /dev/null
+++ b/src/stream-ai.test.ts
@@ -0,0 +1,26 @@
+import { describe, it, expect, vi, afterEach } from 'vitest';
+import { StreamAI } from './stream-ai';
+
+describe('stream-ai', () => {
+  it('should correctly transform', async () => {
+    const transformStream = new StreamAI();
+
+    const writer = transformStream.writable.getWriter();
+    writer.write('hello');
+    writer.write('helloworld');
+    writer.close();
+
+    const reader = transformStream.readable.getReader();
+    expect(await reader.read()).toMatchObject({
+      value: { type: 'text-delta', textDelta: 'hello' },
+      done: false,
+    });
+    expect(await reader.read()).toMatchObject({
+      value: { type: 'text-delta', textDelta: 'world' },
+      done: false,
+    });
+    expect(await reader.read()).toMatchObject({
+      value: { type: 'finish' },
+    });
+  });
+});
diff --git a/src/stream-ai.ts b/src/stream-ai.ts
new file mode 100644
index 0000000..57f6a12
--- /dev/null
+++ b/src/stream-ai.ts
@@ -0,0 +1,32 @@
+import { LanguageModelV1StreamPart } from '@ai-sdk/provider';
+import createDebug from 'debug';
+
+const debug = createDebug('chromeai');
+
+export class StreamAI extends TransformStream<
+  string,
+  LanguageModelV1StreamPart
+> {
+  public constructor() {
+    let textTemp = '';
+    super({
+      start: () => {
+        textTemp = '';
+      },
+      transform: (chunk, controller) => {
+        const textDelta = chunk.replace(textTemp, '');
+        textTemp += textDelta;
+        controller.enqueue({ type: 'text-delta', textDelta });
+      },
+      flush: (controller) => {
+        controller.enqueue({
+          type: 'finish',
+          finishReason: 'stop',
+          usage: { completionTokens: 0, promptTokens: 0 },
+        });
+        debug('stream result:', textTemp);
+        textTemp = '';
+      },
+    });
+  }
+}
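
For context, here is a rough sketch of how the streaming changes above are consumed through the Vercel AI SDK. It is not part of the patch; it assumes the package's `chromeai()` model factory and the `streamText`/`streamObject` helpers from the `ai` package, and exact option names may differ between SDK versions.

```ts
// Usage sketch (not part of this patch). Assumes chrome-ai exports a
// `chromeai()` model factory and that the `ai` package provides
// streamText/streamObject as in AI SDK v3.
import { streamText, streamObject } from 'ai';
import { chromeai } from 'chrome-ai';
import { z } from 'zod';

async function demo() {
  // Text streaming: StreamAI now emits incremental 'text-delta' parts,
  // so each iteration yields only the newly generated text.
  const { textStream } = await streamText({
    model: chromeai(),
    prompt: 'Write a haiku about streams.',
  });
  for await (const delta of textStream) {
    console.log(delta);
  }

  // Object streaming: possible because doStream now accepts 'object-json' mode.
  const { partialObjectStream } = await streamObject({
    model: chromeai(),
    mode: 'json',
    schema: z.object({ title: z.string(), tags: z.array(z.string()) }),
    prompt: 'Summarize this change as JSON.',
  });
  for await (const partial of partialObjectStream) {
    console.log(partial);
  }
}

demo().catch(console.error);
```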