From 8c15eaeee63688d6850405a7efe034e09081501f Mon Sep 17 00:00:00 2001 From: Chris Battarbee Date: Wed, 11 Dec 2024 18:02:58 +0000 Subject: [PATCH] Add integration test --- go.mod | 2 ++ go.sum | 4 ++++ integration_test.go | 7 +++++-- internal/protocol/protocol.go | 24 +++++++++--------------- server.go | 19 ++++++++++--------- transport/stdio/internal/stdio/stdio.go | 3 --- transport/stdio/stdio_server.go | 6 ++++-- 7 files changed, 34 insertions(+), 31 deletions(-) diff --git a/go.mod b/go.mod index 6d8d7f3..8238b6e 100644 --- a/go.mod +++ b/go.mod @@ -20,5 +20,7 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + go.uber.org/multierr v1.11.0 // indirect + go.uber.org/zap v1.27.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index abfb010..0b9bac1 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,10 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/integration_test.go b/integration_test.go index 0150020..f5762d8 100644 --- a/integration_test.go +++ b/integration_test.go @@ -45,6 +45,8 @@ func main() { } ` +var i = 1 + func TestServerIntegration(t *testing.T) { // Get the current module's root directory currentDir, err := os.Getwd() @@ -126,8 +128,9 @@ func TestServerIntegration(t *testing.T) { Jsonrpc: "2.0", Method: method, Params: json.RawMessage(paramsBytes), - Id: 1, + Id: transport.RequestId(i), } + i++ reqBytes, err := json.Marshal(req) if err != nil { @@ -182,7 +185,7 @@ func TestServerIntegration(t *testing.T) { time.Sleep(100 * time.Millisecond) // Test 2: List tools - resp, err = sendRequest("tools/list", nil) + resp, err = sendRequest("tools/list", map[string]interface{}{}) require.NoError(t, err) tools, ok := resp["result"].(map[string]interface{})["tools"].([]interface{}) require.True(t, ok) diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index b972b96..c7eb9ea 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -180,6 +180,10 @@ func (p *Protocol) Connect(tr transport.Transport) error { p.handleRequest(message.JsonRpcRequest) case m == transport.BaseMessageTypeJSONRPCNotificationType: p.handleNotification(message.JsonRpcNotification) + case m == transport.BaseMessageTypeJSONRPCResponseType: + p.handleResponse(message.JsonRpcResponse, nil) + case m == transport.BaseMessageTypeJSONRPCErrorType: + p.handleResponse(nil, message.JsonRpcError) } }) @@ -267,12 +271,14 @@ func (p *Protocol) handleRequest(request *transport.BaseJSONRPCRequest) { result, err := handler(request, RequestHandlerExtra{Context: ctx}) if err != nil { + //println("error:", err.Error()) p.sendErrorResponse(request.Id, err) return } jsonResult, err := json.Marshal(result) if err != nil { + //println("error:", err.Error()) p.sendErrorResponse(request.Id, fmt.Errorf("failed to marshal result: %w", err)) return } @@ -334,8 +340,8 @@ func (p *Protocol) handleCancelledNotification(notification *transport.BaseJSONR return nil } -func (p *Protocol) handleResponse(response interface{}, errResp *transport.BaseJSONRPCError) { - var id transport.RequestId +func (p *Protocol) handleResponse(response *transport.BaseJSONRPCResponse, errResp *transport.BaseJSONRPCError) { + var id = response.Id var result interface{} var err error @@ -344,19 +350,7 @@ func (p *Protocol) handleResponse(response interface{}, errResp *transport.BaseJ err = fmt.Errorf("RPC error %d: %s", errResp.Error.Code, errResp.Error.Message) } else { // Parse the response - resp := response.(map[string]interface{}) - switch v := resp["id"].(type) { - case float64: - id = transport.RequestId(int64(v)) - case int64: - id = transport.RequestId(v) - case int: - id = transport.RequestId(v) - default: - p.handleError(fmt.Errorf("unexpected id type: %T", resp["id"])) - return - } - result = resp["result"] + result = response.Result } p.mu.RLock() diff --git a/server.go b/server.go index 6d8de56..e32f0f4 100644 --- a/server.go +++ b/server.go @@ -399,15 +399,15 @@ func createWrappedToolHandler(userHandler any) func(baseCallToolRequestParams) * } func (s *Server) Serve() error { - protocol := protocol.NewProtocol(nil) - protocol.SetRequestHandler("initialize", s.handleInitialize) - protocol.SetRequestHandler("tools/list", s.handleListTools) - protocol.SetRequestHandler("tools/call", s.handleToolCalls) - protocol.SetRequestHandler("prompts/list", s.handleListPrompts) - protocol.SetRequestHandler("prompts/get", s.handlePromptCalls) - protocol.SetRequestHandler("resources/list", s.handleListResources) - protocol.SetRequestHandler("resources/read", s.handleResourceCalls) - return protocol.Connect(s.transport) + pr := protocol.NewProtocol(nil) + pr.SetRequestHandler("initialize", s.handleInitialize) + pr.SetRequestHandler("tools/list", s.handleListTools) + pr.SetRequestHandler("tools/call", s.handleToolCalls) + pr.SetRequestHandler("prompts/list", s.handleListPrompts) + pr.SetRequestHandler("prompts/get", s.handlePromptCalls) + pr.SetRequestHandler("resources/list", s.handleListResources) + pr.SetRequestHandler("resources/read", s.handleResourceCalls) + return pr.Connect(s.transport) } func (s *Server) handleInitialize(_ *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) { @@ -424,6 +424,7 @@ func (s *Server) handleInitialize(_ *transport.BaseJSONRPCRequest, _ protocol.Re } func (s *Server) handleListTools(_ *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) { + //println("listing tools") return tools.ToolsResponse{ Tools: func() []tools.ToolRetType { var ts []tools.ToolRetType diff --git a/transport/stdio/internal/stdio/stdio.go b/transport/stdio/internal/stdio/stdio.go index 5c3fa24..04dfcc2 100644 --- a/transport/stdio/internal/stdio/stdio.go +++ b/transport/stdio/internal/stdio/stdio.go @@ -103,7 +103,6 @@ func (rb *ReadBuffer) ReadMessage() (*transport.BaseJsonRpcMessage, error) { // Extract line line := string(rb.buffer[:i]) rb.buffer = rb.buffer[i+1:] - //println("serialized message:", line) return deserializeMessage(line) } } @@ -143,8 +142,6 @@ func deserializeMessage(line string) (*transport.BaseJsonRpcMessage, error) { return transport.NewBaseMessageError(&errorResponse), nil } - // TODO: Add error handling and response deserialization - // Must be a response return nil, errors.New("failed to unmarshal JSON-RPC message, unrecognized type") } diff --git a/transport/stdio/stdio_server.go b/transport/stdio/stdio_server.go index e82e33a..fa2df0a 100644 --- a/transport/stdio/stdio_server.go +++ b/transport/stdio/stdio_server.go @@ -73,7 +73,7 @@ func (t *StdioServerTransport) Send(message *transport.BaseJsonRpcMessage) error } data = append(data, '\n') - println("serialized message:", string(data)) + //println("serialized message:", string(data)) t.mu.Lock() defer t.mu.Unlock() @@ -136,13 +136,15 @@ func (t *StdioServerTransport) processReadBuffer() { for { msg, err := t.readBuf.ReadMessage() if err != nil { + //println("error reading message:", err.Error()) t.handleError(err) return } if msg == nil { + //println("no message") return } - + //println("received message:", spew.Sprint(msg)) t.handleMessage(msg) } }