diff --git a/hey.go b/hey.go index f727e26b..0c17b3cb 100644 --- a/hey.go +++ b/hey.go @@ -16,6 +16,7 @@ package main import ( + "context" "flag" "fmt" "io/ioutil" @@ -191,7 +192,10 @@ func main() { } } - req, err := http.NewRequest(method, url, nil) + parentCtx, parentCtxCancel := context.WithCancel(context.Background()) + defer parentCtxCancel() + + req, err := http.NewRequestWithContext(parentCtx, method, url, nil) if err != nil { usageAndExit(err.Error()) } @@ -241,11 +245,13 @@ func main() { signal.Notify(c, os.Interrupt) go func() { <-c + parentCtxCancel() w.Stop() }() if dur > 0 { go func() { time.Sleep(dur) + parentCtxCancel() w.Stop() }() } diff --git a/requester/requester.go b/requester/requester.go index fd7277e7..6efacd55 100644 --- a/requester/requester.go +++ b/requester/requester.go @@ -53,10 +53,6 @@ type Work struct { RequestBody []byte - // RequestFunc is a function to generate requests. If it is nil, then - // Request and RequestData are cloned for each request. - RequestFunc func() *http.Request - // N is the total number of requests to make. N int @@ -150,12 +146,7 @@ func (b *Work) makeRequest(c *http.Client) { var code int var dnsStart, connStart, resStart, reqStart, delayStart time.Duration var dnsDuration, connDuration, resDuration, reqDuration, delayDuration time.Duration - var req *http.Request - if b.RequestFunc != nil { - req = b.RequestFunc() - } else { - req = cloneRequest(b.Request, b.RequestBody) - } + req := cloneRequest(b.Request, b.RequestBody) trace := &httptrace.ClientTrace{ DNSStart: func(info httptrace.DNSStartInfo) { dnsStart = now() @@ -265,9 +256,7 @@ func (b *Work) runWorkers() { // cloneRequest returns a clone of the provided *http.Request. // The clone is a shallow copy of the struct and its Header map. func cloneRequest(r *http.Request, body []byte) *http.Request { - // shallow copy of the struct - r2 := new(http.Request) - *r2 = *r + r2 := r.WithContext(r.Context()) // deep copy of the Header r2.Header = make(http.Header, len(r.Header)) for k, s := range r.Header {