diff --git a/echo.go b/echo.go index 60f7061d8..1f00b86d7 100644 --- a/echo.go +++ b/echo.go @@ -206,24 +206,28 @@ const ( // advertised as supported by the target resource. Returning an Allow header is mandatory // for status 405 (method not found) and useful for the OPTIONS method in responses. // See RFC 7231: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 - HeaderAllow = "Allow" - HeaderAuthorization = "Authorization" - HeaderContentDisposition = "Content-Disposition" - HeaderContentEncoding = "Content-Encoding" - HeaderContentLength = "Content-Length" - HeaderContentType = "Content-Type" - HeaderCookie = "Cookie" - HeaderSetCookie = "Set-Cookie" - HeaderIfModifiedSince = "If-Modified-Since" - HeaderLastModified = "Last-Modified" - HeaderLocation = "Location" - HeaderRetryAfter = "Retry-After" - HeaderUpgrade = "Upgrade" - HeaderVary = "Vary" - HeaderWWWAuthenticate = "WWW-Authenticate" - HeaderXForwardedFor = "X-Forwarded-For" - HeaderXForwardedProto = "X-Forwarded-Proto" - HeaderXForwardedProtocol = "X-Forwarded-Protocol" + HeaderAllow = "Allow" + HeaderAuthorization = "Authorization" + HeaderContentDisposition = "Content-Disposition" + HeaderContentEncoding = "Content-Encoding" + HeaderContentLength = "Content-Length" + HeaderContentType = "Content-Type" + HeaderCookie = "Cookie" + HeaderForwarded = "Forwarded" + HeaderSetCookie = "Set-Cookie" + HeaderIfModifiedSince = "If-Modified-Since" + HeaderLastModified = "Last-Modified" + HeaderLocation = "Location" + HeaderRetryAfter = "Retry-After" + HeaderUpgrade = "Upgrade" + HeaderVary = "Vary" + HeaderWWWAuthenticate = "WWW-Authenticate" + HeaderXForwardedFor = "X-Forwarded-For" + HeaderXForwardedHost = "X-Forwarded-Host" + HeaderXForwardedPrefix = "X-Forwarded-Prefix" + HeaderXForwardedProto = "X-Forwarded-Proto" + HeaderXForwardedProtocol = "X-Forwarded-Protocol" + HeaderXForwardedSsl = "X-Forwarded-Ssl" HeaderXUrlScheme = "X-Url-Scheme" HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" diff --git a/middleware/proxy_headers.go b/middleware/proxy_headers.go new file mode 100644 index 000000000..1353171b2 --- /dev/null +++ b/middleware/proxy_headers.go @@ -0,0 +1,58 @@ +package middleware + +import ( + "net/http" + "net/url" + "regexp" + "strings" + + "github.com/labstack/echo/v4" +) + +var ( + protoRegex = regexp.MustCompile(`(?i)(?:proto=)(https|http)`) + ipRegex = regexp.MustCompile("(?i)(?:for=)([^(;|,| )]+)") +) + +func ProxyHeaders() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if fwd := c.Request().Header.Get(echo.HeaderForwarded); fwd != "" { + if match := ipRegex.FindStringSubmatch(fwd); len(match) > 1 { + c.Request().RemoteAddr = strings.Trim(match[1], `"`) + } + } else if fwd := c.RealIP(); fwd != "" { + c.Request().RemoteAddr = fwd + } + + if scheme := getScheme(c.Request()); scheme != "" { + c.Request().URL.Scheme = scheme + } + + if c.Request().Header.Get(echo.HeaderXForwardedHost) != "" { + c.Request().Host = c.Request().Header.Get(echo.HeaderXForwardedHost) + } + + if prefix := c.Request().Header.Get(echo.HeaderXForwardedPrefix); prefix != "" { + c.Request().RequestURI, _ = url.JoinPath(prefix, c.Request().RequestURI) + c.Request().URL.Path, _ = url.JoinPath(prefix, c.Request().URL.Path) + } + return next(c) + } + } +} + +func getScheme(r *http.Request) string { + var scheme string + + if proto := r.Header.Get(echo.HeaderXForwardedProto); proto != "" { + scheme = strings.ToLower(proto) + } else if proto := r.Header.Get(echo.HeaderXForwardedProtocol); proto != "" { + scheme = strings.ToLower(proto) + } else if proto = r.Header.Get(echo.HeaderForwarded); proto != "" { + if match := protoRegex.FindStringSubmatch(proto); len(match) > 1 { + scheme = strings.ToLower(match[1]) + } + } + return scheme +} diff --git a/middleware/proxy_headers_test.go b/middleware/proxy_headers_test.go new file mode 100644 index 000000000..47d0f4c82 --- /dev/null +++ b/middleware/proxy_headers_test.go @@ -0,0 +1,144 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func Test_getScheme(t *testing.T) { + tests := []struct { + name string + r *http.Request + headerName string + whenHeader string + want string + }{ + { + name: "test only X-Forwarded-Proto: https", + headerName: "X-Forwarded-Proto", + whenHeader: "https", + want: "https", + }, + { + name: "test only X-Forwarded-Proto: http", + headerName: "X-Forwarded-Proto", + whenHeader: "http", + want: "http", + }, + { + name: "test only X-Forwarded-Proto: HTTP", + headerName: "X-Forwarded-Proto", + whenHeader: "HTTP", + want: "http", + }, + { + name: "test only X-Forwarded-Protocol: https", + headerName: "X-Forwarded-Protocol", + whenHeader: "https", + want: "https", + }, + { + name: "test only X-Forwarded-Protocol: http", + headerName: "X-Forwarded-Protocol", + whenHeader: "http", + want: "http", + }, + { + name: "test only X-Forwarded-Protocol: HTTP", + headerName: "X-Forwarded-Protocol", + whenHeader: "HTTP", + want: "http", + }, + { + name: "test only Forwarded https", + headerName: "Forwarded", + whenHeader: "proto=https", + want: "https", + }, + { + name: "test only Forwarded: http", + headerName: "Forwarded", + whenHeader: "proto=http", + want: "http", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &http.Request{ + Header: http.Header{ + tt.headerName: []string{tt.whenHeader}, + }, + } + + if got := getScheme(req); got != tt.want { + t.Errorf("getScheme() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestProxyHeaders(t *testing.T) { + tests := []struct { + name string + givenMW echo.MiddlewareFunc + whenMethod string + whenHeaders map[string]string + expectURL string + }{ + { + name: "Test X-Forwarded-Proto_HTTPS", + whenMethod: "GET", + whenHeaders: map[string]string{echo.HeaderXForwardedProto: "HTTPS"}, + expectURL: "https://srv.lan/tst/", + }, + { + name: "Test X-Forwarded-Prefix_TEST", + whenMethod: "GET", + whenHeaders: map[string]string{echo.HeaderXForwardedPrefix: "/test/"}, + expectURL: "http://srv.lan/test/tst/", + }, + { + name: "Test X-Forwarded-Prefix_TEST2", + whenMethod: "GET", + whenHeaders: map[string]string{echo.HeaderXForwardedPrefix: "/test2/"}, + expectURL: "http://srv.lan/test2/tst/", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + mw := ProxyHeaders() + if tc.givenMW != nil { + mw = tc.givenMW + } + + e.Use(mw) + + h := mw(func(c echo.Context) error { + return nil + }) + + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + req := httptest.NewRequest(method, "http://srv.lan/tst/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + for k, v := range tc.whenHeaders { + req.Header.Set(k, v) + } + + err := h(c) + + assert.NoError(t, err) + url := c.Request().URL.String() + assert.Equal(t, tc.expectURL, url, "url: `%v` should be `%v`", tc.expectURL, url) + }) + } +}