Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add integration tests #29

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ require (
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
208 changes: 208 additions & 0 deletions integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package mcp_golang

import (
"bytes"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"testing"
"time"

"github.com/metoro-io/mcp-golang/transport"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const testServerCode = `package main

import (
mcp "github.com/metoro-io/mcp-golang"
"github.com/metoro-io/mcp-golang/transport/stdio"
)

type EchoArgs struct {
Message string ` + "`json:\"message\" jsonschema:\"required,description=Message to echo back\"`" + `
}

func main() {
server := mcp.NewServer(stdio.NewStdioServerTransport())
err := server.RegisterTool("echo", "Echo back the input message", func(args EchoArgs) (*mcp.ToolResponse, error) {
return mcp.NewToolReponse(mcp.NewTextContent(args.Message)), nil
})
if err != nil {
panic(err)
}

err = server.Serve()
if err != nil {
panic(err)
}

select {}
}
`

var i = 1

func TestServerIntegration(t *testing.T) {
// Get the current module's root directory
currentDir, err := os.Getwd()
require.NoError(t, err)

// Create a temporary directory for our test server
tmpDir, err := os.MkdirTemp("", "mcp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)

// Initialize a new module
cmd := exec.Command("go", "mod", "init", "testserver")
cmd.Dir = tmpDir
output, err := cmd.CombinedOutput()
require.NoError(t, err, "Failed to initialize module: %s", string(output))

// Replace the dependency with the local version
cmd = exec.Command("go", "mod", "edit", "-replace", "github.com/metoro-io/mcp-golang="+currentDir)
cmd.Dir = tmpDir
output, err = cmd.CombinedOutput()
require.NoError(t, err, "Failed to replace dependency: %s", string(output))

// Write the test server code
serverPath := filepath.Join(tmpDir, "test_server.go")
err = os.WriteFile(serverPath, []byte(testServerCode), 0644)
require.NoError(t, err)

// Run go mod tidy
cmd = exec.Command("go", "mod", "tidy")
cmd.Dir = tmpDir
output, err = cmd.CombinedOutput()
require.NoError(t, err, "Failed to tidy modules: %s", string(output))

// Build the test server
binPath := filepath.Join(tmpDir, "test_server")
cmd = exec.Command("go", "build", "-o", binPath, serverPath)
cmd.Dir = tmpDir
output, err = cmd.CombinedOutput()
require.NoError(t, err, "Failed to build test server: %s\nServer code:\n%s", string(output), testServerCode)

// Start the server process
serverProc := exec.Command(binPath)
stdin, err := serverProc.StdinPipe()
require.NoError(t, err)
stdout, err := serverProc.StdoutPipe()
require.NoError(t, err)
stderr, err := serverProc.StderrPipe()
require.NoError(t, err)

err = serverProc.Start()
require.NoError(t, err)
defer serverProc.Process.Kill()

// Start a goroutine to read stderr
go func() {
buf := make([]byte, 1024)
for {
n, err := stderr.Read(buf)
if err != nil {
if err != io.EOF {
t.Logf("Error reading stderr: %v", err)
}
return
}
if n > 0 {
t.Logf("Server stderr: %s", string(buf[:n]))
}
}
}()

// Helper function to send a request and read response
sendRequest := func(method string, params interface{}) (map[string]interface{}, error) {
paramsBytes, err := json.Marshal(params)
if err != nil {
return nil, err
}

req := transport.BaseJSONRPCRequest{
Jsonrpc: "2.0",
Method: method,
Params: json.RawMessage(paramsBytes),
Id: transport.RequestId(i),
}
i++

reqBytes, err := json.Marshal(req)
if err != nil {
return nil, err
}
reqBytes = append(reqBytes, '\n')

t.Logf("Sending request: %s", string(reqBytes))
_, err = stdin.Write(reqBytes)
if err != nil {
return nil, err
}

// Read response with timeout
respChan := make(chan map[string]interface{}, 1)
errChan := make(chan error, 1)

go func() {
var buf bytes.Buffer
reader := io.TeeReader(stdout, &buf)
decoder := json.NewDecoder(reader)

t.Log("Waiting for response...")
var response map[string]interface{}
err := decoder.Decode(&response)
if err != nil {
errChan <- fmt.Errorf("failed to decode response: %v\nraw response: %s", err, buf.String())
return
}
t.Logf("Got response: %+v", response)
respChan <- response
}()

select {
case resp := <-respChan:
return resp, nil
case err := <-errChan:
return nil, err
case <-time.After(5 * time.Second): // Increased timeout to 5 seconds
return nil, fmt.Errorf("timeout waiting for response")
}
}

// Test 1: Initialize
resp, err := sendRequest("initialize", map[string]interface{}{
"capabilities": map[string]interface{}{},
})
require.NoError(t, err)
assert.Equal(t, float64(1), resp["id"])
assert.NotNil(t, resp["result"])

time.Sleep(100 * time.Millisecond)

// Test 2: List tools
resp, err = sendRequest("tools/list", map[string]interface{}{})
require.NoError(t, err)
tools, ok := resp["result"].(map[string]interface{})["tools"].([]interface{})
require.True(t, ok)
assert.Len(t, tools, 1)
tool := tools[0].(map[string]interface{})
assert.Equal(t, "echo", tool["name"])

// Test 3: Call echo tool
callParams := map[string]interface{}{
"name": "echo",
"arguments": map[string]interface{}{
"message": "Hello, World!",
},
}
resp, err = sendRequest("tools/call", callParams)
require.NoError(t, err)
result := resp["result"].(map[string]interface{})
content := result["content"].([]interface{})[0].(map[string]interface{})
assert.Equal(t, "Hello, World!", content["text"])
}
24 changes: 9 additions & 15 deletions internal/protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ func (p *Protocol) Connect(tr transport.Transport) error {
p.handleRequest(message.JsonRpcRequest)
case m == transport.BaseMessageTypeJSONRPCNotificationType:
p.handleNotification(message.JsonRpcNotification)
case m == transport.BaseMessageTypeJSONRPCResponseType:
p.handleResponse(message.JsonRpcResponse, nil)
case m == transport.BaseMessageTypeJSONRPCErrorType:
p.handleResponse(nil, message.JsonRpcError)
}
})

Expand Down Expand Up @@ -267,12 +271,14 @@ func (p *Protocol) handleRequest(request *transport.BaseJSONRPCRequest) {

result, err := handler(request, RequestHandlerExtra{Context: ctx})
if err != nil {
//println("error:", err.Error())
p.sendErrorResponse(request.Id, err)
return
}

jsonResult, err := json.Marshal(result)
if err != nil {
//println("error:", err.Error())
p.sendErrorResponse(request.Id, fmt.Errorf("failed to marshal result: %w", err))
return
}
Expand Down Expand Up @@ -334,8 +340,8 @@ func (p *Protocol) handleCancelledNotification(notification *transport.BaseJSONR
return nil
}

func (p *Protocol) handleResponse(response interface{}, errResp *transport.BaseJSONRPCError) {
var id transport.RequestId
func (p *Protocol) handleResponse(response *transport.BaseJSONRPCResponse, errResp *transport.BaseJSONRPCError) {
var id = response.Id
var result interface{}
var err error

Expand All @@ -344,19 +350,7 @@ func (p *Protocol) handleResponse(response interface{}, errResp *transport.BaseJ
err = fmt.Errorf("RPC error %d: %s", errResp.Error.Code, errResp.Error.Message)
} else {
// Parse the response
resp := response.(map[string]interface{})
switch v := resp["id"].(type) {
case float64:
id = transport.RequestId(int64(v))
case int64:
id = transport.RequestId(v)
case int:
id = transport.RequestId(v)
default:
p.handleError(fmt.Errorf("unexpected id type: %T", resp["id"]))
return
}
result = resp["result"]
result = response.Result
}

p.mu.RLock()
Expand Down
24 changes: 13 additions & 11 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package mcp_golang
import (
"encoding/json"
"fmt"
"reflect"
"strings"

"github.com/invopop/jsonschema"
"github.com/metoro-io/mcp-golang/internal/protocol"
"github.com/metoro-io/mcp-golang/internal/tools"
"github.com/metoro-io/mcp-golang/transport"
"reflect"
"strings"
)

// Here we define the actual MCP server that users will create and run
Expand Down Expand Up @@ -398,15 +399,15 @@ func createWrappedToolHandler(userHandler any) func(baseCallToolRequestParams) *
}

func (s *Server) Serve() error {
protocol := protocol.NewProtocol(nil)
protocol.SetRequestHandler("initialize", s.handleInitialize)
protocol.SetRequestHandler("tools/list", s.handleListTools)
protocol.SetRequestHandler("tools/call", s.handleToolCalls)
protocol.SetRequestHandler("prompts/list", s.handleListPrompts)
protocol.SetRequestHandler("prompts/get", s.handlePromptCalls)
protocol.SetRequestHandler("resources/list", s.handleListResources)
protocol.SetRequestHandler("resources/read", s.handleResourceCalls)
return protocol.Connect(s.transport)
pr := protocol.NewProtocol(nil)
pr.SetRequestHandler("initialize", s.handleInitialize)
pr.SetRequestHandler("tools/list", s.handleListTools)
pr.SetRequestHandler("tools/call", s.handleToolCalls)
pr.SetRequestHandler("prompts/list", s.handleListPrompts)
pr.SetRequestHandler("prompts/get", s.handlePromptCalls)
pr.SetRequestHandler("resources/list", s.handleListResources)
pr.SetRequestHandler("resources/read", s.handleResourceCalls)
return pr.Connect(s.transport)
}

func (s *Server) handleInitialize(_ *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) {
Expand All @@ -423,6 +424,7 @@ func (s *Server) handleInitialize(_ *transport.BaseJSONRPCRequest, _ protocol.Re
}

func (s *Server) handleListTools(_ *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) {
//println("listing tools")
return tools.ToolsResponse{
Tools: func() []tools.ToolRetType {
var ts []tools.ToolRetType
Expand Down
3 changes: 0 additions & 3 deletions transport/stdio/internal/stdio/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ func (rb *ReadBuffer) ReadMessage() (*transport.BaseJsonRpcMessage, error) {
// Extract line
line := string(rb.buffer[:i])
rb.buffer = rb.buffer[i+1:]
//println("serialized message:", line)
return deserializeMessage(line)
}
}
Expand Down Expand Up @@ -143,8 +142,6 @@ func deserializeMessage(line string) (*transport.BaseJsonRpcMessage, error) {
return transport.NewBaseMessageError(&errorResponse), nil
}

// TODO: Add error handling and response deserialization

// Must be a response
return nil, errors.New("failed to unmarshal JSON-RPC message, unrecognized type")
}
6 changes: 4 additions & 2 deletions transport/stdio/stdio_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (t *StdioServerTransport) Send(message *transport.BaseJsonRpcMessage) error
}
data = append(data, '\n')

println("serialized message:", string(data))
//println("serialized message:", string(data))

t.mu.Lock()
defer t.mu.Unlock()
Expand Down Expand Up @@ -136,13 +136,15 @@ func (t *StdioServerTransport) processReadBuffer() {
for {
msg, err := t.readBuf.ReadMessage()
if err != nil {
//println("error reading message:", err.Error())
t.handleError(err)
return
}
if msg == nil {
//println("no message")
return
}

//println("received message:", spew.Sprint(msg))
t.handleMessage(msg)
}
}
Expand Down
Loading