Skip to content

feat(go): added dynamic tools support #3025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,29 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
modelName = genOpts.ModelName
}

var dynamicTools []Tool
tools := make([]string, len(genOpts.Tools))
for i, tool := range genOpts.Tools {
tools[i] = tool.Name()
toolNames := make(map[string]bool)
for i, toolRef := range genOpts.Tools {
name := toolRef.Name()
// Redundant duplicate tool check with GenerateWithRequest otherwise we will panic when we register the dynamic tools.
if toolNames[name] {
return nil, core.NewError(core.INVALID_ARGUMENT, "ai.Generate: duplicate tool %q", name)
}
toolNames[name] = true
tools[i] = name
// Dynamic tools wouldn't have been registered by this point.
if LookupTool(r, name) == nil {
if tool, ok := toolRef.(Tool); ok {
dynamicTools = append(dynamicTools, tool)
}
}
}
if len(dynamicTools) > 0 {
r = r.NewChild()
for _, tool := range dynamicTools {
tool.Register(r)
}
}

messages := []*Message{}
Expand Down Expand Up @@ -527,7 +547,7 @@ func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelReq

output, err := tool.RunRaw(ctx, toolReq.Input)
if err != nil {
var interruptErr *ToolInterruptError
var interruptErr *toolInterruptError
if errors.As(err, &interruptErr) {
logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", toolReq.Name, interruptErr.Metadata)
revisedMessage.Content[idx] = &Part{
Expand Down Expand Up @@ -559,7 +579,7 @@ func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelReq
for range toolCount {
result := <-resultChan
if result.err != nil {
var interruptErr *ToolInterruptError
var interruptErr *toolInterruptError
if errors.As(result.err, &interruptErr) {
hasInterrupts = true
continue
Expand Down
118 changes: 118 additions & 0 deletions go/ai/generator_test.go → go/ai/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,124 @@ func TestGenerate(t *testing.T) {
t.Errorf("got text %q, want %q", res.Text(), expectedText)
}
})

t.Run("registers dynamic tools", func(t *testing.T) {
// Create a tool that is NOT registered in the global registry
dynamicTool := NewTool("dynamicTestTool", "a tool that is dynamically registered",
func(ctx *ToolContext, input struct {
Message string
}) (string, error) {
return "Dynamic: " + input.Message, nil
},
)

// Verify the tool is not in the global registry
if LookupTool(r, "dynamicTestTool") != nil {
t.Fatal("dynamicTestTool should not be registered in global registry")
}

// Create a model that will call the dynamic tool then provide a final response
roundCount := 0
info := &ModelInfo{
Supports: &ModelSupports{
Multiturn: true,
Tools: true,
},
}
toolCallModel := DefineModel(r, "test", "toolcall", info,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
// First response: call the dynamic tool
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewToolRequestPart(&ToolRequest{
Name: "dynamicTestTool",
Input: map[string]any{"Message": "Hello from dynamic tool"},
}),
},
},
}, nil
}
// Second response: provide final answer based on tool response
var toolResult string
for _, msg := range gr.Messages {
if msg.Role == RoleTool {
for _, part := range msg.Content {
if part.ToolResponse != nil {
toolResult = part.ToolResponse.Output.(string)
}
}
}
}
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewTextPart(toolResult),
},
},
}, nil
})

// Use Generate with the dynamic tool - this should trigger the dynamic registration
res, err := Generate(context.Background(), r,
WithModel(toolCallModel),
WithPrompt("call the dynamic tool"),
WithTools(dynamicTool),
)
if err != nil {
t.Fatal(err)
}

// The tool should have been called and returned a response
expectedText := "Dynamic: Hello from dynamic tool"
if res.Text() != expectedText {
t.Errorf("expected text %q, got %q", expectedText, res.Text())
}

// Verify two rounds were executed: tool call + final response
if roundCount != 2 {
t.Errorf("expected 2 rounds, got %d", roundCount)
}

// Verify the tool is still not in the global registry (it was registered in a child)
if LookupTool(r, "dynamicTestTool") != nil {
t.Error("dynamicTestTool should not be registered in global registry after generation")
}
})

