diff --git a/http2/transport.go b/http2/transport.go index f965579f7..ac90a2631 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -1266,6 +1266,27 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { return res, nil } + cancelRequest := func(cs *clientStream, err error) error { + cs.cc.mu.Lock() + defer cs.cc.mu.Unlock() + cs.abortStreamLocked(err) + if cs.ID != 0 { + // This request may have failed because of a problem with the connection, + // or for some unrelated reason. (For example, the user might have canceled + // the request without waiting for a response.) Mark the connection as + // not reusable, since trying to reuse a dead connection is worse than + // unnecessarily creating a new one. + // + // If cs.ID is 0, then the request was never allocated a stream ID and + // whatever went wrong was unrelated to the connection. We might have + // timed out waiting for a stream slot when StrictMaxConcurrentStreams + // is set, for example, in which case retrying on a different connection + // will not help. + cs.cc.doNotReuse = true + } + return err + } + for { select { case <-cs.respHeaderRecv: @@ -1280,15 +1301,12 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { return handleResponseHeaders() default: waitDone() - return nil, cs.abortErr + return nil, cancelRequest(cs, cs.abortErr) } case <-ctx.Done(): - err := ctx.Err() - cs.abortStream(err) - return nil, err + return nil, cancelRequest(cs, ctx.Err()) case <-cs.reqCancel: - cs.abortStream(errRequestCanceled) - return nil, errRequestCanceled + return nil, cancelRequest(cs, errRequestCanceled) } } } diff --git a/http2/transport_test.go b/http2/transport_test.go index 5adef4292..54d455148 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -775,7 +775,6 @@ func newClientTester(t *testing.T) *clientTester { cc, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) - } sc, err := ln.Accept() if err != nil { @@ -1765,6 +1764,18 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { defer tr.CloseIdleConnections() checkRoundTrip := func(req *http.Request, wantErr error, desc string) { + // Make an arbitrary request to ensure we get the server's + // settings frame and initialize peerMaxHeaderListSize. + req0, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatalf("newRequest: NewRequest: %v", err) + } + res0, err := tr.RoundTrip(req0) + if err != nil { + t.Errorf("%v: Initial RoundTrip err = %v", desc, err) + } + res0.Body.Close() + res, err := tr.RoundTrip(req) if err != wantErr { if res != nil { @@ -1825,13 +1836,9 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { return req } - // Make an arbitrary request to ensure we get the server's - // settings frame and initialize peerMaxHeaderListSize. + // Validate peerMaxHeaderListSize. req := newRequest() checkRoundTrip(req, nil, "Initial request") - - // Get the ClientConn associated with the request and validate - // peerMaxHeaderListSize. addr := authorityAddr(req.URL.Scheme, req.URL.Host) cc, err := tr.connPool().GetClientConn(req, addr) if err != nil { @@ -3738,35 +3745,33 @@ func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.D ct.run() } -func TestTransportRetryAfterGOAWAY(t *testing.T) { - var dialer struct { - sync.Mutex - count int - } - ct1 := make(chan *clientTester) - ct2 := make(chan *clientTester) - +func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) { ln := newLocalListener(t) defer ln.Close() + var ( + mu sync.Mutex + count int + conns []net.Conn + ) + var wg sync.WaitGroup tr := &Transport{ TLSClientConfig: tlsConfigInsecure, } tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { - dialer.Lock() - defer dialer.Unlock() - dialer.count++ - if dialer.count == 3 { - return nil, errors.New("unexpected number of dials") - } + mu.Lock() + defer mu.Unlock() + count++ cc, err := net.Dial("tcp", ln.Addr().String()) if err != nil { return nil, fmt.Errorf("dial error: %v", err) } + conns = append(conns, cc) sc, err := ln.Accept() if err != nil { return nil, fmt.Errorf("accept error: %v", err) } + conns = append(conns, sc) ct := &clientTester{ t: t, tr: tr, @@ -3774,19 +3779,26 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { sc: sc, fr: NewFramer(sc, sc), } - switch dialer.count { - case 1: - ct1 <- ct - case 2: - ct2 <- ct - } + wg.Add(1) + go func(count int) { + defer wg.Done() + server(count, ct) + sc.Close() + }(count) return cc, nil } - errs := make(chan error, 3) + client(tr) + tr.CloseIdleConnections() + ln.Close() + for _, c := range conns { + c.Close() + } + wg.Wait() +} - // Client. - go func() { +func TestTransportRetryAfterGOAWAY(t *testing.T) { + client := func(tr *Transport) { req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) res, err := tr.RoundTrip(req) if res != nil { @@ -3796,102 +3808,76 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) { } } if err != nil { - err = fmt.Errorf("RoundTrip: %v", err) - } - errs <- err - }() - - connToClose := make(chan io.Closer, 2) - - // Server for the first request. - go func() { - ct := <-ct1 - - connToClose <- ct.cc - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err) - return + t.Errorf("RoundTrip: %v", err) } - t.Logf("server1 got %v", hf) - if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil { - errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err) - return - } - errs <- nil - }() + } - // Server for the second request. - go func() { - ct := <-ct2 + server := func(count int, ct *clientTester) { + switch count { + case 1: + ct.greet() + hf, err := ct.firstHeaders() + if err != nil { + t.Errorf("server1 failed reading HEADERS: %v", err) + return + } + t.Logf("server1 got %v", hf) + if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil { + t.Errorf("server1 failed writing GOAWAY: %v", err) + return + } + case 2: + ct.greet() + hf, err := ct.firstHeaders() + if err != nil { + t.Errorf("server2 failed reading HEADERS: %v", err) + return + } + t.Logf("server2 got %v", hf) - connToClose <- ct.cc - ct.greet() - hf, err := ct.firstHeaders() - if err != nil { - errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err) + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) + err = ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + if err != nil { + t.Errorf("server2 failed writing response HEADERS: %v", err) + } + default: + t.Errorf("unexpected number of dials") return } - t.Logf("server2 got %v", hf) - - var buf bytes.Buffer - enc := hpack.NewEncoder(&buf) - enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"}) - err = ct.fr.WriteHeaders(HeadersFrameParam{ - StreamID: hf.StreamID, - EndHeaders: true, - EndStream: false, - BlockFragment: buf.Bytes(), - }) - if err != nil { - errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err) - } else { - errs <- nil - } - }() - - for k := 0; k < 3; k++ { - err := <-errs - if err != nil { - t.Error(err) - } } - close(connToClose) - for c := range connToClose { - c.Close() - } + testClientMultipleDials(t, client, server) } func TestTransportRetryAfterRefusedStream(t *testing.T) { clientDone := make(chan struct{}) - ct := newClientTester(t) - ct.client = func() error { - defer ct.cc.(*net.TCPConn).CloseWrite() - if runtime.GOOS == "plan9" { - // CloseWrite not supported on Plan 9; Issue 17906 - defer ct.cc.(*net.TCPConn).Close() - } + client := func(tr *Transport) { defer close(clientDone) req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) - resp, err := ct.tr.RoundTrip(req) + resp, err := tr.RoundTrip(req) if err != nil { - return fmt.Errorf("RoundTrip: %v", err) + t.Errorf("RoundTrip: %v", err) + return } resp.Body.Close() if resp.StatusCode != 204 { - return fmt.Errorf("Status = %v; want 204", resp.StatusCode) + t.Errorf("Status = %v; want 204", resp.StatusCode) + return } - return nil } - ct.server = func() error { + + server := func(count int, ct *clientTester) { ct.greet() var buf bytes.Buffer enc := hpack.NewEncoder(&buf) - nreq := 0 - for { f, err := ct.fr.ReadFrame() if err != nil { @@ -3900,19 +3886,19 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) { // If the client's done, it // will have reported any // errors on its side. - return nil default: - return err + t.Error(err) } + return } switch f := f.(type) { case *WindowUpdateFrame, *SettingsFrame: case *HeadersFrame: if !f.HeadersEnded() { - return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) + t.Errorf("headers should have END_HEADERS be ended: %v", f) + return } - nreq++ - if nreq == 1 { + if count == 1 { ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) } else { enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"}) @@ -3924,11 +3910,13 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) { }) } default: - return fmt.Errorf("Unexpected client frame %v", f) + t.Errorf("Unexpected client frame %v", f) + return } } } - ct.run() + + testClientMultipleDials(t, client, server) } func TestTransportRetryHasLimit(t *testing.T) { @@ -4143,6 +4131,7 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) { greet := make(chan struct{}) // server sends initial SETTINGS frame gotRequest := make(chan struct{}) // server received a request clientDone := make(chan struct{}) + cancelClientRequest := make(chan struct{}) // Collect errors from goroutines. var wg sync.WaitGroup @@ -4221,9 +4210,8 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) { req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body) if k == maxConcurrent { // This request will be canceled. - cancel := make(chan struct{}) - req.Cancel = cancel - close(cancel) + req.Cancel = cancelClientRequest + close(cancelClientRequest) _, err := ct.tr.RoundTrip(req) close(clientRequestCancelled) if err == nil { @@ -5986,14 +5974,21 @@ func TestTransportRetriesOnStreamProtocolError(t *testing.T) { } func TestClientConnReservations(t *testing.T) { - cc := &ClientConn{ - reqHeaderMu: make(chan struct{}, 1), - streams: make(map[uint32]*clientStream), - maxConcurrentStreams: initialMaxConcurrentStreams, - nextStreamID: 1, - t: &Transport{}, + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + }, func(s *Server) { + s.MaxConcurrentStreams = initialMaxConcurrentStreams + }) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + cc, err := tr.newClientConn(st.cc, false) + if err != nil { + t.Fatal(err) } - cc.cond = sync.NewCond(&cc.mu) + + req, _ := http.NewRequest("GET", st.ts.URL, nil) n := 0 for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() { n++ @@ -6001,8 +5996,8 @@ func TestClientConnReservations(t *testing.T) { if n != initialMaxConcurrentStreams { t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams) } - if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, errNilRequestURL) { - t.Fatalf("RoundTrip error = %v; want errNilRequestURL", err) + if _, err := cc.RoundTrip(req); err != nil { + t.Fatalf("RoundTrip error = %v", err) } n2 := 0 for n2 <= 5 && cc.ReserveNewRequest() { @@ -6014,7 +6009,7 @@ func TestClientConnReservations(t *testing.T) { // Use up all the reservations for i := 0; i < n; i++ { - cc.RoundTrip(new(http.Request)) + cc.RoundTrip(req) } n2 = 0 @@ -6370,3 +6365,95 @@ func TestTransportSlowClose(t *testing.T) { } res.Body.Close() } + +type blockReadConn struct { + net.Conn + blockc chan struct{} +} + +func (c *blockReadConn) Read(b []byte) (n int, err error) { + <-c.blockc + return c.Conn.Read(b) +} + +func TestTransportReuseAfterError(t *testing.T) { + serverReqc := make(chan struct{}, 3) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + serverReqc <- struct{}{} + }, optOnlyServer) + defer st.Close() + + var ( + unblockOnce sync.Once + blockc = make(chan struct{}) + connCountMu sync.Mutex + connCount int + ) + tr := &Transport{ + TLSClientConfig: tlsConfigInsecure, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + // The first connection dialed will block on reads until blockc is closed. + connCountMu.Lock() + defer connCountMu.Unlock() + connCount++ + conn, err := tls.Dial(network, addr, cfg) + if err != nil { + return nil, err + } + if connCount == 1 { + return &blockReadConn{ + Conn: conn, + blockc: blockc, + }, nil + } + return conn, nil + }, + } + defer tr.CloseIdleConnections() + defer unblockOnce.Do(func() { + // Ensure that reads on blockc are unblocked if we return early. + close(blockc) + }) + + req, _ := http.NewRequest("GET", st.ts.URL, nil) + + // Request 1 is made on conn 1. + // Reading the response will block. + // Wait until the server receives the request, and continue. + req1c := make(chan struct{}) + go func() { + defer close(req1c) + res1, err := tr.RoundTrip(req.Clone(context.Background())) + if err != nil { + t.Errorf("request 1: %v", err) + } else { + res1.Body.Close() + } + }() + <-serverReqc + + // Request 2 is also made on conn 1. + // Reading the response will block. + // The request fails when the context deadline expires. + // Conn 1 should now be flagged as unfit for reuse. + timeoutCtx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + _, err := tr.RoundTrip(req.Clone(timeoutCtx)) + if err == nil { + t.Errorf("request 2 unexpectedly succeeded (want timeout)") + } + time.Sleep(1 * time.Millisecond) + + // Request 3 is made on a new conn, and succeeds. + res3, err := tr.RoundTrip(req.Clone(context.Background())) + if err != nil { + t.Fatalf("request 3: %v", err) + } + res3.Body.Close() + + // Unblock conn 1, and verify that request 1 completes. + unblockOnce.Do(func() { + close(blockc) + }) + <-req1c +}