diff --git a/cmd/aws-sigv4-proxy/main.go b/cmd/aws-sigv4-proxy/main.go index c4817cd..7a62f70 100644 --- a/cmd/aws-sigv4-proxy/main.go +++ b/cmd/aws-sigv4-proxy/main.go @@ -39,7 +39,9 @@ var ( logFailedResponse = kingpin.Flag("log-failed-requests", "Log 4xx and 5xx response body").Bool() logSinging = kingpin.Flag("log-signing-process", "Log sigv4 signing process").Bool() port = kingpin.Flag("port", "Port to serve http on").Default(":8080").String() - strip = kingpin.Flag("strip", "Headers to strip from incoming request").Short('s').Strings() + stripDeprecated = kingpin.Flag("strip", "Headers to strip from incoming request").Hidden().Short('s').Strings() + stripHeaders = kingpin.Flag("strip-header", "Headers to strip from incoming request. Use wildcard suffix '*' to match prefix e.g. x-aws-*").Strings() + stripParams = kingpin.Flag("strip-param", "Query parameters to strip from incoming request. Use wildcard suffix '*' to match prefix e.g. x-aws-*").Strings() duplicateHeaders = kingpin.Flag("duplicate-headers", "Duplicate headers to an X-Original- prefix name").Strings() roleArn = kingpin.Flag("role-arn", "Amazon Resource Name (ARN) of the role to assume").String() signingNameOverride = kingpin.Flag("name", "AWS Service to sign for").String() @@ -119,8 +121,22 @@ func main() { }, } - log.WithFields(log.Fields{"StripHeaders": *strip}).Infof("Stripping headers %s", *strip) - log.WithFields(log.Fields{"DuplicateHeaders": *duplicateHeaders}).Infof("Duplicating headers %s", *duplicateHeaders) + if *stripDeprecated != nil { + log.Warn("Using deprecated flag 'strip' - use 'strip-header' instead") + if *stripHeaders != nil { + log.Fatal("Use either 'strip' or 'strip-header'") + } + stripHeaders = stripDeprecated + } + if *stripHeaders != nil { + log.WithFields(log.Fields{"StripHeaders": *stripHeaders}).Infof("Stripping headers %s", *stripHeaders) + } + if *stripParams != nil { + log.WithFields(log.Fields{"StripParams": *stripParams}).Infof("Stripping query parameters %s", *stripParams) + } + if *duplicateHeaders != nil { + log.WithFields(log.Fields{"DuplicateHeaders": *duplicateHeaders}).Infof("Duplicating headers %s", *duplicateHeaders) + } log.WithFields(log.Fields{"port": *port}).Infof("Listening on %s", *port) log.Fatal( @@ -128,7 +144,8 @@ func main() { ProxyClient: &handler.ProxyClient{ Signer: signer, Client: client, - StripRequestHeaders: *strip, + StripRequestHeaders: *stripHeaders, + StripRequestQueryParams: *stripParams, DuplicateRequestHeaders: *duplicateHeaders, SigningNameOverride: *signingNameOverride, SigningHostOverride: *signingHostOverride, diff --git a/handler/proxy_client.go b/handler/proxy_client.go index 492edf2..1b0705e 100644 --- a/handler/proxy_client.go +++ b/handler/proxy_client.go @@ -22,6 +22,7 @@ import ( "io/ioutil" "net/http" "net/http/httputil" + "strings" "time" "github.com/aws/aws-sdk-go/aws/endpoints" @@ -39,6 +40,7 @@ type ProxyClient struct { Signer *v4.Signer Client Client StripRequestHeaders []string + StripRequestQueryParams []string DuplicateRequestHeaders []string SigningNameOverride string SigningHostOverride string @@ -141,6 +143,28 @@ func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) { proxyURL.Scheme = p.SchemeOverride } + // Remove any query params specified + if len(p.StripRequestQueryParams) > 0 { + query := req.URL.Query() + for _, stripParam := range p.StripRequestQueryParams { + // handle wildcard strip values + if strings.HasSuffix(stripParam, "*") { + prefix := strings.ToLower(stripParam[:len(stripParam)-1]) + for param := range query { + p := strings.ToLower(param) + if strings.HasPrefix(p, prefix) { + log.WithField("StripParam", string(param)).Debug("Stripping Param:") + query.Del(param) + } + } + } else { + log.WithField("StripParam", string(stripParam)).Debug("Stripping Param:") + query.Del(stripParam) + } + } + proxyURL.RawQuery = query.Encode() + } + if log.GetLevel() == log.DebugLevel { initialReqDump, err := httputil.DumpRequest(req, true) if err != nil { @@ -205,8 +229,20 @@ func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) { // Remove any headers specified for _, header := range p.StripRequestHeaders { - log.WithField("StripHeader", string(header)).Debug("Stripping Header:") - req.Header.Del(header) + // handle wildcard strip values + if strings.HasSuffix(header, "*") { + prefix := strings.ToLower(header[:len(header)-1]) + for rHeader := range req.Header { + h := strings.ToLower(rHeader) + if strings.HasPrefix(h, prefix) { + log.WithField("StripHeader", string(h)).Debug("Stripping Header:") + req.Header.Del(h) + } + } + } else { + log.WithField("StripHeader", string(header)).Debug("Stripping Header:") + req.Header.Del(header) + } } // Duplicate the header value for any headers specified into a new header diff --git a/handler/proxy_client_test.go b/handler/proxy_client_test.go index edaf1e6..c2d7029 100644 --- a/handler/proxy_client_test.go +++ b/handler/proxy_client_test.go @@ -419,6 +419,65 @@ func TestProxyClient_Do(t *testing.T) { }, }, }, + { + name: "should strip request headers", + request: &http.Request{ + Method: "GET", + URL: &url.URL{}, + Host: "execute-api.us-west-2.amazonaws.com", + Header: http.Header{ + "X-Goodheader": []string{"x"}, + "X-Badheader": []string{"x"}, + "X-Badprefix-1": []string{"x"}, + "X-Badprefix-2": []string{"x"}, + }, + Body: nil, + }, + proxyClient: &ProxyClient{ + Signer: v4.NewSigner(credentials.NewCredentials(&mockProvider{})), + Client: &mockHTTPClient{}, + StripRequestHeaders: []string{ + "X-Badheader", + "X-Badprefix-*", + }, + }, + want: &want{ + resp: &http.Response{}, + err: nil, + request: &http.Request{ + Host: "execute-api.us-west-2.amazonaws.com", + Header: http.Header{ + "X-Goodheader": []string{"x"}, + // Ensure these headers are not present + "X-Badheader": nil, + "X-Badprefix-1": nil, + "X-Badprefix-2": nil, + }, + }, + }, + }, + { + name: "should strip request query params", + request: &http.Request{ + Method: "GET", + URL: &url.URL{RawQuery: "badparam=x&badprefix1=x&goodparam=x"}, + Host: "execute-api.us-west-2.amazonaws.com", + Body: nil, + }, + proxyClient: &ProxyClient{ + Signer: v4.NewSigner(credentials.NewCredentials(&mockProvider{})), + Client: &mockHTTPClient{}, + StripRequestQueryParams: []string{"badparam", "badprefix*"}, + }, + want: &want{ + resp: &http.Response{}, + err: nil, + request: &http.Request{ + Host: "execute-api.us-west-2.amazonaws.com", + URL: &url.URL{RawQuery: "goodparam=x"}, + }, + }, + }, } for _, tt := range tests { @@ -438,8 +497,16 @@ func TestProxyClient_Do(t *testing.T) { // Ensure specific headers are propagated (or not in certain cases) to the proxy request for kk, vv := range tt.want.request.Header { + if vv == nil && proxyRequest.Header.Get(kk) != "" { + t.Logf("got unexpected header key %q in proxy request", kk) + t.Fail() + } assert.Equal(t, vv, proxyRequest.Header[kk]) } + // Ensure query parameters match + if tt.want.request.URL != nil { + assert.Equal(t, tt.want.request.URL.Query(), proxyRequest.URL.Query()) + } // Ensure encoding is propagated to the proxy request. assert.Equal(t, chunked(tt.request.TransferEncoding), chunked(proxyRequest.TransferEncoding))