Skip to content

Commit

Permalink
Handle list tools with no body (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chrisbattarbee authored Jan 21, 2025
1 parent b6b1dca commit 2f1bb12
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 26 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ require (
github.com/buger/jsonparser v1.1.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uO
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
Expand Down
55 changes: 29 additions & 26 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/metoro-io/mcp-golang/internal/protocol"
"github.com/metoro-io/mcp-golang/internal/tools"
"github.com/metoro-io/mcp-golang/transport"
"github.com/pkg/errors"
"reflect"
"sort"
"strings"
Expand Down Expand Up @@ -319,19 +320,19 @@ func createWrappedPromptHandler(userHandler any) func(baseGetPromptRequestParams
// Unmarshal the JSON into the correct type
err := json.Unmarshal(arguments.Arguments, &unmarshaledArguments)
if err != nil {
return newPromptResponseSentError(fmt.Errorf("failed to unmarshal arguments: %w", err))
return newPromptResponseSentError(errors.Wrap(err, "failed to unmarshal arguments"))
}

// Need to dereference the unmarshaled arguments
of := reflect.ValueOf(unmarshaledArguments)
if of.Kind() != reflect.Ptr || !of.Elem().CanInterface() {
return newPromptResponseSentError(fmt.Errorf("arguments must be a struct"))
return newPromptResponseSentError(errors.Wrap(err, "arguments must be a struct"))
}
// Call the handler with the typed arguments
output := handlerValue.Call([]reflect.Value{of.Elem()})

if len(output) != 2 {
return newPromptResponseSentError(fmt.Errorf("handler must return exactly two values, got %d", len(output)))
return newPromptResponseSentError(errors.New(fmt.Sprintf("handler must return exactly two values, got %d", len(output))))
}

if !output[0].CanInterface() {
Expand Down Expand Up @@ -437,40 +438,40 @@ func createWrappedToolHandler(userHandler any) func(baseCallToolRequestParams) *
return func(arguments baseCallToolRequestParams) *toolResponseSent {
// Instantiate a struct of the type of the arguments
if !reflect.New(argumentType).CanInterface() {
return newToolResponseSentError(fmt.Errorf("arguments must be a struct"))
return newToolResponseSentError(errors.Wrap(fmt.Errorf("arguments must be a struct"), "failed to create argument struct"))
}
unmarshaledArguments := reflect.New(argumentType).Interface()

// Unmarshal the JSON into the correct type
err := json.Unmarshal(arguments.Arguments, &unmarshaledArguments)
if err != nil {
return newToolResponseSentError(fmt.Errorf("failed to unmarshal arguments: %w", err))
return newToolResponseSentError(errors.Wrap(err, "failed to unmarshal arguments"))
}

// Need to dereference the unmarshaled arguments
of := reflect.ValueOf(unmarshaledArguments)
if of.Kind() != reflect.Ptr || !of.Elem().CanInterface() {
return newToolResponseSentError(fmt.Errorf("arguments must be a struct"))
return newToolResponseSentError(errors.Wrap(fmt.Errorf("arguments must be a struct"), "failed to dereference arguments"))
}
// Call the handler with the typed arguments
output := handlerValue.Call([]reflect.Value{of.Elem()})

if len(output) != 2 {
return newToolResponseSentError(fmt.Errorf("handler must return exactly two values, got %d", len(output)))
return newToolResponseSentError(errors.Wrap(fmt.Errorf("handler must return exactly two values, got %d", len(output)), "invalid handler return"))
}

if !output[0].CanInterface() {
return newToolResponseSentError(fmt.Errorf("handler must return a struct, got %s", output[0].Type().Name()))
return newToolResponseSentError(errors.Wrap(fmt.Errorf("handler must return a struct, got %s", output[0].Type().Name()), "invalid handler return"))
}
tool := output[0].Interface()
if !output[1].CanInterface() {
return newToolResponseSentError(fmt.Errorf("handler must return an error, got %s", output[1].Type().Name()))
return newToolResponseSentError(errors.Wrap(fmt.Errorf("handler must return an error, got %s", output[1].Type().Name()), "invalid handler return"))
}
errorOut := output[1].Interface()
if errorOut == nil {
return newToolResponseSent(tool.(*ToolResponse))
}
return newToolResponseSentError(errorOut.(error))
return newToolResponseSentError(errors.Wrap(errorOut.(error), "handler returned an error"))
}
}

Expand Down Expand Up @@ -514,9 +515,13 @@ func (s *Server) handleListTools(request *transport.BaseJSONRPCRequest, _ protoc
Cursor *string `json:"cursor"`
}
var params toolRequestParams
err := json.Unmarshal(request.Params, &params)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal arguments: %w", err)
if request.Params == nil {
params = toolRequestParams{}
} else {
err := json.Unmarshal(request.Params, &params)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal arguments")
}
}

// Order by name for pagination
Expand All @@ -534,7 +539,7 @@ func (s *Server) handleListTools(request *transport.BaseJSONRPCRequest, _ protoc
// Base64 decode the cursor
c, err := base64.StdEncoding.DecodeString(*params.Cursor)
if err != nil {
return nil, fmt.Errorf("failed to decode cursor: %w", err)
return nil, errors.Wrap(err, "failed to decode cursor")
}
cString := string(c)
// Iterate through the tools until we find an entry > the cursor
Expand Down Expand Up @@ -585,7 +590,7 @@ func (s *Server) handleToolCalls(req *transport.BaseJSONRPCRequest, _ protocol.R
// Instantiate a struct of the type of the arguments
err := json.Unmarshal(req.Params, &params)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal arguments: %w", err)
return nil, errors.Wrap(err, "failed to unmarshal arguments")
}

var toolToUse *tool
Expand All @@ -598,11 +603,10 @@ func (s *Server) handleToolCalls(req *transport.BaseJSONRPCRequest, _ protocol.R
})

if toolToUse == nil {
return nil, fmt.Errorf("unknown tool: %s", req.Method)
return nil, errors.Wrapf(err, "unknown tool: %s", req.Method)
}
return toolToUse.Handler(params), nil
}

func (s *Server) generateCapabilities() serverCapabilities {
t := false
return serverCapabilities{
Expand All @@ -623,15 +627,14 @@ func (s *Server) generateCapabilities() serverCapabilities {
}(),
}
}

func (s *Server) handleListPrompts(request *transport.BaseJSONRPCRequest, extra protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) {
type promptRequestParams struct {
Cursor *string `json:"cursor"`
}
var params promptRequestParams
err := json.Unmarshal(request.Params, &params)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal arguments: %w", err)
return nil, errors.Wrap(err, "failed to unmarshal arguments")
}

// Order by name for pagination
Expand All @@ -649,7 +652,7 @@ func (s *Server) handleListPrompts(request *transport.BaseJSONRPCRequest, extra
// Base64 decode the cursor
c, err := base64.StdEncoding.DecodeString(*params.Cursor)
if err != nil {
return nil, fmt.Errorf("failed to decode cursor: %w", err)
return nil, errors.Wrap(err, "failed to decode cursor")
}
cString := string(c)
// Iterate through the prompts until we find an entry > the cursor
Expand Down Expand Up @@ -694,7 +697,7 @@ func (s *Server) handleListResources(request *transport.BaseJSONRPCRequest, extr
var params resourceRequestParams
err := json.Unmarshal(request.Params, &params)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal arguments: %w", err)
return nil, errors.Wrap(err, "failed to unmarshal arguments")
}

// Order by URI for pagination
Expand All @@ -712,7 +715,7 @@ func (s *Server) handleListResources(request *transport.BaseJSONRPCRequest, extr
// Base64 decode the cursor
c, err := base64.StdEncoding.DecodeString(*params.Cursor)
if err != nil {
return nil, fmt.Errorf("failed to decode cursor: %w", err)
return nil, errors.Wrap(err, "failed to decode cursor")
}
cString := string(c)
// Iterate through the resources until we find an entry > the cursor
Expand Down Expand Up @@ -760,7 +763,7 @@ func (s *Server) handlePromptCalls(req *transport.BaseJSONRPCRequest, extra prot
// Instantiate a struct of the type of the arguments
err := json.Unmarshal(req.Params, &params)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal arguments: %w", err)
return nil, errors.Wrap(err, "failed to unmarshal arguments")
}

var promptToUse *prompt
Expand All @@ -773,7 +776,7 @@ func (s *Server) handlePromptCalls(req *transport.BaseJSONRPCRequest, extra prot
})

if promptToUse == nil {
return nil, fmt.Errorf("unknown prompt: %s", req.Method)
return nil, errors.Wrapf(err, "unknown prompt: %s", req.Method)
}
return promptToUse.Handler(params), nil
}
Expand All @@ -783,7 +786,7 @@ func (s *Server) handleResourceCalls(req *transport.BaseJSONRPCRequest, extra pr
// Instantiate a struct of the type of the arguments
err := json.Unmarshal(req.Params, &params)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal arguments: %w", err)
return nil, errors.Wrap(err, "failed to unmarshal arguments")
}

var resourceToUse *resource
Expand All @@ -796,7 +799,7 @@ func (s *Server) handleResourceCalls(req *transport.BaseJSONRPCRequest, extra pr
})

if resourceToUse == nil {
return nil, fmt.Errorf("unknown prompt: %s", req.Method)
return nil, errors.Wrapf(err, "unknown prompt: %s", req.Method)
}
return resourceToUse.Handler(), nil
}
Expand Down

0 comments on commit 2f1bb12

Please sign in to comment.