diff --git a/client.go b/client.go index 35459a5..5fc126a 100644 --- a/client.go +++ b/client.go @@ -86,7 +86,8 @@ func (c *Client) DialContext(ctx context.Context, network string, destination M. if err != nil { return nil, err } - return bufio.NewUnbindPacketConn(&clientPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}), nil + extendedConn := bufio.NewExtendedConn(stream) + return &clientPacketConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil default: return nil, E.Extend(N.ErrUnknownNetwork, network) } @@ -97,7 +98,8 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net if err != nil { return nil, err } - return &clientPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}, nil + extendedConn := bufio.NewExtendedConn(stream) + return &clientPacketAddrConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil } func (c *Client) openStream(ctx context.Context) (net.Conn, error) { diff --git a/client_conn.go b/client_conn.go index 694cb62..c0291f6 100644 --- a/client_conn.go +++ b/client_conn.go @@ -93,12 +93,16 @@ func (c *clientConn) Upstream() any { return c.Conn } +var _ N.NetPacketConn = (*clientPacketConn)(nil) + type clientPacketConn struct { - N.ExtendedConn - access sync.Mutex - destination M.Socksaddr - requestWritten bool - responseRead bool + N.AbstractConn + conn N.ExtendedConn + access sync.Mutex + destination M.Socksaddr + requestWritten bool + responseRead bool + readWaitOptions N.ReadWaitOptions } func (c *clientPacketConn) NeedHandshake() bool { @@ -106,7 +110,7 @@ func (c *clientPacketConn) NeedHandshake() bool { } func (c *clientPacketConn) readResponse() error { - response, err := ReadStreamResponse(c.ExtendedConn) + response, err := ReadStreamResponse(c.conn) if err != nil { return err } @@ -125,14 +129,14 @@ func (c *clientPacketConn) Read(b []byte) (n int, err error) { c.responseRead = true } var length uint16 - err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + err = binary.Read(c.conn, binary.BigEndian, &length) if err != nil { return } if cap(b) < int(length) { return 0, io.ErrShortBuffer } - return io.ReadFull(c.ExtendedConn, b[:length]) + return io.ReadFull(c.conn, b[:length]) } func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) { @@ -156,7 +160,7 @@ func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) { common.Error(buffer.Write(payload)), ) } - _, err = c.ExtendedConn.Write(buffer.Bytes()) + _, err = c.conn.Write(buffer.Bytes()) if err != nil { return } @@ -174,11 +178,11 @@ func (c *clientPacketConn) Write(b []byte) (n int, err error) { return c.writeRequest(b) } } - err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b))) + err = binary.Write(c.conn, binary.BigEndian, uint16(len(b))) if err != nil { return } - return c.ExtendedConn.Write(b) + return c.conn.Write(b) } func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) { @@ -190,11 +194,11 @@ func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) { c.responseRead = true } var length uint16 - err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + err = binary.Read(c.conn, binary.BigEndian, &length) if err != nil { return } - _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) + _, err = buffer.ReadFullFrom(c.conn, int(length)) return } @@ -211,7 +215,7 @@ func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error { } bLen := buffer.Len() binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen)) - return c.ExtendedConn.WriteBuffer(buffer) + return c.conn.WriteBuffer(buffer) } func (c *clientPacketConn) FrontHeadroom() int { @@ -227,14 +231,14 @@ func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) c.responseRead = true } var length uint16 - err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + err = binary.Read(c.conn, binary.BigEndian, &length) if err != nil { return } if cap(p) < int(length) { return 0, nil, io.ErrShortBuffer } - n, err = io.ReadFull(c.ExtendedConn, p[:length]) + n, err = io.ReadFull(c.conn, p[:length]) return } @@ -248,11 +252,11 @@ func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { return c.writeRequest(p) } } - err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) + err = binary.Write(c.conn, binary.BigEndian, uint16(len(p))) if err != nil { return } - return c.ExtendedConn.Write(p) + return c.conn.Write(p) } func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { @@ -265,7 +269,7 @@ func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksad } func (c *clientPacketConn) LocalAddr() net.Addr { - return c.ExtendedConn.LocalAddr() + return c.conn.LocalAddr() } func (c *clientPacketConn) RemoteAddr() net.Addr { @@ -277,17 +281,19 @@ func (c *clientPacketConn) NeedAdditionalReadDeadline() bool { } func (c *clientPacketConn) Upstream() any { - return c.ExtendedConn + return c.conn } var _ N.NetPacketConn = (*clientPacketAddrConn)(nil) type clientPacketAddrConn struct { - N.ExtendedConn - access sync.Mutex - destination M.Socksaddr - requestWritten bool - responseRead bool + N.AbstractConn + conn N.ExtendedConn + access sync.Mutex + destination M.Socksaddr + requestWritten bool + responseRead bool + readWaitOptions N.ReadWaitOptions } func (c *clientPacketAddrConn) NeedHandshake() bool { @@ -295,7 +301,7 @@ func (c *clientPacketAddrConn) NeedHandshake() bool { } func (c *clientPacketAddrConn) readResponse() error { - response, err := ReadStreamResponse(c.ExtendedConn) + response, err := ReadStreamResponse(c.conn) if err != nil { return err } @@ -313,7 +319,7 @@ func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err } c.responseRead = true } - destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) + destination, err := M.SocksaddrSerializer.ReadAddrPort(c.conn) if err != nil { return } @@ -323,14 +329,14 @@ func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err addr = destination.UDPAddr() } var length uint16 - err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + err = binary.Read(c.conn, binary.BigEndian, &length) if err != nil { return } if cap(p) < int(length) { return 0, nil, io.ErrShortBuffer } - n, err = io.ReadFull(c.ExtendedConn, p[:length]) + n, err = io.ReadFull(c.conn, p[:length]) return } @@ -360,7 +366,7 @@ func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa common.Error(buffer.Write(payload)), ) } - _, err = c.ExtendedConn.Write(buffer.Bytes()) + _, err = c.conn.Write(buffer.Bytes()) if err != nil { return } @@ -378,15 +384,15 @@ func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err erro return c.writeRequest(p, M.SocksaddrFromNet(addr)) } } - err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr)) + err = M.SocksaddrSerializer.WriteAddrPort(c.conn, M.SocksaddrFromNet(addr)) if err != nil { return } - err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) + err = binary.Write(c.conn, binary.BigEndian, uint16(len(p))) if err != nil { return } - return c.ExtendedConn.Write(p) + return c.conn.Write(p) } func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { @@ -397,16 +403,16 @@ func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Soc } c.responseRead = true } - destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) + destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn) if err != nil { return } var length uint16 - err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) + err = binary.Read(c.conn, binary.BigEndian, &length) if err != nil { return } - _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) + _, err = buffer.ReadFullFrom(c.conn, int(length)) return } @@ -428,11 +434,11 @@ func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Soc return err } common.Must(binary.Write(header, binary.BigEndian, uint16(bLen))) - return c.ExtendedConn.WriteBuffer(buffer) + return c.conn.WriteBuffer(buffer) } func (c *clientPacketAddrConn) LocalAddr() net.Addr { - return c.ExtendedConn.LocalAddr() + return c.conn.LocalAddr() } func (c *clientPacketAddrConn) FrontHeadroom() int { @@ -444,5 +450,5 @@ func (c *clientPacketAddrConn) NeedAdditionalReadDeadline() bool { } func (c *clientPacketAddrConn) Upstream() any { - return c.ExtendedConn + return c.conn } diff --git a/client_conn_wait.go b/client_conn_wait.go new file mode 100644 index 0000000..fd8580d --- /dev/null +++ b/client_conn_wait.go @@ -0,0 +1,73 @@ +package mux + +import ( + "encoding/binary" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var _ N.PacketReadWaiter = (*clientPacketConn)(nil) + +func (c *clientPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return false +} + +func (c *clientPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + if !c.responseRead { + err = c.readResponse() + if err != nil { + return + } + c.responseRead = true + } + var length uint16 + err = binary.Read(c.conn, binary.BigEndian, &length) + if err != nil { + return + } + buffer = c.readWaitOptions.NewPacketBuffer() + _, err = buffer.ReadFullFrom(c.conn, int(length)) + if err != nil { + buffer.Release() + return nil, M.Socksaddr{}, err + } + c.readWaitOptions.PostReturn(buffer) + return +} + +var _ N.PacketReadWaiter = (*clientPacketAddrConn)(nil) + +func (c *clientPacketAddrConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return false +} + +func (c *clientPacketAddrConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + if !c.responseRead { + err = c.readResponse() + if err != nil { + return + } + c.responseRead = true + } + destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn) + if err != nil { + return + } + var length uint16 + err = binary.Read(c.conn, binary.BigEndian, &length) + if err != nil { + return + } + buffer = c.readWaitOptions.NewPacketBuffer() + _, err = buffer.ReadFullFrom(c.conn, int(length)) + if err != nil { + buffer.Release() + return nil, M.Socksaddr{}, err + } + c.readWaitOptions.PostReturn(buffer) + return +} diff --git a/go.mod b/go.mod index 6aef2b8..34a2b22 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.18 require ( github.com/hashicorp/yamux v0.1.1 - github.com/sagernet/sing v0.2.18 + github.com/sagernet/sing v0.2.19-0.20231207034108-445cd4f41e3f github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 golang.org/x/net v0.19.0 golang.org/x/sys v0.15.0 diff --git a/go.sum b/go.sum index a53462e..909eb34 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE= github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= -github.com/sagernet/sing v0.2.18 h1:2Ce4dl0pkWft+4914NGXPb8OiQpgA8UHQ9xFOmgvKuY= -github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= +github.com/sagernet/sing v0.2.19-0.20231207034108-445cd4f41e3f h1:hYkBnmJjVphGc4b02b4jN46ojh05vACYZI3ciD/V3pA= +github.com/sagernet/sing v0.2.19-0.20231207034108-445cd4f41e3f/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80= github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as= github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37/go.mod h1:3skNSftZDJWTGVtVaM2jfbce8qHnmH/AGDRe62iNOg0= golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=