Skip to content

Commit

Permalink
feat(internal/lsp/jsonrpc): 完成 conn 相关功能
Browse files Browse the repository at this point in the history
update #56
  • Loading branch information
caixw committed Apr 29, 2020
1 parent cbf2fe5 commit 706a377
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 22 deletions.
162 changes: 160 additions & 2 deletions internal/lsp/jsonrpc/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,175 @@

package jsonrpc

import "io"
import (
"context"
"encoding/json"
"io"
"log"
"reflect"
"strconv"
"sync"

"github.com/caixw/apidoc/v6/internal/locale"
)

// ServerFunc 每个服务的函数签名
type ServerFunc func(in, out interface{}) error

// Conn 连接对象,json-rpc 客户端和服务端是对等的,两者都使用 conn 初始化。
type Conn struct {
sequence int64
errlog *log.Logger
stream *stream
servers sync.Map
}

type handler struct {
f ServerFunc
in, out reflect.Type
}

// NewConn 声明新的 Conn 实例
func NewConn(in io.Reader, out io.Writer) *Conn {
func NewConn(errlog *log.Logger, in io.Reader, out io.Writer) *Conn {
return &Conn{
errlog: errlog,
stream: newStream(in, out),
}
}

// Register 注册一个新的服务
//
// 返回值表示是否添加成功,在已经存在相同值时,会添加失败。
func (conn *Conn) Register(method string, f ServerFunc) bool {
if _, found := conn.servers.Load(method); found {
return false
}

conn.servers.Store(method, newHandler(f))
return true
}

func newHandler(f ServerFunc) *handler {
t := reflect.TypeOf(f)
return &handler{
f: f,
in: t.Method(0).Type.Elem(),
out: t.Method(0).Type.Elem(),
}
}

// Notify 发送通知信息
func (conn *Conn) Notify(method string, in interface{}) error {
return conn.send(true, method, in, nil)
}

// Send 发送请求内容,并获取其返回的数据
func (conn *Conn) Send(method string, in, out interface{}) error {
return conn.send(false, method, in, out)
}

func (conn *Conn) send(notify bool, method string, in, out interface{}) error {
data, err := json.Marshal(in)
if err != nil {
return err
}

req := &Request{
Version: Version,
Method: method,
Params: (*json.RawMessage)(&data),
}
if !notify {
req.ID = strconv.FormatInt(conn.sequence, 10)
}

if _, err = conn.stream.write(req); err != nil {
return err
}

if notify {
return nil
}

resp := &Response{}
if err = conn.stream.readResponse(resp); err != nil {
return err
}

if resp.Error != nil {
return resp.Error
}

if req.ID != resp.ID {
return NewError(CodeInvalidParams, locale.Sprintf(locale.VersionInCompatible))
}

return json.Unmarshal([]byte(*resp.Result), out)
}

// Serve 作为服务端运行
func (conn *Conn) Serve(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
default:
if err := conn.serve(); err != nil {
return err
}
}
}
}

// 作为服务端,根据参数查找和执行服务
func (conn *Conn) serve() error {
req := &Request{}
if err := conn.stream.readRequest(req); err != nil {
return conn.writeError("", CodeParseError, err, nil)
}

f, found := conn.servers.Load(req.Method)
if !found {
return conn.writeError("", CodeMethodNotFound, locale.Errorf(locale.ErrInvalidValue), nil)
}
h := f.(*handler)

in := reflect.New(h.in).Interface()
if err := json.Unmarshal([]byte(*req.Params), in); err != nil {
return conn.writeError("", CodeParseError, err, nil)
}

out := reflect.New(h.out).Interface()
if err := h.f(in, out); err != nil {
return conn.writeError("", CodeInternalError, err, nil)
}

data, err := json.Marshal(out)
if err != nil {
return err
}

resp := &Response{
Version: Version,
Result: (*json.RawMessage)(&data),
ID: req.ID,
}
_, err = conn.stream.write(resp)
return err
}

func (conn *Conn) writeError(id string, code int, err error, data interface{}) error {
resp := &Response{
ID: id,
Version: Version,
}

if err2, ok := err.(*Error); ok {
resp.Error = err2
} else {
resp.Error = NewError(code, err.Error())
}

_, err = conn.stream.write(resp)
return err
}
4 changes: 3 additions & 1 deletion internal/lsp/jsonrpc/jsonrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
// https://wiki.geekdream.com/Specification/json-rpc_2.0.html
package jsonrpc

import "encoding/json"
import (
"encoding/json"
)

// Version json-rpc 的版本
const Version = "2.0"
Expand Down
44 changes: 31 additions & 13 deletions internal/lsp/jsonrpc/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,33 @@ func newStream(in io.Reader, out io.Writer) *stream {
}
}

