diff --git a/go.mod b/go.mod index 6d8d7f3..8238b6e 100644 --- a/go.mod +++ b/go.mod @@ -20,5 +20,7 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + go.uber.org/multierr v1.11.0 // indirect + go.uber.org/zap v1.27.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index abfb010..0b9bac1 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,10 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/integration_test.go b/integration_test.go new file mode 100644 index 0000000..f5762d8 --- /dev/null +++ b/integration_test.go @@ -0,0 +1,208 @@ +package mcp_golang + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "testing" + "time" + + "github.com/metoro-io/mcp-golang/transport" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testServerCode = `package main + +import ( + mcp "github.com/metoro-io/mcp-golang" + "github.com/metoro-io/mcp-golang/transport/stdio" +) + +type EchoArgs struct { + Message string ` + "`json:\"message\" jsonschema:\"required,description=Message to echo back\"`" + ` +} + +func main() { + server := mcp.NewServer(stdio.NewStdioServerTransport()) + err := server.RegisterTool("echo", "Echo back the input message", func(args EchoArgs) (*mcp.ToolResponse, error) { + return mcp.NewToolReponse(mcp.NewTextContent(args.Message)), nil + }) + if err != nil { + panic(err) + } + + err = server.Serve() + if err != nil { + panic(err) + } + + select {} +} +` + +var i = 1 + +func TestServerIntegration(t *testing.T) { + // Get the current module's root directory + currentDir, err := os.Getwd() + require.NoError(t, err) + + // Create a temporary directory for our test server + tmpDir, err := os.MkdirTemp("", "mcp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Initialize a new module + cmd := exec.Command("go", "mod", "init", "testserver") + cmd.Dir = tmpDir + output, err := cmd.CombinedOutput() + require.NoError(t, err, "Failed to initialize module: %s", string(output)) + + // Replace the dependency with the local version + cmd = exec.Command("go", "mod", "edit", "-replace", "github.com/metoro-io/mcp-golang="+currentDir) + cmd.Dir = tmpDir + output, err = cmd.CombinedOutput() + require.NoError(t, err, "Failed to replace dependency: %s", string(output)) + + // Write the test server code + serverPath := filepath.Join(tmpDir, "test_server.go") + err = os.WriteFile(serverPath, []byte(testServerCode), 0644) + require.NoError(t, err) + + // Run go mod tidy + cmd = exec.Command("go", "mod", "tidy") + cmd.Dir = tmpDir + output, err = cmd.CombinedOutput() + require.NoError(t, err, "Failed to tidy modules: %s", string(output)) + + // Build the test server + binPath := filepath.Join(tmpDir, "test_server") + cmd = exec.Command("go", "build", "-o", binPath, serverPath) + cmd.Dir = tmpDir + output, err = cmd.CombinedOutput() + require.NoError(t, err, "Failed to build test server: %s\nServer code:\n%s", string(output), testServerCode) + + // Start the server process + serverProc := exec.Command(binPath) + stdin, err := serverProc.StdinPipe() + require.NoError(t, err) + stdout, err := serverProc.StdoutPipe() + require.NoError(t, err) + stderr, err := serverProc.StderrPipe() + require.NoError(t, err) + + err = serverProc.Start() + require.NoError(t, err) + defer serverProc.Process.Kill() + + // Start a goroutine to read stderr + go func() { + buf := make([]byte, 1024) + for { + n, err := stderr.Read(buf) + if err != nil { + if err != io.EOF { + t.Logf("Error reading stderr: %v", err) + } + return + } + if n > 0 { + t.Logf("Server stderr: %s", string(buf[:n])) + } + } + }() + + // Helper function to send a request and read response + sendRequest := func(method string, params interface{}) (map[string]interface{}, error) { + paramsBytes, err := json.Marshal(params) + if err != nil { + return nil, err + } + + req := transport.BaseJSONRPCRequest{ + Jsonrpc: "2.0", + Method: method, + Params: json.RawMessage(paramsBytes), + Id: transport.RequestId(i), + } + i++ + + reqBytes, err := json.Marshal(req) + if err != nil { + return nil, err + } + reqBytes = append(reqBytes, '\n') + + t.Logf("Sending request: %s", string(reqBytes)) + _, err = stdin.Write(reqBytes) + if err != nil { + return nil, err + } + + // Read response with timeout + respChan := make(chan map[string]interface{}, 1) + errChan := make(chan error, 1) + + go func() { + var buf bytes.Buffer + reader := io.TeeReader(stdout, &buf) + decoder := json.NewDecoder(reader) + + t.Log("Waiting for response...") + var response map[string]interface{} + err := decoder.Decode(&response) + if err != nil { + errChan <- fmt.Errorf("failed to decode response: %v\nraw response: %s", err, buf.String()) + return + } + t.Logf("Got response: %+v", response) + respChan <- response + }() + + select { + case resp := <-respChan: + return resp, nil + case err := <-errChan: + return nil, err + case <-time.After(5 * time.Second): // Increased timeout to 5 seconds + return nil, fmt.Errorf("timeout waiting for response") + } + } + + // Test 1: Initialize + resp, err := sendRequest("initialize", map[string]interface{}{ + "capabilities": map[string]interface{}{}, + }) + require.NoError(t, err) + assert.Equal(t, float64(1), resp["id"]) + assert.NotNil(t, resp["result"]) + + time.Sleep(100 * time.Millisecond) + + // Test 2: List tools + resp, err = sendRequest("tools/list", map[string]interface{}{}) + require.NoError(t, err) + tools, ok := resp["result"].(map[string]interface{})["tools"].([]interface{}) + require.True(t, ok) + assert.Len(t, tools, 1) + tool := tools[0].(map[string]interface{}) + assert.Equal(t, "echo", tool["name"]) + + // Test 3: Call echo tool + callParams := map[string]interface{}{ + "name": "echo", + "arguments": map[string]interface{}{ + "message": "Hello, World!", + }, + } + resp, err = sendRequest("tools/call", callParams) + require.NoError(t, err) + result := resp["result"].(map[string]interface{}) + content := result["content"].([]interface{})[0].(map[string]interface{}) + assert.Equal(t, "Hello, World!", content["text"]) +} diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index b972b96..c7eb9ea 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -180,6 +180,10 @@ func (p *Protocol) Connect(tr transport.Transport) error { p.handleRequest(message.JsonRpcRequest) case m == transport.BaseMessageTypeJSONRPCNotificationType: p.handleNotification(message.JsonRpcNotification) + case m == transport.BaseMessageTypeJSONRPCResponseType: + p.handleResponse(message.JsonRpcResponse, nil) + case m == transport.BaseMessageTypeJSONRPCErrorType: + p.handleResponse(nil, message.JsonRpcError) } }) @@ -267,12 +271,14 @@ func (p *Protocol) handleRequest(request *transport.BaseJSONRPCRequest) { result, err := handler(request, RequestHandlerExtra{Context: ctx}) if err != nil { + //println("error:", err.Error()) p.sendErrorResponse(request.Id, err) return } jsonResult, err := json.Marshal(result) if err != nil { + //println("error:", err.Error()) p.sendErrorResponse(request.Id, fmt.Errorf("failed to marshal result: %w", err)) return } @@ -334,8 +340,8 @@ func (p *Protocol) handleCancelledNotification(notification *transport.BaseJSONR return nil } -func (p *Protocol) handleResponse(response interface{}, errResp *transport.BaseJSONRPCError) { - var id transport.RequestId +func (p *Protocol) handleResponse(response *transport.BaseJSONRPCResponse, errResp *transport.BaseJSONRPCError) { + var id = response.Id var result interface{} var err error @@ -344,19 +350,7 @@ func (p *Protocol) handleResponse(response interface{}, errResp *transport.BaseJ err = fmt.Errorf("RPC error %d: %s", errResp.Error.Code, errResp.Error.Message) } else { // Parse the response - resp := response.(map[string]interface{}) - switch v := resp["id"].(type) { - case float64: - id = transport.RequestId(int64(v)) - case int64: - id = transport.RequestId(v) - case int: - id = transport.RequestId(v) - default: - p.handleError(fmt.Errorf("unexpected id type: %T", resp["id"])) - return - } - result = resp["result"] + result = response.Result } p.mu.RLock() diff --git a/server.go b/server.go index 2d2f5b3..e32f0f4 100644 --- a/server.go +++ b/server.go @@ -3,12 +3,13 @@ package mcp_golang import ( "encoding/json" "fmt" + "reflect" + "strings" + "github.com/invopop/jsonschema" "github.com/metoro-io/mcp-golang/internal/protocol" "github.com/metoro-io/mcp-golang/internal/tools" "github.com/metoro-io/mcp-golang/transport" - "reflect" - "strings" ) // Here we define the actual MCP server that users will create and run @@ -398,15 +399,15 @@ func createWrappedToolHandler(userHandler any) func(baseCallToolRequestParams) * } func (s *Server) Serve() error { - protocol := protocol.NewProtocol(nil) - protocol.SetRequestHandler("initialize", s.handleInitialize) - protocol.SetRequestHandler("tools/list", s.handleListTools) - protocol.SetRequestHandler("tools/call", s.handleToolCalls) - protocol.SetRequestHandler("prompts/list", s.handleListPrompts) - protocol.SetRequestHandler("prompts/get", s.handlePromptCalls) - protocol.SetRequestHandler("resources/list", s.handleListResources) - protocol.SetRequestHandler("resources/read", s.handleResourceCalls) - return protocol.Connect(s.transport) + pr := protocol.NewProtocol(nil) + pr.SetRequestHandler("initialize", s.handleInitialize) + pr.SetRequestHandler("tools/list", s.handleListTools) + pr.SetRequestHandler("tools/call", s.handleToolCalls) + pr.SetRequestHandler("prompts/list", s.handleListPrompts) + pr.SetRequestHandler("prompts/get", s.handlePromptCalls) + pr.SetRequestHandler("resources/list", s.handleListResources) + pr.SetRequestHandler("resources/read", s.handleResourceCalls) + return pr.Connect(s.transport) } func (s *Server) handleInitialize(_ *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) { @@ -423,6 +424,7 @@ func (s *Server) handleInitialize(_ *transport.BaseJSONRPCRequest, _ protocol.Re } func (s *Server) handleListTools(_ *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) { + //println("listing tools") return tools.ToolsResponse{ Tools: func() []tools.ToolRetType { var ts []tools.ToolRetType diff --git a/transport/stdio/internal/stdio/stdio.go b/transport/stdio/internal/stdio/stdio.go index 5c3fa24..04dfcc2 100644 --- a/transport/stdio/internal/stdio/stdio.go +++ b/transport/stdio/internal/stdio/stdio.go @@ -103,7 +103,6 @@ func (rb *ReadBuffer) ReadMessage() (*transport.BaseJsonRpcMessage, error) { // Extract line line := string(rb.buffer[:i]) rb.buffer = rb.buffer[i+1:] - //println("serialized message:", line) return deserializeMessage(line) } } @@ -143,8 +142,6 @@ func deserializeMessage(line string) (*transport.BaseJsonRpcMessage, error) { return transport.NewBaseMessageError(&errorResponse), nil } - // TODO: Add error handling and response deserialization - // Must be a response return nil, errors.New("failed to unmarshal JSON-RPC message, unrecognized type") } diff --git a/transport/stdio/stdio_server.go b/transport/stdio/stdio_server.go index e82e33a..fa2df0a 100644 --- a/transport/stdio/stdio_server.go +++ b/transport/stdio/stdio_server.go @@ -73,7 +73,7 @@ func (t *StdioServerTransport) Send(message *transport.BaseJsonRpcMessage) error } data = append(data, '\n') - println("serialized message:", string(data)) + //println("serialized message:", string(data)) t.mu.Lock() defer t.mu.Unlock() @@ -136,13 +136,15 @@ func (t *StdioServerTransport) processReadBuffer() { for { msg, err := t.readBuf.ReadMessage() if err != nil { + //println("error reading message:", err.Error()) t.handleError(err) return } if msg == nil { + //println("no message") return } - + //println("received message:", spew.Sprint(msg)) t.handleMessage(msg) } }