Skip to content

Add MethodType query parameter #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 131 additions & 17 deletions wsproxy/websocket_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@ import (
"golang.org/x/net/context"
)

// MethodOverrideParam defines the special URL parameter that is translated into the subsequent proxied streaming http request's method.
//
// Deprecated: it is preferable to use the Options parameters to WebSocketProxy to supply parameters.
var MethodOverrideParam = "method"
var (
// MethodOverrideParam defines the special URL parameter that is translated into the subsequent
// proxied streaming http request's method.
//
// Deprecated: it is preferable to use the Options parameters to WebSocketProxy to supply parameters.
MethodOverrideParam = "method"

// defaultMethodTypeParam defines the default URL parameter name for the special parameter that is
// translated into the subsequent proxied streaming http request's gRPC method type.
defaultMethodTypeParam = "methodType"
)

// TokenCookieName defines the cookie name that is translated to an 'Authorization: Bearer' header in the streaming http request's headers.
//
Expand All @@ -32,6 +39,7 @@ type Proxy struct {
logger Logger
maxRespBodyBufferBytes int
methodOverrideParam string
methodTypeParam string
tokenCookieName string
requestMutator RequestMutatorFunc
headerForwarder func(header string) bool
Expand All @@ -46,6 +54,28 @@ type Logger interface {
Debugln(...interface{})
}

// MethodType defines the type of gRPC method.
type MethodType string

const (
MethodTypeUnary MethodType = "Unary"
MethodTypeClientStreaming MethodType = "ClientStreaming"
MethodTypeServerStreaming MethodType = "ServerStreaming"
MethodTypeDuplexStreaming MethodType = "DuplexStreaming"
)

func (mt MethodType) String() string {
return string(mt)
}

func (mt MethodType) IsValid() bool {
switch mt {
case MethodTypeUnary, MethodTypeClientStreaming, MethodTypeServerStreaming, MethodTypeDuplexStreaming:
return true
}
return false
}

func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !websocket.IsWebSocketUpgrade(r) {
p.h.ServeHTTP(w, r)
Expand Down Expand Up @@ -73,6 +103,17 @@ func WithMethodParamOverride(param string) Option {
}
}

// WithMethodTypeParam allows specification of gRPC method type.
// Default name of the query parameter is "methodType".
//
// The query parameter expects one of the values:
// "Unary", "ClientStreaming", "ServerStreaming", "DuplexStreaming".
func WithMethodTypeParam(param string) Option {
return func(p *Proxy) {
p.methodTypeParam = param
}
}

// WithTokenCookieName allows specification of the cookie that is supplied as an upstream 'Authorization: Bearer' http header.
func WithTokenCookieName(param string) Option {
return func(p *Proxy) {
Expand Down Expand Up @@ -130,16 +171,20 @@ func defaultHeaderForwarder(header string) bool {
// The cookie name is specified by the TokenCookieName value.
//
// example:
// Sec-Websocket-Protocol: Bearer, foobar
//
// Sec-Websocket-Protocol: Bearer, foobar
//
// is converted to:
// Authorization: Bearer foobar
//
// Authorization: Bearer foobar
//
// Method can be overwritten with the MethodOverrideParam get parameter in the requested URL
func WebsocketProxy(h http.Handler, opts ...Option) http.Handler {
p := &Proxy{
h: h,
logger: logrus.New(),
methodOverrideParam: MethodOverrideParam,
methodTypeParam: defaultMethodTypeParam,
tokenCookieName: TokenCookieName,
headerForwarder: defaultHeaderForwarder,
}
Expand All @@ -166,6 +211,8 @@ func isClosedConnError(err error) bool {

func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) {
var responseHeader http.Header
var methodType MethodType

// If Sec-WebSocket-Protocol starts with "Bearer", respond in kind.
// TODO(tmc): consider customizability/extension point here.
if strings.HasPrefix(r.Header.Get("Sec-WebSocket-Protocol"), "Bearer") {
Expand Down Expand Up @@ -204,6 +251,15 @@ func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) {
if m := r.URL.Query().Get(p.methodOverrideParam); m != "" {
request.Method = m
}
if m := r.URL.Query().Get(p.methodTypeParam); m != "" {
methodType = MethodType(m)
if !methodType.IsValid() {
p.logger.Warnln("invalid", p.methodTypeParam, "parameter:", m,
"expected one of:", MethodTypeUnary, MethodTypeClientStreaming,
MethodTypeServerStreaming, MethodTypeDuplexStreaming)
return
}
}

if p.requestMutator != nil {
request = p.requestMutator(r, request)
Expand All @@ -225,15 +281,16 @@ func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) {
}()

// read loop -- take messages from websocket and write to http request
go func() {
if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 {
conn.SetReadDeadline(time.Now().Add(p.pongWait))
conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(p.pongWait)); return nil })
}
defer func() {
cancelFn()
}()
for {
switch methodType {
case MethodTypeUnary, MethodTypeServerStreaming:
go func() {
if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 {
conn.SetReadDeadline(time.Now().Add(p.pongWait))
conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(p.pongWait)); return nil })
}
defer func() {
cancelFn()
}()
select {
case <-ctx.Done():
p.logger.Debugln("read loop done")
Expand All @@ -259,8 +316,62 @@ func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) {
p.logger.Warnln("[read] error writing message to upstream http server:", err)
return
}
}
}()
// Close request body since server doesn't expect any more data
requestBodyW.Close()
messageType, _, err := conn.ReadMessage()
if err != nil {
if isClosedConnError(err) {
p.logger.Debugln("[read] websocket closed:", err)
return
}
p.logger.Warnln("error reading websocket message:", err)
return
}
if messageType == websocket.CloseMessage {
p.logger.Debugln("[read] websocket closed")
return
}
p.logger.Debugln("[read] unexpected message type:", messageType)
}()
case MethodTypeClientStreaming, MethodTypeDuplexStreaming:
go func() {
if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 {
conn.SetReadDeadline(time.Now().Add(p.pongWait))
conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(p.pongWait)); return nil })
}
defer func() {
cancelFn()
}()
for {
select {
case <-ctx.Done():
p.logger.Debugln("read loop done")
return
default:
}
p.logger.Debugln("[read] reading from socket.")
_, payload, err := conn.ReadMessage()
if err != nil {
if isClosedConnError(err) {
p.logger.Debugln("[read] websocket closed:", err)
return
}
p.logger.Warnln("error reading websocket message:", err)
return
}
p.logger.Debugln("[read] read payload:", string(payload))
p.logger.Debugln("[read] writing to requestBody:")
n, err := requestBodyW.Write(payload)
requestBodyW.Write([]byte("\n"))
p.logger.Debugln("[read] wrote to requestBody", n)
if err != nil {
p.logger.Warnln("[read] error writing message to upstream http server:", err)
return
}
}
}()
}

// ping write loop
if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 {
go func() {
Expand Down Expand Up @@ -338,12 +449,15 @@ func transformSubProtocolHeader(header string) string {
func (w *inMemoryResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}

func (w *inMemoryResponseWriter) Header() http.Header {
return w.header
}

func (w *inMemoryResponseWriter) WriteHeader(code int) {
w.code = code
}

func (w *inMemoryResponseWriter) CloseNotify() <-chan bool {
return w.closed
}
Expand Down