func (s *stream) read(req *Request) error {
func (s *stream) readRequest(req *Request) error {
data, err := s.read()
if err != nil {
return err
}

return json.Unmarshal(data, req)
}

func (s *stream) readResponse(resp *Response) error {
data, err := s.read()
if err != nil {
return err
}

return json.Unmarshal(data, resp)
}

// 读取内容,先验证报头,并返回实际 body 的内容
func (s *stream) read() ([]byte, error) {
buf := bufio.NewReader(s.in)
var l int

for {
line, err := buf.ReadString('\n')
if err != nil {
return NewError(CodeParseError, err.Error())
return nil, err
}
line = strings.TrimSpace(line)
if line == "" {
Expand All @@ -52,39 +71,38 @@ func (s *stream) read(req *Request) error {

index := strings.IndexByte(line, ':')
if index <= 0 {
return NewError(CodeParseError, locale.Sprintf(locale.ErrInvalidHeaderFormat))
return nil, locale.Errorf(locale.ErrInvalidHeaderFormat)
}

v := strings.TrimSpace(line[index+1:])
switch http.CanonicalHeaderKey(strings.TrimSpace(line[:index])) {
case contentLength:
l, err = strconv.Atoi(v)
if err != nil {
return NewError(CodeParseError, err.Error())
return nil, err
}
case contentType:
if err := validContentType(v); err != nil {
return err
return nil, err
}
default: // 忽略其它报头
}
}

if l <= 0 {
return NewError(CodeParseError, locale.Sprintf(locale.ErrInvalidContentLength))
return nil, locale.Errorf(locale.ErrInvalidContentLength)
}

data := make([]byte, l)
n, err := io.ReadFull(buf, data)
if err != nil {
return NewError(CodeParseError, err.Error())
return nil, err
}
if n == 0 {
return NewError(CodeParseError, locale.Sprintf(locale.ErrBodyIsEmpty))
return nil, locale.Errorf(locale.ErrBodyIsEmpty)
}

data = data[:n]
return json.Unmarshal(data, req)
return data[:n], nil
}

func validContentType(header string) error {
Expand All @@ -95,15 +113,15 @@ func validContentType(header string) error {
if index > 0 &&
strings.ToLower(strings.TrimSpace(pair[:index])) == "charset" &&
strings.ToLower(strings.TrimSpace(pair[index+1:])) != charset {
return NewError(CodeParseError, locale.Sprintf(locale.ErrInvalidContentTypeCharset))
return locale.Errorf(locale.ErrInvalidContentTypeCharset)
}
}

return nil
}

func (s *stream) write(resp *Response) (int, error) {
data, err := json.Marshal(resp)
func (s *stream) write(obj interface{}) (int, error) {
data, err := json.Marshal(obj)
if err != nil {
return 0, err
}
Expand Down
12 changes: 6 additions & 6 deletions internal/lsp/jsonrpc/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/issue9/assert"
)

func TestStream_read(t *testing.T) {
func TestStream_readRequest(t *testing.T) {
a := assert.New(t)

r := new(bytes.Buffer)
Expand All @@ -20,7 +20,7 @@ func TestStream_read(t *testing.T) {
{"jsonrpc":"2.0","id":"1"}`)
rr := &Request{}
a.NotError(s.read(rr))
a.NotError(s.readRequest(rr))
a.Equal(rr.Version, Version).Equal(rr.ID, "1")

// 无效的 content-length
Expand All @@ -32,7 +32,7 @@ func TestStream_read(t *testing.T) {
{"jsonrpc":"2.0","id":"1"}`)
rr = &Request{}
a.Error(s.read(rr))
a.Error(s.readRequest(rr))

// content-type 中未指定 charset
r = new(bytes.Buffer)
Expand All @@ -43,7 +43,7 @@ func TestStream_read(t *testing.T) {
{"jsonrpc":"2.0","id":"1"}`)
rr = &Request{}
a.NotError(s.read(rr))
a.NotError(s.readRequest(rr))

// content-length 格式无效
r = new(bytes.Buffer)
Expand All @@ -54,7 +54,7 @@ func TestStream_read(t *testing.T) {
{"jsonrpc":"2.0","id":"1"}`)
rr = &Request{}
a.Error(s.read(rr))
a.Error(s.readRequest(rr))

// content-type 是指定了非 utf-8 编码
r = new(bytes.Buffer)
Expand All @@ -65,7 +65,7 @@ func TestStream_read(t *testing.T) {
{"jsonrpc":"2.0","id":"1"}`)
rr = &Request{}
a.Error(s.read(rr))
a.Error(s.readRequest(rr))
}

func TestStream_write(t *testing.T) {
Expand Down

0 comments on commit 706a377

Please sign in to comment.