Skip to content

Commit

Permalink
feat: Add support for relay types MWS (#339)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ehco1996 authored Jun 8, 2024
1 parent 90d5a37 commit 68d2dc1
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 18 deletions.
24 changes: 24 additions & 0 deletions examples/mws.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"relay_configs": [
{
"listen": "127.0.0.1:1235",
"listen_type": "raw",
"transport_type": "mws",
"tcp_remotes": ["ws://0.0.0.0:2443"],
"ws_config": {
"path": "pwd",
"remote_addr": "127.0.0.1:5201"
}
},
{
"listen": "127.0.0.1:2443",
"listen_type": "mws",
"transport_type": "raw",
"tcp_remotes": ["0.0.0.0:5201"],
"ws_config": {
"path": "pwd",
"remote_addr": "127.0.0.1:5201"
}
}
]
}
6 changes: 5 additions & 1 deletion internal/constant/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@ const (

// relay type
const (
// tcp relay
RelayTypeRaw = "raw"
RelayTypeMTCP = "mtcp"

// ws relay
RelayTypeWS = "ws"
RelayTypeMWS = "mws"
RelayTypeWSS = "wss"
RelayTypeMWSS = "mwss"
RelayTypeMTCP = "mtcp"
)
45 changes: 28 additions & 17 deletions internal/relay/conf/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,17 @@ type WSConfig struct {
}

type Config struct {
Label string `json:"label,omitempty"`
Listen string `json:"listen"`
ListenType string `json:"listen_type"`
TransportType string `json:"transport_type"`
TCPRemotes []string `json:"tcp_remotes"`
UDPRemotes []string `json:"udp_remotes"`

Label string `json:"label,omitempty"`
MaxConnection int `json:"max_connection,omitempty"`
BlockedProtocols []string `json:"blocked_protocols,omitempty"`
WSConfig *WSConfig `json:"ws_config,omitempty"`
MaxConnection int `json:"max_connection,omitempty"`
BlockedProtocols []string `json:"blocked_protocols,omitempty"`

WSConfig *WSConfig `json:"ws_config,omitempty"`
}

func (r *Config) GetWSHandShakePath() string {
Expand All @@ -58,20 +59,9 @@ func (r *Config) Validate() error {
if r.Adjust() != nil {
return errors.New("adjust config failed")
}
if r.ListenType != constant.RelayTypeRaw &&
r.ListenType != constant.RelayTypeWS &&
r.ListenType != constant.RelayTypeWSS &&
r.ListenType != constant.RelayTypeMTCP &&
r.ListenType != constant.RelayTypeMWSS {
return fmt.Errorf("invalid listen type:%s", r.ListenType)
}

if r.TransportType != constant.RelayTypeRaw &&
r.TransportType != constant.RelayTypeWS &&
r.TransportType != constant.RelayTypeWSS &&
r.TransportType != constant.RelayTypeMTCP &&
r.TransportType != constant.RelayTypeMWSS {
return fmt.Errorf("invalid transport type:%s", r.ListenType)
if err := r.validateType(); err != nil {
return err
}

if r.Listen == "" {
Expand Down Expand Up @@ -176,3 +166,24 @@ func (r *Config) ToTCPRemotes() lb.RoundRobin {
func (r *Config) GetLoggerName() string {
return fmt.Sprintf("%s(%s<->%s)", r.Label, r.ListenType, r.TransportType)
}

func (r *Config) validateType() error {
if r.ListenType != constant.RelayTypeRaw &&
r.ListenType != constant.RelayTypeWS &&
r.ListenType != constant.RelayTypeMWS &&
r.ListenType != constant.RelayTypeWSS &&
r.ListenType != constant.RelayTypeMTCP &&
r.ListenType != constant.RelayTypeMWSS {
return fmt.Errorf("invalid listen type:%s", r.ListenType)
}

if r.TransportType != constant.RelayTypeRaw &&
r.TransportType != constant.RelayTypeWS &&
r.TransportType != constant.RelayTypeMWS &&
r.TransportType != constant.RelayTypeWSS &&
r.TransportType != constant.RelayTypeMTCP &&
r.TransportType != constant.RelayTypeMWSS {
return fmt.Errorf("invalid transport type:%s", r.ListenType)
}
return nil
}
4 changes: 4 additions & 0 deletions internal/transporter/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ func newRelayClient(base *baseTransporter) (RelayClient, error) {
return newRawClient(base)
case constant.RelayTypeWS:
return newWsClient(base)
case constant.RelayTypeMWS:
return newMwsClient(base)
case constant.RelayTypeWSS:
return newWssClient(base)
case constant.RelayTypeMWSS:
Expand All @@ -45,6 +47,8 @@ func NewRelayServer(cfg *conf.Config, cmgr cmgr.Cmgr) (RelayServer, error) {
return newRawServer(base)
case constant.RelayTypeWS:
return newWsServer(base)
case constant.RelayTypeMWS:
return newMwsServer(base)
case constant.RelayTypeWSS:
return newWssServer(base)
case constant.RelayTypeMWSS:
Expand Down
121 changes: 121 additions & 0 deletions internal/transporter/ws_mux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// NOTE CAN NOT use real ws frame to transport smux frame
// err: accept stream err: buffer size:8 too small to transport ws payload size:45
// so this transport just use ws protocol to handshake and then use smux protocol to transport
package transporter

import (
"context"
"net"
"net/http"
"time"

"github.com/gobwas/ws"
"github.com/labstack/echo/v4"
"github.com/xtaci/smux"

"github.com/Ehco1996/ehco/internal/metrics"
"github.com/Ehco1996/ehco/pkg/lb"
)

var (
_ RelayClient = &MwsClient{}
_ RelayServer = &MwsServer{}
_ muxServer = &MwsServer{}
)

type MwsClient struct {
*WssClient

muxTP *smuxTransporter
}

func newMwsClient(base *baseTransporter) (*MwsClient, error) {
wc, err := newWssClient(base)
if err != nil {
return nil, err
}
c := &MwsClient{WssClient: wc}
c.muxTP = NewSmuxTransporter(c.l.Named("mwss"), c.initNewSession)
return c, nil
}

func (c *MwsClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) {
rc, _, _, err := c.dialer.Dial(ctx, addr)
if err != nil {
return nil, err
}
// stream multiplex
cfg := smux.DefaultConfig()
cfg.KeepAliveDisabled = true
session, err := smux.Client(rc, cfg)
if err != nil {
return nil, err
}
c.l.Infof("init new session to: %s", rc.RemoteAddr())
return session, nil
}

func (s *MwsClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
addr, err := s.cfg.GetWSRemoteAddr(remote.Address)
if err != nil {
return nil, err
}
mwssc, err := s.muxTP.Dial(context.TODO(), addr)
if err != nil {
return nil, err
}
latency := time.Since(t1)
metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds()))
remote.HandShakeDuration = latency
return mwssc, nil
}

type MwsServer struct {
*WsServer
*muxServerImpl
}

func newMwsServer(base *baseTransporter) (*MwsServer, error) {
wsServer, err := newWsServer(base)
if err != nil {
return nil, err
}
s := &MwsServer{
WsServer: wsServer,
muxServerImpl: newMuxServer(base.cfg.Listen, base.l.Named("mwss")),
}
s.e.GET(base.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.HandleRequest)))
return s, nil
}

func (s *MwsServer) ListenAndServe() error {
go func() {
s.errChan <- s.e.StartServer(s.httpServer)
}()

for {
conn, e := s.Accept()
if e != nil {
return e
}
go func(c net.Conn) {
if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil {
s.l.Errorf("RelayTCPConn error: %s", err.Error())
}
}(conn)
}
}

func (s *MwsServer) HandleRequest(w http.ResponseWriter, r *http.Request) {
c, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
s.l.Error(err)
return
}
s.mux(c)
}

func (s *MwsServer) Close() error {
return s.e.Close()
}
28 changes: 28 additions & 0 deletions test/relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ const (
MTCP_LISTEN = "0.0.0.0:1238"
MTCP_REMOTE = "0.0.0.0:2003"
MTCP_SERVER = "0.0.0.0:2003"

MWS_LISTEN = "0.0.0.0:1239"
MWS_REMOTE = "ws://0.0.0.0:2004"
MSS_SERVER = "0.0.0.0:2004"
)

func init() {
Expand Down Expand Up @@ -127,6 +131,20 @@ func init() {
TCPRemotes: []string{ECHO_SERVER},
TransportType: constant.RelayTypeRaw,
},

// mws
{
Listen: MWS_LISTEN,
ListenType: constant.RelayTypeRaw,
TCPRemotes: []string{MWS_REMOTE},
TransportType: constant.RelayTypeMWS,
},
{
Listen: MSS_SERVER,
ListenType: constant.RelayTypeMWS,
TCPRemotes: []string{ECHO_SERVER},
TransportType: constant.RelayTypeRaw,
},
},
}
logger := zap.S()
Expand Down Expand Up @@ -265,6 +283,16 @@ func TestRelayOverMTCP(t *testing.T) {
t.Log("test tcp over mtcp done!")
}

func TestRelayOverMWS(t *testing.T) {
msg := []byte("hello")
// test tcp
res := echo.SendTcpMsg(msg, MWS_LISTEN)
if string(res) != string(msg) {
t.Fatal(res)
}
t.Log("test tcp over mws done!")
}

func BenchmarkTcpRelay(b *testing.B) {
msg := []byte("hello")
for i := 0; i <= b.N; i++ {
Expand Down

0 comments on commit 68d2dc1

Please sign in to comment.