diff --git a/intra/udp.go b/intra/udp.go index fa8dfc7f..faeff816 100644 --- a/intra/udp.go +++ b/intra/udp.go @@ -70,6 +70,7 @@ var ( errIcmpFirewalled = errors.New("icmp: firewalled") errUdpFirewalled = errors.New("udp: firewalled") errUdpSetupConn = errors.New("udp: could not create conn") + errProxyMismatch = errors.New("udp: proxy mismatch") errUdpUnconnected = errors.New("udp: cannot connect") errUdpEnd = errors.New("udp: stopped") errIcmpEnd = errors.New("icmp: stopped") @@ -323,7 +324,7 @@ func (h *udpHandler) Connect(gconn *netstack.GUDPConn, src, target netip.AddrPor for i, dstipp := range makeIPPorts(realips, target, 0) { selectedTarget = dstipp if mux { // mux is not supported by all proxies (few like Exit, Base, WG support it) - pc, err = h.mux.associate(cid, src, selectedTarget, px.Dialer().Announce, dmx) + pc, err = h.mux.associate(cid, pid, src, selectedTarget, px.Dialer().Announce, dmx) } else { pc, err = px.Dialer().Dial("udp", selectedTarget.String()) } diff --git a/intra/udpmux.go b/intra/udpmux.go index 00918d0a..d332730d 100644 --- a/intra/udpmux.go +++ b/intra/udpmux.go @@ -48,9 +48,10 @@ func (s *stats) String() string { // muxer muxes multiple connections grouped by remote addr over net.PacketConn type muxer struct { - // id, mxconn, stats are immutable (never reassigned) - cid string + // cid, pid, mxconn, stats are immutable (never reassigned) mxconn net.PacketConn + cid string // connection id of mxconn + pid string // proxy id mxconn is listening on stats *stats until time.Time // deadline extension @@ -96,9 +97,10 @@ var _ sender = (*muxer)(nil) var _ core.UDPConn = (*demuxconn)(nil) // newMuxer creates a muxer/demuxer for a connectionless conn. -func newMuxerLocked(id string, conn net.PacketConn, vnd netstack.DemuxerFn, f func()) *muxer { +func newMuxer(cid, pid string, conn net.PacketConn, vnd netstack.DemuxerFn, f func()) *muxer { x := &muxer{ - cid: id, + cid: cid, + pid: pid, mxconn: conn, stats: &stats{start: time.Now()}, routes: make(map[netip.AddrPort]*demuxconn), @@ -456,9 +458,8 @@ func newMuxTable() *muxTable { return &muxTable{t: make(map[netip.AddrPort]*muxer)} } -func (e *muxTable) associate(id string, src, dst netip.AddrPort, mk assocFn, v netstack.DemuxerFn) (c net.Conn, err error) { - e.Lock() - defer e.Unlock() +func (e *muxTable) associate(cid, pid string, src, dst netip.AddrPort, mk assocFn, v netstack.DemuxerFn) (c net.Conn, err error) { + e.Lock() // lock var mxr *muxer // dst may be of a different family than src (4to6, 6to4 etc) @@ -470,19 +471,27 @@ func (e *muxTable) associate(id string, src, dst netip.AddrPort, mk assocFn, v n pc, err = mk(proto, anyaddr) if err != nil { core.Close(pc) - return nil, err + e.Unlock() // unlock + return nil, err // return } - mxr = newMuxerLocked(id, pc, v, func() { - e.dissociate(id, src) + mxr = newMuxer(cid, pid, pc, v, func() { + e.dissociate(cid, pid, src) }) e.t[src] = mxr - log.I("udp: mux: %s new assoc for %s", id, src) + log.I("udp: mux: %s new assoc for %s", cid, src) + } else if mxr.pid != pid { + // client rules prevent from associating w/ a different proxy + log.E("udp: mux: %s assoc proxy mismatch: %s != %s", cid, mxr.pid, pid) + e.Unlock() // unlock + return nil, errProxyMismatch // return } - return mxr.vend(dst) + + e.Unlock() // unlock + return mxr.vend(dst) // do not hold e.lock on calls into mxr } -func (e *muxTable) dissociate(id string, src netip.AddrPort) { - log.I("udp: mux: %s dissoc for %s", id, src) +func (e *muxTable) dissociate(cid, pid string, src netip.AddrPort) { + log.I("udp: mux: %s (%s) dissoc for %s", cid, pid, src) e.Lock() defer e.Unlock()