Skip to content

Commit

Permalink
Add integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
Chrisbattarbee committed Dec 11, 2024
1 parent 6d2a485 commit 8c15eae
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 31 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ require (
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
7 changes: 5 additions & 2 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ func main() {
}
`

var i = 1

func TestServerIntegration(t *testing.T) {
// Get the current module's root directory
currentDir, err := os.Getwd()
Expand Down Expand Up @@ -126,8 +128,9 @@ func TestServerIntegration(t *testing.T) {
Jsonrpc: "2.0",
Method: method,
Params: json.RawMessage(paramsBytes),
Id: 1,
Id: transport.RequestId(i),
}
i++

reqBytes, err := json.Marshal(req)
if err != nil {
Expand Down Expand Up @@ -182,7 +185,7 @@ func TestServerIntegration(t *testing.T) {
time.Sleep(100 * time.Millisecond)

// Test 2: List tools
resp, err = sendRequest("tools/list", nil)
resp, err = sendRequest("tools/list", map[string]interface{}{})
require.NoError(t, err)
tools, ok := resp["result"].(map[string]interface{})["tools"].([]interface{})
require.True(t, ok)
Expand Down
24 changes: 9 additions & 15 deletions internal/protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ func (p *Protocol) Connect(tr transport.Transport) error {
p.handleRequest(message.JsonRpcRequest)
case m == transport.BaseMessageTypeJSONRPCNotificationType:
p.handleNotification(message.JsonRpcNotification)
case m == transport.BaseMessageTypeJSONRPCResponseType:
p.handleResponse(message.JsonRpcResponse, nil)
case m == transport.BaseMessageTypeJSONRPCErrorType:
p.handleResponse(nil, message.JsonRpcError)
}
})

Expand Down Expand Up @@ -267,12 +271,14 @@ func (p *Protocol) handleRequest(request *transport.BaseJSONRPCRequest) {

result, err := handler(request, RequestHandlerExtra{Context: ctx})
if err != nil {
//println("error:", err.Error())
p.sendErrorResponse(request.Id, err)
return
}

jsonResult, err := json.Marshal(result)
if err != nil {
//println("error:", err.Error())
p.sendErrorResponse(request.Id, fmt.Errorf("failed to marshal result: %w", err))
return
}
Expand Down Expand Up @@ -334,8 +340,8 @@ func (p *Protocol) handleCancelledNotification(notification *transport.BaseJSONR
return nil
}

func (p *Protocol) handleResponse(response interface{}, errResp *transport.BaseJSONRPCError) {
var id transport.RequestId
func (p *Protocol) handleResponse(response *transport.BaseJSONRPCResponse, errResp *transport.BaseJSONRPCError) {
var id = response.Id
var result interface{}
var err error

Expand All @@ -344,19 +350,7 @@ func (p *Protocol) handleResponse(response interface{}, errResp *transport.BaseJ
err = fmt.Errorf("RPC error %d: %s", errResp.Error.Code, errResp.Error.Message)
} else {
// Parse the response
resp := response.(map[string]interface{})
switch v := resp["id"].(type) {
case float64:
id = transport.RequestId(int64(v))
case int64:
id = transport.RequestId(v)
case int:
id = transport.RequestId(v)
default:
p.handleError(fmt.Errorf("unexpected id type: %T", resp["id"]))
return
}
result = resp["result"]
result = response.Result
}

p.mu.RLock()
Expand Down
19 changes: 10 additions & 9 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,15 +399,15 @@ func createWrappedToolHandler(userHandler any) func(baseCallToolRequestParams) *
}

func (s *Server) Serve() error {
protocol := protocol.NewProtocol(nil)
protocol.SetRequestHandler("initialize", s.handleInitialize)
protocol.SetRequestHandler("tools/list", s.handleListTools)
protocol.SetRequestHandler("tools/call", s.handleToolCalls)
protocol.SetRequestHandler("prompts/list", s.handleListPrompts)
protocol.SetRequestHandler("prompts/get", s.handlePromptCalls)
protocol.SetRequestHandler("resources/list", s.handleListResources)
protocol.SetRequestHandler("resources/read", s.handleResourceCalls)
return protocol.Connect(s.transport)
pr := protocol.NewProtocol(nil)
pr.SetRequestHandler("initialize", s.handleInitialize)
pr.SetRequestHandler("tools/list", s.handleListTools)
pr.SetRequestHandler("tools/call", s.handleToolCalls)
pr.SetRequestHandler("prompts/list", s.handleListPrompts)
pr.SetRequestHandler("prompts/get", s.handlePromptCalls)
pr.SetRequestHandler("resources/list", s.handleListResources)
pr.SetRequestHandler("resources/read", s.handleResourceCalls)
return pr.Connect(s.transport)
}

func (s *Server) handleInitialize(_ *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) {
Expand All @@ -424,6 +424,7 @@ func (s *Server) handleInitialize(_ *transport.BaseJSONRPCRequest, _ protocol.Re
}

func (s *Server) handleListTools(_ *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) {
//println("listing tools")
return tools.ToolsResponse{
Tools: func() []tools.ToolRetType {
var ts []tools.ToolRetType
Expand Down
3 changes: 0 additions & 3 deletions transport/stdio/internal/stdio/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ func (rb *ReadBuffer) ReadMessage() (*transport.BaseJsonRpcMessage, error) {
// Extract line
line := string(rb.buffer[:i])
rb.buffer = rb.buffer[i+1:]
//println("serialized message:", line)
return deserializeMessage(line)
}
}
Expand Down Expand Up @@ -143,8 +142,6 @@ func deserializeMessage(line string) (*transport.BaseJsonRpcMessage, error) {
return transport.NewBaseMessageError(&errorResponse), nil
}

// TODO: Add error handling and response deserialization

// Must be a response
return nil, errors.New("failed to unmarshal JSON-RPC message, unrecognized type")
}
6 changes: 4 additions & 2 deletions transport/stdio/stdio_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (t *StdioServerTransport) Send(message *transport.BaseJsonRpcMessage) error
}
data = append(data, '\n')

println("serialized message:", string(data))
//println("serialized message:", string(data))

t.mu.Lock()
defer t.mu.Unlock()
Expand Down Expand Up @@ -136,13 +136,15 @@ func (t *StdioServerTransport) processReadBuffer() {
for {
msg, err := t.readBuf.ReadMessage()
if err != nil {
//println("error reading message:", err.Error())
t.handleError(err)
return
}
if msg == nil {
//println("no message")
return
}

//println("received message:", spew.Sprint(msg))
t.handleMessage(msg)
}
}
Expand Down

0 comments on commit 8c15eae

Please sign in to comment.