diff --git a/wsproxy/websocket_proxy.go b/wsproxy/websocket_proxy.go index 7092162..5c65918 100644 --- a/wsproxy/websocket_proxy.go +++ b/wsproxy/websocket_proxy.go @@ -13,10 +13,17 @@ import ( "golang.org/x/net/context" ) -// MethodOverrideParam defines the special URL parameter that is translated into the subsequent proxied streaming http request's method. -// -// Deprecated: it is preferable to use the Options parameters to WebSocketProxy to supply parameters. -var MethodOverrideParam = "method" +var ( + // MethodOverrideParam defines the special URL parameter that is translated into the subsequent + // proxied streaming http request's method. + // + // Deprecated: it is preferable to use the Options parameters to WebSocketProxy to supply parameters. + MethodOverrideParam = "method" + + // defaultMethodTypeParam defines the default URL parameter name for the special parameter that is + // translated into the subsequent proxied streaming http request's gRPC method type. + defaultMethodTypeParam = "methodType" +) // TokenCookieName defines the cookie name that is translated to an 'Authorization: Bearer' header in the streaming http request's headers. // @@ -32,6 +39,7 @@ type Proxy struct { logger Logger maxRespBodyBufferBytes int methodOverrideParam string + methodTypeParam string tokenCookieName string requestMutator RequestMutatorFunc headerForwarder func(header string) bool @@ -46,6 +54,28 @@ type Logger interface { Debugln(...interface{}) } +// MethodType defines the type of gRPC method. +type MethodType string + +const ( + MethodTypeUnary MethodType = "Unary" + MethodTypeClientStreaming MethodType = "ClientStreaming" + MethodTypeServerStreaming MethodType = "ServerStreaming" + MethodTypeDuplexStreaming MethodType = "DuplexStreaming" +) + +func (mt MethodType) String() string { + return string(mt) +} + +func (mt MethodType) IsValid() bool { + switch mt { + case MethodTypeUnary, MethodTypeClientStreaming, MethodTypeServerStreaming, MethodTypeDuplexStreaming: + return true + } + return false +} + func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { if !websocket.IsWebSocketUpgrade(r) { p.h.ServeHTTP(w, r) @@ -73,6 +103,17 @@ func WithMethodParamOverride(param string) Option { } } +// WithMethodTypeParam allows specification of gRPC method type. +// Default name of the query parameter is "methodType". +// +// The query parameter expects one of the values: +// "Unary", "ClientStreaming", "ServerStreaming", "DuplexStreaming". +func WithMethodTypeParam(param string) Option { + return func(p *Proxy) { + p.methodTypeParam = param + } +} + // WithTokenCookieName allows specification of the cookie that is supplied as an upstream 'Authorization: Bearer' http header. func WithTokenCookieName(param string) Option { return func(p *Proxy) { @@ -130,9 +171,12 @@ func defaultHeaderForwarder(header string) bool { // The cookie name is specified by the TokenCookieName value. // // example: -// Sec-Websocket-Protocol: Bearer, foobar +// +// Sec-Websocket-Protocol: Bearer, foobar +// // is converted to: -// Authorization: Bearer foobar +// +// Authorization: Bearer foobar // // Method can be overwritten with the MethodOverrideParam get parameter in the requested URL func WebsocketProxy(h http.Handler, opts ...Option) http.Handler { @@ -140,6 +184,7 @@ func WebsocketProxy(h http.Handler, opts ...Option) http.Handler { h: h, logger: logrus.New(), methodOverrideParam: MethodOverrideParam, + methodTypeParam: defaultMethodTypeParam, tokenCookieName: TokenCookieName, headerForwarder: defaultHeaderForwarder, } @@ -166,6 +211,8 @@ func isClosedConnError(err error) bool { func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) { var responseHeader http.Header + var methodType MethodType + // If Sec-WebSocket-Protocol starts with "Bearer", respond in kind. // TODO(tmc): consider customizability/extension point here. if strings.HasPrefix(r.Header.Get("Sec-WebSocket-Protocol"), "Bearer") { @@ -204,6 +251,15 @@ func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) { if m := r.URL.Query().Get(p.methodOverrideParam); m != "" { request.Method = m } + if m := r.URL.Query().Get(p.methodTypeParam); m != "" { + methodType = MethodType(m) + if !methodType.IsValid() { + p.logger.Warnln("invalid", p.methodTypeParam, "parameter:", m, + "expected one of:", MethodTypeUnary, MethodTypeClientStreaming, + MethodTypeServerStreaming, MethodTypeDuplexStreaming) + return + } + } if p.requestMutator != nil { request = p.requestMutator(r, request) @@ -225,15 +281,16 @@ func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) { }() // read loop -- take messages from websocket and write to http request - go func() { - if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 { - conn.SetReadDeadline(time.Now().Add(p.pongWait)) - conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(p.pongWait)); return nil }) - } - defer func() { - cancelFn() - }() - for { + switch methodType { + case MethodTypeUnary, MethodTypeServerStreaming: + go func() { + if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 { + conn.SetReadDeadline(time.Now().Add(p.pongWait)) + conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(p.pongWait)); return nil }) + } + defer func() { + cancelFn() + }() select { case <-ctx.Done(): p.logger.Debugln("read loop done") @@ -259,8 +316,62 @@ func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) { p.logger.Warnln("[read] error writing message to upstream http server:", err) return } - } - }() + // Close request body since server doesn't expect any more data + requestBodyW.Close() + messageType, _, err := conn.ReadMessage() + if err != nil { + if isClosedConnError(err) { + p.logger.Debugln("[read] websocket closed:", err) + return + } + p.logger.Warnln("error reading websocket message:", err) + return + } + if messageType == websocket.CloseMessage { + p.logger.Debugln("[read] websocket closed") + return + } + p.logger.Debugln("[read] unexpected message type:", messageType) + }() + case MethodTypeClientStreaming, MethodTypeDuplexStreaming: + go func() { + if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 { + conn.SetReadDeadline(time.Now().Add(p.pongWait)) + conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(p.pongWait)); return nil }) + } + defer func() { + cancelFn() + }() + for { + select { + case <-ctx.Done(): + p.logger.Debugln("read loop done") + return + default: + } + p.logger.Debugln("[read] reading from socket.") + _, payload, err := conn.ReadMessage() + if err != nil { + if isClosedConnError(err) { + p.logger.Debugln("[read] websocket closed:", err) + return + } + p.logger.Warnln("error reading websocket message:", err) + return + } + p.logger.Debugln("[read] read payload:", string(payload)) + p.logger.Debugln("[read] writing to requestBody:") + n, err := requestBodyW.Write(payload) + requestBodyW.Write([]byte("\n")) + p.logger.Debugln("[read] wrote to requestBody", n) + if err != nil { + p.logger.Warnln("[read] error writing message to upstream http server:", err) + return + } + } + }() + } + // ping write loop if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 { go func() { @@ -338,12 +449,15 @@ func transformSubProtocolHeader(header string) string { func (w *inMemoryResponseWriter) Write(b []byte) (int, error) { return w.Writer.Write(b) } + func (w *inMemoryResponseWriter) Header() http.Header { return w.header } + func (w *inMemoryResponseWriter) WriteHeader(code int) { w.code = code } + func (w *inMemoryResponseWriter) CloseNotify() <-chan bool { return w.closed }