Skip to content
This repository has been archived by the owner on Jan 28, 2025. It is now read-only.

Commit

Permalink
Merge branch 'XTLS:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
MHSanaei authored Jul 21, 2024
2 parents 220a587 + c27d652 commit 27231ed
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 28 deletions.
2 changes: 1 addition & 1 deletion core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
var (
Version_x byte = 1
Version_y byte = 8
Version_z byte = 20
Version_z byte = 21
)

var (
Expand Down
4 changes: 4 additions & 0 deletions transport/internet/splithttp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
gotDownResponse.Close()
}()

if c.isH3 {
gotConn.Close()
}

// we want to block Dial until we know the remote address of the server,
// for logging purposes
<-gotConn.Wait()
Expand Down
19 changes: 7 additions & 12 deletions transport/internet/splithttp/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,24 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
return &BrowserDialerClient{}
}

tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
isH2 := tlsConfig != nil && !(len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "http/1.1")
isH3 := tlsConfig != nil && (len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "h3")

globalDialerAccess.Lock()
defer globalDialerAccess.Unlock()

if globalDialerMap == nil {
globalDialerMap = make(map[dialerConf]DialerClient)
}

if isH3 {
dest.Network = net.Network_UDP
}
if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
return client
}

tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
isH2 := tlsConfig != nil && !(len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "http/1.1")
isH3 := tlsConfig != nil && (len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "h3")

var gotlsConfig *gotls.Config

if tlsConfig != nil {
Expand Down Expand Up @@ -86,16 +89,8 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
var uploadTransport http.RoundTripper

if isH3 {
dest.Network = net.Network_UDP
quicConfig := &quic.Config{
HandshakeIdleTimeout: 10 * time.Second,
MaxIdleTimeout: 90 * time.Second,
KeepAlivePeriod: 3 * time.Second,
Allow0RTT: true,
}
roundTripper := &http3.RoundTripper{
TLSClientConfig: gotlsConfig,
QUICConfig: quicConfig,
Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
if err != nil {
Expand Down
24 changes: 14 additions & 10 deletions transport/internet/splithttp/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,6 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
tlsConfig := getTLSConfig(streamSettings)
l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"


if port == net.Port(0) { // unix
listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
Name: address.Domain(),
Expand All @@ -285,9 +284,9 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
Port: int(port),
}, streamSettings.SocketSettings)
if err != nil {
return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err)
return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err)
}
h3listener, err := quic.ListenEarly(Conn,tlsConfig, nil)
h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil)
if err != nil {
return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err)
}
Expand All @@ -314,7 +313,6 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
if err != nil {
return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err)
}
l.listener = listener
errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port)

// h2cHandler can handle both plaintext HTTP/1.1 and h2c
Expand All @@ -324,18 +322,24 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
ReadHeaderTimeout: time.Second * 4,
MaxHeaderBytes: 8192,
}
}

// tcp/unix (h1/h2)
if listener != nil {
if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
listener = tls.NewListener(listener, tlsConfig)
}
}

l.listener = listener

go func() {
if err := l.server.Serve(l.listener); err != nil {
errors.LogWarningInner(ctx, err, "failed to serve http for splithttp")
}
}()
}
l.listener = listener
if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
listener = tls.NewListener(listener, tlsConfig)
}
}

return l, err
}
Expand Down
64 changes: 59 additions & 5 deletions transport/internet/splithttp/splithttp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package splithttp_test

import (
"context"
"crypto/rand"
gotls "crypto/tls"
"fmt"
gonet "net"
Expand All @@ -10,7 +11,9 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol/tls/cert"
"github.com/xtls/xray-core/testing/servers/tcp"
Expand Down Expand Up @@ -143,15 +146,32 @@ func Test_listenSHAndDial_TLS(t *testing.T) {
}
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
go func() {
_ = conn.Close()
defer conn.Close()

var b [1024]byte
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err := conn.Read(b[:])
if err != nil {
return
}

common.Must2(conn.Write([]byte("Response")))
}()
})
common.Must(err)
defer listen.Close()

conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
common.Must(err)
_ = conn.Close()

_, err = conn.Write([]byte("Test connection 1"))
common.Must(err)

var b [1024]byte
n, _ := conn.Read(b[:])
if string(b[:n]) != "Response" {
t.Error("response: ", string(b[:n]))
}

end := time.Now()
if !end.Before(start.Add(time.Second * 5)) {
Expand Down Expand Up @@ -229,18 +249,52 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
}
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
go func() {
_ = conn.Close()
defer conn.Close()

b := buf.New()
defer b.Release()

for {
b.Clear()
if _, err := b.ReadFrom(conn); err != nil {
return
}
common.Must2(conn.Write(b.Bytes()))
}
}()
})
common.Must(err)
defer listen.Close()

time.Sleep(time.Second)

conn, err := Dial(context.Background(), net.UDPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
common.Must(err)
_ = conn.Close()
defer conn.Close()

const N = 1024
b1 := make([]byte, N)
common.Must2(rand.Read(b1))
b2 := buf.New()

common.Must2(conn.Write(b1))

b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N))
if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}

common.Must2(conn.Write(b1))

b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N))
if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}

end := time.Now()
if !end.Before(start.Add(time.Second * 5)) {
t.Error("end: ", end, " start: ", start)
}
}
}

0 comments on commit 27231ed

Please sign in to comment.