From 351bf845591062158c01733ca4a23dc7e5d40dd0 Mon Sep 17 00:00:00 2001 From: yusing Date: Sun, 31 Mar 2024 11:26:39 +0000 Subject: [PATCH] tcp/udp fix --- Dockerfile | 26 ++++-- Makefile | 3 +- src/go-proxy/file_reader.go | 23 ----- src/go-proxy/io.go | 93 +++++++++++++------- src/go-proxy/route.go | 1 - src/go-proxy/stream_route.go | 1 + src/go-proxy/tcp_route.go | 19 ++-- src/go-proxy/udp_route.go | 163 ++++++++++++----------------------- 8 files changed, 151 insertions(+), 178 deletions(-) delete mode 100644 src/go-proxy/file_reader.go diff --git a/Dockerfile b/Dockerfile index 793807c2..c9f7251f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,20 +1,30 @@ -FROM golang:1.22.1 as builder +FROM alpine:latest AS codemirror +RUN apk add --no-cache unzip wget make +COPY Makefile . +RUN make setup-codemirror -COPY go.mod /app/go.mod -COPY src/ /app/src -COPY Makefile /app -WORKDIR /app -RUN make get -RUN make build +FROM golang:1.22.1-alpine as builder +COPY src/ /src +COPY go.mod go.sum /src/go-proxy +WORKDIR /src/go-proxy +RUN --mount=type=cache,target="/go/pkg/mod" \ + go mod download + +ENV GOCACHE=/root/.cache/go-build +RUN --mount=type=cache,target="/go/pkg/mod" \ + --mount=type=cache,target="/root/.cache/go-build" \ + CGO_ENABLED=0 GOOS=linux go build -pgo=auto -o go-proxy FROM alpine:latest LABEL maintainer="yusing@6uo.me" RUN apk add --no-cache tzdata -COPY --from=builder /app/bin/go-proxy /app/ +RUN mkdir -p /app/templates +COPY --from=codemirror templates/codemirror/ /app/templates/codemirror COPY templates/ /app/templates COPY schema/ /app/schema +COPY --from=builder /src/go-proxy /app/ RUN chmod +x /app/go-proxy ENV DOCKER_HOST unix:///var/run/docker.sock diff --git a/Makefile b/Makefile index 69036f56..46b6f02e 100755 --- a/Makefile +++ b/Makefile @@ -11,6 +11,7 @@ setup-codemirror: wget https://codemirror.net/5/codemirror.zip unzip codemirror.zip rm codemirror.zip + mkdir -p templates mv codemirror-* templates/codemirror build: @@ -35,6 +36,6 @@ udp-server: -p 9999:9999/udp \ --label proxy.test-udp.scheme=udp \ --label proxy.test-udp.port=20003:9999 \ - --network data_default \ + --network host \ --name test-udp \ $$(docker build -q -f udp-test-server.Dockerfile .) diff --git a/src/go-proxy/file_reader.go b/src/go-proxy/file_reader.go deleted file mode 100644 index 96dc7ca4..00000000 --- a/src/go-proxy/file_reader.go +++ /dev/null @@ -1,23 +0,0 @@ -package main - -import "os" - -type Reader interface { - Read() ([]byte, error) -} - -type FileReader struct { - Path string -} - -func (r *FileReader) Read() ([]byte, error) { - return os.ReadFile(r.Path) -} - -type ByteReader struct { - Data []byte -} - -func (r *ByteReader) Read() ([]byte, error) { - return r.Data, nil -} \ No newline at end of file diff --git a/src/go-proxy/io.go b/src/go-proxy/io.go index e26b8068..854e1e09 100644 --- a/src/go-proxy/io.go +++ b/src/go-proxy/io.go @@ -2,13 +2,37 @@ package main import ( "context" + "errors" + "fmt" "io" - "sync" + "os" + "sync/atomic" ) +type Reader interface { + Read() ([]byte, error) +} + +type FileReader struct { + Path string +} + +func (r *FileReader) Read() ([]byte, error) { + return os.ReadFile(r.Path) +} + +type ByteReader struct { + Data []byte +} + +func (r *ByteReader) Read() ([]byte, error) { + return r.Data, nil +} + type ReadCloser struct { - ctx context.Context - r io.ReadCloser + ctx context.Context + r io.ReadCloser + closed atomic.Bool } func (r *ReadCloser) Read(p []byte) (int, error) { @@ -21,13 +45,16 @@ func (r *ReadCloser) Read(p []byte) (int, error) { } func (r *ReadCloser) Close() error { + if r.closed.Load() { + return nil + } + r.closed.Store(true) return r.r.Close() } type Pipe struct { r ReadCloser w io.WriteCloser - wg sync.WaitGroup ctx context.Context cancel context.CancelFunc } @@ -35,32 +62,24 @@ type Pipe struct { func NewPipe(ctx context.Context, r io.ReadCloser, w io.WriteCloser) *Pipe { ctx, cancel := context.WithCancel(ctx) return &Pipe{ - r: ReadCloser{ctx, r}, + r: ReadCloser{ctx: ctx, r: r}, w: w, ctx: ctx, cancel: cancel, } } -func (p *Pipe) Start() { - p.wg.Add(1) - go func() { - Copy(p.ctx, p.w, &p.r) - p.wg.Done() - }() +func (p *Pipe) Start() error { + return Copy(p.ctx, p.w, &p.r) } -func (p *Pipe) Stop() { +func (p *Pipe) Stop() error { p.cancel() - p.wg.Wait() -} - -func (p *Pipe) Close() (error, error) { - return p.r.Close(), p.w.Close() + return errors.Join(fmt.Errorf("read: %w", p.r.Close()), fmt.Errorf("write: %w", p.w.Close())) } -func (p *Pipe) Wait() { - p.wg.Wait() +func (p *Pipe) Write(b []byte) (int, error) { + return p.w.Write(b) } type BidirectionalPipe struct { @@ -75,26 +94,34 @@ func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.Re } } -func (p *BidirectionalPipe) Start() { - p.pSrcDst.Start() - p.pDstSrc.Start() -} - -func (p *BidirectionalPipe) Stop() { - p.pSrcDst.Stop() - p.pDstSrc.Stop() +func NewBidirectionalPipeIntermediate(ctx context.Context, listener io.ReadCloser, client io.ReadWriteCloser, target io.ReadWriteCloser) *BidirectionalPipe { + return &BidirectionalPipe{ + pSrcDst: *NewPipe(ctx, listener, client), + pDstSrc: *NewPipe(ctx, client, target), + } } -func (p *BidirectionalPipe) Close() (error, error) { - return p.pSrcDst.Close() +func (p *BidirectionalPipe) Start() error { + errCh := make(chan error, 2) + go func() { + errCh <- p.pSrcDst.Start() + }() + go func() { + errCh <- p.pDstSrc.Start() + }() + for err := range errCh { + if err != nil { + return err + } + } + return nil } -func (p *BidirectionalPipe) Wait() { - p.pSrcDst.Wait() - p.pDstSrc.Wait() +func (p *BidirectionalPipe) Stop() error { + return errors.Join(p.pSrcDst.Stop(), p.pDstSrc.Stop()) } func Copy(ctx context.Context, dst io.WriteCloser, src io.ReadCloser) error { - _, err := io.Copy(dst, &ReadCloser{ctx, src}) + _, err := io.Copy(dst, &ReadCloser{ctx: ctx, r: src}) return err } diff --git a/src/go-proxy/route.go b/src/go-proxy/route.go index e1e5bc34..7ef3ac04 100755 --- a/src/go-proxy/route.go +++ b/src/go-proxy/route.go @@ -15,7 +15,6 @@ func NewRoute(cfg *ProxyConfig) (Route, error) { if err != nil { return nil, NewNestedErrorFrom(err).Subject(cfg.Alias) } - streamRoutes.Set(id, route) return route, nil } else { httpRoutes.Ensure(cfg.Alias) diff --git a/src/go-proxy/stream_route.go b/src/go-proxy/stream_route.go index fbe29904..4ffc8a96 100755 --- a/src/go-proxy/stream_route.go +++ b/src/go-proxy/stream_route.go @@ -143,6 +143,7 @@ func (route *StreamRouteBase) Start() { route.l.Errorf("failed to setup: %v", err) return } + streamRoutes.Set(route.id, route) route.started = true route.wg.Add(2) go route.grAcceptConnections() diff --git a/src/go-proxy/tcp_route.go b/src/go-proxy/tcp_route.go index a6d6762b..e70d67a2 100755 --- a/src/go-proxy/tcp_route.go +++ b/src/go-proxy/tcp_route.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "sync" "time" ) @@ -14,12 +15,15 @@ type Pipes []*BidirectionalPipe type TCPRoute struct { *StreamRouteBase listener net.Listener + pipe Pipes + mu sync.Mutex } func NewTCPRoute(base *StreamRouteBase) StreamImpl { return &TCPRoute{ StreamRouteBase: base, listener: nil, + pipe: make(Pipes, 0), } } @@ -40,7 +44,6 @@ func (route *TCPRoute) Handle(c interface{}) error { clientConn := c.(net.Conn) defer clientConn.Close() - defer route.wg.Done() ctx, cancel := context.WithTimeout(context.Background(), tcpDialTimeout) defer cancel() @@ -58,11 +61,12 @@ func (route *TCPRoute) Handle(c interface{}) error { <-route.stopCh pipeCancel() }() + + route.mu.Lock() pipe := NewBidirectionalPipe(pipeCtx, clientConn, serverConn) - pipe.Start() - pipe.Wait() - pipe.Close() - return nil + route.pipe = append(route.pipe, pipe) + route.mu.Unlock() + return pipe.Start() } func (route *TCPRoute) CloseListeners() { @@ -71,4 +75,9 @@ func (route *TCPRoute) CloseListeners() { } route.listener.Close() route.listener = nil + for _, pipe := range route.pipe { + if err := pipe.Stop(); err != nil { + route.l.Error(err) + } + } } diff --git a/src/go-proxy/udp_route.go b/src/go-proxy/udp_route.go index fc274ad7..acc5a6c5 100755 --- a/src/go-proxy/udp_route.go +++ b/src/go-proxy/udp_route.go @@ -1,52 +1,55 @@ package main import ( + "context" "fmt" "io" "net" "sync" - - "github.com/sirupsen/logrus" ) type UDPRoute struct { *StreamRouteBase - connMap map[net.Addr]net.Conn + connMap UDPConnMap connMapMutex sync.Mutex listeningConn *net.UDPConn - targetConn *net.UDPConn + targetAddr *net.UDPAddr } type UDPConn struct { - remoteAddr net.Addr - buffer []byte - bytesReceived []byte - nReceived int + src *net.UDPConn + dst *net.UDPConn + *BidirectionalPipe } +type UDPConnMap map[net.Addr]*UDPConn + func NewUDPRoute(base *StreamRouteBase) StreamImpl { return &UDPRoute{ StreamRouteBase: base, - connMap: make(map[net.Addr]net.Conn), + connMap: make(UDPConnMap), } } func (route *UDPRoute) Setup() error { - source, err := net.ListenPacket(route.ListeningScheme, fmt.Sprintf(":%v", route.ListeningPort)) + laddr, err := net.ResolveUDPAddr(route.ListeningScheme, fmt.Sprintf(":%v", route.ListeningPort)) if err != nil { return err } - - target, err := net.Dial(route.TargetScheme, fmt.Sprintf("%s:%v", route.TargetHost, route.TargetPort)) + source, err := net.ListenUDP(route.ListeningScheme, laddr) + if err != nil { + return err + } + raddr, err := net.ResolveUDPAddr(route.TargetScheme, fmt.Sprintf("%s:%v", route.TargetHost, route.TargetPort)) if err != nil { source.Close() return err } - route.listeningConn = source.(*net.UDPConn) - route.targetConn = target.(*net.UDPConn) + route.listeningConn = source + route.targetAddr = raddr return nil } @@ -64,71 +67,39 @@ func (route *UDPRoute) Accept() (interface{}, error) { return nil, io.ErrShortBuffer } - conn := &UDPConn{ - remoteAddr: srcAddr, - buffer: buffer, - bytesReceived: buffer[:nRead], - nReceived: nRead, - } - return conn, nil -} - -func (route *UDPRoute) Handle(c interface{}) error { - var err error + conn, ok := route.connMap[srcAddr] - conn := c.(*UDPConn) - srcConn, ok := route.connMap[conn.remoteAddr] if !ok { route.connMapMutex.Lock() - srcConn, err = net.DialUDP("udp", nil, conn.remoteAddr.(*net.UDPAddr)) + srcConn, err := net.DialUDP("udp", nil, srcAddr) if err != nil { - return err + return nil, err } - route.connMap[conn.remoteAddr] = srcConn + dstConn, err := net.DialUDP("udp", nil, route.targetAddr) + if err != nil { + srcConn.Close() + return nil, err + } + pipeCtx, pipeCancel := context.WithCancel(context.Background()) + go func() { + <-route.stopCh + pipeCancel() + }() + conn = &UDPConn{ + srcConn, + dstConn, + NewBidirectionalPipe(pipeCtx, sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}), + } + route.connMap[srcAddr] = conn route.connMapMutex.Unlock() } - var forwarder func(*UDPConn, net.Conn) error - - if logLevel == logrus.DebugLevel { - forwarder = route.forwardReceivedDebug - } else { - forwarder = route.forwardReceivedReal - } - - // initiate connection to target - err = forwarder(conn, route.targetConn) - if err != nil { - return err - } + _, err = conn.dst.Write(buffer[:nRead]) + return conn, err +} - for { - select { - case <-route.stopCh: - return nil - default: - // receive from target - conn, err = route.readFrom(route.targetConn, conn.buffer) - if err != nil { - return err - } - // forward to source - err = forwarder(conn, srcConn) - if err != nil { - return err - } - // read from source - conn, err = route.readFrom(srcConn, conn.buffer) - if err != nil { - continue - } - // forward to target - err = forwarder(conn, route.targetConn) - if err != nil { - return err - } - } - } +func (route *UDPRoute) Handle(c interface{}) error { + return c.(*UDPConn).Start() } func (route *UDPRoute) CloseListeners() { @@ -136,50 +107,28 @@ func (route *UDPRoute) CloseListeners() { route.listeningConn.Close() route.listeningConn = nil } - if route.targetConn != nil { - route.targetConn.Close() - route.targetConn = nil - } for _, conn := range route.connMap { - conn.(*net.UDPConn).Close() // TODO: change on non udp target + if err := conn.dst.Close(); err != nil { + route.l.Error(err) + } } - route.connMap = make(map[net.Addr]net.Conn) + route.connMap = make(UDPConnMap) } -func (route *UDPRoute) readFrom(src net.Conn, buffer []byte) (*UDPConn, error) { - nRead, err := src.Read(buffer) - - if err != nil { - return nil, err - } - - if nRead == 0 { - return nil, io.ErrShortBuffer - } - - return &UDPConn{ - remoteAddr: src.RemoteAddr(), - buffer: buffer, - bytesReceived: buffer[:nRead], - nReceived: nRead, - }, nil +type sourceRWCloser struct { + server *net.UDPConn + target *net.UDPConn } -func (route *UDPRoute) forwardReceivedReal(receivedConn *UDPConn, dest net.Conn) error { - nWritten, err := dest.Write(receivedConn.bytesReceived) - - if nWritten != receivedConn.nReceived { - err = io.ErrShortWrite - } +func (w sourceRWCloser) Read(p []byte) (int, error) { + n, _, err := w.target.ReadFrom(p) + return n, err +} - return err +func (w sourceRWCloser) Write(p []byte) (int, error) { + return w.server.WriteToUDP(p, w.target.RemoteAddr().(*net.UDPAddr)) // TODO: support non udp } -func (route *UDPRoute) forwardReceivedDebug(receivedConn *UDPConn, dest net.Conn) error { - route.l.WithField("size", receivedConn.nReceived).Debugf( - "forwarding from %s to %s", - receivedConn.remoteAddr.String(), - dest.RemoteAddr().String(), - ) - return route.forwardReceivedReal(receivedConn, dest) +func (w sourceRWCloser) Close() error { + return w.target.Close() }