From 13bd145831cc15e314a56e009a1c57032c9975f9 Mon Sep 17 00:00:00 2001 From: Jeason Date: Thu, 8 Aug 2024 20:10:53 +0800 Subject: [PATCH 1/2] feat: improve structured output --- src/language-model.ts | 13 +++++++++--- src/stream-ai.test.ts | 46 +++++++++++++++++++++++++++++++++++++++--- src/stream-ai.ts | 47 ++++++++++++++++++++++++++++++++----------- 3 files changed, 88 insertions(+), 18 deletions(-) diff --git a/src/language-model.ts b/src/language-model.ts index fe93ddd..e787b6f 100644 --- a/src/language-model.ts +++ b/src/language-model.ts @@ -16,7 +16,7 @@ import { } from '@ai-sdk/provider'; import { ChromeAISession, ChromeAISessionOptions } from './global'; import createDebug from 'debug'; -import { StreamAI } from './stream-ai'; +import { objectStartSequence, objectStopSequence, StreamAI } from './stream-ai'; const debug = createDebug('chromeai'); @@ -164,8 +164,15 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 { const session = await this.getSession(); const message = this.formatMessages(options); - const text = await session.prompt(message); + let text = await session.prompt(message); + + if (options.mode.type === 'object-json') { + text = text.replace(new RegExp('^' + objectStartSequence, 'ig'), ''); + text = text.replace(new RegExp(objectStopSequence + '$', 'ig'), ''); + } + debug('generate result:', text); + return { text, finishReason: 'stop', @@ -193,7 +200,7 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 { const session = await this.getSession(); const message = this.formatMessages(options); const promptStream = session.promptStreaming(message); - const transformStream = new StreamAI(options.abortSignal); + const transformStream = new StreamAI(options); const stream = promptStream.pipeThrough(transformStream); return { diff --git a/src/stream-ai.test.ts b/src/stream-ai.test.ts index 37755fe..bfa3c3e 100644 --- a/src/stream-ai.test.ts +++ b/src/stream-ai.test.ts @@ -1,9 +1,15 @@ import { describe, it, expect 
} from 'vitest'; import { StreamAI } from './stream-ai'; +import type { LanguageModelV1CallOptions } from '@ai-sdk/provider'; -describe('stream-ai', () => { +describe('stream-ai', async () => { + const defaultOptions: LanguageModelV1CallOptions = { + prompt: [], + mode: { type: 'regular' }, + inputFormat: 'messages', + }; it('should correctly transform', async () => { - const transformStream = new StreamAI(); + const transformStream = new StreamAI(defaultOptions); const writer = transformStream.writable.getWriter(); writer.write('hello'); @@ -26,7 +32,10 @@ describe('stream-ai', () => { it('should abort when signal', async () => { const controller = new AbortController(); - const transformStream = new StreamAI(controller.signal); + const transformStream = new StreamAI({ + ...defaultOptions, + abortSignal: controller.signal, + }); const writer = transformStream.writable.getWriter(); const reader = transformStream.readable.getReader(); @@ -41,4 +50,35 @@ describe('stream-ai', () => { controller.abort(); expect(await reader.read()).toMatchObject({ done: true }); }); + + it('should transform when object-json', async () => { + const transformStream = new StreamAI({ + ...defaultOptions, + mode: { type: 'object-json', schema: {} }, + }); + + const writer = transformStream.writable.getWriter(); + const reader = transformStream.readable.getReader(); + + for (const chunk of [ + ' ```', + ' ```json\n', + ' ```json\n{}', + ' ```json\n{}\n```', + ]) { + writer.write(chunk); + } + writer.close(); + + let output = ''; + while (true) { + const item = await reader.read(); + if (item.done || item.value?.type === 'finish') break; + if (item.value?.type === 'text-delta') { + output += item.value.textDelta; + } + } + + expect(output).toBe('{}'); + }); }); diff --git a/src/stream-ai.ts b/src/stream-ai.ts index d440e80..ffe2d15 100644 --- a/src/stream-ai.ts +++ b/src/stream-ai.ts @@ -1,28 +1,52 @@ -import { LanguageModelV1StreamPart } from '@ai-sdk/provider'; +import { + 
LanguageModelV1CallOptions, + LanguageModelV1StreamPart, +} from '@ai-sdk/provider'; import createDebug from 'debug'; const debug = createDebug('chromeai'); +export const objectStartSequence = ' ```json\n'; +export const objectStopSequence = '\n```'; + export class StreamAI extends TransformStream< string, LanguageModelV1StreamPart > { - public constructor(abortSignal?: AbortSignal) { - let textTemp = ''; + public constructor(options: LanguageModelV1CallOptions) { + let buffer = ''; + let transforming = false; + + const reset = () => { + buffer = ''; + transforming = false; + }; + super({ start: (controller) => { - textTemp = ''; - if (!abortSignal) return; - abortSignal.addEventListener('abort', () => { + reset(); + if (!options.abortSignal) return; + options.abortSignal.addEventListener('abort', () => { debug('streamText terminate by abortSignal'); controller.terminate(); - textTemp = ''; }); }, transform: (chunk, controller) => { - const textDelta = chunk.replace(textTemp, ''); - textTemp += textDelta; - controller.enqueue({ type: 'text-delta', textDelta }); + if (options.mode.type === 'object-json') { + transforming = + chunk.startsWith(objectStartSequence) && + !chunk.endsWith(objectStopSequence); + chunk = chunk.replace( + new RegExp('^' + objectStartSequence, 'ig'), + '' + ); + chunk = chunk.replace(new RegExp(objectStopSequence + '$', 'ig'), ''); + } else { + transforming = true; + } + const textDelta = chunk.replace(buffer, ''); // See: https://github.com/jeasonstudio/chrome-ai/issues/11 + if (transforming) controller.enqueue({ type: 'text-delta', textDelta }); + buffer = chunk; }, flush: (controller) => { controller.enqueue({ @@ -30,8 +54,7 @@ export class StreamAI extends TransformStream< finishReason: 'stop', usage: { completionTokens: 0, promptTokens: 0 }, }); - debug('stream result:', textTemp); - textTemp = ''; + debug('stream result:', buffer); }, }); } From d1f472772b6f4349720573abfa607233e58a1918 Mon Sep 17 00:00:00 2001 From: Jeason Date: Thu, 8 
Aug 2024 20:11:39 +0800 Subject: [PATCH 2/2] chore: add changeset --- .changeset/silver-chefs-add.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/silver-chefs-add.md diff --git a/.changeset/silver-chefs-add.md b/.changeset/silver-chefs-add.md new file mode 100644 index 0000000..b006506 --- /dev/null +++ b/.changeset/silver-chefs-add.md @@ -0,0 +1,5 @@ +--- +"chrome-ai": patch +--- + +feat: improve structured output