From 72c8512091621fabb34620ba364970ef6c087110 Mon Sep 17 00:00:00 2001 From: obanby Date: Tue, 15 Oct 2024 11:54:03 -0400 Subject: [PATCH] Added validation to avoid partial configuration for sigv4 --- pkg/remotewrite/config.go | 18 ++++++++++++++ pkg/sigv4/tripper.go | 29 ++++++++++++----------- pkg/sigv4/tripper_test.go | 49 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 14 deletions(-) diff --git a/pkg/remotewrite/config.go b/pkg/remotewrite/config.go index 46c7f56..98360da 100644 --- a/pkg/remotewrite/config.go +++ b/pkg/remotewrite/config.go @@ -3,6 +3,7 @@ package remotewrite import ( "crypto/tls" "encoding/json" + "errors" "fmt" "github.com/grafana/xk6-output-prometheus-remote/pkg/sigv4" "net/http" @@ -123,6 +124,14 @@ func (conf Config) RemoteConfig() (*remote.HTTPConfig, error) { hc.TLSConfig.Certificates = []tls.Certificate{cert} } + if isSigV4PartiallyConfigured(conf.SigV4Region, conf.SigV4AccessKey, conf.SigV4SecretKey) { + return nil, errors.New( + "sigv4 seems to be partially configured. All of " + + "K6_PROMETHEUS_RW_SIGV4_REGION, K6_PROMETHEUS_RW_SIGV4_ACCESS_KEY, K6_PROMETHEUS_RW_SIGV4_SECRET_KEY " + + "must all be set. Unset all to bypass sigv4", + ) + } + if conf.SigV4Region.Valid && conf.SigV4AccessKey.Valid && conf.SigV4SecretKey.Valid { hc.SigV4 = &sigv4.Config{ Region: conf.SigV4Region.String, @@ -429,3 +438,12 @@ func parseArg(text string) (Config, error) { return c, nil } + +func isSigV4PartiallyConfigured(region, accessKey, secretKey null.String) bool { + hasRegion := region.Valid && len(strings.TrimSpace(region.String)) != 0 + hasAccessID := accessKey.Valid && len(strings.TrimSpace(accessKey.String)) != 0 + hasSecretAccessKey := secretKey.Valid && len(strings.TrimSpace(secretKey.String)) != 0 + // either they are all set, or all not set. False if partial + isComplete := (hasRegion && hasAccessID && hasSecretAccessKey) || (!hasRegion && !hasAccessID && !hasSecretAccessKey) + return !isComplete +} diff --git a/pkg/sigv4/tripper.go b/pkg/sigv4/tripper.go index f7ed9cc..2874a56 100644 --- a/pkg/sigv4/tripper.go +++ b/pkg/sigv4/tripper.go @@ -14,31 +14,32 @@ type Tripper struct { type Config struct { Region string - AwsSecretAccessKey string AwsAccessKeyID string + AwsSecretAccessKey string } -func NewRoundTripper(config *Config, next http.RoundTripper) (*Tripper, error) { - if config == nil { - return nil, errors.New("can't initialize a sigv4 round tripper with nil config") - } - - if len(strings.TrimSpace(config.Region)) == 0 { - return nil, errors.New("sigV4 config `Region` must be set") +func (c *Config) validate() error { + if c == nil { + return errors.New("config should not be nil") } - - if len(strings.TrimSpace(config.AwsSecretAccessKey)) == 0 { - return nil, errors.New("sigV4 config `AwsSecretAccessKey` must be set") + hasRegion := len(strings.TrimSpace(c.Region)) != 0 + hasAccessID := len(strings.TrimSpace(c.AwsAccessKeyID)) != 0 + hasSecretAccessKey := len(strings.TrimSpace(c.AwsSecretAccessKey)) != 0 + if !hasRegion || !hasAccessID || !hasSecretAccessKey { + return errors.New("sigV4 config `Region`, `AwsAccessKeyID`, `AwsSecretAccessKey` must all be set") } + return nil +} - if len(strings.TrimSpace(config.AwsAccessKeyID)) == 0 { - return nil, errors.New("sigV4 config `AwsAccessKeyID` must be set") +func NewRoundTripper(config *Config, next http.RoundTripper) (*Tripper, error) { + if err := config.validate(); err != nil { + return nil, err } if next == nil { next = http.DefaultTransport } - + tripper := &Tripper{ config: config, next: next, diff --git a/pkg/sigv4/tripper_test.go b/pkg/sigv4/tripper_test.go index 7bbd77d..c4aa883 100644 --- a/pkg/sigv4/tripper_test.go +++ b/pkg/sigv4/tripper_test.go @@ -43,3 +43,52 @@ func TestTripper_request_includes_required_headers(t *testing.T) { client.Do(req) } + +func TestConfig_Validation(t *testing.T) { + testCases := []struct { + shouldError bool + arg *Config + }{ + { + shouldError: false, + arg: &Config{ + Region: "us-east1", + AwsAccessKeyID: "someAccessKey", + AwsSecretAccessKey: "someSecretKey", + }, + }, + { + shouldError: true, + arg: nil, + }, + { + shouldError: true, + arg: &Config{ + Region: "us-east1", + }, + }, + { + shouldError: true, + arg: &Config{ + Region: "us-east1", + AwsAccessKeyID: "someAccessKeyId", + }, + }, + { + shouldError: true, + arg: &Config{ + AwsAccessKeyID: "SomeAccessKey", + AwsSecretAccessKey: "SomeSecretKey", + }, + }, + } + + for _, tc := range testCases { + got := tc.arg.validate() + if tc.shouldError { + assert.Error(t, got) + continue + } + assert.NoError(t, got) + } +}