diff --git a/http.go b/http.go
index 14a0f02..c97ef5d 100644
--- a/http.go
+++ b/http.go
@@ -25,14 +25,14 @@ var httpClient = &http.Client{
 
 func doRequest(r request) response {
 
-	req, err := http.NewRequest(r.method, r.url.String(), nil)
+	req, err := http.NewRequest(r.method, r.URL(), nil)
 	if err != nil {
 		return response{request: r, err: err}
 	}
 	req.Close = true
 
 	// add the host header to the request manually so it shows up in the output
-	r.headers = append(r.headers, fmt.Sprintf("Host: %s", r.url.Hostname()))
+	r.headers = append(r.headers, fmt.Sprintf("Host: %s", r.Hostname()))
 
 	for _, h := range r.headers {
 		parts := strings.SplitN(h, ":", 2)
diff --git a/main.go b/main.go
index 9c8f24d..d60df7b 100644
--- a/main.go
+++ b/main.go
@@ -3,32 +3,12 @@ package main
 import (
 	"bufio"
 	"fmt"
-	"net/url"
 	"os"
 	"path/filepath"
 	"sync"
 	"time"
 )
 
-// a request is a wrapper for a URL that we want to request
-type request struct {
-	method  string
-	url     *url.URL
-	headers []string
-}
-
-// a response is a wrapper around an HTTP response;
-// it contains the request value for context.
-type response struct {
-	request request
-
-	status     string
-	statusCode int
-	headers    []string
-	body       []byte
-	err        error
-}
-
 // a requester is a function that makes HTTP requests
 type requester func(request) response
 
@@ -74,10 +54,7 @@ func main() {
 	}
 
 	// set up a rate limiter
-	rl := &rateLimiter{
-		delay: time.Duration(c.delay * 1000000),
-		reqs:  make(map[string]time.Time),
-	}
+	rl := newRateLimiter(time.Duration(c.delay * 1000000))
 
 	// the request and response channels for
 	// the worker pool
@@ -91,7 +68,7 @@ func main() {
 
 		go func() {
 			for req := range requests {
-				rl.Block(req.url)
+				rl.Block(req.Hostname())
 				responses <- doRequest(req)
 			}
 			wg.Done()
@@ -113,7 +90,7 @@ func main() {
 				fmt.Fprintf(os.Stderr, "failed to save file: %s\n", err)
 			}
 
-			line := fmt.Sprintf("%s %s (%s)\n", path, res.request.url, res.status)
+			line := fmt.Sprintf("%s %s (%s)\n", path, res.request.URL(), res.status)
 			fmt.Fprintf(index, "%s", line)
 			if c.verbose {
 				fmt.Printf("%s", line)
@@ -125,12 +102,13 @@ func main() {
 	// send requests for each suffix for every prefix
 	for _, suffix := range suffixes {
 		for _, prefix := range prefixes {
-			u, err := url.Parse(prefix + suffix)
-			if err != nil {
-				fmt.Printf("failed to parse url: %s\n", err)
-				continue
+
+			requests <- request{
+				method:  c.method,
+				prefix:  prefix,
+				suffix:  suffix,
+				headers: c.headers,
 			}
-			requests <- request{method: c.method, url: u, headers: c.headers}
 		}
 	}
 
diff --git a/ratelimit.go b/ratelimit.go
index 1bdbcdf..5925082 100644
--- a/ratelimit.go
+++ b/ratelimit.go
@@ -1,44 +1,56 @@
 package main
 
 import (
-	"net/url"
 	"sync"
 	"time"
 )
 
+// a rateLimiter allows you to delay operations
+// on a per-key basis. I.e. only one operation for
+// a given key can be done within the delay time
 type rateLimiter struct {
 	sync.Mutex
 	delay time.Duration
-	reqs  map[string]time.Time
+	ops   map[string]time.Time
 }
 
-func (r *rateLimiter) Block(u *url.URL) {
+// newRateLimiter returns a new *rateLimiter for the
+// provided delay
+func newRateLimiter(delay time.Duration) *rateLimiter {
+	return &rateLimiter{
+		delay: delay,
+		ops:   make(map[string]time.Time),
+	}
+}
+
+// Block blocks until an operation for key is
+// allowed to proceed
+func (r *rateLimiter) Block(key string) {
 	now := time.Now()
 
-	key := u.Hostname()
-
 	r.Lock()
 
 	// if there's nothing in the map we can
 	// return straight away
-	if _, ok := r.reqs[key]; !ok {
-		r.reqs[key] = now
+	if _, ok := r.ops[key]; !ok {
+		r.ops[key] = now
 		r.Unlock()
 		return
 	}
 
 	// if time is up we can return straight away
-	t := r.reqs[key]
+	t := r.ops[key]
 	deadline := t.Add(r.delay)
 	if now.After(deadline) {
-		r.reqs[key] = now
+		r.ops[key] = now
 		r.Unlock()
 		return
 	}
 
 	remaining := deadline.Sub(now)
 
-	// Set the time of the request
-	r.reqs[key] = now.Add(remaining)
+	// Set the time of the operation
+	r.ops[key] = now.Add(remaining)
 	r.Unlock()
 
 	// Block for the remaining time
diff --git a/request.go b/request.go
new file mode 100644
index 0000000..75a73ef
--- /dev/null
+++ b/request.go
@@ -0,0 +1,28 @@
+package main
+
+import "net/url"
+
+// a request is a wrapper for a URL that we want to request
+type request struct {
+	method  string
+	prefix  string
+	suffix  string
+	headers []string
+}
+
+// Hostname returns the hostname part of the request
+func (r request) Hostname() string {
+	u, err := url.Parse(r.prefix)
+
+	// the hostname is only used for rate limiting and
+	// the output path, so a placeholder is fine on error
+	if err != nil {
+		return "unknown"
+	}
+	return u.Hostname()
+}
+
+// URL returns the full URL to request
+func (r request) URL() string {
+	return r.prefix + r.suffix
+}
diff --git a/response.go b/response.go
index 2a0280a..4975b4d 100644
--- a/response.go
+++ b/response.go
@@ -9,18 +9,26 @@ import (
 	"path"
 )
 
+// a response is a wrapper around an HTTP response;
+// it contains the request value for context.
+type response struct {
+	request request
+
+	status     string
+	statusCode int
+	headers    []string
+	body       []byte
+	err        error
+}
+
+// String returns a string representation of the request and response
func (r response) String() string {
 	b := &bytes.Buffer{}
 
-	b.WriteString(r.request.url.String())
+	b.WriteString(r.request.URL())
 	b.WriteString("\n\n")
 
-	qs := ""
-	if len(r.request.url.Query()) > 0 {
-		qs = "?" + r.request.url.Query().Encode()
-	}
-
-	b.WriteString(fmt.Sprintf("> %s %s%s HTTP/1.1\n", r.request.method, r.request.url.EscapedPath(), qs))
+	b.WriteString(fmt.Sprintf("> %s %s HTTP/1.1\n", r.request.method, r.request.suffix))
 
 	// request headers
 	for _, h := range r.request.headers {
@@ -43,13 +51,14 @@ func (r response) String() string {
 	return b.String()
 }
 
+// save writes the request and response output to disk
 func (r response) save(pathPrefix string) (string, error) {
 	content := []byte(r.String())
 
 	checksum := sha1.Sum(content)
 	parts := []string{pathPrefix}
 
-	parts = append(parts, r.request.url.Hostname())
+	parts = append(parts, r.request.Hostname())
 	parts = append(parts, fmt.Sprintf("%x", checksum))
 
 	p := path.Join(parts...)
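
A quick usage sketch of the pieces this diff introduces; it is not part of the change. The 500ms delay and the example.com URLs are made up for illustration, and the request/rateLimiter definitions below are trimmed-down copies of the types from request.go and ratelimit.go so the sketch compiles on its own.

package main

import (
	"fmt"
	"net/url"
	"sync"
	"time"
)

// trimmed-down copy of the request type added in request.go
type request struct {
	method string
	prefix string
	suffix string
}

// Hostname returns the hostname part of the prefix, falling back
// to a placeholder if the prefix doesn't parse
func (r request) Hostname() string {
	u, err := url.Parse(r.prefix)
	if err != nil {
		return "unknown"
	}
	return u.Hostname()
}

// URL returns the full URL to request
func (r request) URL() string {
	return r.prefix + r.suffix
}

// trimmed-down copy of the per-key rate limiter from ratelimit.go
type rateLimiter struct {
	sync.Mutex
	delay time.Duration
	ops   map[string]time.Time
}

func newRateLimiter(delay time.Duration) *rateLimiter {
	return &rateLimiter{delay: delay, ops: make(map[string]time.Time)}
}

// Block sleeps until another operation for key is allowed
func (r *rateLimiter) Block(key string) {
	r.Lock()
	now := time.Now()
	last, ok := r.ops[key]
	if !ok || now.After(last.Add(r.delay)) {
		r.ops[key] = now
		r.Unlock()
		return
	}
	remaining := last.Add(r.delay).Sub(now)
	r.ops[key] = now.Add(remaining)
	r.Unlock()
	time.Sleep(remaining)
}

func main() {
	// one request per host every 500ms; the value is arbitrary here
	rl := newRateLimiter(500 * time.Millisecond)

	reqs := []request{
		{method: "GET", prefix: "https://example.com", suffix: "/robots.txt"},
		{method: "GET", prefix: "https://example.com", suffix: "/.well-known/security.txt"},
	}

	for _, req := range reqs {
		// both prefixes share a hostname, so the second Block call
		// sleeps until the 500ms window for example.com has passed
		rl.Block(req.Hostname())
		fmt.Println("would request:", req.method, req.URL())
	}
}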