diff --git a/websocket.go b/websocket.go index 07c73a3b..97108ee5 100644 --- a/websocket.go +++ b/websocket.go @@ -4,6 +4,7 @@ import ( "bufio" "crypto/tls" "io" + "net" "net/http" "net/url" "strings" @@ -32,7 +33,14 @@ func (proxy *ProxyHttpServer) serveWebsocketTLS( tlsConfig *tls.Config, clientConn *tls.Conn, ) { - targetURL := url.URL{Scheme: "wss", Host: req.URL.Host, Path: req.URL.Path} + host := req.URL.Host + // Port is optional in req.URL.Host, in this case SplitHostPort returns + // an error, and we add the default port + _, port, err := net.SplitHostPort(req.URL.Host) + if err != nil || port == "" { + host = net.JoinHostPort(req.URL.Host, "443") + } + targetURL := url.URL{Scheme: "wss", Host: host, Path: req.URL.Path} // Connect to upstream targetConn, err := tls.Dial("tcp", targetURL.Host, tlsConfig) @@ -58,7 +66,14 @@ func (proxy *ProxyHttpServer) serveWebsocketHttpOverTLS( req *http.Request, clientConn *tls.Conn, ) { - targetURL := url.URL{Scheme: "ws", Host: req.URL.Host, Path: req.URL.Path} + host := req.URL.Host + // Port is optional in req.URL.Host, in this case SplitHostPort returns + // an error, and we add the default port + _, port, err := net.SplitHostPort(req.URL.Host) + if err != nil || port == "" { + host = net.JoinHostPort(req.URL.Host, "80") + } + targetURL := url.URL{Scheme: "ws", Host: host, Path: req.URL.Path} // Connect to upstream targetConn, err := proxy.connectDial(ctx, "tcp", targetURL.Host)