Skip to content

Commit

Permalink
Update dialer implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 15, 2024
1 parent 54badfa commit 9a245c7
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 361 deletions.
6 changes: 4 additions & 2 deletions brutal.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ const (
BrutalMinSpeedBPS = 65536
)

func WriteBrutalRequest(writer io.Writer, receiveBPS uint64) error {
return binary.Write(writer, binary.BigEndian, receiveBPS)
func EncodeBrutalRequest(receiveBPS uint64) *buf.Buffer {
buffer := buf.NewSize(8)
common.Must(binary.Write(buffer, binary.BigEndian, receiveBPS))
return buffer
}

func ReadBrutalRequest(reader io.Reader) (uint64, error) {
Expand Down
97 changes: 86 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package mux

import (
"context"
"encoding/binary"
"net"
"sync"

"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
Expand All @@ -14,6 +16,11 @@ import (
"github.com/sagernet/sing/common/x/list"
)

var (
_ N.Dialer = (*Client)(nil)
_ N.PayloadDialer = (*Client)(nil)
)

type Client struct {
dialer N.Dialer
logger logger.Logger
Expand Down Expand Up @@ -74,18 +81,71 @@ func NewClient(options Options) (*Client, error) {
}

func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return c.DialPayloadContext(ctx, network, destination, nil)
}

func (c *Client) DialPayloadContext(ctx context.Context, network string, destination M.Socksaddr, payloads []*buf.Buffer) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
stream, err := c.openStream(ctx)
if err != nil {
buf.ReleaseMulti(payloads)
return nil, err
}
return &clientConn{Conn: stream, destination: destination}, nil
request := StreamRequest{
Network: N.NetworkTCP,
Destination: destination,
}
buffer := buf.NewSize(streamRequestLen(request) + buf.LenMulti(payloads))
defer buffer.Release()
EncodeStreamRequest(request, buffer)
for _, payload := range payloads {
buffer.Write(payload.Bytes())
payload.Release()
}
_, err = stream.Write(buffer.Bytes())
if err != nil {
stream.Close()
return nil, E.Cause(err, "write multiplex handshake request")
}
response, err := ReadStreamResponse(stream)
if err != nil {
return nil, E.Cause(err, "read multiplex handshake response")
}
if response.Status == statusError {
return nil, E.New("remote error: " + response.Message)
}
return stream, nil
case N.NetworkUDP:
stream, err := c.openStream(ctx)
if err != nil {
buf.ReleaseMulti(payloads)
return nil, err
}
request := StreamRequest{
Network: N.NetworkUDP,
Destination: destination,
}
buffer := buf.NewSize(streamRequestLen(request) + 2*len(payloads) + buf.LenMulti(payloads))
defer buffer.Release()
EncodeStreamRequest(request, buffer)
for _, packetPayload := range payloads {
binary.Write(buffer, binary.BigEndian, uint16(packetPayload.Len()))
buffer.Write(packetPayload.Bytes())
packetPayload.Release()
}
_, err = stream.Write(buffer.Bytes())
if err != nil {
stream.Close()
return nil, E.Cause(err, "write multiplex handshake request")
}
response, err := ReadStreamResponse(stream)
if err != nil {
return nil, E.Cause(err, "read multiplex handshake response")
}
if response.Status == statusError {
return nil, E.New("remote error: " + response.Message)
}
extendedConn := bufio.NewExtendedConn(stream)
return &clientPacketConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
default:
Expand All @@ -98,6 +158,26 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net
if err != nil {
return nil, err
}
request := StreamRequest{
Network: N.NetworkUDP,
Destination: destination,
PacketAddr: true,
}
buffer := buf.NewSize(streamRequestLen(request))
defer buffer.Release()
EncodeStreamRequest(request, buffer)
_, err = stream.Write(buffer.Bytes())
if err != nil {
stream.Close()
return nil, E.Cause(err, "write multiplex handshake request")
}
response, err := ReadStreamResponse(stream)
if err != nil {
return nil, E.Cause(err, "read multiplex handshake response")
}
if response.Status == statusError {
return nil, E.New("remote error: " + response.Message)
}
extendedConn := bufio.NewExtendedConn(stream)
return &clientPacketAddrConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
}
Expand Down Expand Up @@ -194,7 +274,7 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
return nil, err
}
if c.brutal.Enabled {
err = c.brutalExchange(ctx, conn, session)
err = c.brutalExchange(ctx, conn)
if err != nil {
conn.Close()
session.Close()
Expand All @@ -205,21 +285,16 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
return session, nil
}

func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn, session abstractSession) error {
stream, err := session.Open()
if err != nil {
return err
}
conn := &clientConn{Conn: &wrapStream{stream}, destination: M.Socksaddr{Fqdn: BrutalExchangeDomain}}
err = WriteBrutalRequest(conn, c.brutal.ReceiveBPS)
func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn) error {
stream, err := c.DialPayloadContext(ctx, N.NetworkTCP, M.Socksaddr{Fqdn: BrutalExchangeDomain}, []*buf.Buffer{EncodeBrutalRequest(c.brutal.SendBPS)})
if err != nil {
return err
}
serverReceiveBPS, err := ReadBrutalResponse(conn)
serverReceiveBPS, err := ReadBrutalResponse(stream)
if err != nil {
return err
}
conn.Close()
stream.Close()
sendBPS := c.brutal.SendBPS
if serverReceiveBPS < sendBPS {
sendBPS = serverReceiveBPS
Expand Down
Loading

0 comments on commit 9a245c7

Please sign in to comment.