From db8ebe3e40dba1e06395f87c1539f3bc4308258d Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 3 Jun 2025 22:00:22 -0400 Subject: [PATCH 1/9] feat(js): added dynamicTool factory function that does not depend on genkit instance --- js/ai/src/generate.ts | 12 +- js/ai/src/tool.ts | 112 ++++++++---- js/core/src/action.ts | 294 ++++++++++++++++++------------- js/genkit/src/common.ts | 1 + js/genkit/src/genkit.ts | 2 +- js/genkit/tests/generate_test.ts | 4 +- 6 files changed, 259 insertions(+), 166 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 869e157a3d..b4c5f52b38 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -53,7 +53,12 @@ import { type ToolResponsePart, } from './model.js'; import type { ExecutablePrompt } from './prompt.js'; -import { resolveTools, toToolDefinition, type ToolArgument } from './tool.js'; +import { + DynamicToolAction, + resolveTools, + toToolDefinition, + type ToolArgument, +} from './tool.js'; export { GenerateResponse, GenerateResponseChunk }; /** Specifies how tools should be called by the model. */ @@ -271,8 +276,6 @@ async function toolsToActionRefs( } else if (typeof (t as ExecutablePrompt).asTool === 'function') { const promptToolAction = await (t as ExecutablePrompt).asTool(); tools.push(`/prompt/${promptToolAction.__action.name}`); - } else if (t.name) { - tools.push(await resolveFullToolName(registry, t.name)); } else { throw new Error(`Unable to determine type of tool: ${JSON.stringify(t)}`); } @@ -375,6 +378,9 @@ function maybeRegisterDynamicTools< (t as Action).__action.metadata?.type === 'tool' && (t as Action).__action.metadata?.dynamic ) { + if (typeof (t as DynamicToolAction).register === 'function') { + t = (t as DynamicToolAction).register(registry); + } if (!hasDynamicTools) { hasDynamicTools = true; // Create a temporary registry with dynamic tools for the duration of this diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index 6ca1616ee1..ea788a8b31 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -15,17 +15,18 @@ */ import { - action, assertUnstable, defineAction, stripUndefinedProps, + unregisteredAction, z, type Action, type ActionContext, type ActionRunOptions, type JSONSchema7, + type UnregisteredAction, } from '@genkit-ai/core'; -import type { HasRegistry, Registry } from '@genkit-ai/core/registry'; +import type { Registry } from '@genkit-ai/core/registry'; import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; import type { @@ -36,18 +37,10 @@ import type { } from './model.js'; import type { ExecutablePrompt } from './prompt.js'; -/** - * An action with a `tool` type. - */ -export type ToolAction< +export interface Resumable< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, -> = Action & { - __action: { - metadata: { - type: 'tool'; - }; - }; +> { /** * respond constructs a tool response corresponding to the provided interrupt tool request * using the provided reply data, validating it against the output schema of the tool if @@ -90,7 +83,37 @@ export type ToolAction< replaceInput?: z.infer; } ): ToolRequestPart; -}; +} + +/** + * An action with a `tool` type. + */ +export type ToolAction< + I extends z.ZodTypeAny = z.ZodTypeAny, + O extends z.ZodTypeAny = z.ZodTypeAny, +> = Action & + Resumable & { + __action: { + metadata: { + type: 'tool'; + }; + }; + }; + +/** + * An action with a `tool` type. + */ +export type DynamicToolAction< + I extends z.ZodTypeAny = z.ZodTypeAny, + O extends z.ZodTypeAny = z.ZodTypeAny, +> = UnregisteredAction & + Resumable & { + __action: { + metadata: { + type: 'tool'; + }; + }; + }; export interface ToolRunOptions extends ActionRunOptions { /** @@ -128,7 +151,12 @@ export interface ToolConfig { export type ToolArgument< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, -> = string | ToolAction | Action | ExecutablePrompt; +> = + | string + | ToolAction + | DynamicToolAction + | Action + | ExecutablePrompt; /** * Converts an action to a tool action by setting the appropriate metadata. @@ -174,10 +202,11 @@ export async function resolveTools< return asTool(registry, ref as Action); } else if (typeof (ref as ExecutablePrompt).asTool === 'function') { return await (ref as ExecutablePrompt).asTool(); - } else if (ref.name) { + } else if ((ref as ToolDefinition).name) { return await lookupToolByName( registry, - (ref as ToolDefinition).metadata?.originalName || ref.name + (ref as ToolDefinition).metadata?.originalName || + (ref as ToolDefinition).name ); } throw new Error('Tools must be strings, tool definitions, or actions.'); @@ -278,14 +307,16 @@ export function defineTool( function implementTool( a: ToolAction, config: ToolConfig, - registry: Registry + registry?: Registry ) { (a as ToolAction).respond = (interrupt, responseData, options) => { - assertUnstable( - registry, - 'beta', - "The 'tool.reply' method is part of the 'interrupts' beta feature." - ); + if (registry) { + assertUnstable( + registry, + 'beta', + "The 'tool.reply' method is part of the 'interrupts' beta feature." + ); + } parseSchema(responseData, { jsonSchema: config.outputJsonSchema, schema: config.outputSchema, @@ -303,11 +334,13 @@ function implementTool( }; (a as ToolAction).restart = (interrupt, resumedMetadata, options) => { - assertUnstable( - registry, - 'beta', - "The 'tool.restart' method is part of the 'interrupts' beta feature." - ); + if (registry) { + assertUnstable( + registry, + 'beta', + "The 'tool.restart' method is part of the 'interrupts' beta feature." + ); + } let replaceInput = options?.replaceInput; if (replaceInput) { replaceInput = parseSchema(replaceInput, { @@ -396,19 +429,17 @@ function interruptTool(registry: Registry) { * Genkit registry and can be defined dynamically at runtime. */ export function dynamicTool( - ai: HasRegistry, config: ToolConfig, fn?: ToolFn -): ToolAction { - const a = action( - ai.registry, +): DynamicToolAction { + const a = unregisteredAction( { ...config, actionType: 'tool', metadata: { ...(config.metadata || {}), type: 'tool', dynamic: true }, }, (i, runOptions) => { - const interrupt = interruptTool(ai.registry); + const interrupt = interruptTool(runOptions.registry); if (fn) { return fn(i, { ...runOptions, @@ -419,6 +450,19 @@ export function dynamicTool( return interrupt(); } ); - implementTool(a as ToolAction, config, ai.registry); - return a as ToolAction; + implementTool(a as any, config); + return { + __action: { + ...a.__action, + metadata: { + ...a.__action.metadata, + type: 'tool', + }, + }, + register(registry) { + const bound = a.register(registry); + implementTool(bound as ToolAction, config); + return bound; + }, + } as DynamicToolAction; } diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 7852e92f17..c3936e4228 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -137,6 +137,19 @@ export type Action< ): StreamingResponse; }; +/** + * Self-describing, validating, observable, locally and remotely callable function. + */ +export type UnregisteredAction< + I extends z.ZodTypeAny = z.ZodTypeAny, + O extends z.ZodTypeAny = z.ZodTypeAny, + S extends z.ZodTypeAny = z.ZodTypeAny, + RunOptions extends ActionRunOptions = ActionRunOptions, +> = { + __action: ActionMetadata; + register(registry: Registry): Action; +}; + /** * Action factory params. */ @@ -263,18 +276,28 @@ export function action< options: ActionFnArg> ) => Promise> ): Action> { + return unregisteredAction(config, fn).register(registry); +} + +/** + * Creates an action with the provided config. + */ +export function unregisteredAction< + I extends z.ZodTypeAny, + O extends z.ZodTypeAny, + S extends z.ZodTypeAny = z.ZodTypeAny, +>( + config: ActionParams, + fn: ( + input: z.infer, + options: ActionFnArg> & {registry: Registry} + ) => Promise> +): UnregisteredAction> { const actionName = typeof config.name === 'string' ? config.name : `${config.name.pluginId}/${config.name.actionId}`; - const actionFn = async ( - input?: I, - options?: ActionRunOptions> - ) => { - return (await actionFn.run(input, options)).result; - }; - actionFn.__registry = registry; - actionFn.__action = { + const actionMetadata = { name: actionName, description: config.description, inputSchema: config.inputSchema, @@ -285,129 +308,148 @@ export function action< metadata: config.metadata, actionType: config.actionType, } as ActionMetadata; - actionFn.run = async ( - input: z.infer, - options?: ActionRunOptions> - ): Promise>> => { - input = parseSchema(input, { - schema: config.inputSchema, - jsonSchema: config.inputJsonSchema, - }); - let traceId; - let spanId; - let output = await newTrace( - registry, - { - name: actionName, - labels: { - [SPAN_TYPE_ATTR]: 'action', - 'genkit:metadata:subtype': config.actionType, - ...options?.telemetryLabels, - }, - }, - async (metadata, span) => { - setCustomMetadataAttributes(registry, { subtype: config.actionType }); - if (options?.context) { - setCustomMetadataAttributes(registry, { - context: JSON.stringify(options.context), - }); - } - traceId = span.spanContext().traceId; - spanId = span.spanContext().spanId; - metadata.name = actionName; - metadata.input = input; - - try { - const actionFn = () => - fn(input, { - ...options, - // Context can either be explicitly set, or inherited from the parent action. - context: options?.context ?? getContext(registry), - sendChunk: options?.onChunk ?? sentinelNoopStreamingCallback, - trace: { - traceId, - spanId, - }, + return { + __action: actionMetadata, + register(registry: Registry): Action> { + const actionFn = async ( + input?: I, + options?: ActionRunOptions> + ) => { + return (await actionFn.run(input, options)).result; + }; + actionFn.__registry = registry; + actionFn.__action = actionMetadata; + actionFn.run = async ( + input: z.infer, + options?: ActionRunOptions> + ): Promise>> => { + input = parseSchema(input, { + schema: config.inputSchema, + jsonSchema: config.inputJsonSchema, + }); + let traceId; + let spanId; + let output = await newTrace( + registry, + { + name: actionName, + labels: { + [SPAN_TYPE_ATTR]: 'action', + 'genkit:metadata:subtype': config.actionType, + ...options?.telemetryLabels, + }, + }, + async (metadata, span) => { + setCustomMetadataAttributes(registry, { + subtype: config.actionType, }); - // if context is explicitly passed in, we run action with the provided context, - // otherwise we let upstream context carry through. - const output = await runWithContext( - registry, - options?.context, - actionFn - ); - - metadata.output = JSON.stringify(output); - return output; - } catch (err) { - if (typeof err === 'object') { - (err as any).traceId = traceId; + if (options?.context) { + setCustomMetadataAttributes(registry, { + context: JSON.stringify(options.context), + }); + } + + traceId = span.spanContext().traceId; + spanId = span.spanContext().spanId; + metadata.name = actionName; + metadata.input = input; + + try { + const actionFn = () => + fn(input, { + ...options, + // Context can either be explicitly set, or inherited from the parent action. + context: options?.context ?? getContext(registry), + sendChunk: options?.onChunk ?? sentinelNoopStreamingCallback, + trace: { + traceId, + spanId, + }, + registry, + }); + // if context is explicitly passed in, we run action with the provided context, + // otherwise we let upstream context carry through. + const output = await runWithContext( + registry, + options?.context, + actionFn + ); + + metadata.output = JSON.stringify(output); + return output; + } catch (err) { + if (typeof err === 'object') { + (err as any).traceId = traceId; + } + throw err; + } } - throw err; - } - } - ); - output = parseSchema(output, { - schema: config.outputSchema, - jsonSchema: config.outputJsonSchema, - }); - return { - result: output, - telemetry: { - traceId, - spanId, - }, - }; - }; + ); + output = parseSchema(output, { + schema: config.outputSchema, + jsonSchema: config.outputJsonSchema, + }); + return { + result: output, + telemetry: { + traceId, + spanId, + }, + }; + }; - actionFn.stream = ( - input?: z.infer, - opts?: ActionRunOptions> - ): StreamingResponse => { - let chunkStreamController: ReadableStreamController>; - const chunkStream = new ReadableStream>({ - start(controller) { - chunkStreamController = controller; - }, - pull() {}, - cancel() {}, - }); - - const invocationPromise = actionFn - .run(config.inputSchema ? config.inputSchema.parse(input) : input, { - onChunk: ((chunk: z.infer) => { - chunkStreamController.enqueue(chunk); - }) as S extends z.ZodVoid ? undefined : StreamingCallback>, - context: opts?.context, - }) - .then((s) => s.result) - .finally(() => { - chunkStreamController.close(); - }); - - return { - output: invocationPromise, - stream: (async function* () { - const reader = chunkStream.getReader(); - while (true) { - const chunk = await reader.read(); - if (chunk.value) { - yield chunk.value; - } - if (chunk.done) { - break; - } - } - return await invocationPromise; - })(), - }; - }; + actionFn.stream = ( + input?: z.infer, + opts?: ActionRunOptions> + ): StreamingResponse => { + let chunkStreamController: ReadableStreamController>; + const chunkStream = new ReadableStream>({ + start(controller) { + chunkStreamController = controller; + }, + pull() {}, + cancel() {}, + }); + + const invocationPromise = actionFn + .run(config.inputSchema ? config.inputSchema.parse(input) : input, { + onChunk: ((chunk: z.infer) => { + chunkStreamController.enqueue(chunk); + }) as S extends z.ZodVoid + ? undefined + : StreamingCallback>, + context: opts?.context, + }) + .then((s) => s.result) + .finally(() => { + chunkStreamController.close(); + }); - if (config.use) { - return actionWithMiddleware(actionFn, config.use); - } - return actionFn; + return { + output: invocationPromise, + stream: (async function* () { + const reader = chunkStream.getReader(); + while (true) { + const chunk = await reader.read(); + if (chunk.value) { + yield chunk.value; + } + if (chunk.done) { + break; + } + } + return await invocationPromise; + })(), + }; + }; + + if (config.use) { + return actionWithMiddleware(actionFn, config.use); + } + return actionFn; + }, + }; } /** diff --git a/js/genkit/src/common.ts b/js/genkit/src/common.ts index a44571561a..c9486a2f82 100644 --- a/js/genkit/src/common.ts +++ b/js/genkit/src/common.ts @@ -107,6 +107,7 @@ export { type ToolResponsePart, } from '@genkit-ai/ai'; export { Chat } from '@genkit-ai/ai/chat'; +export { dynamicTool } from '@genkit-ai/ai/tool'; export { Session, type SessionData, diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index efb62a1c73..5c47d72607 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -207,7 +207,7 @@ export class Genkit implements HasRegistry { config: ToolConfig, fn?: ToolFn ): ToolAction { - return dynamicTool(this, config, fn); + return dynamicTool(config, fn).register(this.registry) as ToolAction; } /** diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index ede1a6ea32..7f81db1a8d 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -19,7 +19,7 @@ import { z, type JSONSchema7 } from '@genkit-ai/core'; import * as assert from 'assert'; import { beforeEach, describe, it } from 'node:test'; import { modelRef } from '../../ai/src/model'; -import { genkit, type GenkitBeta } from '../src/beta'; +import { genkit, type GenkitBeta, dynamicTool } from '../src/beta'; import { defineEchoModel, defineProgrammableModel, @@ -405,7 +405,7 @@ describe('generate', () => { foo: { type: 'string' }, }, } as JSONSchema7; - const dynamicTestTool1 = ai.dynamicTool( + const dynamicTestTool1 = dynamicTool( { name: 'dynamicTestTool1', inputJsonSchema: schema, From 2918aeb68a16884aad44fdf96e7a319924e21f86 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 3 Jun 2025 22:05:03 -0400 Subject: [PATCH 2/9] fmt --- js/core/src/action.ts | 2 +- js/genkit/src/common.ts | 2 +- js/genkit/tests/generate_test.ts | 2 +- js/testapps/flow-simple-ai/src/index.ts | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/js/core/src/action.ts b/js/core/src/action.ts index c3936e4228..aa3d20f6e8 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -290,7 +290,7 @@ export function unregisteredAction< config: ActionParams, fn: ( input: z.infer, - options: ActionFnArg> & {registry: Registry} + options: ActionFnArg> & { registry: Registry } ) => Promise> ): UnregisteredAction> { const actionName = diff --git a/js/genkit/src/common.ts b/js/genkit/src/common.ts index c9486a2f82..5c5f40b6c9 100644 --- a/js/genkit/src/common.ts +++ b/js/genkit/src/common.ts @@ -107,12 +107,12 @@ export { type ToolResponsePart, } from '@genkit-ai/ai'; export { Chat } from '@genkit-ai/ai/chat'; -export { dynamicTool } from '@genkit-ai/ai/tool'; export { Session, type SessionData, type SessionStore, } from '@genkit-ai/ai/session'; +export { dynamicTool } from '@genkit-ai/ai/tool'; export { GENKIT_CLIENT_HEADER, GENKIT_VERSION, diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index 7f81db1a8d..079a65872a 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -19,7 +19,7 @@ import { z, type JSONSchema7 } from '@genkit-ai/core'; import * as assert from 'assert'; import { beforeEach, describe, it } from 'node:test'; import { modelRef } from '../../ai/src/model'; -import { genkit, type GenkitBeta, dynamicTool } from '../src/beta'; +import { dynamicTool, genkit, type GenkitBeta } from '../src/beta'; import { defineEchoModel, defineProgrammableModel, diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 73d118c2de..565a0f6918 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -33,7 +33,7 @@ import { GoogleAIFileManager } from '@google/generative-ai/server'; import { AlwaysOnSampler } from '@opentelemetry/sdk-trace-base'; import { initializeApp } from 'firebase-admin/app'; import { getFirestore } from 'firebase-admin/firestore'; -import { MessageSchema, genkit, z, type GenerateResponseData } from 'genkit'; +import { MessageSchema, dynamicTool, genkit, z, type GenerateResponseData } from 'genkit'; import { logger } from 'genkit/logging'; import { simulateConstrainedGeneration, @@ -489,7 +489,7 @@ export const dynamicToolCaller = ai.defineFlow( streamSchema: z.any(), }, async (input, { sendChunk }) => { - const dynamicGablorkenTool = ai.dynamicTool( + const dynamicGablorkenTool = dynamicTool( { name: 'dynamicGablorkenTool', inputSchema: z.object({ From f9e86eedf377c0c73f65265ce2cbfe22c54ab9ce Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 3 Jun 2025 22:05:50 -0400 Subject: [PATCH 3/9] fmt --- js/testapps/flow-simple-ai/src/index.ts | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 565a0f6918..6186d45356 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -33,7 +33,13 @@ import { GoogleAIFileManager } from '@google/generative-ai/server'; import { AlwaysOnSampler } from '@opentelemetry/sdk-trace-base'; import { initializeApp } from 'firebase-admin/app'; import { getFirestore } from 'firebase-admin/firestore'; -import { MessageSchema, dynamicTool, genkit, z, type GenerateResponseData } from 'genkit'; +import { + MessageSchema, + dynamicTool, + genkit, + z, + type GenerateResponseData, +} from 'genkit'; import { logger } from 'genkit/logging'; import { simulateConstrainedGeneration, From 6e22b1195037ec1bc0f56c2123c3b78e1635d239 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 4 Jun 2025 15:59:16 -0400 Subject: [PATCH 4/9] rename --- js/ai/src/generate.ts | 4 ++-- js/ai/src/tool.ts | 12 ++++++------ js/core/src/action.ts | 12 ++++++------ js/genkit/src/genkit.ts | 2 +- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index b4c5f52b38..5e384d878a 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -378,8 +378,8 @@ function maybeRegisterDynamicTools< (t as Action).__action.metadata?.type === 'tool' && (t as Action).__action.metadata?.dynamic ) { - if (typeof (t as DynamicToolAction).register === 'function') { - t = (t as DynamicToolAction).register(registry); + if (typeof (t as DynamicToolAction).attach === 'function') { + t = (t as DynamicToolAction).attach(registry); } if (!hasDynamicTools) { hasDynamicTools = true; diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index ea788a8b31..aa9237bc9f 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -17,14 +17,14 @@ import { assertUnstable, defineAction, + detachedAction, stripUndefinedProps, - unregisteredAction, z, type Action, type ActionContext, type ActionRunOptions, + type DetachedAction, type JSONSchema7, - type UnregisteredAction, } from '@genkit-ai/core'; import type { Registry } from '@genkit-ai/core/registry'; import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; @@ -106,7 +106,7 @@ export type ToolAction< export type DynamicToolAction< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, -> = UnregisteredAction & +> = DetachedAction & Resumable & { __action: { metadata: { @@ -432,7 +432,7 @@ export function dynamicTool( config: ToolConfig, fn?: ToolFn ): DynamicToolAction { - const a = unregisteredAction( + const a = detachedAction( { ...config, actionType: 'tool', @@ -459,8 +459,8 @@ export function dynamicTool( type: 'tool', }, }, - register(registry) { - const bound = a.register(registry); + attach(registry) { + const bound = a.attach(registry); implementTool(bound as ToolAction, config); return bound; }, diff --git a/js/core/src/action.ts b/js/core/src/action.ts index aa3d20f6e8..9360395b8a 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -140,14 +140,14 @@ export type Action< /** * Self-describing, validating, observable, locally and remotely callable function. */ -export type UnregisteredAction< +export type DetachedAction< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, RunOptions extends ActionRunOptions = ActionRunOptions, > = { __action: ActionMetadata; - register(registry: Registry): Action; + attach(registry: Registry): Action; }; /** @@ -276,13 +276,13 @@ export function action< options: ActionFnArg> ) => Promise> ): Action> { - return unregisteredAction(config, fn).register(registry); + return detachedAction(config, fn).attach(registry); } /** * Creates an action with the provided config. */ -export function unregisteredAction< +export function detachedAction< I extends z.ZodTypeAny, O extends z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, @@ -292,7 +292,7 @@ export function unregisteredAction< input: z.infer, options: ActionFnArg> & { registry: Registry } ) => Promise> -): UnregisteredAction> { +): DetachedAction> { const actionName = typeof config.name === 'string' ? config.name @@ -311,7 +311,7 @@ export function unregisteredAction< return { __action: actionMetadata, - register(registry: Registry): Action> { + attach(registry: Registry): Action> { const actionFn = async ( input?: I, options?: ActionRunOptions> diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 5c47d72607..6d375c1860 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -207,7 +207,7 @@ export class Genkit implements HasRegistry { config: ToolConfig, fn?: ToolFn ): ToolAction { - return dynamicTool(config, fn).register(this.registry) as ToolAction; + return dynamicTool(config, fn).attach(this.registry) as ToolAction; } /** From 8ce7fc655ebb89c8bb4bea5fecb4e270000c1004 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 4 Jun 2025 16:30:07 -0400 Subject: [PATCH 5/9] added global context --- js/core/src/action.ts | 10 +++++-- js/core/src/registry.ts | 3 ++ js/core/tests/action_test.ts | 42 ++++++++++++++++++++++++++++ js/genkit/src/genkit.ts | 5 ++++ js/genkit/tests/flow_test.ts | 15 +++++++++- js/genkit/tests/generate_test.ts | 47 ++++++++++++++++++++++++++++++++ 6 files changed, 119 insertions(+), 3 deletions(-) diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 9360395b8a..cf5561c946 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -360,7 +360,10 @@ export function detachedAction< fn(input, { ...options, // Context can either be explicitly set, or inherited from the parent action. - context: options?.context ?? getContext(registry), + context: { + ...registry.context, + ...(options?.context ?? getContext(registry)), + }, sendChunk: options?.onChunk ?? sentinelNoopStreamingCallback, trace: { traceId, @@ -419,7 +422,10 @@ export function detachedAction< }) as S extends z.ZodVoid ? undefined : StreamingCallback>, - context: opts?.context, + context: { + ...registry.context, + ...(opts?.context ?? getContext(registry)), + }, }) .then((s) => s.result) .finally(() => { diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index fdc95f5891..45a5a769fb 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -22,6 +22,7 @@ import { type Action, type ActionMetadata, } from './action.js'; +import { ActionContext } from './context.js'; import { GenkitError } from './error.js'; import { logger } from './logging.js'; import type { PluginProvider } from './plugin.js'; @@ -118,6 +119,8 @@ export class Registry { readonly asyncStore: AsyncStore; readonly dotprompt: Dotprompt; readonly parent?: Registry; + /** Additional runtime context data for flows and tools. */ + context?: ActionContext; constructor(parent?: Registry) { if (parent) { diff --git a/js/core/tests/action_test.ts b/js/core/tests/action_test.ts index 0d34be7694..37ba9d7566 100644 --- a/js/core/tests/action_test.ts +++ b/js/core/tests/action_test.ts @@ -118,6 +118,48 @@ describe('action', () => { assert.deepStrictEqual(chunks, [1, 2, 3]); }); + it('run the action with context plus registry global context', async () => { + let passedContext; + const act = action( + registry, + { + name: 'foo', + inputSchema: z.string(), + outputSchema: z.number(), + actionType: 'util', + }, + async (input, { sendChunk, context }) => { + passedContext = context; + sendChunk(1); + sendChunk(2); + sendChunk(3); + return input.length; + } + ); + + registry.context = { bar: 'baz' }; + + await act.run('1234', { + context: { foo: 'bar' }, + }); + + assert.deepStrictEqual(passedContext, { + foo: 'bar', + bar: 'baz', // these come from glboal registry context + }); + + registry.context = { bar2: 'baz2' }; + const { output } = act.stream('1234', { + context: { foo2: 'bar2' }, + }); + await output; + + assert.deepStrictEqual(passedContext, { + foo2: 'bar2', + bar2: 'baz2', // these come from glboal registry context + }); + }); + it('should stream the response', async () => { const action = defineAction( registry, diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 6d375c1860..b12367765f 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -134,6 +134,8 @@ export interface GenkitOptions { promptDir?: string; /** Default model to use if no model is specified. */ model?: ModelArgument; + /** Additional runtime context data for flows and tools. */ + context?: ActionContext; } /** @@ -162,6 +164,9 @@ export class Genkit implements HasRegistry { constructor(options?: GenkitOptions) { this.options = options || {}; this.registry = new Registry(); + if (this.options.context) { + this.registry.context = this.options.context; + } this.configure(); if (isDevEnv() && !disableReflectionApi) { this.reflectionServer = new ReflectionServer(this.registry, { diff --git a/js/genkit/tests/flow_test.ts b/js/genkit/tests/flow_test.ts index 60ed100bb9..0cb7854bed 100644 --- a/js/genkit/tests/flow_test.ts +++ b/js/genkit/tests/flow_test.ts @@ -23,7 +23,9 @@ describe('flow', () => { let ai: Genkit; beforeEach(() => { - ai = genkit({}); + ai = genkit({ + context: { something: 'extra' }, + }); }); it('calls simple flow', async () => { @@ -78,4 +80,15 @@ describe('flow', () => { // a "streaming" flow can be invoked in non-streaming mode. assert.strictEqual(await streamingBananaFlow('banana2'), 'banana2'); }); + + it('pass thought the context', async () => { + const bananaFlow = ai.defineFlow('banana', (_, { context }) => + JSON.stringify(context) + ); + + assert.strictEqual( + await bananaFlow(undefined, { context: { foo: 'bar' } }), + '{"something":"extra","foo":"bar"}' + ); + }); }); diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index 079a65872a..41afe0a46b 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -309,6 +309,7 @@ describe('generate', () => { beforeEach(() => { ai = genkit({ model: 'programmableModel', + context: { something: 'extra' }, }); pm = defineProgrammableModel(ai); defineEchoModel(ai); @@ -399,6 +400,52 @@ describe('generate', () => { ); }); + it('call the tool with context', async () => { + ai.defineTool( + { name: 'testTool', description: 'description' }, + async (_, { context }) => JSON.stringify(context) + ); + + // first response be tools call, the subsequent just text response from agent b. + let reqCounter = 0; + pm.handleResponse = async (req, sc) => { + return { + message: { + role: 'model', + content: [ + reqCounter++ === 0 + ? { + toolRequest: { + name: 'testTool', + input: {}, + ref: 'ref123', + }, + } + : { text: 'done' }, + ], + }, + }; + }; + + const { messages } = await ai.generate({ + prompt: 'call the tool', + tools: ['testTool'], + }); + + assert.deepStrictEqual(messages[2], { + role: 'tool', + content: [ + { + toolResponse: { + name: 'testTool', + output: '{"something":"extra"}', + ref: 'ref123', + }, + }, + ], + }); + }); + it('calls the dynamic tool', async () => { const schema = { properties: { From b06388da4dd1ef289409051aa4b5122e453dd7ca Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 4 Jun 2025 17:37:14 -0400 Subject: [PATCH 6/9] Test: Add non-serializable data to generate context Modify the generate test to include a non-serializable object in the context. This ensures that the serialization logic handles such cases correctly. --- js/genkit/tests/generate_test.ts | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index 41afe0a46b..dd2db0677e 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -307,9 +307,15 @@ describe('generate', () => { let pm: ProgrammableModel; beforeEach(() => { + class Extra { + toJSON() { + return 'extra'; + } + } ai = genkit({ model: 'programmableModel', - context: { something: 'extra' }, + // testing with a non-serializable data in the context + context: { something: new Extra() }, }); pm = defineProgrammableModel(ai); defineEchoModel(ai); From 50db0b662cf4b2511a9d532255af95f0cf11b8f8 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 10 Jun 2025 11:33:54 -0400 Subject: [PATCH 7/9] cleanup --- js/ai/src/generate.ts | 26 +++++++++++--------------- js/ai/src/prompt.ts | 2 +- js/ai/src/tool.ts | 10 ++++++++++ js/core/src/action.ts | 20 +++++++++++++++++++- js/genkit/tests/generate_test.ts | 2 +- 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 5e384d878a..fb4c00ac24 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -16,6 +16,8 @@ import { GenkitError, + isAction, + isDetachedAction, runWithContext, runWithStreamingCallback, sentinelNoopStreamingCallback, @@ -52,9 +54,9 @@ import { type ToolRequestPart, type ToolResponsePart, } from './model.js'; -import type { ExecutablePrompt } from './prompt.js'; +import { isExecutablePrompt } from './prompt.js'; import { - DynamicToolAction, + isDynamicTool, resolveTools, toToolDefinition, type ToolArgument, @@ -269,12 +271,10 @@ async function toolsToActionRefs( for (const t of toolOpt) { if (typeof t === 'string') { tools.push(await resolveFullToolName(registry, t)); - } else if ((t as Action).__action) { - tools.push( - `/${(t as Action).__action.metadata?.type}/${(t as Action).__action.name}` - ); - } else if (typeof (t as ExecutablePrompt).asTool === 'function') { - const promptToolAction = await (t as ExecutablePrompt).asTool(); + } else if (isAction(t) || isDynamicTool(t)) { + tools.push(`/${t.__action.metadata?.type}/${t.__action.name}`); + } else if (isExecutablePrompt(t)) { + const promptToolAction = await t.asTool(); tools.push(`/prompt/${promptToolAction.__action.name}`); } else { throw new Error(`Unable to determine type of tool: ${JSON.stringify(t)}`); @@ -373,13 +373,9 @@ function maybeRegisterDynamicTools< >(registry: Registry, options: GenerateOptions): Registry { let hasDynamicTools = false; options?.tools?.forEach((t) => { - if ( - (t as Action).__action && - (t as Action).__action.metadata?.type === 'tool' && - (t as Action).__action.metadata?.dynamic - ) { - if (typeof (t as DynamicToolAction).attach === 'function') { - t = (t as DynamicToolAction).attach(registry); + if (isDynamicTool(t)) { + if (isDetachedAction(t)) { + t = t.attach(registry); } if (!hasDynamicTools) { hasDynamicTools = true; diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index fda11063eb..e235e8657e 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -694,7 +694,7 @@ async function renderDotpromptToParts< /** * Checks whether the provided object is an executable prompt. */ -export function isExecutablePrompt(obj: any): boolean { +export function isExecutablePrompt(obj: any): obj is ExecutablePrompt { return ( !!(obj as ExecutablePrompt)?.render && !!(obj as ExecutablePrompt)?.asTool && diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index aa9237bc9f..6adab47b4c 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -18,6 +18,8 @@ import { assertUnstable, defineAction, detachedAction, + isAction, + isDetachedAction, stripUndefinedProps, z, type Action, @@ -385,6 +387,14 @@ export function isToolResponse(part: Part): part is ToolResponsePart { return !!part.toolResponse; } +export function isDynamicTool(t: unknown): t is DynamicToolAction { + return ( + (isDetachedAction(t) || isAction(t)) && + t.__action.metadata?.type === 'tool' && + t.__action.metadata?.dynamic + ); +} + export function defineInterrupt( registry: Registry, config: InterruptConfig diff --git a/js/core/src/action.ts b/js/core/src/action.ts index cf5561c946..ceaf48b9f1 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -46,6 +46,7 @@ export interface ActionMetadata< outputJsonSchema?: JSONSchema7; streamSchema?: S; metadata?: Record; + detached?: boolean; } /** @@ -307,6 +308,7 @@ export function detachedAction< streamSchema: config.streamSchema, metadata: config.metadata, actionType: config.actionType, + detached: true, } as ActionMetadata; return { @@ -319,7 +321,9 @@ export function detachedAction< return (await actionFn.run(input, options)).result; }; actionFn.__registry = registry; - actionFn.__action = actionMetadata; + actionFn.__action = { ...actionMetadata }; + delete actionFn.__action['detached']; + actionFn.run = async ( input: z.infer, options?: ActionRunOptions> @@ -458,6 +462,20 @@ export function detachedAction< }; } +export function isAction(a: unknown): a is Action { + return ( + typeof a === 'function' && '__action' in a && !(a as any).__action.detached + ); +} + +export function isDetachedAction(a: unknown): a is DetachedAction { + return ( + !!(a as DetachedAction).__action && + !!(a as DetachedAction).__action.detached && + typeof (a as DetachedAction).attach === 'function' + ); +} + /** * Defines an action with the given config and registers it in the registry. */ diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index dd2db0677e..e50d5f57d2 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -582,7 +582,7 @@ describe('generate', () => { ); }); - it.only('interrupts the dynamic tool with no impl', async () => { + it('interrupts the dynamic tool with no impl', async () => { const schema = { properties: { foo: { type: 'string' }, From 8b5d1ed47f6d64c06c4f6df00c300d058f08a374 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 10 Jun 2025 11:38:17 -0400 Subject: [PATCH 8/9] cleanup --- js/ai/src/tool.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index 6adab47b4c..5443af7da6 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -37,7 +37,7 @@ import type { ToolRequestPart, ToolResponsePart, } from './model.js'; -import type { ExecutablePrompt } from './prompt.js'; +import { isExecutablePrompt, type ExecutablePrompt } from './prompt.js'; export interface Resumable< I extends z.ZodTypeAny = z.ZodTypeAny, @@ -103,7 +103,7 @@ export type ToolAction< }; /** - * An action with a `tool` type. + * A dynamic action with a `tool` type. Dynamic tools are detached actions -- not associated with any registry. */ export type DynamicToolAction< I extends z.ZodTypeAny = z.ZodTypeAny, @@ -200,10 +200,10 @@ export async function resolveTools< tools.map(async (ref): Promise => { if (typeof ref === 'string') { return await lookupToolByName(registry, ref); - } else if ((ref as Action).__action) { - return asTool(registry, ref as Action); - } else if (typeof (ref as ExecutablePrompt).asTool === 'function') { - return await (ref as ExecutablePrompt).asTool(); + } else if (isAction(ref)) { + return asTool(registry, ref); + } else if (isExecutablePrompt(ref)) { + return await ref.asTool(); } else if ((ref as ToolDefinition).name) { return await lookupToolByName( registry, From 6f99a2ab00cd97f5fc9eea7f87888bbf264e3083 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 10 Jun 2025 11:43:22 -0400 Subject: [PATCH 9/9] cleanup --- js/core/tests/action_test.ts | 2 +- js/genkit/tests/chat_test.ts | 4 ++-- js/genkit/tests/flow_test.ts | 2 +- js/genkit/tests/generate_test.ts | 16 ++++++++-------- js/genkit/tests/prompts_test.ts | 4 ++-- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/js/core/tests/action_test.ts b/js/core/tests/action_test.ts index 37ba9d7566..77ed2af526 100644 --- a/js/core/tests/action_test.ts +++ b/js/core/tests/action_test.ts @@ -118,7 +118,7 @@ describe('action', () => { assert.deepStrictEqual(chunks, [1, 2, 3]); }); - it('run the action with context plus registry global context', async () => { + it('runs the action with context plus registry global context', async () => { let passedContext; const act = action( registry, diff --git a/js/genkit/tests/chat_test.ts b/js/genkit/tests/chat_test.ts index 65986cc746..aaf435c628 100644 --- a/js/genkit/tests/chat_test.ts +++ b/js/genkit/tests/chat_test.ts @@ -278,7 +278,7 @@ describe('preamble', () => { // transfer to agent B... - // first response be tools call, the subsequent just text response from agent b. + // first response is a tool call, the subsequent responses are just text response from agent b. let reqCounter = 0; pm.handleResponse = async (req, sc) => { return { @@ -367,7 +367,7 @@ describe('preamble', () => { // transfer back to to agent A... - // first response be tools call, the subsequent just text response from agent a. + // first response is a tool call, the subsequent responses are just text response from agent a. reqCounter = 0; pm.handleResponse = async (req, sc) => { return { diff --git a/js/genkit/tests/flow_test.ts b/js/genkit/tests/flow_test.ts index 0cb7854bed..3ab1208315 100644 --- a/js/genkit/tests/flow_test.ts +++ b/js/genkit/tests/flow_test.ts @@ -81,7 +81,7 @@ describe('flow', () => { assert.strictEqual(await streamingBananaFlow('banana2'), 'banana2'); }); - it('pass thought the context', async () => { + it('passes through the context', async () => { const bananaFlow = ai.defineFlow('banana', (_, { context }) => JSON.stringify(context) ); diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index e50d5f57d2..4b68a055e4 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -327,7 +327,7 @@ describe('generate', () => { async () => 'tool called' ); - // first response be tools call, the subsequent just text response from agent b. + // first response is a tool call, the subsequent responses are just text response from agent b. let reqCounter = 0; pm.handleResponse = async (req, sc) => { return { @@ -412,7 +412,7 @@ describe('generate', () => { async (_, { context }) => JSON.stringify(context) ); - // first response be tools call, the subsequent just text response from agent b. + // first response is a tool call, the subsequent responses are just text response from agent b. let reqCounter = 0; pm.handleResponse = async (req, sc) => { return { @@ -475,7 +475,7 @@ describe('generate', () => { async () => 'tool called 2' ); - // first response be tools call, the subsequent just text response from agent b. + // first response is a tool call, the subsequent responses are just text response from agent b. let reqCounter = 0; pm.handleResponse = async (req, sc) => { return { @@ -594,7 +594,7 @@ describe('generate', () => { description: 'description', }); - // first response be tools call, the subsequent just text response from agent b. + // first response is a tool call, the subsequent responses are just text response from agent b. let reqCounter = 0; pm.handleResponse = async (req, sc) => { return { @@ -655,7 +655,7 @@ describe('generate', () => { } ); - // first response be tools call, the subsequent just text response from agent b. + // first response is a tool call, the subsequent responses are just text response from agent b. let reqCounter = 0; pm.handleResponse = async (req, sc) => { return { @@ -707,7 +707,7 @@ describe('generate', () => { } ); - // first response be tools call, the subsequent just text response from agent b. + // first response is a tool call, the subsequent responses are just text response from agent b. let reqCounter = 0; pm.handleResponse = async (req, sc) => { return { @@ -753,7 +753,7 @@ describe('generate', () => { async () => 'tool called' ); - // first response be tools call, the subsequent just text response from agent b. + // first response is a tool call, the subsequent responses are just text response from agent b. let reqCounter = 0; pm.handleResponse = async (req, sc) => { if (sc) { @@ -891,7 +891,7 @@ describe('generate', () => { } ); - // first response be tools call, the subsequent just text response from agent b. + // first response is a tool call, the subsequent responses are just text response from agent b. let reqCounter = 0; pm.handleResponse = async (req, sc) => { return { diff --git a/js/genkit/tests/prompts_test.ts b/js/genkit/tests/prompts_test.ts index 397409eeed..a6fed70dc1 100644 --- a/js/genkit/tests/prompts_test.ts +++ b/js/genkit/tests/prompts_test.ts @@ -1447,7 +1447,7 @@ describe('asTool', () => { // transfer to toolPrompt... - // first response be tools call, the subsequent just text response from agent b. + // first response is a tool call, the subsequent responses are just text response from agent b. let reqCounter = 0; pm.handleResponse = async (req, sc) => { return { @@ -1536,7 +1536,7 @@ describe('asTool', () => { // transfer back to to agent A... - // first response be tools call, the subsequent just text response from agent a. + // first response is a tool call, the subsequent responses are just text response from agent a. reqCounter = 0; pm.handleResponse = async (req, sc) => { return {