diff --git a/examples/mws.json b/examples/mws.json new file mode 100644 index 000000000..62aff74a0 --- /dev/null +++ b/examples/mws.json @@ -0,0 +1,24 @@ +{ + "relay_configs": [ + { + "listen": "127.0.0.1:1235", + "listen_type": "raw", + "transport_type": "mws", + "tcp_remotes": ["ws://0.0.0.0:2443"], + "ws_config": { + "path": "pwd", + "remote_addr": "127.0.0.1:5201" + } + }, + { + "listen": "127.0.0.1:2443", + "listen_type": "mws", + "transport_type": "raw", + "tcp_remotes": ["0.0.0.0:5201"], + "ws_config": { + "path": "pwd", + "remote_addr": "127.0.0.1:5201" + } + } + ] +} diff --git a/internal/constant/constant.go b/internal/constant/constant.go index d8a8b22d0..bbaafcdde 100644 --- a/internal/constant/constant.go +++ b/internal/constant/constant.go @@ -29,9 +29,13 @@ const ( // relay type const ( + // tcp relay RelayTypeRaw = "raw" + RelayTypeMTCP = "mtcp" + + // ws relay RelayTypeWS = "ws" + RelayTypeMWS = "mws" RelayTypeWSS = "wss" RelayTypeMWSS = "mwss" - RelayTypeMTCP = "mtcp" ) diff --git a/internal/relay/conf/cfg.go b/internal/relay/conf/cfg.go index eea710d1b..c0082e210 100644 --- a/internal/relay/conf/cfg.go +++ b/internal/relay/conf/cfg.go @@ -24,16 +24,17 @@ type WSConfig struct { } type Config struct { + Label string `json:"label,omitempty"` Listen string `json:"listen"` ListenType string `json:"listen_type"` TransportType string `json:"transport_type"` TCPRemotes []string `json:"tcp_remotes"` UDPRemotes []string `json:"udp_remotes"` - Label string `json:"label,omitempty"` - MaxConnection int `json:"max_connection,omitempty"` - BlockedProtocols []string `json:"blocked_protocols,omitempty"` - WSConfig *WSConfig `json:"ws_config,omitempty"` + MaxConnection int `json:"max_connection,omitempty"` + BlockedProtocols []string `json:"blocked_protocols,omitempty"` + + WSConfig *WSConfig `json:"ws_config,omitempty"` } func (r *Config) GetWSHandShakePath() string { @@ -58,20 +59,9 @@ func (r *Config) Validate() error { if r.Adjust() != nil { return errors.New("adjust config failed") } - if r.ListenType != constant.RelayTypeRaw && - r.ListenType != constant.RelayTypeWS && - r.ListenType != constant.RelayTypeWSS && - r.ListenType != constant.RelayTypeMTCP && - r.ListenType != constant.RelayTypeMWSS { - return fmt.Errorf("invalid listen type:%s", r.ListenType) - } - if r.TransportType != constant.RelayTypeRaw && - r.TransportType != constant.RelayTypeWS && - r.TransportType != constant.RelayTypeWSS && - r.TransportType != constant.RelayTypeMTCP && - r.TransportType != constant.RelayTypeMWSS { - return fmt.Errorf("invalid transport type:%s", r.ListenType) + if err := r.validateType(); err != nil { + return err } if r.Listen == "" { @@ -176,3 +166,24 @@ func (r *Config) ToTCPRemotes() lb.RoundRobin { func (r *Config) GetLoggerName() string { return fmt.Sprintf("%s(%s<->%s)", r.Label, r.ListenType, r.TransportType) } + +func (r *Config) validateType() error { + if r.ListenType != constant.RelayTypeRaw && + r.ListenType != constant.RelayTypeWS && + r.ListenType != constant.RelayTypeMWS && + r.ListenType != constant.RelayTypeWSS && + r.ListenType != constant.RelayTypeMTCP && + r.ListenType != constant.RelayTypeMWSS { + return fmt.Errorf("invalid listen type:%s", r.ListenType) + } + + if r.TransportType != constant.RelayTypeRaw && + r.TransportType != constant.RelayTypeWS && + r.TransportType != constant.RelayTypeMWS && + r.TransportType != constant.RelayTypeWSS && + r.TransportType != constant.RelayTypeMTCP && + r.TransportType != constant.RelayTypeMWSS { + return fmt.Errorf("invalid transport type:%s", r.ListenType) + } + return nil +} diff --git a/internal/transporter/interface.go b/internal/transporter/interface.go index 360f89662..f36c48e25 100644 --- a/internal/transporter/interface.go +++ b/internal/transporter/interface.go @@ -22,6 +22,8 @@ func newRelayClient(base *baseTransporter) (RelayClient, error) { return newRawClient(base) case constant.RelayTypeWS: return newWsClient(base) + case constant.RelayTypeMWS: + return newMwsClient(base) case constant.RelayTypeWSS: return newWssClient(base) case constant.RelayTypeMWSS: @@ -45,6 +47,8 @@ func NewRelayServer(cfg *conf.Config, cmgr cmgr.Cmgr) (RelayServer, error) { return newRawServer(base) case constant.RelayTypeWS: return newWsServer(base) + case constant.RelayTypeMWS: + return newMwsServer(base) case constant.RelayTypeWSS: return newWssServer(base) case constant.RelayTypeMWSS: diff --git a/internal/transporter/ws_mux.go b/internal/transporter/ws_mux.go new file mode 100644 index 000000000..46a204455 --- /dev/null +++ b/internal/transporter/ws_mux.go @@ -0,0 +1,121 @@ +// NOTE CAN NOT use real ws frame to transport smux frame +// err: accept stream err: buffer size:8 too small to transport ws payload size:45 +// so this transport just use ws protocol to handshake and then use smux protocol to transport +package transporter + +import ( + "context" + "net" + "net/http" + "time" + + "github.com/gobwas/ws" + "github.com/labstack/echo/v4" + "github.com/xtaci/smux" + + "github.com/Ehco1996/ehco/internal/metrics" + "github.com/Ehco1996/ehco/pkg/lb" +) + +var ( + _ RelayClient = &MwsClient{} + _ RelayServer = &MwsServer{} + _ muxServer = &MwsServer{} +) + +type MwsClient struct { + *WssClient + + muxTP *smuxTransporter +} + +func newMwsClient(base *baseTransporter) (*MwsClient, error) { + wc, err := newWssClient(base) + if err != nil { + return nil, err + } + c := &MwsClient{WssClient: wc} + c.muxTP = NewSmuxTransporter(c.l.Named("mwss"), c.initNewSession) + return c, nil +} + +func (c *MwsClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) { + rc, _, _, err := c.dialer.Dial(ctx, addr) + if err != nil { + return nil, err + } + // stream multiplex + cfg := smux.DefaultConfig() + cfg.KeepAliveDisabled = true + session, err := smux.Client(rc, cfg) + if err != nil { + return nil, err + } + c.l.Infof("init new session to: %s", rc.RemoteAddr()) + return session, nil +} + +func (s *MwsClient) TCPHandShake(remote *lb.Node) (net.Conn, error) { + t1 := time.Now() + addr, err := s.cfg.GetWSRemoteAddr(remote.Address) + if err != nil { + return nil, err + } + mwssc, err := s.muxTP.Dial(context.TODO(), addr) + if err != nil { + return nil, err + } + latency := time.Since(t1) + metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds())) + remote.HandShakeDuration = latency + return mwssc, nil +} + +type MwsServer struct { + *WsServer + *muxServerImpl +} + +func newMwsServer(base *baseTransporter) (*MwsServer, error) { + wsServer, err := newWsServer(base) + if err != nil { + return nil, err + } + s := &MwsServer{ + WsServer: wsServer, + muxServerImpl: newMuxServer(base.cfg.Listen, base.l.Named("mwss")), + } + s.e.GET(base.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.HandleRequest))) + return s, nil +} + +func (s *MwsServer) ListenAndServe() error { + go func() { + s.errChan <- s.e.StartServer(s.httpServer) + }() + + for { + conn, e := s.Accept() + if e != nil { + return e + } + go func(c net.Conn) { + if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil { + s.l.Errorf("RelayTCPConn error: %s", err.Error()) + } + }(conn) + } +} + +func (s *MwsServer) HandleRequest(w http.ResponseWriter, r *http.Request) { + c, _, _, err := ws.UpgradeHTTP(r, w) + if err != nil { + s.l.Error(err) + return + } + s.mux(c) +} + +func (s *MwsServer) Close() error { + return s.e.Close() +} diff --git a/test/relay_test.go b/test/relay_test.go index dc0521c8e..43ba035eb 100644 --- a/test/relay_test.go +++ b/test/relay_test.go @@ -41,6 +41,10 @@ const ( MTCP_LISTEN = "0.0.0.0:1238" MTCP_REMOTE = "0.0.0.0:2003" MTCP_SERVER = "0.0.0.0:2003" + + MWS_LISTEN = "0.0.0.0:1239" + MWS_REMOTE = "ws://0.0.0.0:2004" + MSS_SERVER = "0.0.0.0:2004" ) func init() { @@ -127,6 +131,20 @@ func init() { TCPRemotes: []string{ECHO_SERVER}, TransportType: constant.RelayTypeRaw, }, + + // mws + { + Listen: MWS_LISTEN, + ListenType: constant.RelayTypeRaw, + TCPRemotes: []string{MWS_REMOTE}, + TransportType: constant.RelayTypeMWS, + }, + { + Listen: MSS_SERVER, + ListenType: constant.RelayTypeMWS, + TCPRemotes: []string{ECHO_SERVER}, + TransportType: constant.RelayTypeRaw, + }, }, } logger := zap.S() @@ -265,6 +283,16 @@ func TestRelayOverMTCP(t *testing.T) { t.Log("test tcp over mtcp done!") } +func TestRelayOverMWS(t *testing.T) { + msg := []byte("hello") + // test tcp + res := echo.SendTcpMsg(msg, MWS_LISTEN) + if string(res) != string(msg) { + t.Fatal(res) + } + t.Log("test tcp over mws done!") +} + func BenchmarkTcpRelay(b *testing.B) { msg := []byte("hello") for i := 0; i <= b.N; i++ {