diff --git a/core/network/conn.go b/core/network/conn.go index 3be8cb0d69..aa6b96f718 100644 --- a/core/network/conn.go +++ b/core/network/conn.go @@ -2,15 +2,60 @@ package network import ( "context" + "fmt" "io" ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" ma "github.com/multiformats/go-multiaddr" ) +type ConnErrorCode uint32 + +type ConnError struct { + Remote bool + ErrorCode ConnErrorCode + TransportError error +} + +func (c *ConnError) Error() string { + side := "local" + if c.Remote { + side = "remote" + } + if c.TransportError != nil { + return fmt.Sprintf("connection closed (%s): code: 0x%x: transport error: %s", side, c.ErrorCode, c.TransportError) + } + return fmt.Sprintf("connection closed (%s): code: 0x%x", side, c.ErrorCode) +} + +func (c *ConnError) Is(target error) bool { + if tce, ok := target.(*ConnError); ok { + return tce.ErrorCode == c.ErrorCode && tce.Remote == c.Remote + } + return false +} + +func (c *ConnError) Unwrap() []error { + return []error{ErrReset, c.TransportError} +} + +const ( + ConnNoError ConnErrorCode = 0 + ConnProtocolNegotiationFailed ConnErrorCode = 0x1000 + ConnResourceLimitExceeded ConnErrorCode = 0x1001 + ConnRateLimited ConnErrorCode = 0x1002 + ConnProtocolViolation ConnErrorCode = 0x1003 + ConnSupplanted ConnErrorCode = 0x1004 + ConnGarbageCollected ConnErrorCode = 0x1005 + ConnShutdown ConnErrorCode = 0x1006 + ConnGated ConnErrorCode = 0x1007 + ConnCodeOutOfRange ConnErrorCode = 0x1008 +) + // Conn is a connection to a remote peer. It multiplexes streams. // Usually there is no need to use a Conn directly, but it may // be useful to get information about the peer on the other side: @@ -24,6 +69,11 @@ type Conn interface { ConnStat ConnScoper + // CloseWithError closes the connection with errCode. The errCode is sent to the + // peer on a best effort basis. For transports that do not support sending error + // codes on connection close, the behavior is identical to calling Close. + CloseWithError(errCode ConnErrorCode) error + // ID returns an identifier that uniquely identifies this Conn within this // host, during this run. Connection IDs may repeat across restarts. ID() string diff --git a/core/network/mux.go b/core/network/mux.go index d12e2ea34b..be61ccf62a 100644 --- a/core/network/mux.go +++ b/core/network/mux.go @@ -3,6 +3,7 @@ package network import ( "context" "errors" + "fmt" "io" "net" "time" @@ -11,6 +12,49 @@ import ( // ErrReset is returned when reading or writing on a reset stream. var ErrReset = errors.New("stream reset") +type StreamErrorCode uint32 + +type StreamError struct { + ErrorCode StreamErrorCode + Remote bool + TransportError error +} + +func (s *StreamError) Error() string { + side := "local" + if s.Remote { + side = "remote" + } + if s.TransportError != nil { + return fmt.Sprintf("stream reset (%s): code: 0x%x: transport error: %s", side, s.ErrorCode, s.TransportError) + } + return fmt.Sprintf("stream reset (%s): code: 0x%x", side, s.ErrorCode) +} + +func (s *StreamError) Is(target error) bool { + if tse, ok := target.(*StreamError); ok { + return tse.ErrorCode == s.ErrorCode && tse.Remote == s.Remote + } + return false +} + +func (s *StreamError) Unwrap() []error { + return []error{ErrReset, s.TransportError} +} + +const ( + StreamNoError StreamErrorCode = 0 + StreamProtocolNegotiationFailed StreamErrorCode = 0x1001 + StreamResourceLimitExceeded StreamErrorCode = 0x1002 + StreamRateLimited StreamErrorCode = 0x1003 + StreamProtocolViolation StreamErrorCode = 0x1004 + StreamSupplanted StreamErrorCode = 0x1005 + StreamGarbageCollected StreamErrorCode = 0x1006 + StreamShutdown StreamErrorCode = 0x1007 + StreamGated StreamErrorCode = 0x1008 + StreamCodeOutOfRange StreamErrorCode = 0x1009 +) + // MuxedStream is a bidirectional io pipe within a connection. type MuxedStream interface { io.Reader @@ -56,6 +100,11 @@ type MuxedStream interface { // side to hang up and go away. Reset() error + // ResetWithError aborts both ends of the stream with `errCode`. `errCode` is sent + // to the peer on a best effort basis. For transports that do not support sending + // error codes to remote peer, the behavior is identical to calling Reset + ResetWithError(errCode StreamErrorCode) error + SetDeadline(time.Time) error SetReadDeadline(time.Time) error SetWriteDeadline(time.Time) error @@ -75,6 +124,10 @@ type MuxedConn interface { // Close closes the stream muxer and the the underlying net.Conn. io.Closer + // CloseWithError closes the connection with errCode. The errCode is sent + // to the peer. + CloseWithError(errCode ConnErrorCode) error + // IsClosed returns whether a connection is fully closed, so it can // be garbage collected. IsClosed() bool diff --git a/core/network/stream.go b/core/network/stream.go index 62e230034c..f2b6cbcb88 100644 --- a/core/network/stream.go +++ b/core/network/stream.go @@ -27,4 +27,8 @@ type Stream interface { // Scope returns the user's view of this stream's resource scope Scope() StreamScope + + // ResetWithError closes both ends of the stream with errCode. The errCode is sent + // to the peer. + ResetWithError(errCode StreamErrorCode) error } diff --git a/go.mod b/go.mod index 620db8e9ac..d132eee35b 100644 --- a/go.mod +++ b/go.mod @@ -30,7 +30,7 @@ require ( github.com/libp2p/go-nat v0.2.0 github.com/libp2p/go-netroute v0.2.2 github.com/libp2p/go-reuseport v0.4.0 - github.com/libp2p/go-yamux/v4 v4.0.2 + github.com/libp2p/go-yamux/v5 v5.0.0 github.com/libp2p/zeroconf/v2 v2.2.0 github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b diff --git a/go.sum b/go.sum index 5fc39d583a..07dcb34476 100644 --- a/go.sum +++ b/go.sum @@ -193,8 +193,8 @@ github.com/libp2p/go-netroute v0.2.2 h1:Dejd8cQ47Qx2kRABg6lPwknU7+nBnFRpko45/fFP github.com/libp2p/go-netroute v0.2.2/go.mod h1:Rntq6jUAH0l9Gg17w5bFGhcC9a+vk4KNXs6s7IljKYE= github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s= github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= -github.com/libp2p/go-yamux/v4 v4.0.2 h1:nrLh89LN/LEiqcFiqdKDRHjGstN300C1269K/EX0CPU= -github.com/libp2p/go-yamux/v4 v4.0.2/go.mod h1:C808cCRgOs1iBwY4S71T5oxgMxgLmqUw56qh4AeBW2o= +github.com/libp2p/go-yamux/v5 v5.0.0 h1:2djUh96d3Jiac/JpGkKs4TO49YhsfLopAoryfPmf+Po= +github.com/libp2p/go-yamux/v5 v5.0.0/go.mod h1:en+3cdX51U0ZslwRdRLrvQsdayFt3TSUKvBGErzpWbU= github.com/libp2p/zeroconf/v2 v2.2.0 h1:Cup06Jv6u81HLhIj1KasuNM/RHHrJ8T7wOTS4+Tv53Q= github.com/libp2p/zeroconf/v2 v2.2.0/go.mod h1:fuJqLnUwZTshS3U/bMRJ3+ow/v9oid1n0DmyYyNO1Xs= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index a85d1978d7..5abd305956 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -464,7 +464,7 @@ func (h *BasicHost) newStreamHandler(s network.Stream) { } else { log.Debugf("protocol mux failed: %s (took %s, id:%s, remote peer:%s, remote addr:%v)", err, took, s.ID(), s.Conn().RemotePeer(), s.Conn().RemoteMultiaddr()) } - s.Reset() + s.ResetWithError(network.StreamProtocolNegotiationFailed) return } @@ -478,7 +478,7 @@ func (h *BasicHost) newStreamHandler(s network.Stream) { if err := s.SetProtocol(protoID); err != nil { log.Debugf("error setting stream protocol: %s", err) - s.Reset() + s.ResetWithError(network.StreamResourceLimitExceeded) return } @@ -717,7 +717,7 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I } defer func() { if strErr != nil && s != nil { - s.Reset() + s.ResetWithError(network.StreamProtocolNegotiationFailed) } }() @@ -761,13 +761,14 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I return nil, fmt.Errorf("failed to negotiate protocol: %w", err) } case <-ctx.Done(): - s.Reset() + s.ResetWithError(network.StreamProtocolNegotiationFailed) // wait for `SelectOneOf` to error out because of resetting the stream. <-errCh return nil, fmt.Errorf("failed to negotiate protocol: %w", ctx.Err()) } if err := s.SetProtocol(selected); err != nil { + s.ResetWithError(network.StreamResourceLimitExceeded) return nil, err } _ = h.Peerstore().AddProtocols(p, selected) // adding the protocol to the peerstore isn't critical diff --git a/p2p/muxer/testsuite/mux.go b/p2p/muxer/testsuite/mux.go index 5b47117fd6..93d24785ea 100644 --- a/p2p/muxer/testsuite/mux.go +++ b/p2p/muxer/testsuite/mux.go @@ -4,6 +4,7 @@ import ( "bytes" "context" crand "crypto/rand" + "errors" "fmt" "io" mrand "math/rand" @@ -462,7 +463,7 @@ func SubtestStreamReset(t *testing.T, tr network.Multiplexer) { time.Sleep(time.Millisecond * 50) _, err = s.Write([]byte("foo")) - if err != network.ErrReset { + if !errors.Is(err, network.ErrReset) { t.Error("should have been stream reset") } s.Close() diff --git a/p2p/muxer/yamux/conn.go b/p2p/muxer/yamux/conn.go index 40c4af4052..54a856e58c 100644 --- a/p2p/muxer/yamux/conn.go +++ b/p2p/muxer/yamux/conn.go @@ -5,7 +5,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-yamux/v4" + "github.com/libp2p/go-yamux/v5" ) // conn implements mux.MuxedConn over yamux.Session. @@ -23,6 +23,10 @@ func (c *conn) Close() error { return c.yamux().Close() } +func (c *conn) CloseWithError(errCode network.ConnErrorCode) error { + return c.yamux().CloseWithError(uint32(errCode)) +} + // IsClosed checks if yamux.Session is in closed state. func (c *conn) IsClosed() bool { return c.yamux().IsClosed() @@ -32,7 +36,7 @@ func (c *conn) IsClosed() bool { func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { s, err := c.yamux().OpenStream(ctx) if err != nil { - return nil, err + return nil, parseError(err) } return (*stream)(s), nil @@ -41,7 +45,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { // AcceptStream accepts a stream opened by the other side. func (c *conn) AcceptStream() (network.MuxedStream, error) { s, err := c.yamux().AcceptStream() - return (*stream)(s), err + return (*stream)(s), parseError(err) } func (c *conn) yamux() *yamux.Session { diff --git a/p2p/muxer/yamux/stream.go b/p2p/muxer/yamux/stream.go index b50bc0bb87..450bdec479 100644 --- a/p2p/muxer/yamux/stream.go +++ b/p2p/muxer/yamux/stream.go @@ -1,11 +1,13 @@ package yamux import ( + "errors" + "fmt" "time" "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-yamux/v4" + "github.com/libp2p/go-yamux/v5" ) // stream implements mux.MuxedStream over yamux.Stream. @@ -13,22 +15,32 @@ type stream yamux.Stream var _ network.MuxedStream = &stream{} -func (s *stream) Read(b []byte) (n int, err error) { - n, err = s.yamux().Read(b) - if err == yamux.ErrStreamReset { - err = network.ErrReset +func parseError(err error) error { + if err == nil { + return err + } + se := &yamux.StreamError{} + if errors.As(err, &se) { + return &network.StreamError{Remote: se.Remote, ErrorCode: network.StreamErrorCode(se.ErrorCode), TransportError: err} } + ce := &yamux.GoAwayError{} + if errors.As(err, &ce) { + return &network.ConnError{Remote: ce.Remote, ErrorCode: network.ConnErrorCode(ce.ErrorCode), TransportError: err} + } + if errors.Is(err, yamux.ErrStreamReset) { + return fmt.Errorf("%w: %w", network.ErrReset, err) + } + return err +} - return n, err +func (s *stream) Read(b []byte) (n int, err error) { + n, err = s.yamux().Read(b) + return n, parseError(err) } func (s *stream) Write(b []byte) (n int, err error) { n, err = s.yamux().Write(b) - if err == yamux.ErrStreamReset { - err = network.ErrReset - } - - return n, err + return n, parseError(err) } func (s *stream) Close() error { @@ -39,6 +51,10 @@ func (s *stream) Reset() error { return s.yamux().Reset() } +func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { + return s.yamux().ResetWithError(uint32(errCode)) +} + func (s *stream) CloseRead() error { return s.yamux().CloseRead() } diff --git a/p2p/muxer/yamux/transport.go b/p2p/muxer/yamux/transport.go index 3273836331..8350abdd86 100644 --- a/p2p/muxer/yamux/transport.go +++ b/p2p/muxer/yamux/transport.go @@ -7,7 +7,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-yamux/v4" + "github.com/libp2p/go-yamux/v5" ) var DefaultTransport *Transport diff --git a/p2p/net/connmgr/connmgr.go b/p2p/net/connmgr/connmgr.go index 5033538e3b..c2cd307259 100644 --- a/p2p/net/connmgr/connmgr.go +++ b/p2p/net/connmgr/connmgr.go @@ -175,7 +175,8 @@ func (cm *BasicConnMgr) memoryEmergency() { // Trim connections without paying attention to the silence period. for _, c := range cm.getConnsToCloseEmergency(target) { log.Infow("low on memory. closing conn", "peer", c.RemotePeer()) - c.Close() + + c.CloseWithError(network.ConnGarbageCollected) } // finally, update the last trim time. @@ -388,7 +389,7 @@ func (cm *BasicConnMgr) trim() { // do the actual trim. for _, c := range cm.getConnsToClose() { log.Debugw("closing conn", "peer", c.RemotePeer()) - c.Close() + c.CloseWithError(network.ConnGarbageCollected) } } diff --git a/p2p/net/connmgr/connmgr_test.go b/p2p/net/connmgr/connmgr_test.go index 2c657255f0..f47557b02c 100644 --- a/p2p/net/connmgr/connmgr_test.go +++ b/p2p/net/connmgr/connmgr_test.go @@ -11,8 +11,11 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" tu "github.com/libp2p/go-libp2p/core/test" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" ) @@ -33,6 +36,14 @@ func (c *tconn) Close() error { return nil } +func (c *tconn) CloseWithError(code network.ConnErrorCode) error { + atomic.StoreUint32(&c.closed, 1) + if c.disconnectNotify != nil { + c.disconnectNotify(nil, c) + } + return nil +} + func (c *tconn) isClosed() bool { return atomic.LoadUint32(&c.closed) == 1 } @@ -794,6 +805,7 @@ type mockConn struct { } func (m mockConn) Close() error { panic("implement me") } +func (m mockConn) CloseWithError(errCode network.ConnErrorCode) error { panic("implement me") } func (m mockConn) LocalPeer() peer.ID { panic("implement me") } func (m mockConn) RemotePeer() peer.ID { panic("implement me") } func (m mockConn) RemotePublicKey() crypto.PubKey { panic("implement me") } @@ -986,3 +998,79 @@ type testLimitGetter struct { func (g testLimitGetter) GetConnLimit() int { return g.limit } + +func TestErrorCode(t *testing.T) { + sw1, sw2, sw3 := swarmt.GenSwarm(t), swarmt.GenSwarm(t), swarmt.GenSwarm(t) + defer sw1.Close() + defer sw2.Close() + defer sw3.Close() + + cm, err := NewConnManager(1, 1, WithGracePeriod(0), WithSilencePeriod(10)) + require.NoError(t, err) + defer cm.Close() + + sw1.Peerstore().AddAddrs(sw2.LocalPeer(), sw2.ListenAddresses(), peerstore.PermanentAddrTTL) + sw1.Peerstore().AddAddrs(sw3.LocalPeer(), sw3.ListenAddresses(), peerstore.PermanentAddrTTL) + + c12, err := sw1.DialPeer(context.Background(), sw2.LocalPeer()) + require.NoError(t, err) + + var c21 network.Conn + require.Eventually(t, func() bool { + conns := sw2.ConnsToPeer(sw1.LocalPeer()) + if len(conns) == 0 { + return false + } + c21 = conns[0] + return true + }, 10*time.Second, 100*time.Millisecond) + + c13, err := sw1.DialPeer(context.Background(), sw3.LocalPeer()) + require.NoError(t, err) + + var c31 network.Conn + require.Eventually(t, func() bool { + conns := sw3.ConnsToPeer(sw1.LocalPeer()) + if len(conns) == 0 { + return false + } + c31 = conns[0] + return true + }, 10*time.Second, 100*time.Millisecond) + + not := cm.Notifee() + not.Connected(sw1, c12) + not.Connected(sw1, c13) + + cm.TrimOpenConns(context.Background()) + + require.True(t, c12.IsClosed() || c13.IsClosed()) + var c, cr network.Conn + if c12.IsClosed() { + c = c12 + require.Eventually(t, func() bool { + conns := sw2.ConnsToPeer(sw1.LocalPeer()) + if len(conns) == 0 { + cr = c21 + return true + } + return false + }, 5*time.Second, 100*time.Millisecond) + } else { + c = c13 + require.Eventually(t, func() bool { + conns := sw3.ConnsToPeer(sw1.LocalPeer()) + if len(conns) == 0 { + cr = c31 + return true + } + return false + }, 5*time.Second, 100*time.Millisecond) + } + + _, err = c.NewStream(context.Background()) + require.ErrorIs(t, err, &network.ConnError{ErrorCode: network.ConnGarbageCollected, Remote: false}) + + _, err = cr.NewStream(context.Background()) + require.ErrorIs(t, err, &network.ConnError{ErrorCode: network.ConnGarbageCollected, Remote: true}) +} diff --git a/p2p/net/mock/mock_conn.go b/p2p/net/mock/mock_conn.go index 8c3dc87299..fc4e0ad670 100644 --- a/p2p/net/mock/mock_conn.go +++ b/p2p/net/mock/mock_conn.go @@ -185,3 +185,7 @@ func (c *conn) Stat() network.ConnStats { func (c *conn) Scope() network.ConnScope { return &network.NullScope{} } + +func (c *conn) CloseWithError(_ network.ConnErrorCode) error { + return c.Close() +} diff --git a/p2p/net/mock/mock_stream.go b/p2p/net/mock/mock_stream.go index c85cca544d..27e32d9e9e 100644 --- a/p2p/net/mock/mock_stream.go +++ b/p2p/net/mock/mock_stream.go @@ -144,6 +144,24 @@ func (s *stream) Reset() error { return nil } +// ResetWithError resets the stream. It ignores the provided error code. +// TODO: Implement error code support. +func (s *stream) ResetWithError(_ network.StreamErrorCode) error { + // Cancel any pending reads/writes with an error. + + s.write.CloseWithError(network.ErrReset) + s.read.CloseWithError(network.ErrReset) + + select { + case s.reset <- struct{}{}: + default: + } + <-s.closed + + // No meaningful error case here. + return nil +} + func (s *stream) teardown() { // at this point, no streams are writing. s.conn.removeStream(s) diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index eb6abbcd84..b5ca654462 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -385,8 +385,7 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, // If we do this in the Upgrader, we will not be able to do this. if s.gater != nil { if allow, _ := s.gater.InterceptUpgraded(c); !allow { - // TODO Send disconnect with reason here - err := tc.Close() + err := tc.CloseWithError(network.ConnGated) if err != nil { log.Warnf("failed to close connection with peer %s and addr %s; err: %s", p, addr, err) } @@ -845,6 +844,14 @@ func (c *connWithMetrics) Close() error { return c.closeErr } +func (c *connWithMetrics) CloseWithError(errCode network.ConnErrorCode) error { + c.once.Do(func() { + c.metricsTracer.ClosedConnection(c.dir, time.Since(c.opened), c.ConnState(), c.LocalMultiaddr()) + c.closeErr = c.CapableConn.CloseWithError(errCode) + }) + return c.closeErr +} + func (c *connWithMetrics) Stat() network.ConnStats { if cs, ok := c.CapableConn.(network.ConnStat); ok { return cs.Stat() diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index 5fd41c8d9f..1d6cf96b4e 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -58,11 +58,20 @@ func (c *Conn) ID() string { // open notifications must finish before we can fire off the close // notifications). func (c *Conn) Close() error { - c.closeOnce.Do(c.doClose) + c.closeOnce.Do(func() { + c.doClose(0) + }) return c.err } -func (c *Conn) doClose() { +func (c *Conn) CloseWithError(errCode network.ConnErrorCode) error { + c.closeOnce.Do(func() { + c.doClose(errCode) + }) + return c.err +} + +func (c *Conn) doClose(errCode network.ConnErrorCode) { c.swarm.removeConn(c) // Prevent new streams from opening. @@ -71,7 +80,11 @@ func (c *Conn) doClose() { c.streams.m = nil c.streams.Unlock() - c.err = c.conn.Close() + if errCode != 0 { + c.err = c.conn.CloseWithError(errCode) + } else { + c.err = c.conn.Close() + } // Send the connectedness event after closing the connection. // This ensures that both remote connection close and local connection @@ -121,7 +134,7 @@ func (c *Conn) start() { } scope, err := c.swarm.ResourceManager().OpenStream(c.RemotePeer(), network.DirInbound) if err != nil { - ts.Reset() + ts.ResetWithError(network.StreamResourceLimitExceeded) continue } c.swarm.refs.Add(1) diff --git a/p2p/net/swarm/swarm_stream.go b/p2p/net/swarm/swarm_stream.go index b7846adec2..4fee368250 100644 --- a/p2p/net/swarm/swarm_stream.go +++ b/p2p/net/swarm/swarm_stream.go @@ -91,6 +91,12 @@ func (s *Stream) Reset() error { return err } +func (s *Stream) ResetWithError(errCode network.StreamErrorCode) error { + err := s.stream.ResetWithError(errCode) + s.closeAndRemoveStream() + return err +} + func (s *Stream) closeAndRemoveStream() { s.closeMx.Lock() defer s.closeMx.Unlock() diff --git a/p2p/net/swarm/swarm_test.go b/p2p/net/swarm/swarm_test.go index 3d92690b98..496236f826 100644 --- a/p2p/net/swarm/swarm_test.go +++ b/p2p/net/swarm/swarm_test.go @@ -538,7 +538,7 @@ func TestResourceManagerAcceptStream(t *testing.T) { if err == nil { _, err = str.Read([]byte{0}) } - require.EqualError(t, err, "stream reset") + require.ErrorContains(t, err, "stream reset") } func TestListenCloseCount(t *testing.T) { diff --git a/p2p/net/upgrader/conn.go b/p2p/net/upgrader/conn.go index 1c23a01aed..2cc4dcfbb6 100644 --- a/p2p/net/upgrader/conn.go +++ b/p2p/net/upgrader/conn.go @@ -63,3 +63,8 @@ func (t *transportConn) ConnState() network.ConnectionState { UsedEarlyMuxerNegotiation: t.usedEarlyMuxerNegotiation, } } + +func (t *transportConn) CloseWithError(errCode network.ConnErrorCode) error { + defer t.scope.Done() + return t.MuxedConn.CloseWithError(errCode) +} diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index c2e81d2e93..55783f0154 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -162,7 +162,7 @@ func (l *listener) handleIncoming() { // if we stop accepting connections for some reason, // we'll eventually close all the open ones // instead of hanging onto them. - conn.Close() + conn.CloseWithError(network.ConnRateLimited) } }() } diff --git a/p2p/protocol/circuitv2/relay/relay_test.go b/p2p/protocol/circuitv2/relay/relay_test.go index f6b63e32de..7c5ec927df 100644 --- a/p2p/protocol/circuitv2/relay/relay_test.go +++ b/p2p/protocol/circuitv2/relay/relay_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/rand" + "errors" "fmt" "io" "testing" @@ -267,12 +268,12 @@ func TestRelayLimitTime(t *testing.T) { if n > 0 { t.Fatalf("expected to write 0 bytes, wrote %d", n) } - if err != network.ErrReset { + if !errors.Is(err, network.ErrReset) { t.Fatalf("expected reset, but got %s", err) } err = <-rch - if err != network.ErrReset { + if !errors.Is(err, network.ErrReset) { t.Fatalf("expected reset, but got %s", err) } } @@ -300,7 +301,7 @@ func TestRelayLimitData(t *testing.T) { } n, err := s.Read(buf) - if err != network.ErrReset { + if !errors.Is(err, network.ErrReset) { t.Fatalf("expected reset but got %s", err) } rch <- n diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 4984419dce..9936463e23 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -36,6 +36,7 @@ import ( "go.uber.org/mock/gomock" ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -866,3 +867,170 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) { }) } } + +func TestErrorCodes(t *testing.T) { + assertStreamErrors := func(s network.Stream, expectedError error) { + buf := make([]byte, 10) + _, err := s.Read(buf) + require.ErrorIs(t, err, expectedError) + + _, err = s.Write(buf) + require.ErrorIs(t, err, expectedError) + } + + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { + server := tc.HostGenerator(t, TransportTestCaseOpts{}) + client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client.Close() + + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + + // setup stream handler + remoteStreamQ := make(chan network.Stream) + server.SetStreamHandler("/test", func(s network.Stream) { + b := make([]byte, 10) + n, err := s.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = s.Write(b[:n]) + if !assert.NoError(t, err) { + return + } + remoteStreamQ <- s + }) + + // pingPong writes and reads "hello" on the stream + pingPong := func(s network.Stream) { + buf := []byte("hello") + _, err := s.Write(buf) + require.NoError(t, err) + + _, err = s.Read(buf) + require.NoError(t, err) + require.Equal(t, buf, []byte("hello")) + } + + t.Run("StreamResetWithError", func(t *testing.T) { + if tc.Name == "WebTransport" { + t.Skipf("skipping: %s, not implemented", tc.Name) + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := client.NewStream(ctx, server.ID(), "/test") + require.NoError(t, err) + pingPong(s) + + remoteStream := <-remoteStreamQ + defer remoteStream.Reset() + + err = s.ResetWithError(42) + require.NoError(t, err) + assertStreamErrors(s, &network.StreamError{ + ErrorCode: 42, + Remote: false, + }) + + assertStreamErrors(remoteStream, &network.StreamError{ + ErrorCode: 42, + Remote: true, + }) + }) + t.Run("StreamResetWithErrorByRemote", func(t *testing.T) { + if tc.Name == "WebTransport" { + t.Skipf("skipping: %s, not implemented", tc.Name) + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := client.NewStream(ctx, server.ID(), "/test") + require.NoError(t, err) + pingPong(s) + + remoteStream := <-remoteStreamQ + + err = remoteStream.ResetWithError(42) + require.NoError(t, err) + + assertStreamErrors(s, &network.StreamError{ + ErrorCode: 42, + Remote: true, + }) + + assertStreamErrors(remoteStream, &network.StreamError{ + ErrorCode: 42, + Remote: false, + }) + }) + + t.Run("StreamResetByConnCloseWithError", func(t *testing.T) { + if tc.Name == "WebTransport" || tc.Name == "WebRTC" { + t.Skipf("skipping: %s, not implemented", tc.Name) + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := client.NewStream(ctx, server.ID(), "/test") + require.NoError(t, err) + pingPong(s) + + remoteStream := <-remoteStreamQ + defer remoteStream.Reset() + + err = s.Conn().CloseWithError(42) + require.NoError(t, err) + + assertStreamErrors(s, &network.ConnError{ + ErrorCode: 42, + Remote: false, + }) + + assertStreamErrors(remoteStream, &network.ConnError{ + ErrorCode: 42, + Remote: true, + }) + }) + + t.Run("NewStreamErrorByConnCloseWithError", func(t *testing.T) { + if tc.Name == "WebTransport" || tc.Name == "WebRTC" { + t.Skipf("skipping: %s, not implemented", tc.Name) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := client.NewStream(ctx, server.ID(), "/test") + require.NoError(t, err) + pingPong(s) + + err = s.Conn().CloseWithError(42) + require.NoError(t, err) + + remoteStream := <-remoteStreamQ + defer remoteStream.Reset() + + localErr := &network.ConnError{ + ErrorCode: 42, + Remote: false, + } + + remoteErr := &network.ConnError{ + ErrorCode: 42, + Remote: true, + } + + // assert these first to ensure that remote has closed the connection + assertStreamErrors(remoteStream, remoteErr) + + _, err = s.Conn().NewStream(ctx) + require.ErrorIs(t, err, localErr) + + _, err = remoteStream.Conn().NewStream(ctx) + require.ErrorIs(t, err, remoteErr) + }) + }) + } +} diff --git a/p2p/transport/quic/conn.go b/p2p/transport/quic/conn.go index a2da81eb34..8b381d8eda 100644 --- a/p2p/transport/quic/conn.go +++ b/p2p/transport/quic/conn.go @@ -34,6 +34,13 @@ func (c *conn) Close() error { return c.closeWithError(0, "") } +// CloseWithError closes the connection +// It must be called even if the peer closed the connection in order for +// garbage collection to properly work in this package. +func (c *conn) CloseWithError(errCode network.ConnErrorCode) error { + return c.closeWithError(quic.ApplicationErrorCode(errCode), "") +} + func (c *conn) closeWithError(errCode quic.ApplicationErrorCode, errString string) error { c.transport.removeConn(c.quicConn) err := c.quicConn.CloseWithError(errCode, errString) @@ -53,13 +60,19 @@ func (c *conn) allowWindowIncrease(size uint64) bool { // OpenStream creates a new stream. func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { qstr, err := c.quicConn.OpenStreamSync(ctx) - return &stream{Stream: qstr}, err + if err != nil { + return nil, parseStreamError(err) + } + return &stream{Stream: qstr}, nil } // AcceptStream accepts a stream opened by the other side. func (c *conn) AcceptStream() (network.MuxedStream, error) { qstr, err := c.quicConn.AcceptStream(context.Background()) - return &stream{Stream: qstr}, err + if err != nil { + return nil, parseStreamError(err) + } + return &stream{Stream: qstr}, nil } // LocalPeer returns our peer ID diff --git a/p2p/transport/quic/conn_test.go b/p2p/transport/quic/conn_test.go index d3e27a7e16..bf3f7b0751 100644 --- a/p2p/transport/quic/conn_test.go +++ b/p2p/transport/quic/conn_test.go @@ -270,6 +270,9 @@ func TestStreams(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { testStreams(t, tc) }) + t.Run(tc.Name, func(t *testing.T) { + testStreamsErrorCode(t, tc) + }) } } @@ -305,6 +308,45 @@ func testStreams(t *testing.T, tc *connTestCase) { require.Equal(t, data, []byte("foobar")) } +func testStreamsErrorCode(t *testing.T, tc *connTestCase) { + serverID, serverKey := createPeer(t) + _, clientKey := createPeer(t) + + serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) + require.NoError(t, err) + defer serverTransport.(io.Closer).Close() + ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") + defer ln.Close() + + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) + require.NoError(t, err) + defer clientTransport.(io.Closer).Close() + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) + require.NoError(t, err) + defer conn.Close() + serverConn, err := ln.Accept() + require.NoError(t, err) + defer serverConn.Close() + + str, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + err = str.ResetWithError(42) + require.NoError(t, err) + + sstr, err := serverConn.AcceptStream() + require.NoError(t, err) + _, err = io.ReadAll(sstr) + require.Error(t, err) + se := &network.StreamError{} + if errors.As(err, &se) { + require.Equal(t, se.ErrorCode, network.StreamErrorCode(42)) + require.True(t, se.Remote) + } else { + t.Fatalf("expected error to be of network.StreamError type, got %T, %v", err, err) + } + +} + func TestHandshakeFailPeerIDMismatch(t *testing.T) { for _, tc := range connTestCases { t.Run(tc.Name, func(t *testing.T) { diff --git a/p2p/transport/quic/listener.go b/p2p/transport/quic/listener.go index f90bdf53f0..30868e49eb 100644 --- a/p2p/transport/quic/listener.go +++ b/p2p/transport/quic/listener.go @@ -11,7 +11,6 @@ import ( tpt "github.com/libp2p/go-libp2p/core/transport" p2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" - ma "github.com/multiformats/go-multiaddr" "github.com/quic-go/quic-go" ) @@ -54,12 +53,12 @@ func (l *listener) Accept() (tpt.CapableConn, error) { c, err := l.wrapConn(qconn) if err != nil { log.Debugf("failed to setup connection: %s", err) - qconn.CloseWithError(1, "") + qconn.CloseWithError(quic.ApplicationErrorCode(network.ConnResourceLimitExceeded), "") continue } l.transport.addConn(qconn, c) if l.transport.gater != nil && !(l.transport.gater.InterceptAccept(c) && l.transport.gater.InterceptSecured(network.DirInbound, c.remotePeerID, c)) { - c.closeWithError(errorCodeConnectionGating, "connection gated") + c.closeWithError(quic.ApplicationErrorCode(network.ConnGated), "connection gated") continue } diff --git a/p2p/transport/quic/listener_test.go b/p2p/transport/quic/listener_test.go index dbd6d810e4..53d6001d35 100644 --- a/p2p/transport/quic/listener_test.go +++ b/p2p/transport/quic/listener_test.go @@ -159,10 +159,11 @@ func TestCleanupConnWhenBlocked(t *testing.T) { s.SetReadDeadline(time.Now().Add(10 * time.Second)) b := [1]byte{} _, err = s.Read(b[:]) - if err != nil && errors.As(err, &quicErr) { + connError := &network.ConnError{} + if err != nil && errors.As(err, &connError) { // We hit our expected application error return } - t.Fatalf("expected application error, got %v", err) + t.Fatalf("expected network.ConnError, got %v", err) } diff --git a/p2p/transport/quic/stream.go b/p2p/transport/quic/stream.go index ee21babe68..1de4770dce 100644 --- a/p2p/transport/quic/stream.go +++ b/p2p/transport/quic/stream.go @@ -2,6 +2,7 @@ package libp2pquic import ( "errors" + "math" "github.com/libp2p/go-libp2p/core/network" @@ -18,24 +19,49 @@ type stream struct { var _ network.MuxedStream = &stream{} -func (s *stream) Read(b []byte) (n int, err error) { - var streamErr *quic.StreamError +func parseStreamError(err error) error { + if err == nil { + return err + } + se := &quic.StreamError{} + if errors.As(err, &se) { + var code network.StreamErrorCode + if se.ErrorCode > math.MaxUint32 { + code = network.StreamCodeOutOfRange + } else { + code = network.StreamErrorCode(se.ErrorCode) + } + err = &network.StreamError{ + ErrorCode: code, + Remote: se.Remote, + TransportError: se, + } + } + ae := &quic.ApplicationError{} + if errors.As(err, &ae) { + var code network.ConnErrorCode + if ae.ErrorCode > math.MaxUint32 { + code = network.ConnCodeOutOfRange + } else { + code = network.ConnErrorCode(ae.ErrorCode) + } + err = &network.ConnError{ + ErrorCode: code, + Remote: ae.Remote, + TransportError: ae, + } + } + return err +} +func (s *stream) Read(b []byte) (n int, err error) { n, err = s.Stream.Read(b) - if err != nil && errors.As(err, &streamErr) { - err = network.ErrReset - } - return n, err + return n, parseStreamError(err) } func (s *stream) Write(b []byte) (n int, err error) { - var streamErr *quic.StreamError - n, err = s.Stream.Write(b) - if err != nil && errors.As(err, &streamErr) { - err = network.ErrReset - } - return n, err + return n, parseStreamError(err) } func (s *stream) Reset() error { @@ -44,6 +70,12 @@ func (s *stream) Reset() error { return nil } +func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { + s.Stream.CancelRead(quic.StreamErrorCode(errCode)) + s.Stream.CancelWrite(quic.StreamErrorCode(errCode)) + return nil +} + func (s *stream) Close() error { s.Stream.CancelRead(reset) return s.Stream.Close() diff --git a/p2p/transport/quic/transport.go b/p2p/transport/quic/transport.go index 4d3d9e551d..62d31a8d2a 100644 --- a/p2p/transport/quic/transport.go +++ b/p2p/transport/quic/transport.go @@ -34,8 +34,6 @@ var ErrHolePunching = errors.New("hole punching attempted; no active dial") var HolePunchTimeout = 5 * time.Second -const errorCodeConnectionGating = 0x47415445 // GATE in ASCII - // The Transport implements the tpt.Transport interface for QUIC connections. type transport struct { privKey ic.PrivKey @@ -169,7 +167,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee remoteMultiaddr: raddr, } if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, c) { - pconn.CloseWithError(errorCodeConnectionGating, "connection gated") + pconn.CloseWithError(quic.ApplicationErrorCode(network.ConnGated), "connection gated") return nil, fmt.Errorf("secured connection gated") } t.addConn(pconn, c) diff --git a/p2p/transport/quic/virtuallistener.go b/p2p/transport/quic/virtuallistener.go index 7927225567..5b23e4c507 100644 --- a/p2p/transport/quic/virtuallistener.go +++ b/p2p/transport/quic/virtuallistener.go @@ -3,6 +3,7 @@ package libp2pquic import ( "sync" + "github.com/libp2p/go-libp2p/core/network" tpt "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" @@ -142,8 +143,8 @@ func (r *acceptLoopRunner) innerAccept(l *listener, expectedVersion quic.Version select { case ch <- acceptVal{conn: conn}: default: + conn.CloseWithError(network.ConnRateLimited) // accept queue filled up, drop the connection - conn.Close() log.Warn("Accept queue filled. Dropping connection.") } diff --git a/p2p/transport/quicreuse/listener.go b/p2p/transport/quicreuse/listener.go index 42f1d00cef..44028197d8 100644 --- a/p2p/transport/quicreuse/listener.go +++ b/p2p/transport/quicreuse/listener.go @@ -10,6 +10,7 @@ import ( "strings" "sync" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" "github.com/quic-go/quic-go" @@ -212,7 +213,7 @@ func (l *listener) Close() error { close(l.queue) // drain the queue for conn := range l.queue { - conn.CloseWithError(1, "closing") + conn.CloseWithError(quic.ApplicationErrorCode(network.ConnShutdown), "closing") } }) return nil diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index b614b2c45d..77b293fadb 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -132,6 +132,12 @@ func (c *connection) Close() error { return nil } +// CloseWithError closes the connection ignoring the error code. As there's no way to signal +// the remote peer on closing the underlying peerconnection, we ignore the error code. +func (c *connection) CloseWithError(_ network.ConnErrorCode) error { + return c.Close() +} + // closeWithError is used to Close the connection when the underlying DTLS connection fails func (c *connection) closeWithError(err error) { c.closeOnce.Do(func() { diff --git a/p2p/transport/webrtc/pb/message.pb.go b/p2p/transport/webrtc/pb/message.pb.go index 8a3e788651..6fc068b560 100644 --- a/p2p/transport/webrtc/pb/message.pb.go +++ b/p2p/transport/webrtc/pb/message.pb.go @@ -95,6 +95,7 @@ type Message struct { state protoimpl.MessageState `protogen:"open.v1"` Flag *Message_Flag `protobuf:"varint,1,opt,name=flag,enum=Message_Flag" json:"flag,omitempty"` Message []byte `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"` + ErrorCode *uint32 `protobuf:"varint,3,opt,name=errorCode" json:"errorCode,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -143,24 +144,32 @@ func (x *Message) GetMessage() []byte { return nil } +func (x *Message) GetErrorCode() uint32 { + if x != nil && x.ErrorCode != nil { + return *x.ErrorCode + } + return 0 +} + var File_p2p_transport_webrtc_pb_message_proto protoreflect.FileDescriptor var file_p2p_transport_webrtc_pb_message_proto_rawDesc = string([]byte{ 0x0a, 0x25, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x81, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x9f, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0d, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x22, 0x39, 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, 0x12, 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, - 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, 0x50, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, - 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, 0x53, 0x45, 0x54, 0x10, 0x02, 0x12, 0x0b, - 0x0a, 0x07, 0x46, 0x49, 0x4e, 0x5f, 0x41, 0x43, 0x4b, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67, - 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, - 0x2f, 0x67, 0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, - 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, - 0x70, 0x62, + 0x12, 0x1c, 0x0a, 0x09, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x09, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x22, 0x39, + 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, 0x12, 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, 0x00, 0x12, + 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, 0x50, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, + 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, 0x53, 0x45, 0x54, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07, + 0x46, 0x49, 0x4e, 0x5f, 0x41, 0x43, 0x4b, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, + 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x67, + 0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, + 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, }) var ( diff --git a/p2p/transport/webrtc/pb/message.proto b/p2p/transport/webrtc/pb/message.proto index aab885b0da..2401f7c4d2 100644 --- a/p2p/transport/webrtc/pb/message.proto +++ b/p2p/transport/webrtc/pb/message.proto @@ -21,4 +21,6 @@ message Message { optional Flag flag=1; optional bytes message = 2; + + optional uint32 errorCode = 3; } diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index bce2f3f2e3..39873c9d66 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -69,8 +69,9 @@ type stream struct { // readerMx ensures that only a single goroutine reads from the reader. Read is not threadsafe // But we may need to read from reader for control messages from a different goroutine. - readerMx sync.Mutex - reader pbio.Reader + readerMx sync.Mutex + reader pbio.Reader + readError error // this buffer is limited up to a single message. Reason we need it // is because a reader might read a message midway, and so we need a @@ -82,6 +83,7 @@ type stream struct { writeStateChanged chan struct{} sendState sendState writeDeadline time.Time + writeError error controlMessageReaderOnce sync.Once // controlMessageReaderEndTime is the end time for reading FIN_ACK from the control @@ -146,6 +148,10 @@ func (s *stream) Close() error { } func (s *stream) Reset() error { + return s.ResetWithError(0) +} + +func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { s.mx.Lock() isClosed := s.closeForShutdownErr != nil s.mx.Unlock() @@ -154,8 +160,8 @@ func (s *stream) Reset() error { } defer s.cleanup() - cancelWriteErr := s.cancelWrite() - closeReadErr := s.CloseRead() + cancelWriteErr := s.cancelWrite(errCode) + closeReadErr := s.closeRead(errCode, false) s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour)) return errors.Join(closeReadErr, cancelWriteErr) } @@ -175,19 +181,20 @@ func (s *stream) SetDeadline(t time.Time) error { return s.SetWriteDeadline(t) } -// processIncomingFlag process the flag on an incoming message +// processIncomingFlag processes the flag(FIN/RST/etc) on msg. // It needs to be called while the mutex is locked. -func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { - if flag == nil { +func (s *stream) processIncomingFlag(msg *pb.Message) { + if msg.Flag == nil { return } - switch *flag { + switch msg.GetFlag() { case pb.Message_STOP_SENDING: // We must process STOP_SENDING after sending a FIN(sendStateDataSent). Remote peer // may not send a FIN_ACK once it has sent a STOP_SENDING if s.sendState == sendStateSending || s.sendState == sendStateDataSent { s.sendState = sendStateReset + s.writeError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())} } s.notifyWriteStateChanged() case pb.Message_FIN_ACK: @@ -206,6 +213,11 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { case pb.Message_RESET: if s.receiveState == receiveStateReceiving { s.receiveState = receiveStateReset + s.readError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())} + } + if s.sendState == sendStateSending || s.sendState == sendStateDataSent { + s.sendState = sendStateReset + s.writeError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())} } s.spawnControlMessageReader() } @@ -235,7 +247,7 @@ func (s *stream) spawnControlMessageReader() { s.readerMx.Unlock() if s.nextMessage != nil { - s.processIncomingFlag(s.nextMessage.Flag) + s.processIncomingFlag(s.nextMessage) s.nextMessage = nil } var msg pb.Message @@ -266,7 +278,7 @@ func (s *stream) spawnControlMessageReader() { } return } - s.processIncomingFlag(msg.Flag) + s.processIncomingFlag(&msg) } }() }) diff --git a/p2p/transport/webrtc/stream_read.go b/p2p/transport/webrtc/stream_read.go index 80d99ea91c..003d5f563e 100644 --- a/p2p/transport/webrtc/stream_read.go +++ b/p2p/transport/webrtc/stream_read.go @@ -22,7 +22,7 @@ func (s *stream) Read(b []byte) (int, error) { case receiveStateDataRead: return 0, io.EOF case receiveStateReset: - return 0, network.ErrReset + return 0, s.readError } if len(b) == 0 { @@ -52,10 +52,11 @@ func (s *stream) Read(b []byte) (int, error) { // datachannel. For these implementations a stream reset will be observed as an // abrupt closing of the datachannel. s.receiveState = receiveStateReset - return 0, network.ErrReset + s.readError = &network.StreamError{Remote: true} + return 0, s.readError } if s.receiveState == receiveStateReset { - return 0, network.ErrReset + return 0, s.readError } if s.receiveState == receiveStateDataRead { return 0, io.EOF @@ -73,7 +74,7 @@ func (s *stream) Read(b []byte) (int, error) { } // process flags on the message after reading all the data - s.processIncomingFlag(s.nextMessage.Flag) + s.processIncomingFlag(s.nextMessage) s.nextMessage = nil if s.closeForShutdownErr != nil { return read, s.closeForShutdownErr @@ -82,7 +83,7 @@ func (s *stream) Read(b []byte) (int, error) { case receiveStateDataRead: return read, io.EOF case receiveStateReset: - return read, network.ErrReset + return read, s.readError } } } @@ -101,12 +102,18 @@ func (s *stream) setDataChannelReadDeadline(t time.Time) error { } func (s *stream) CloseRead() error { + return s.closeRead(0, false) +} + +func (s *stream) closeRead(errCode network.StreamErrorCode, remote bool) error { s.mx.Lock() defer s.mx.Unlock() var err error if s.receiveState == receiveStateReceiving && s.closeForShutdownErr == nil { - err = s.writer.WriteMsg(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()}) + code := uint32(errCode) + err = s.writer.WriteMsg(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum(), ErrorCode: &code}) s.receiveState = receiveStateReset + s.readError = &network.StreamError{Remote: remote, ErrorCode: errCode} } s.spawnControlMessageReader() return err diff --git a/p2p/transport/webrtc/stream_write.go b/p2p/transport/webrtc/stream_write.go index 534a8d8e60..01fddac331 100644 --- a/p2p/transport/webrtc/stream_write.go +++ b/p2p/transport/webrtc/stream_write.go @@ -24,7 +24,7 @@ func (s *stream) Write(b []byte) (int, error) { } switch s.sendState { case sendStateReset: - return 0, network.ErrReset + return 0, s.writeError case sendStateDataSent, sendStateDataReceived: return 0, errWriteAfterClose } @@ -48,7 +48,7 @@ func (s *stream) Write(b []byte) (int, error) { } switch s.sendState { case sendStateReset: - return n, network.ErrReset + return n, s.writeError case sendStateDataSent, sendStateDataReceived: return n, errWriteAfterClose } @@ -119,7 +119,7 @@ func (s *stream) availableSendSpace() int { return availableSpace } -func (s *stream) cancelWrite() error { +func (s *stream) cancelWrite(errCode network.StreamErrorCode) error { s.mx.Lock() defer s.mx.Unlock() @@ -129,10 +129,12 @@ func (s *stream) cancelWrite() error { return nil } s.sendState = sendStateReset + s.writeError = &network.StreamError{Remote: false, ErrorCode: errCode} // Remove reference to this stream from data channel s.dataChannel.OnBufferedAmountLow(nil) s.notifyWriteStateChanged() - return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()}) + code := uint32(errCode) + return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum(), ErrorCode: &code}) } func (s *stream) CloseWrite() error { diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index d914398e0e..f76ad10438 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -78,6 +78,10 @@ func (c *conn) Close() error { return err } +func (c *conn) CloseWithError(_ network.ConnErrorCode) error { + return c.Close() +} + func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil } func (c *conn) Scope() network.ConnScope { return c.scope } func (c *conn) Transport() tpt.Transport { return c.transport } diff --git a/p2p/transport/webtransport/stream.go b/p2p/transport/webtransport/stream.go index 0849fc9f38..83ee52a5d1 100644 --- a/p2p/transport/webtransport/stream.go +++ b/p2p/transport/webtransport/stream.go @@ -56,6 +56,17 @@ func (s *stream) Reset() error { return nil } +// ResetWithError resets the stream ignoring the error code. Error codes aren't +// specified for WebTransport as the current implementation of WebTransport in +// browsers(https://www.ietf.org/archive/id/draft-kinnear-webtransport-http2-02.html) +// only supports 1 byte error codes. For more details, see +// https://github.com/libp2p/specs/blob/4eca305185c7aef219e936bef76c48b1ab0a8b43/error-codes/README.md?plain=1#L84 +func (s *stream) ResetWithError(_ network.StreamErrorCode) error { + s.Stream.CancelRead(reset) + s.Stream.CancelWrite(reset) + return nil +} + func (s *stream) Close() error { s.Stream.CancelRead(reset) return s.Stream.Close() diff --git a/test-plans/go.mod b/test-plans/go.mod index f7fe7cfada..a94915f5f0 100644 --- a/test-plans/go.mod +++ b/test-plans/go.mod @@ -43,7 +43,7 @@ require ( github.com/libp2p/go-nat v0.2.0 // indirect github.com/libp2p/go-netroute v0.2.2 // indirect github.com/libp2p/go-reuseport v0.4.0 // indirect - github.com/libp2p/go-yamux/v4 v4.0.2 // indirect + github.com/libp2p/go-yamux/v5 v5.0.0 // indirect github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/miekg/dns v1.1.63 // indirect diff --git a/test-plans/go.sum b/test-plans/go.sum index 84205c0ae0..696e98aea6 100644 --- a/test-plans/go.sum +++ b/test-plans/go.sum @@ -149,8 +149,8 @@ github.com/libp2p/go-netroute v0.2.2 h1:Dejd8cQ47Qx2kRABg6lPwknU7+nBnFRpko45/fFP github.com/libp2p/go-netroute v0.2.2/go.mod h1:Rntq6jUAH0l9Gg17w5bFGhcC9a+vk4KNXs6s7IljKYE= github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s= github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= -github.com/libp2p/go-yamux/v4 v4.0.2 h1:nrLh89LN/LEiqcFiqdKDRHjGstN300C1269K/EX0CPU= -github.com/libp2p/go-yamux/v4 v4.0.2/go.mod h1:C808cCRgOs1iBwY4S71T5oxgMxgLmqUw56qh4AeBW2o= +github.com/libp2p/go-yamux/v5 v5.0.0 h1:2djUh96d3Jiac/JpGkKs4TO49YhsfLopAoryfPmf+Po= +github.com/libp2p/go-yamux/v5 v5.0.0/go.mod h1:en+3cdX51U0ZslwRdRLrvQsdayFt3TSUKvBGErzpWbU= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd h1:br0buuQ854V8u83wA0rVZ8ttrq5CpaPZdvrK0LP2lOk=