Skip to content

Commit

Permalink
migrate from gorilla/websocket to gobwas/ws
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Oct 17, 2023
1 parent de2abe2 commit 8827a26
Show file tree
Hide file tree
Showing 14 changed files with 1,120 additions and 712 deletions.
140 changes: 72 additions & 68 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package client

import (
"crypto/tls"
"fmt"
"log"
"net"
"net/http"
Expand All @@ -15,12 +16,13 @@ import (
"github.com/wwqgtxx/wstunnel/config"
"github.com/wwqgtxx/wstunnel/fallback"
"github.com/wwqgtxx/wstunnel/listener"
"github.com/wwqgtxx/wstunnel/proxy"
"github.com/wwqgtxx/wstunnel/tunnel"
"github.com/wwqgtxx/wstunnel/utils"

"github.com/gorilla/websocket"
)

const DialTimeout = 8 * time.Second

type client struct {
common.ClientImpl
serverWSPath string
Expand Down Expand Up @@ -73,21 +75,22 @@ func (c *client) GetServerWSPath() string {
}

type wsClientImpl struct {
header http.Header
wsUrl string
wsDialer *websocket.Dialer
ed uint32
proxy string
header http.Header
wsUrl *url.URL
tlsConfig *tls.Config
netDial proxy.NetDialerFunc
ed uint32
proxy string
}

type tcpClientImpl struct {
targetAddress string
netDial NetDialerFunc
netDial proxy.NetDialerFunc
proxy string
}

func (c *wsClientImpl) Target() string {
return c.wsUrl
return c.wsUrl.String()
}

func (c *wsClientImpl) Proxy() string {
Expand All @@ -97,12 +100,12 @@ func (c *wsClientImpl) Proxy() string {
func (c *wsClientImpl) Handle(tcp net.Conn) {
defer tcp.Close()
log.Println("Incoming --> ", tcp.RemoteAddr(), " --> ", c.Target(), c.Proxy())
header, edBuf, err := utils.EncodeXray0rtt(tcp, c.ed)
edBuf, err := utils.PrepareXray0rtt(tcp, c.ed)
if err != nil {
log.Println(err)
return
}
conn, err := c.Dial(edBuf, header)
conn, err := c.Dial(edBuf, nil)
if err != nil {
log.Println(err)
return
Expand All @@ -112,7 +115,7 @@ func (c *wsClientImpl) Handle(tcp net.Conn) {
}

func (c *wsClientImpl) Dial(edBuf []byte, inHeader http.Header) (common.ClientConn, error) {
header := c.header
var header http.Header
if len(inHeader) > 0 {
// copy from inHeader
header = inHeader.Clone()
Expand Down Expand Up @@ -141,43 +144,48 @@ func (c *wsClientImpl) Dial(edBuf []byte, inHeader http.Header) (common.ClientCo
edBuf, _ = utils.DecodeEd(secProtocol)
}
}
} else {
// copy from c.header
header = c.header.Clone()
}
if c.ed > 0 && len(edBuf) > 0 {
header.Set("Sec-WebSocket-Protocol", utils.EncodeEd(edBuf))
edBuf = nil
}

wsConn, header, err := utils.ClientWebsocketDial(*c.wsUrl, header, c.netDial, c.tlsConfig, DialTimeout)
log.Println("Dial to", c.Target(), c.Proxy(), "with", header)
ws, resp, err := c.wsDialer.Dial(c.Target(), header)
if resp != nil {
log.Println("Dial", c.Target(), c.Proxy(), "get response", resp.Header)
if err != nil {
return nil, err
}

if len(edBuf) > 0 {
err = ws.WriteMessage(websocket.BinaryMessage, edBuf)
_, err = wsConn.Write(edBuf)
if err != nil {
return nil, err
}
}
return &wsClientConn{ws: ws}, err
return &wsClientConn{wsConn: wsConn}, err
}

type wsClientConn struct {
ws *websocket.Conn
close sync.Once
wsConn *utils.WebsocketConn
close sync.Once
}

func (c *wsClientConn) Close() {
c.close.Do(func() {
_ = c.ws.Close()
_ = c.wsConn.Close()
})
}

func (c *wsClientConn) TunnelTcp(tcp net.Conn) {
tunnel.TcpWs(tcp, c.ws)
tunnel.Tunnel(tcp, c.wsConn)
}

func (c *wsClientConn) TunnelWs(ws *websocket.Conn) {
func (c *wsClientConn) TunnelWs(wsConn *utils.WebsocketConn) {
// fastpath for direct tunnel underlying ws connection
tunnel.TcpTcp(ws.UnderlyingConn(), c.ws.UnderlyingConn())
tunnel.Tunnel(wsConn.Conn, c.wsConn.Conn)
}

func (c *tcpClientImpl) Target() string {
Expand Down Expand Up @@ -223,11 +231,11 @@ func (c *tcpClientConn) Close() {
}

func (c *tcpClientConn) TunnelTcp(tcp net.Conn) {
tunnel.TcpTcp(tcp, c.tcp)
tunnel.Tunnel(tcp, c.tcp)
}

func (c *tcpClientConn) TunnelWs(ws *websocket.Conn) {
tunnel.TcpWs(c.tcp, ws)
func (c *tcpClientConn) TunnelWs(wsConn *utils.WebsocketConn) {
tunnel.Tunnel(c.tcp, wsConn)
}

func BuildClient(clientConfig config.ClientConfig) {
Expand Down Expand Up @@ -267,35 +275,38 @@ func parseProxy(proxyString string) (proxyUrl *url.URL, proxyStr string) {
return
}

func NewClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
if len(clientConfig.TargetAddress) > 0 {
return NewTcpClientImpl(clientConfig)
} else {
return NewWsClientImpl(clientConfig)
}
}

func NewTcpClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
proxyUrl, proxyStr := parseProxy(clientConfig.Proxy)

var netDial NetDialerFunc
func getNetDial(proxyUrl *url.URL) (netDial proxy.NetDialerFunc) {
tcpDialer := &net.Dialer{
Timeout: 8 * time.Second,
Timeout: DialTimeout,
}
netDial = tcpDialer.Dial

proxyDialer := proxy_FromEnvironment()
proxyDialer := proxy.FromEnvironment()
if proxyUrl != nil {
dialer, err := proxy_FromURL(proxyUrl, netDial)
dialer, err := proxy.FromURL(proxyUrl, netDial)
if err != nil {
log.Println(err)
} else {
proxyDialer = dialer
}
}
if proxyDialer != proxy_Direct {
if proxyDialer != proxy.Direct {
netDial = proxyDialer.Dial
}
return
}

func NewClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
if len(clientConfig.TargetAddress) > 0 {
return NewTcpClientImpl(clientConfig)
} else {
return NewWsClientImpl(clientConfig)
}
}

func NewTcpClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
proxyUrl, proxyStr := parseProxy(clientConfig.Proxy)
netDial := getNetDial(proxyUrl)

return &tcpClientImpl{
targetAddress: clientConfig.TargetAddress,
Expand All @@ -306,45 +317,38 @@ func NewTcpClientImpl(clientConfig config.ClientConfig) common.ClientImpl {

func NewWsClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
proxyUrl, proxyStr := parseProxy(clientConfig.Proxy)

proxy := http.ProxyFromEnvironment
if proxyUrl != nil {
proxy = http.ProxyURL(proxyUrl)
}
netDial := getNetDial(proxyUrl)

header := http.Header{}
if len(clientConfig.WSHeaders) != 0 {
for key, value := range clientConfig.WSHeaders {
header.Add(key, value)
}
}
wsDialer := &websocket.Dialer{
Proxy: proxy,
HandshakeTimeout: 8 * time.Second,
ReadBufferSize: tunnel.BufSize,
WriteBufferSize: tunnel.BufSize,
WriteBufferPool: tunnel.WriteBufferPool,
}
wsDialer.TLSClientConfig = &tls.Config{
tlsConfig := &tls.Config{
ServerName: clientConfig.ServerName,
InsecureSkipVerify: clientConfig.SkipCertVerify,
}
var ed uint32
if u, err := url.Parse(clientConfig.WSUrl); err == nil {
if q := u.Query(); q.Get("ed") != "" {
Ed, _ := strconv.Atoi(q.Get("ed"))
ed = uint32(Ed)
q.Del("ed")
u.RawQuery = q.Encode()
clientConfig.WSUrl = u.String()
}
u, err := url.Parse(clientConfig.WSUrl)
if err != nil {
panic(fmt.Errorf("parse url %s error: %w", clientConfig.WSUrl, err))
}
if q := u.Query(); q.Get("ed") != "" {
Ed, _ := strconv.Atoi(q.Get("ed"))
ed = uint32(Ed)
q.Del("ed")
u.RawQuery = q.Encode()
//clientConfig.WSUrl = u.String()
}

return &wsClientImpl{
header: header,
wsUrl: clientConfig.WSUrl,
wsDialer: wsDialer,
ed: ed,
proxy: proxyStr,
header: header,
wsUrl: u,
netDial: netDial,
tlsConfig: tlsConfig,
ed: ed,
proxy: proxyStr,
}
}

Expand Down
Loading

0 comments on commit 8827a26

Please sign in to comment.