From 688854971b81bc9166c4268a3ae2f4bc3c242eb2 Mon Sep 17 00:00:00 2001 From: caixw Date: Tue, 28 Jan 2020 00:22:11 +0800 Subject: [PATCH] =?UTF-8?q?feat(internal/lsp/jsonrpc):=20=E6=8A=BD?= =?UTF-8?q?=E8=B1=A1=20stream=20=E5=AF=B9=E8=B1=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加对 websocket 的支持 update #56 --- internal/lsp/jsonrpc/conn.go | 19 +++---- .../lsp/jsonrpc/{stream.go => httpstream.go} | 57 +++++++------------ .../{stream_test.go => httpstream_test.go} | 32 ++++++----- internal/lsp/jsonrpc/jsonrpc.go | 10 +++- internal/lsp/jsonrpc/wsstream.go | 24 ++++++++ internal/lsp/jsonrpc/wsstream_test.go | 5 ++ 6 files changed, 81 insertions(+), 66 deletions(-) rename internal/lsp/jsonrpc/{stream.go => httpstream.go} (62%) rename internal/lsp/jsonrpc/{stream_test.go => httpstream_test.go} (79%) create mode 100644 internal/lsp/jsonrpc/wsstream.go create mode 100644 internal/lsp/jsonrpc/wsstream_test.go diff --git a/internal/lsp/jsonrpc/conn.go b/internal/lsp/jsonrpc/conn.go index c40210b2..17165675 100644 --- a/internal/lsp/jsonrpc/conn.go +++ b/internal/lsp/jsonrpc/conn.go @@ -5,7 +5,6 @@ package jsonrpc import ( "context" "encoding/json" - "io" "reflect" "strconv" "sync" @@ -17,7 +16,7 @@ import ( // Conn 连接对象,json-rpc 客户端和服务端是对等的,两者都使用 conn 初始化。 type Conn struct { - stream *stream + stream Streamer servers sync.Map autoinc *autoinc.AutoInc } @@ -28,9 +27,9 @@ type handler struct { } // NewConn 声明新的 Conn 实例 -func NewConn(in io.Reader, out io.Writer) *Conn { +func NewConn(stream Streamer) *Conn { return &Conn{ - stream: newStream(in, out), + stream: stream, autoinc: autoinc.New(0, 1, 1000), } } @@ -109,7 +108,7 @@ func (conn *Conn) send(notify bool, method string, in, out interface{}) error { req.ID = strconv.FormatInt(conn.autoinc.MustID(), 10) } - if _, err = conn.stream.write(req); err != nil { + if err = conn.stream.Write(req); err != nil { return err } @@ -118,7 +117,7 @@ func (conn *Conn) send(notify bool, method string, in, out interface{}) error { } resp := &Response{} - if err = conn.stream.readResponse(resp); err != nil { + if err = conn.stream.Read(resp); err != nil { return err } @@ -150,7 +149,7 @@ func (conn *Conn) Serve(ctx context.Context) error { // 作为服务端,根据参数查找和执行服务 func (conn *Conn) serve() error { req := &Request{} - if err := conn.stream.readRequest(req); err != nil { + if err := conn.stream.Read(req); err != nil { return conn.writeError("", CodeParseError, err, nil) } @@ -181,8 +180,7 @@ func (conn *Conn) serve() error { Result: (*json.RawMessage)(&data), ID: req.ID, } - _, err = conn.stream.write(resp) - return err + return conn.stream.Write(resp) } func (conn *Conn) writeError(id string, code int, err error, data interface{}) error { @@ -197,6 +195,5 @@ func (conn *Conn) writeError(id string, code int, err error, data interface{}) e resp.Error = NewError(code, err.Error()) } - _, err = conn.stream.write(resp) - return err + return conn.stream.Write(resp) } diff --git a/internal/lsp/jsonrpc/stream.go b/internal/lsp/jsonrpc/httpstream.go similarity index 62% rename from internal/lsp/jsonrpc/stream.go rename to internal/lsp/jsonrpc/httpstream.go index 8e947adc..82fe65e7 100644 --- a/internal/lsp/jsonrpc/stream.go +++ b/internal/lsp/jsonrpc/httpstream.go @@ -23,46 +23,29 @@ var ( contentLength = http.CanonicalHeaderKey("content-length") ) -type stream struct { +type httpStream struct { in io.Reader out io.Writer outMux sync.Mutex } -func newStream(in io.Reader, out io.Writer) *stream { - return &stream{ +// NewHTTPStream 声明基于 HTTP 的 streamer 实例 +func NewHTTPStream(in io.Reader, out io.Writer) Streamer { + return &httpStream{ in: in, out: out, } } -func (s *stream) readRequest(req *Request) error { - data, err := s.read() - if err != nil { - return err - } - - return json.Unmarshal(data, req) -} - -func (s *stream) readResponse(resp *Response) error { - data, err := s.read() - if err != nil { - return err - } - - return json.Unmarshal(data, resp) -} - -// 读取内容,先验证报头,并返回实际 body 的内容 -func (s *stream) read() ([]byte, error) { +// Read 读取内容,先验证报头,并返回实际 body 的内容 +func (s *httpStream) Read(v interface{}) error { buf := bufio.NewReader(s.in) var l int for { line, err := buf.ReadString('\n') if err != nil { - return nil, err + return err } line = strings.TrimSpace(line) if line == "" { @@ -71,7 +54,7 @@ func (s *stream) read() ([]byte, error) { index := strings.IndexByte(line, ':') if index <= 0 { - return nil, locale.Errorf(locale.ErrInvalidHeaderFormat) + return locale.Errorf(locale.ErrInvalidHeaderFormat) } v := strings.TrimSpace(line[index+1:]) @@ -79,30 +62,30 @@ func (s *stream) read() ([]byte, error) { case contentLength: l, err = strconv.Atoi(v) if err != nil { - return nil, err + return err } case contentType: if err := validContentType(v); err != nil { - return nil, err + return err } default: // 忽略其它报头 } } if l <= 0 { - return nil, locale.Errorf(locale.ErrInvalidContentLength) + return locale.Errorf(locale.ErrInvalidContentLength) } data := make([]byte, l) n, err := io.ReadFull(buf, data) if err != nil { - return nil, err + return err } if n == 0 { - return nil, locale.Errorf(locale.ErrBodyIsEmpty) + return locale.Errorf(locale.ErrBodyIsEmpty) } - return data[:n], nil + return json.Unmarshal(data[:n], v) } func validContentType(header string) error { @@ -120,20 +103,20 @@ func validContentType(header string) error { return nil } -func (s *stream) write(obj interface{}) (int, error) { +func (s *httpStream) Write(obj interface{}) error { data, err := json.Marshal(obj) if err != nil { - return 0, err + return err } s.outMux.Lock() defer s.outMux.Unlock() - n, err := fmt.Fprintf(s.out, "%s: %s\r\n%s: %d\r\n\r\n", contentType, charset, contentLength, len(data)) + _, err = fmt.Fprintf(s.out, "%s: %s\r\n%s: %d\r\n\r\n", contentType, charset, contentLength, len(data)) if err != nil { - return 0, err + return err } - size, err := s.out.Write(data) - return n + size, err + _, err = s.out.Write(data) + return err } diff --git a/internal/lsp/jsonrpc/stream_test.go b/internal/lsp/jsonrpc/httpstream_test.go similarity index 79% rename from internal/lsp/jsonrpc/stream_test.go rename to internal/lsp/jsonrpc/httpstream_test.go index 1e125d74..8e43b11f 100644 --- a/internal/lsp/jsonrpc/stream_test.go +++ b/internal/lsp/jsonrpc/httpstream_test.go @@ -9,72 +9,74 @@ import ( "github.com/issue9/assert" ) -func TestStream_readRequest(t *testing.T) { +var _ Streamer = &httpStream{} + +func TestHTTPStream_Read(t *testing.T) { a := assert.New(t) r := new(bytes.Buffer) - s := newStream(r, nil) + s := NewHTTPStream(r, nil) a.NotNil(s) r.WriteString(`Content-Type: text/json;charset=utf-8 Content-Length:26 {"jsonrpc":"2.0","id":"1"}`) rr := &Request{} - a.NotError(s.readRequest(rr)) + a.NotError(s.Read(rr)) a.Equal(rr.Version, Version).Equal(rr.ID, "1") // 无效的 content-length r = new(bytes.Buffer) - s = newStream(r, nil) + s = NewHTTPStream(r, nil) a.NotNil(s) r.WriteString(`Content-Type: text/json;charset=utf-8 Content-Length:0 {"jsonrpc":"2.0","id":"1"}`) rr = &Request{} - a.Error(s.readRequest(rr)) + a.Error(s.Read(rr)) // content-type 中未指定 charset r = new(bytes.Buffer) - s = newStream(r, nil) + s = NewHTTPStream(r, nil) a.NotNil(s) r.WriteString(`Content-Type: text/json;charset-xx=utf-8 Content-Length:26 {"jsonrpc":"2.0","id":"1"}`) rr = &Request{} - a.NotError(s.readRequest(rr)) + a.NotError(s.Read(rr)) // content-length 格式无效 r = new(bytes.Buffer) - s = newStream(r, nil) + s = NewHTTPStream(r, nil) a.NotNil(s) r.WriteString(`Content-Type: text/json;charset-xx=utf-8 Content-Length:26xx {"jsonrpc":"2.0","id":"1"}`) rr = &Request{} - a.Error(s.readRequest(rr)) + a.Error(s.Read(rr)) // content-type 是指定了非 utf-8 编码 r = new(bytes.Buffer) - s = newStream(r, nil) + s = NewHTTPStream(r, nil) a.NotNil(s) r.WriteString(`Content-Type: text/json;charset-xx=utf-7 Content-Length:26xx {"jsonrpc":"2.0","id":"1"}`) rr = &Request{} - a.Error(s.readRequest(rr)) + a.Error(s.Read(rr)) } -func TestStream_write(t *testing.T) { +func TestHTTPStream_Write(t *testing.T) { a := assert.New(t) w := new(bytes.Buffer) - s := newStream(nil, w) + s := NewHTTPStream(nil, w) a.NotNil(s) - size, err := s.write(&Response{ + err := s.Write(&Response{ Version: "1.0.1", Error: &Error{ Code: CodeParseError, @@ -83,7 +85,7 @@ func TestStream_write(t *testing.T) { ID: "1", }) a.NotError(err) - a.NotEmpty(w.Bytes()).True(size > 0) + a.NotEmpty(w.Bytes()) } func TestValidContentType(t *testing.T) { diff --git a/internal/lsp/jsonrpc/jsonrpc.go b/internal/lsp/jsonrpc/jsonrpc.go index dc60bc5b..436e17b4 100644 --- a/internal/lsp/jsonrpc/jsonrpc.go +++ b/internal/lsp/jsonrpc/jsonrpc.go @@ -5,9 +5,7 @@ // https://wiki.geekdream.com/Specification/json-rpc_2.0.html package jsonrpc -import ( - "encoding/json" -) +import "encoding/json" // Version json-rpc 的版本 const Version = "2.0" @@ -30,6 +28,12 @@ const ( CodeContentModified = -32801 ) +// Streamer 用于操作 jsonrpc 的传输层接口 +type Streamer interface { + Read(interface{}) error + Write(interface{}) error +} + // Request 请求对象 type Request struct { // 指定 JSON-RPC 协议版本的字符串 diff --git a/internal/lsp/jsonrpc/wsstream.go b/internal/lsp/jsonrpc/wsstream.go new file mode 100644 index 00000000..c6e2e51b --- /dev/null +++ b/internal/lsp/jsonrpc/wsstream.go @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT + +package jsonrpc + +import ( + "github.com/gorilla/websocket" +) + +type websocketStream struct { + conn *websocket.Conn +} + +// NewWebsocketStream 声明基于 websocket 的 streamer 实例 +func NewWebsocketStream(conn *websocket.Conn) Streamer { + return &websocketStream{conn: conn} +} + +func (s *websocketStream) Read(v interface{}) error { + return s.conn.ReadJSON(v) +} + +func (s *websocketStream) Write(v interface{}) error { + return s.conn.WriteJSON(v) +} diff --git a/internal/lsp/jsonrpc/wsstream_test.go b/internal/lsp/jsonrpc/wsstream_test.go new file mode 100644 index 00000000..76397faa --- /dev/null +++ b/internal/lsp/jsonrpc/wsstream_test.go @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: MIT + +package jsonrpc + +var _ Streamer = &websocketStream{}