Skip to content

Commit

Permalink
feat(internal/lsp/jsonrpc): 抽象 stream 对象
Browse files Browse the repository at this point in the history
添加对 websocket 的支持

update #56
  • Loading branch information
caixw committed Apr 29, 2020
1 parent 4d84b61 commit 6888549
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 66 deletions.
19 changes: 8 additions & 11 deletions internal/lsp/jsonrpc/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package jsonrpc
import (
"context"
"encoding/json"
"io"
"reflect"
"strconv"
"sync"
Expand All @@ -17,7 +16,7 @@ import (

// Conn 连接对象,json-rpc 客户端和服务端是对等的,两者都使用 conn 初始化。
type Conn struct {
stream *stream
stream Streamer
servers sync.Map
autoinc *autoinc.AutoInc
}
Expand All @@ -28,9 +27,9 @@ type handler struct {
}

// NewConn 声明新的 Conn 实例
func NewConn(in io.Reader, out io.Writer) *Conn {
func NewConn(stream Streamer) *Conn {
return &Conn{
stream: newStream(in, out),
stream: stream,
autoinc: autoinc.New(0, 1, 1000),
}
}
Expand Down Expand Up @@ -109,7 +108,7 @@ func (conn *Conn) send(notify bool, method string, in, out interface{}) error {
req.ID = strconv.FormatInt(conn.autoinc.MustID(), 10)
}

if _, err = conn.stream.write(req); err != nil {
if err = conn.stream.Write(req); err != nil {
return err
}

Expand All @@ -118,7 +117,7 @@ func (conn *Conn) send(notify bool, method string, in, out interface{}) error {
}

resp := &Response{}
if err = conn.stream.readResponse(resp); err != nil {
if err = conn.stream.Read(resp); err != nil {
return err
}

Expand Down Expand Up @@ -150,7 +149,7 @@ func (conn *Conn) Serve(ctx context.Context) error {
// 作为服务端,根据参数查找和执行服务
func (conn *Conn) serve() error {
req := &Request{}
if err := conn.stream.readRequest(req); err != nil {
if err := conn.stream.Read(req); err != nil {
return conn.writeError("", CodeParseError, err, nil)
}

Expand Down Expand Up @@ -181,8 +180,7 @@ func (conn *Conn) serve() error {
Result: (*json.RawMessage)(&data),
ID: req.ID,
}
_, err = conn.stream.write(resp)
return err
return conn.stream.Write(resp)
}

func (conn *Conn) writeError(id string, code int, err error, data interface{}) error {
Expand All @@ -197,6 +195,5 @@ func (conn *Conn) writeError(id string, code int, err error, data interface{}) e
resp.Error = NewError(code, err.Error())
}

_, err = conn.stream.write(resp)
return err
return conn.stream.Write(resp)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,46 +23,29 @@ var (
contentLength = http.CanonicalHeaderKey("content-length")
)

type stream struct {
type httpStream struct {
in io.Reader
out io.Writer
outMux sync.Mutex
}

func newStream(in io.Reader, out io.Writer) *stream {
return &stream{
// NewHTTPStream 声明基于 HTTP 的 streamer 实例
func NewHTTPStream(in io.Reader, out io.Writer) Streamer {
return &httpStream{
in: in,
out: out,
}
}

func (s *stream) readRequest(req *Request) error {
data, err := s.read()
if err != nil {
return err
}

return json.Unmarshal(data, req)
}

func (s *stream) readResponse(resp *Response) error {
data, err := s.read()
if err != nil {
return err
}

return json.Unmarshal(data, resp)
}

// 读取内容,先验证报头,并返回实际 body 的内容
func (s *stream) read() ([]byte, error) {
// Read 读取内容,先验证报头,并返回实际 body 的内容
func (s *httpStream) Read(v interface{}) error {
buf := bufio.NewReader(s.in)
var l int

for {
line, err := buf.ReadString('\n')
if err != nil {
return nil, err
return err
}
line = strings.TrimSpace(line)
if line == "" {
Expand All @@ -71,38 +54,38 @@ func (s *stream) read() ([]byte, error) {

index := strings.IndexByte(line, ':')
if index <= 0 {
return nil, locale.Errorf(locale.ErrInvalidHeaderFormat)
return locale.Errorf(locale.ErrInvalidHeaderFormat)
}

v := strings.TrimSpace(line[index+1:])
switch http.CanonicalHeaderKey(strings.TrimSpace(line[:index])) {
case contentLength:
l, err = strconv.Atoi(v)
if err != nil {
return nil, err
return err
}
case contentType:
if err := validContentType(v); err != nil {
return nil, err
return err
}
default: // 忽略其它报头
}
}

if l <= 0 {
return nil, locale.Errorf(locale.ErrInvalidContentLength)
return locale.Errorf(locale.ErrInvalidContentLength)
}

data := make([]byte, l)
n, err := io.ReadFull(buf, data)
if err != nil {
return nil, err
return err
}
if n == 0 {
return nil, locale.Errorf(locale.ErrBodyIsEmpty)
return locale.Errorf(locale.ErrBodyIsEmpty)
}

return data[:n], nil
return json.Unmarshal(data[:n], v)
}

func validContentType(header string) error {
Expand All @@ -120,20 +103,20 @@ func validContentType(header string) error {
return nil
}

func (s *stream) write(obj interface{}) (int, error) {
func (s *httpStream) Write(obj interface{}) error {
data, err := json.Marshal(obj)
if err != nil {
return 0, err
return err
}

s.outMux.Lock()
defer s.outMux.Unlock()

n, err := fmt.Fprintf(s.out, "%s: %s\r\n%s: %d\r\n\r\n", contentType, charset, contentLength, len(data))
_, err = fmt.Fprintf(s.out, "%s: %s\r\n%s: %d\r\n\r\n", contentType, charset, contentLength, len(data))
if err != nil {
return 0, err
return err
}

size, err := s.out.Write(data)
return n + size, err
_, err = s.out.Write(data)
return err
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,72 +9,74 @@ import (
"github.com/issue9/assert"
)

func TestStream_readRequest(t *testing.T) {
var _ Streamer = &httpStream{}

func TestHTTPStream_Read(t *testing.T) {
a := assert.New(t)

r := new(bytes.Buffer)
s := newStream(r, nil)
s := NewHTTPStream(r, nil)
a.NotNil(s)
r.WriteString(`Content-Type: text/json;charset=utf-8
Content-Length:26
{"jsonrpc":"2.0","id":"1"}`)
rr := &Request{}
a.NotError(s.readRequest(rr))
a.NotError(s.Read(rr))
a.Equal(rr.Version, Version).Equal(rr.ID, "1")

// 无效的 content-length
r = new(bytes.Buffer)
s = newStream(r, nil)
s = NewHTTPStream(r, nil)
a.NotNil(s)
r.WriteString(`Content-Type: text/json;charset=utf-8
Content-Length:0
{"jsonrpc":"2.0","id":"1"}`)
rr = &Request{}
a.Error(s.readRequest(rr))
a.Error(s.Read(rr))

// content-type 中未指定 charset
r = new(bytes.Buffer)
s = newStream(r, nil)
s = NewHTTPStream(r, nil)
a.NotNil(s)
r.WriteString(`Content-Type: text/json;charset-xx=utf-8
Content-Length:26
{"jsonrpc":"2.0","id":"1"}`)
rr = &Request{}
a.NotError(s.readRequest(rr))
a.NotError(s.Read(rr))

// content-length 格式无效
r = new(bytes.Buffer)
s = newStream(r, nil)
s = NewHTTPStream(r, nil)
a.NotNil(s)
r.WriteString(`Content-Type: text/json;charset-xx=utf-8
Content-Length:26xx
{"jsonrpc":"2.0","id":"1"}`)
rr = &Request{}
a.Error(s.readRequest(rr))
a.Error(s.Read(rr))

// content-type 是指定了非 utf-8 编码
r = new(bytes.Buffer)
s = newStream(r, nil)
s = NewHTTPStream(r, nil)
a.NotNil(s)
r.WriteString(`Content-Type: text/json;charset-xx=utf-7
Content-Length:26xx
{"jsonrpc":"2.0","id":"1"}`)
rr = &Request{}
a.Error(s.readRequest(rr))
a.Error(s.Read(rr))
}

func TestStream_write(t *testing.T) {
func TestHTTPStream_Write(t *testing.T) {
a := assert.New(t)
w := new(bytes.Buffer)
s := newStream(nil, w)
s := NewHTTPStream(nil, w)
a.NotNil(s)

size, err := s.write(&Response{
err := s.Write(&Response{
Version: "1.0.1",
Error: &Error{
Code: CodeParseError,
Expand All @@ -83,7 +85,7 @@ func TestStream_write(t *testing.T) {
ID: "1",
})
a.NotError(err)
a.NotEmpty(w.Bytes()).True(size > 0)
a.NotEmpty(w.Bytes())
}

func TestValidContentType(t *testing.T) {
Expand Down
10 changes: 7 additions & 3 deletions internal/lsp/jsonrpc/jsonrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
// https://wiki.geekdream.com/Specification/json-rpc_2.0.html
package jsonrpc

import (
"encoding/json"
)
import "encoding/json"

// Version json-rpc 的版本
const Version = "2.0"
Expand All @@ -30,6 +28,12 @@ const (
CodeContentModified = -32801
)

// Streamer 用于操作 jsonrpc 的传输层接口
type Streamer interface {
Read(interface{}) error
Write(interface{}) error
}

// Request 请求对象
type Request struct {
// 指定 JSON-RPC 协议版本的字符串
Expand Down
24 changes: 24 additions & 0 deletions internal/lsp/jsonrpc/wsstream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MIT

package jsonrpc

import (
"github.com/gorilla/websocket"
)

type websocketStream struct {
conn *websocket.Conn
}

// NewWebsocketStream 声明基于 websocket 的 streamer 实例
func NewWebsocketStream(conn *websocket.Conn) Streamer {
return &websocketStream{conn: conn}
}

func (s *websocketStream) Read(v interface{}) error {
return s.conn.ReadJSON(v)
}

func (s *websocketStream) Write(v interface{}) error {
return s.conn.WriteJSON(v)
}
5 changes: 5 additions & 0 deletions internal/lsp/jsonrpc/wsstream_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// SPDX-License-Identifier: MIT

package jsonrpc

var _ Streamer = &websocketStream{}

0 comments on commit 6888549

Please sign in to comment.