diff --git a/core/core.go b/core/core.go index 262cf5f3160f..077b26380c45 100644 --- a/core/core.go +++ b/core/core.go @@ -21,7 +21,7 @@ import ( var ( Version_x byte = 1 Version_y byte = 8 - Version_z byte = 20 + Version_z byte = 21 ) var ( diff --git a/transport/internet/splithttp/client.go b/transport/internet/splithttp/client.go index 2a467d7dab93..8330d5d749fb 100644 --- a/transport/internet/splithttp/client.go +++ b/transport/internet/splithttp/client.go @@ -94,6 +94,10 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string) gotDownResponse.Close() }() + if c.isH3 { + gotConn.Close() + } + // we want to block Dial until we know the remote address of the server, // for logging purposes <-gotConn.Wait() diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index 5c22f8453005..99558975aef5 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -41,6 +41,10 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in return &BrowserDialerClient{} } + tlsConfig := tls.ConfigFromStreamSettings(streamSettings) + isH2 := tlsConfig != nil && !(len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "http/1.1") + isH3 := tlsConfig != nil && (len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "h3") + globalDialerAccess.Lock() defer globalDialerAccess.Unlock() @@ -48,14 +52,13 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in globalDialerMap = make(map[dialerConf]DialerClient) } + if isH3 { + dest.Network = net.Network_UDP + } if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found { return client } - tlsConfig := tls.ConfigFromStreamSettings(streamSettings) - isH2 := tlsConfig != nil && !(len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "http/1.1") - isH3 := tlsConfig != nil && (len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "h3") - var gotlsConfig *gotls.Config if tlsConfig != nil { @@ -86,16 +89,8 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in var uploadTransport http.RoundTripper if isH3 { - dest.Network = net.Network_UDP - quicConfig := &quic.Config{ - HandshakeIdleTimeout: 10 * time.Second, - MaxIdleTimeout: 90 * time.Second, - KeepAlivePeriod: 3 * time.Second, - Allow0RTT: true, - } roundTripper := &http3.RoundTripper{ TLSClientConfig: gotlsConfig, - QUICConfig: quicConfig, Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) if err != nil { diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 1ce8da6b0966..d4579bc72d12 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -269,7 +269,6 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet tlsConfig := getTLSConfig(streamSettings) l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3" - if port == net.Port(0) { // unix listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ Name: address.Domain(), @@ -285,9 +284,9 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet Port: int(port), }, streamSettings.SocketSettings) if err != nil { - return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err) + return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err) } - h3listener, err := quic.ListenEarly(Conn,tlsConfig, nil) + h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil) if err != nil { return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err) } @@ -314,7 +313,6 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet if err != nil { return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err) } - l.listener = listener errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port) // h2cHandler can handle both plaintext HTTP/1.1 and h2c @@ -324,18 +322,24 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet ReadHeaderTimeout: time.Second * 4, MaxHeaderBytes: 8192, } + } + + // tcp/unix (h1/h2) + if listener != nil { + if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { + if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { + listener = tls.NewListener(listener, tlsConfig) + } + } + + l.listener = listener + go func() { if err := l.server.Serve(l.listener); err != nil { errors.LogWarningInner(ctx, err, "failed to serve http for splithttp") } }() } - l.listener = listener - if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { - if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { - listener = tls.NewListener(listener, tlsConfig) - } - } return l, err } diff --git a/transport/internet/splithttp/splithttp_test.go b/transport/internet/splithttp/splithttp_test.go index 5f59a738caa2..a3b609ab48ac 100644 --- a/transport/internet/splithttp/splithttp_test.go +++ b/transport/internet/splithttp/splithttp_test.go @@ -2,6 +2,7 @@ package splithttp_test import ( "context" + "crypto/rand" gotls "crypto/tls" "fmt" gonet "net" @@ -10,7 +11,9 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol/tls/cert" "github.com/xtls/xray-core/testing/servers/tcp" @@ -143,7 +146,16 @@ func Test_listenSHAndDial_TLS(t *testing.T) { } listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { go func() { - _ = conn.Close() + defer conn.Close() + + var b [1024]byte + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err := conn.Read(b[:]) + if err != nil { + return + } + + common.Must2(conn.Write([]byte("Response"))) }() }) common.Must(err) @@ -151,7 +163,15 @@ func Test_listenSHAndDial_TLS(t *testing.T) { conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) common.Must(err) - _ = conn.Close() + + _, err = conn.Write([]byte("Test connection 1")) + common.Must(err) + + var b [1024]byte + n, _ := conn.Read(b[:]) + if string(b[:n]) != "Response" { + t.Error("response: ", string(b[:n])) + } end := time.Now() if !end.Before(start.Add(time.Second * 5)) { @@ -229,18 +249,52 @@ func Test_listenSHAndDial_QUIC(t *testing.T) { } listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { go func() { - _ = conn.Close() + defer conn.Close() + + b := buf.New() + defer b.Release() + + for { + b.Clear() + if _, err := b.ReadFrom(conn); err != nil { + return + } + common.Must2(conn.Write(b.Bytes())) + } }() }) common.Must(err) defer listen.Close() + time.Sleep(time.Second) + conn, err := Dial(context.Background(), net.UDPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) common.Must(err) - _ = conn.Close() + defer conn.Close() + + const N = 1024 + b1 := make([]byte, N) + common.Must2(rand.Read(b1)) + b2 := buf.New() + + common.Must2(conn.Write(b1)) + + b2.Clear() + common.Must2(b2.ReadFullFrom(conn, N)) + if r := cmp.Diff(b2.Bytes(), b1); r != "" { + t.Error(r) + } + + common.Must2(conn.Write(b1)) + + b2.Clear() + common.Must2(b2.ReadFullFrom(conn, N)) + if r := cmp.Diff(b2.Bytes(), b1); r != "" { + t.Error(r) + } end := time.Now() if !end.Before(start.Add(time.Second * 5)) { t.Error("end: ", end, " start: ", start) } -} \ No newline at end of file +}