t.Run("handles duplicate dynamic tools", func(t *testing.T) {
// Create two tools with the same name
dynamicTool1 := NewTool("duplicateTool", "first tool",
func(ctx *ToolContext, input any) (string, error) {
return "tool1", nil
},
)
dynamicTool2 := NewTool("duplicateTool", "second tool",
func(ctx *ToolContext, input any) (string, error) {
return "tool2", nil
},
)

// Using both tools should result in an error
_, err := Generate(context.Background(), r,
WithModel(echoModel),
WithPrompt("test duplicate tools"),
WithTools(dynamicTool1, dynamicTool2),
)

if err == nil {
t.Fatal("expected error for duplicate tool names")
}
if !strings.Contains(err.Error(), "duplicate tool \"duplicateTool\"") {
t.Errorf("unexpected error message: %v", err)
}
})
}

func TestModelVersion(t *testing.T) {
Expand Down
53 changes: 36 additions & 17 deletions go/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,16 @@ type Tool interface {
Definition() *ToolDefinition
// RunRaw runs this tool using the provided raw input.
RunRaw(ctx context.Context, input any) (any, error)
// Register sets the tracing state on the action and registers it with the registry.
Register(r *registry.Registry)
}

// ToolInterruptError represents an intentional interruption of tool execution.
type ToolInterruptError struct {
// toolInterruptError represents an intentional interruption of tool execution.
type toolInterruptError struct {
Metadata map[string]any
}

func (e *ToolInterruptError) Error() string {
func (e *toolInterruptError) Error() string {
return "tool execution interrupted"
}

Expand All @@ -80,32 +82,43 @@ type ToolContext struct {
Interrupt func(opts *InterruptOptions) error
}

// DefineTool defines a tool function with interrupt capability
func DefineTool[In, Out any](
r *registry.Registry,
name, description string,
fn func(ctx *ToolContext, input In) (Out, error),
) Tool {
metadata := make(map[string]any)
metadata["type"] = "tool"
metadata["name"] = name
metadata["description"] = description
// DefineTool defines a tool.
func DefineTool[In, Out any](r *registry.Registry, name, description string,
fn func(ctx *ToolContext, input In) (Out, error)) Tool {
metadata, wrappedFn := implementTool(name, description, fn)
toolAction := core.DefineAction(r, "", name, core.ActionTypeTool, metadata, wrappedFn)
return &tool{Action: toolAction}
}

// NewTool creates a tool but does not register it in the registry. It can be passed directly to [Generate].
func NewTool[In, Out any](name, description string,
fn func(ctx *ToolContext, input In) (Out, error)) Tool {
metadata, wrappedFn := implementTool(name, description, fn)
metadata["dynamic"] = true
toolAction := core.NewAction("", name, core.ActionTypeTool, metadata, wrappedFn)
return &tool{Action: toolAction}
}

// implementTool creates the metadata and wrapped function common to both DefineTool and NewTool.
func implementTool[In, Out any](name, description string, fn func(ctx *ToolContext, input In) (Out, error)) (map[string]any, func(context.Context, In) (Out, error)) {
metadata := map[string]any{
"type": core.ActionTypeTool,
"name": name,
"description": description,
}
wrappedFn := func(ctx context.Context, input In) (Out, error) {
toolCtx := &ToolContext{
Context: ctx,
Interrupt: func(opts *InterruptOptions) error {
return &ToolInterruptError{
return &toolInterruptError{
Metadata: opts.Metadata,
}
},
}
return fn(toolCtx, input)
}

toolAction := core.DefineAction(r, "", name, core.ActionTypeTool, metadata, wrappedFn)

return &tool{Action: toolAction}
return metadata, wrappedFn
}

// Name returns the name of the tool.
Expand Down Expand Up @@ -135,6 +148,12 @@ func (t *tool) RunRaw(ctx context.Context, input any) (any, error) {
return runAction(ctx, t.Definition(), t.Action, input)
}

// Register sets the tracing state on the action and registers it with the registry.
func (t *tool) Register(r *registry.Registry) {
t.Action.SetTracingState(r.TracingState())
r.RegisterAction(fmt.Sprintf("/%s/%s", core.ActionTypeTool, t.Action.Name()), t.Action)
}

// runAction runs the given action with the provided raw input and returns the output in raw format.
func runAction(ctx context.Context, def *ToolDefinition, action core.Action, input any) (any, error) {
mi, err := json.Marshal(input)
Expand Down
40 changes: 37 additions & 3 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ type Action interface {
RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error)
// Desc returns a descriptor of the action.
Desc() ActionDesc
// SetTracingState sets the tracing state on the action.
SetTracingState(tstate *tracing.State)
}

// An ActionType is the kind of an action.
Expand Down Expand Up @@ -106,6 +108,23 @@ func DefineAction[In, Out any](
})
}

// NewAction creates a new non-streaming Action without registering it.
func NewAction[In, Out any](
provider, name string,
atype ActionType,
metadata map[string]any,
fn Func[In, Out],
) *ActionDef[In, Out, struct{}] {
fullName := name
if provider != "" {
fullName = provider + "/" + name
}
return newAction(nil, fullName, atype, metadata, nil,
func(ctx context.Context, in In, cb noStream) (Out, error) {
return fn(ctx, in)
})
}

// DefineStreamingAction creates a new streaming action and registers it.
func DefineStreamingAction[In, Out, Stream any](
r *registry.Registry,
Expand Down Expand Up @@ -155,6 +174,7 @@ func defineAction[In, Out, Stream any](
}

// newAction creates a new Action with the given name and arguments.
// If registry is nil, tracing state is left nil to be set later.
// If inputSchema is nil, it is inferred from In.
func newAction[In, Out, Stream any](
r *registry.Registry,
Expand All @@ -164,23 +184,31 @@ func newAction[In, Out, Stream any](
inputSchema *jsonschema.Schema,
fn StreamingFunc[In, Out, Stream],
) *ActionDef[In, Out, Stream] {
var i In
var o Out
if inputSchema == nil {
var i In
if reflect.ValueOf(i).Kind() != reflect.Invalid {
inputSchema = base.InferJSONSchema(i)
}
}

var o Out
var outputSchema *jsonschema.Schema
if reflect.ValueOf(o).Kind() != reflect.Invalid {
outputSchema = base.InferJSONSchema(o)
}

var description string
if desc, ok := metadata["description"].(string); ok {
description = desc
}

var tstate *tracing.State
if r != nil {
tstate = r.TracingState()
}

return &ActionDef[In, Out, Stream]{
tstate: r.TracingState(),
tstate: tstate,
fn: func(ctx context.Context, input In, cb StreamCallback[Stream]) (Out, error) {
tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype))
return fn(ctx, input, cb)
Expand All @@ -200,6 +228,12 @@ func newAction[In, Out, Stream any](
// Name returns the Action's Name.
func (a *ActionDef[In, Out, Stream]) Name() string { return a.desc.Name }

// SetTracingState sets the tracing state on the action. This is used when an action
// created without a registry needs to have its tracing state set later.
func (a *ActionDef[In, Out, Stream]) SetTracingState(tstate *tracing.State) {
a.tstate = tstate
}

// Run executes the Action's function in a new trace span.
func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb StreamCallback[Stream]) (output Out, err error) {
logger.FromContext(ctx).Debug("Action.Run",
Expand Down
5 changes: 5 additions & 0 deletions go/core/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In) (Out, error)
return (*ActionDef[In, Out, Stream])(f).Run(ctx, input, nil)
}

// SetTracingState sets the tracing state on the flow.
func (f *Flow[In, Out, Stream]) SetTracingState(tstate *tracing.State) {
(*ActionDef[In, Out, Stream])(f).SetTracingState(tstate)
}

// Stream runs the flow in the context of another flow and streams the output.
// It returns a function whose argument function (the "yield function") will be repeatedly
// called with the results.
Expand Down
Loading
Loading