From ef914ab8b59747d111ca2a2c80f80c1a061de35c Mon Sep 17 00:00:00 2001 From: Will McCutchen Date: Fri, 16 Jul 2021 14:28:43 -0400 Subject: [PATCH] Update /ip endpoint to return single origin IP --- httpbin/handlers_test.go | 64 ++++++++++++++++++++++++++++++++-------- httpbin/helpers.go | 9 +++--- 2 files changed, 56 insertions(+), 17 deletions(-) diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go index 2bf8e050..ee703ddb 100644 --- a/httpbin/handlers_test.go +++ b/httpbin/handlers_test.go @@ -318,22 +318,60 @@ func TestCORS(t *testing.T) { } func TestIP(t *testing.T) { - r, _ := http.NewRequest("GET", "/ip", nil) - r.RemoteAddr = "192.168.0.100" - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) + t.Parallel() - assertStatusCode(t, w, http.StatusOK) - assertContentType(t, w, jsonContentType) - - var resp *ipResponse - err := json.Unmarshal(w.Body.Bytes(), &resp) - if err != nil { - t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) + testCases := map[string]struct { + remoteAddr string + headers map[string]string + wantOrigin string + }{ + "remote addr used if no x-forwarded-for": { + remoteAddr: "192.168.0.100", + wantOrigin: "192.168.0.100", + }, + "remote addr used if x-forwarded-for empty": { + remoteAddr: "192.168.0.100", + headers: map[string]string{"X-Forwarded-For": ""}, + wantOrigin: "192.168.0.100", + }, + "first entry in x-forwarded-for used if present": { + remoteAddr: "192.168.0.100", + headers: map[string]string{"X-Forwarded-For": "10.1.1.1, 10.2.2.2, 10.3.3.3"}, + wantOrigin: "10.1.1.1", + }, + "single entry x-forwarded-for ok": { + remoteAddr: "192.168.0.100", + headers: map[string]string{"X-Forwarded-For": "10.1.1.1"}, + wantOrigin: "10.1.1.1", + }, } - if resp.Origin != r.RemoteAddr { - t.Fatalf("%#v != %#v", resp.Origin, r.RemoteAddr) + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + r, _ := http.NewRequest("GET", "/ip", nil) + r.RemoteAddr = tc.remoteAddr + for k, v := range tc.headers { + r.Header.Set(k, v) + } + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + assertStatusCode(t, w, http.StatusOK) + assertContentType(t, w, jsonContentType) + + var resp *ipResponse + err := json.Unmarshal(w.Body.Bytes(), &resp) + if err != nil { + t.Fatalf("failed to unmarshal body %s from JSON: %s", w.Body, err) + } + + if resp.Origin != tc.wantOrigin { + t.Fatalf("got %q, want %q", resp.Origin, tc.wantOrigin) + } + }) } } diff --git a/httpbin/helpers.go b/httpbin/helpers.go index 218386b4..366b8096 100644 --- a/httpbin/helpers.go +++ b/httpbin/helpers.go @@ -33,11 +33,12 @@ func getRequestHeaders(r *http.Request) http.Header { } func getOrigin(r *http.Request) string { - origin := r.Header.Get("X-Forwarded-For") - if origin == "" { - origin = r.RemoteAddr + forwardedFor := r.Header.Get("X-Forwarded-For") + if forwardedFor == "" { + return r.RemoteAddr } - return origin + // take the first entry in a comma-separated list of IP addrs + return strings.TrimSpace(strings.SplitN(forwardedFor, ",", 2)[0]) } func getURL(r *http.Request) *url.URL {