From 86c131fc054761803b5142e358b825745e6105a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 6 Dec 2023 20:31:44 +0800 Subject: [PATCH] Implementation read waiter for pipe --- common/pipe/README.md | 3 + common/pipe/pipe.go | 243 +++++++++++++++++++++++++++++++++++++ common/pipe/pipe_wait.go | 53 ++++++++ protocol/http/handshake.go | 3 +- 4 files changed, 301 insertions(+), 1 deletion(-) create mode 100644 common/pipe/README.md create mode 100644 common/pipe/pipe.go create mode 100644 common/pipe/pipe_wait.go diff --git a/common/pipe/README.md b/common/pipe/README.md new file mode 100644 index 000000000..74035cdfc --- /dev/null +++ b/common/pipe/README.md @@ -0,0 +1,3 @@ +# pipe + +mod from go1.21.4 \ No newline at end of file diff --git a/common/pipe/pipe.go b/common/pipe/pipe.go new file mode 100644 index 000000000..2f87ae304 --- /dev/null +++ b/common/pipe/pipe.go @@ -0,0 +1,243 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pipe + +import ( + "io" + "net" + "os" + "sync" + "time" + + "github.com/sagernet/sing/common/buf" +) + +// pipeDeadline is an abstraction for handling timeouts. +type pipeDeadline struct { + mu sync.Mutex // Guards timer and cancel + timer *time.Timer + cancel chan struct{} // Must be non-nil +} + +func makePipeDeadline() pipeDeadline { + return pipeDeadline{cancel: make(chan struct{})} +} + +// set sets the point in time when the deadline will time out. +// A timeout event is signaled by closing the channel returned by waiter. +// Once a timeout has occurred, the deadline can be refreshed by specifying a +// t value in the future. +// +// A zero value for t prevents timeout. +func (d *pipeDeadline) set(t time.Time) { + d.mu.Lock() + defer d.mu.Unlock() + + if d.timer != nil && !d.timer.Stop() { + <-d.cancel // Wait for the timer callback to finish and close cancel + } + d.timer = nil + + // Time is zero, then there is no deadline. + closed := isClosedChan(d.cancel) + if t.IsZero() { + if closed { + d.cancel = make(chan struct{}) + } + return + } + + // Time in the future, setup a timer to cancel in the future. + if dur := time.Until(t); dur > 0 { + if closed { + d.cancel = make(chan struct{}) + } + d.timer = time.AfterFunc(dur, func() { + close(d.cancel) + }) + return + } + + // Time in the past, so close immediately. + if !closed { + close(d.cancel) + } +} + +// wait returns a channel that is closed when the deadline is exceeded. +func (d *pipeDeadline) wait() chan struct{} { + d.mu.Lock() + defer d.mu.Unlock() + return d.cancel +} + +func isClosedChan(c <-chan struct{}) bool { + select { + case <-c: + return true + default: + return false + } +} + +type pipeAddr struct{} + +func (pipeAddr) Network() string { return "pipe" } +func (pipeAddr) String() string { return "pipe" } + +type pipe struct { + wrMu sync.Mutex // Serialize Write operations + + // Used by local Read to interact with remote Write. + // Successful receive on rdRx is always followed by send on rdTx. + rdRx <-chan []byte + rdTx chan<- int + + // Used by local Write to interact with remote Read. + // Successful send on wrTx is always followed by receive on wrRx. + wrTx chan<- []byte + wrRx <-chan int + + once sync.Once // Protects closing localDone + localDone chan struct{} + remoteDone <-chan struct{} + + readDeadline pipeDeadline + writeDeadline pipeDeadline + + newBuffer func() *buf.Buffer +} + +// Pipe creates a synchronous, in-memory, full duplex +// network connection; both ends implement the Conn interface. +// Reads on one end are matched with writes on the other, +// copying data directly between the two; there is no internal +// buffering. +func Pipe() (net.Conn, net.Conn) { + cb1 := make(chan []byte) + cb2 := make(chan []byte) + cn1 := make(chan int) + cn2 := make(chan int) + done1 := make(chan struct{}) + done2 := make(chan struct{}) + + p1 := &pipe{ + rdRx: cb1, rdTx: cn1, + wrTx: cb2, wrRx: cn2, + localDone: done1, remoteDone: done2, + readDeadline: makePipeDeadline(), + writeDeadline: makePipeDeadline(), + } + p2 := &pipe{ + rdRx: cb2, rdTx: cn2, + wrTx: cb1, wrRx: cn1, + localDone: done2, remoteDone: done1, + readDeadline: makePipeDeadline(), + writeDeadline: makePipeDeadline(), + } + return p1, p2 +} + +func (*pipe) LocalAddr() net.Addr { return pipeAddr{} } +func (*pipe) RemoteAddr() net.Addr { return pipeAddr{} } + +func (p *pipe) Read(b []byte) (int, error) { + n, err := p.read(b) + if err != nil && err != io.EOF && err != io.ErrClosedPipe { + err = &net.OpError{Op: "read", Net: "pipe", Err: err} + } + return n, err +} + +func (p *pipe) read(b []byte) (n int, err error) { + switch { + case isClosedChan(p.localDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return 0, io.EOF + case isClosedChan(p.readDeadline.wait()): + return 0, os.ErrDeadlineExceeded + } + + select { + case bw := <-p.rdRx: + nr := copy(b, bw) + p.rdTx <- nr + return nr, nil + case <-p.localDone: + return 0, io.ErrClosedPipe + case <-p.remoteDone: + return 0, io.EOF + case <-p.readDeadline.wait(): + return 0, os.ErrDeadlineExceeded + } +} + +func (p *pipe) Write(b []byte) (int, error) { + n, err := p.write(b) + if err != nil && err != io.ErrClosedPipe { + err = &net.OpError{Op: "write", Net: "pipe", Err: err} + } + return n, err +} + +func (p *pipe) write(b []byte) (n int, err error) { + switch { + case isClosedChan(p.localDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.writeDeadline.wait()): + return 0, os.ErrDeadlineExceeded + } + + p.wrMu.Lock() // Ensure entirety of b is written together + defer p.wrMu.Unlock() + for once := true; once || len(b) > 0; once = false { + select { + case p.wrTx <- b: + nw := <-p.wrRx + b = b[nw:] + n += nw + case <-p.localDone: + return n, io.ErrClosedPipe + case <-p.remoteDone: + return n, io.ErrClosedPipe + case <-p.writeDeadline.wait(): + return n, os.ErrDeadlineExceeded + } + } + return n, nil +} + +func (p *pipe) SetDeadline(t time.Time) error { + if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) { + return io.ErrClosedPipe + } + p.readDeadline.set(t) + p.writeDeadline.set(t) + return nil +} + +func (p *pipe) SetReadDeadline(t time.Time) error { + if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) { + return io.ErrClosedPipe + } + p.readDeadline.set(t) + return nil +} + +func (p *pipe) SetWriteDeadline(t time.Time) error { + if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) { + return io.ErrClosedPipe + } + p.writeDeadline.set(t) + return nil +} + +func (p *pipe) Close() error { + p.once.Do(func() { close(p.localDone) }) + return nil +} diff --git a/common/pipe/pipe_wait.go b/common/pipe/pipe_wait.go new file mode 100644 index 000000000..27ea27b02 --- /dev/null +++ b/common/pipe/pipe_wait.go @@ -0,0 +1,53 @@ +package pipe + +import ( + "io" + "net" + "os" + + "github.com/sagernet/sing/common/buf" + N "github.com/sagernet/sing/common/network" +) + +var _ N.ReadWaiter = (*pipe)(nil) + +func (p *pipe) InitializeReadWaiter(newBuffer func() *buf.Buffer) { + p.newBuffer = newBuffer +} + +func (p *pipe) WaitReadBuffer() (buffer *buf.Buffer, err error) { + buffer, err = p.waitReadBuffer() + if err != nil && err != io.EOF && err != io.ErrClosedPipe { + err = &net.OpError{Op: "read", Net: "pipe", Err: err} + } + return +} + +func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) { + switch { + case isClosedChan(p.localDone): + return nil, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return nil, io.EOF + case isClosedChan(p.readDeadline.wait()): + return nil, os.ErrDeadlineExceeded + } + select { + case bw := <-p.rdRx: + buffer = p.newBuffer() + var nr int + nr, err = buffer.Write(bw) + if err != nil { + buffer.Release() + return + } + p.rdTx <- nr + return + case <-p.localDone: + return nil, io.ErrClosedPipe + case <-p.remoteDone: + return nil, io.EOF + case <-p.readDeadline.wait(): + return nil, os.ErrDeadlineExceeded + } +} diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 7d90e0537..d60656f37 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -16,6 +16,7 @@ import ( F "github.com/sagernet/sing/common/format" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/pipe" ) type Handler = N.TCPConnectionHandler @@ -102,7 +103,7 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read DialContext: func(context context.Context, network, address string) (net.Conn, error) { metadata.Destination = M.ParseSocksaddr(address) metadata.Protocol = "http" - input, output := net.Pipe() + input, output := pipe.Pipe() go func() { hErr := handler.NewConnection(ctx, output, metadata) if hErr != nil {