diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index df53da6eeb..06b37cbbc2 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -101,7 +101,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) { connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { // remove the certhash component from WebTransport and WebRTC addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String()) + require.Equal(t, h2.Addrs()[0].String(), addrs.RemoteMultiaddr().String()) }), ) err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) @@ -135,8 +135,7 @@ func TestInterceptUpgradedOutgoing(t *testing.T) { connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { - // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr()) + require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr()) require.Equal(t, h1.ID(), c.LocalPeer()) require.Equal(t, h2.ID(), c.RemotePeer()) })) @@ -170,12 +169,12 @@ func TestInterceptAccept(t *testing.T) { // In WebRTC, retransmissions of the STUN packet might cause us to create multiple connections, // if the first connection attempt is rejected. connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { - // remove the certhash component from WebTransport addresses + // remove the certhash component from WebRTC and WebTransport addresses require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) }).AnyTimes() } else { connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { - // remove the certhash component from WebTransport addresses + // remove the certhash component from WebRTC and WebTransport addresses require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) }) } @@ -213,8 +212,7 @@ func TestInterceptSecuredIncoming(t *testing.T) { gomock.InOrder( connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { - // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) + require.Equal(t, h2.Addrs()[0], addrs.LocalMultiaddr()) }), ) h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) @@ -248,7 +246,7 @@ func TestInterceptUpgradedIncoming(t *testing.T) { connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr()) + require.Equal(t, h2.Addrs()[0], c.LocalMultiaddr()) require.Equal(t, h1.ID(), c.RemotePeer()) require.Equal(t, h2.ID(), c.LocalPeer()) }), diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 9994eaead2..050bed4cbb 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -867,3 +867,29 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) { }) } } + +func TestConnMatchingAddress(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { + server := tc.HostGenerator(t, TransportTestCaseOpts{}) + client1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + client2 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client1.Close() + defer client2.Close() + + client1.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := client1.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: server.Addrs()}) + require.NoError(t, err) + + client1Conns := client1.Network().ConnsToPeer(server.ID()) + require.Equal(t, 1, len(client1Conns)) + remoteMA := client1Conns[0].RemoteMultiaddr() + + err = client2.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: []ma.Multiaddr{remoteMA}}) + require.NoError(t, err) + }) + } +} diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index d4ba3c0550..8c20b4824f 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -264,14 +264,13 @@ func (l *listener) setupConnection( return nil, err } - localMultiaddrWithoutCerthash, _ := ma.SplitFunc(l.localMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH }) conn, err := newConnection( network.DirInbound, w.PeerConnection, l.transport, scope, l.transport.localPeerId, - localMultiaddrWithoutCerthash, + l.localMultiaddr, remotePeer, remotePubKey, remoteMultiaddr, diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index c4c16fd402..647f76ddc7 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -387,7 +387,6 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement if err != nil { return nil, err } - remoteMultiaddrWithoutCerthash, _ := ma.SplitFunc(remoteMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH }) conn, err := newConnection( network.DirOutbound, @@ -398,7 +397,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement localAddr, p, remotePubKey, - remoteMultiaddrWithoutCerthash, + remoteMultiaddr, w.IncomingDataChannels, w.PeerConnectionClosedCh, ) diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 30b70055d0..1fb8599cc9 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -1,6 +1,9 @@ package websocket import ( + "fmt" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "io" "net" "sync" @@ -25,17 +28,72 @@ type Conn struct { closeOnce sync.Once readLock, writeLock sync.Mutex + + laddr, raddr *Addr + laddrma, raddrma ma.Multiaddr } var _ net.Conn = (*Conn)(nil) -// NewConn creates a Conn given a regular gorilla/websocket Conn. -func NewConn(raw *ws.Conn, secure bool) *Conn { +// NewOutboundConn creates an outbound Conn given a regular gorilla/websocket Conn. +func NewOutboundConn(raw *ws.Conn, secure bool, sni string) (*Conn, error) { + return newConn(raw, secure, sni, false) +} + +// NewInboundConn creates an inbound Conn given a regular gorilla/websocket Conn. +func NewInboundConn(raw *ws.Conn, secure bool, sni string) (*Conn, error) { + return newConn(raw, secure, sni, true) +} + +// newConn creates a Conn given a regular gorilla/websocket Conn. +func newConn(raw *ws.Conn, secure bool, sni string, inbound bool) (*Conn, error) { + laddr := NewAddrWithScheme(raw.LocalAddr().String(), secure) + raddr := NewAddrWithScheme(raw.RemoteAddr().String(), secure) + + laddrma, err := manet.FromNetAddr(laddr) + if err != nil { + return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err) + } + + raddrma, err := manet.FromNetAddr(raddr) + if err != nil { + return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err) + } + + if secure && sni != "" { + var wssMA ma.Multiaddr + if inbound { + wssMA = laddrma + } else { + wssMA = raddrma + } + + if withoutWSS := wssMA.Decapsulate(ma.StringCast("/wss")); withoutWSS.Equal(wssMA) { + return nil, fmt.Errorf("missing wss component from converted multiaddr") + } else { + tlsSniWsMa, err := ma.NewMultiaddr(fmt.Sprintf("/tls/sni/%s/ws", sni)) + if err != nil { + return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err) + } + wssMA = withoutWSS.Encapsulate(tlsSniWsMa) + } + + if inbound { + laddrma = wssMA + } else { + raddrma = wssMA + } + } + return &Conn{ Conn: raw, secure: secure, DefaultMessageType: ws.BinaryMessage, - } + laddr: laddr, + raddr: raddr, + laddrma: laddrma, + raddrma: raddrma, + }, nil } func (c *Conn) Read(b []byte) (int, error) { @@ -122,11 +180,19 @@ func (c *Conn) Close() error { } func (c *Conn) LocalAddr() net.Addr { - return NewAddrWithScheme(c.Conn.LocalAddr().String(), c.secure) + return c.laddr } func (c *Conn) RemoteAddr() net.Addr { - return NewAddrWithScheme(c.Conn.RemoteAddr().String(), c.secure) + return c.raddr +} + +func (c *Conn) LocalMultiaddr() ma.Multiaddr { + return c.laddrma +} + +func (c *Conn) RemoteMultiaddr() ma.Multiaddr { + return c.raddrma } func (c *Conn) SetDeadline(t time.Time) error { diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index 8071ddb814..6875ac1b2d 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -112,10 +112,20 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + var sni string + if r.TLS != nil { + sni = r.TLS.ServerName + } + mnc, err := NewInboundConn(c, l.isWss, sni) + if err != nil { + _ = c.Close() + return + } + select { - case l.incoming <- NewConn(c, l.isWss): + case l.incoming <- mnc: case <-l.closed: - c.Close() + mnc.Close() } // The connection has been hijacked, it's safe to return. } @@ -126,13 +136,7 @@ func (l *listener) Accept() (manet.Conn, error) { if !ok { return nil, transport.ErrListenerClosed } - - mnc, err := manet.WrapNetConn(c) - if err != nil { - c.Close() - return nil, err - } - return mnc, nil + return c, nil case <-l.closed: return nil, transport.ErrListenerClosed } diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 36818decee..404a11e0eb 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -188,8 +188,8 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma } isWss := wsurl.Scheme == "wss" dialer := ws.Dialer{HandshakeTimeout: 30 * time.Second} + var sni string if isWss { - sni := "" sni, err = raddr.ValueForProtocol(ma.P_SNI) if err != nil { sni = "" @@ -220,7 +220,7 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma return nil, err } - mnc, err := manet.WrapNetConn(NewConn(wscon, isWss)) + mnc, err := NewOutboundConn(wscon, isWss, sni) if err != nil { wscon.Close() return nil, err diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index ff611fe927..84eb044f2c 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -234,10 +234,7 @@ func (l *listener) Accept() (tpt.CapableConn, error) { } func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*connSecurityMultiaddrs, error) { - local, err := toWebtransportMultiaddr(sess.LocalAddr()) - if err != nil { - return nil, fmt.Errorf("error determiniting local addr: %w", err) - } + local := l.Multiaddr() remote, err := toWebtransportMultiaddr(sess.RemoteAddr()) if err != nil { return nil, fmt.Errorf("error determiniting remote addr: %w", err) diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index acb40f0b89..17efa1bcf3 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -172,7 +172,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee if err != nil { return nil, err } - sconn, err := t.upgrade(ctx, sess, p, certHashes) + sconn, err := t.upgrade(ctx, sess, p, certHashes, raddr) if err != nil { sess.CloseWithError(1, "") return nil, err @@ -230,15 +230,11 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string return sess, conn, err } -func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) { +func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash, remote ma.Multiaddr) (*connSecurityMultiaddrs, error) { local, err := toWebtransportMultiaddr(sess.LocalAddr()) if err != nil { return nil, fmt.Errorf("error determining local addr: %w", err) } - remote, err := toWebtransportMultiaddr(sess.RemoteAddr()) - if err != nil { - return nil, fmt.Errorf("error determining remote addr: %w", err) - } str, err := sess.OpenStreamSync(ctx) if err != nil {