Skip to content

Commit

Permalink
net/http: don't treat an alternate protocol as a known round tripper
Browse files Browse the repository at this point in the history
As of CL 175857, the client code checks for known round tripper
implementations, and uses simpler cancellation code when it finds one.
However, this code was not considering the case of a request that uses
a user-defined protocol, where the user-defined protocol was
registered with the transport to use a different round tripper.
The effect was that round trippers that worked with earlier
releases would not see the expected cancellation semantics with tip.

Fixes golang#36820

Change-Id: I60e75b5d0badcfb9fde9d73a966ba1d3f7aa42b1
Reviewed-on: https://go-review.googlesource.com/c/go/+/216618
Run-TryBot: Ian Lance Taylor <[email protected]>
TryBot-Result: Gobot Gobot <[email protected]>
Reviewed-by: Brad Fitzpatrick <[email protected]>
  • Loading branch information
ianlancetaylor committed Jan 29, 2020
1 parent a6701d8 commit c436ead
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 11 deletions.
17 changes: 12 additions & 5 deletions src/net/http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,17 @@ func timeBeforeContextDeadline(t time.Time, ctx context.Context) bool {

// knownRoundTripperImpl reports whether rt is a RoundTripper that's
// maintained by the Go team and known to implement the latest
// optional semantics (notably contexts).
func knownRoundTripperImpl(rt RoundTripper) bool {
switch rt.(type) {
case *Transport, *http2Transport:
// optional semantics (notably contexts). The Request is used
// to check whether this particular request is using an alternate protocol,
// in which case we need to check the RoundTripper for that protocol.
func knownRoundTripperImpl(rt RoundTripper, req *Request) bool {
switch t := rt.(type) {
case *Transport:
if altRT := t.alternateRoundTripper(req); altRT != nil {
return knownRoundTripperImpl(altRT, req)
}
return true
case *http2Transport, http2noDialH2RoundTripper:
return true
}
// There's a very minor chance of a false positive with this.
Expand Down Expand Up @@ -319,7 +326,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi
if deadline.IsZero() {
return nop, alwaysFalse
}
knownTransport := knownRoundTripperImpl(rt)
knownTransport := knownRoundTripperImpl(rt, req)
oldCtx := req.Context()

if req.Cancel == nil && knownTransport {
Expand Down
4 changes: 4 additions & 0 deletions src/net/http/omithttp2.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ type http2erringRoundTripper struct{}

func (http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) }

type http2noDialH2RoundTripper struct{}

func (http2noDialH2RoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) }

type http2noDialClientConnPool struct {
http2clientConnPool http2clientConnPool
}
Expand Down
20 changes: 14 additions & 6 deletions src/net/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,17 @@ func (t *Transport) useRegisteredProtocol(req *Request) bool {
return true
}

// alternateRoundTripper returns the alternate RoundTripper to use
// for this request if the Request's URL scheme requires one,
// or nil for the normal case of using the Transport.
func (t *Transport) alternateRoundTripper(req *Request) RoundTripper {
if !t.useRegisteredProtocol(req) {
return nil
}
altProto, _ := t.altProto.Load().(map[string]RoundTripper)
return altProto[req.URL.Scheme]
}

// roundTrip implements a RoundTripper over HTTP.
func (t *Transport) roundTrip(req *Request) (*Response, error) {
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
Expand Down Expand Up @@ -500,12 +511,9 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
}
}

if t.useRegisteredProtocol(req) {
altProto, _ := t.altProto.Load().(map[string]RoundTripper)
if altRT := altProto[scheme]; altRT != nil {
if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
return resp, err
}
if altRT := t.alternateRoundTripper(req); altRT != nil {
if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
return resp, err
}
}
if !isHTTP {
Expand Down
32 changes: 32 additions & 0 deletions src/net/http/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6143,3 +6143,35 @@ func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
t.Errorf("error occurred: %v", err)
}
}

// Issue 36820
// Test that we use the older backward compatible cancellation protocol
// when a RoundTripper is registered via RegisterProtocol.
func TestAltProtoCancellation(t *testing.T) {
defer afterTest(t)
tr := &Transport{}
c := &Client{
Transport: tr,
Timeout: time.Millisecond,
}
tr.RegisterProtocol("timeout", timeoutProto{})
_, err := c.Get("timeout://bar.com/path")
if err == nil {
t.Error("request unexpectedly succeeded")
} else if !strings.Contains(err.Error(), timeoutProtoErr.Error()) {
t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr)
}
}

var timeoutProtoErr = errors.New("canceled as expected")

type timeoutProto struct{}

func (timeoutProto) RoundTrip(req *Request) (*Response, error) {
select {
case <-req.Cancel:
return nil, timeoutProtoErr
case <-time.After(5 * time.Second):
return nil, errors.New("request was not canceled")
}
}

0 comments on commit c436ead

Please sign in to comment.