From 33e9b7754246a89b809f78988016fb316e13f269 Mon Sep 17 00:00:00 2001 From: jo Date: Wed, 19 Jun 2024 15:24:16 +0200 Subject: [PATCH] refactor: deduplicate clone request code --- hcloud/client_handler.go | 15 +++++++++++++++ hcloud/client_handler_debug.go | 10 +++------- hcloud/client_handler_retry.go | 10 +++------- hcloud/client_handler_test.go | 28 ++++++++++++++++++++++++++++ 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/hcloud/client_handler.go b/hcloud/client_handler.go index db4e56b9..5e7321e9 100644 --- a/hcloud/client_handler.go +++ b/hcloud/client_handler.go @@ -1,6 +1,7 @@ package hcloud import ( + "context" "net/http" ) @@ -39,3 +40,17 @@ func assembleHandlerChain(client *Client) handler { return h } + +// cloneRequest clones both the request and the request body. +func cloneRequest(req *http.Request, ctx context.Context) (cloned *http.Request, err error) { //revive:disable:context-as-argument + cloned = req.Clone(ctx) + + if req.ContentLength > 0 { + cloned.Body, err = req.GetBody() + if err != nil { + return nil, err + } + } + + return cloned, nil +} diff --git a/hcloud/client_handler_debug.go b/hcloud/client_handler_debug.go index d2459947..4aa867db 100644 --- a/hcloud/client_handler_debug.go +++ b/hcloud/client_handler_debug.go @@ -20,13 +20,9 @@ type debugHandler struct { func (h *debugHandler) Do(req *http.Request, v any) (resp *Response, err error) { // Clone the request, so we can redact the auth header, read the body // and use a new context. - cloned := req.Clone(context.Background()) - - if req.ContentLength > 0 { - cloned.Body, err = req.GetBody() - if err != nil { - return nil, err - } + cloned, err := cloneRequest(req, context.Background()) + if err != nil { + return nil, err } cloned.Header.Set("Authorization", "REDACTED") diff --git a/hcloud/client_handler_retry.go b/hcloud/client_handler_retry.go index a6317ec4..52347eb6 100644 --- a/hcloud/client_handler_retry.go +++ b/hcloud/client_handler_retry.go @@ -19,13 +19,9 @@ func (h *retryHandler) Do(req *http.Request, v any) (resp *Response, err error) for { // Clone the request using the original context - cloned := req.Clone(req.Context()) - - if req.ContentLength > 0 { - cloned.Body, err = req.GetBody() - if err != nil { - return nil, err - } + cloned, err := cloneRequest(req, req.Context()) + if err != nil { + return nil, err } resp, err = h.handler.Do(cloned, v) diff --git a/hcloud/client_handler_test.go b/hcloud/client_handler_test.go index 37c0c4eb..9a9ca83d 100644 --- a/hcloud/client_handler_test.go +++ b/hcloud/client_handler_test.go @@ -1,10 +1,14 @@ package hcloud import ( + "bytes" + "context" + "io" "net/http" "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -33,3 +37,27 @@ func fakeResponse(t *testing.T, statusCode int, body string, json bool) *Respons return resp } + +func TestCloneRequest(t *testing.T) { + ctx := context.Background() + + req, err := http.NewRequest("GET", "/", bytes.NewBufferString("Hello")) + require.NoError(t, err) + req.Header.Set("Authorization", "sensitive") + + cloned, err := cloneRequest(req, ctx) + require.NoError(t, err) + cloned.Header.Set("Authorization", "REDACTED") + cloned.Body = io.NopCloser(bytes.NewBufferString("Changed")) + + // Check context + assert.Equal(t, req.Context(), cloned.Context()) + + // Check headers + assert.Equal(t, req.Header.Get("Authorization"), "sensitive") + + // Check body + reqBody, err := io.ReadAll(req.Body) + require.NoError(t, err) + assert.Equal(t, string(reqBody), "Hello") +}