diff --git a/httpclient/config.go b/httpclient/config.go new file mode 100644 index 0000000..534b6c9 --- /dev/null +++ b/httpclient/config.go @@ -0,0 +1,400 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package httpclient + +import ( + "errors" + "github.com/acronis/go-appkit/config" + "github.com/acronis/go-appkit/retry" + "github.com/cenkalti/backoff/v4" + "time" +) + +const ( + // DefaultClientWaitTimeout is a default timeout for a client to wait for a request. + DefaultClientWaitTimeout = 10 * time.Second + + // RetryPolicyExponential is a policy for exponential retries. + RetryPolicyExponential = "exponential" + + // RetryPolicyConstant is a policy for constant retries. + RetryPolicyConstant = "constant" + + // configuration properties + cfgKeyRetriesEnabled = "retries.enabled" + cfgKeyRetriesMax = "retries.maxAttempts" + cfgKeyRetriesPolicyStrategy = "retries.policy.strategy" + cfgKeyRetriesPolicyExponentialInitialInterval = "retries.policy.exponentialBackoffInitialInterval" + cfgKeyRetriesPolicyExponentialMultiplier = "retries.policy.exponentialBackoffMultiplier" + cfgKeyRetriesPolicyConstantInternal = "retries.policy.constantBackoffInterval" + cfgKeyRateLimitsEnabled = "rateLimits.enabled" + cfgKeyRateLimitsLimit = "rateLimits.limit" + cfgKeyRateLimitsBurst = "rateLimits.burst" + cfgKeyRateLimitsWaitTimeout = "rateLimits.waitTimeout" + cfgKeyLoggerEnabled = "logger.enabled" + cfgKeyLoggerMode = "logger.mode" + cfgKeyLoggerSlowRequestThreshold = "logger.slowRequestThreshold" + cfgKeyMetricsEnabled = "metrics.enabled" + cfgKeyTimeout = "timeout" +) + +var _ config.Config = (*Config)(nil) +var _ config.KeyPrefixProvider = (*Config)(nil) + +// RateLimitConfig represents configuration options for HTTP client rate limits. +type RateLimitConfig struct { + // Enabled is a flag that enables rate limiting. + Enabled bool `mapstructure:"enabled"` + + // Limit is the maximum number of requests that can be made. + Limit int `mapstructure:"limit"` + + // Burst allow temporary spikes in request rate. + Burst int `mapstructure:"burst"` + + // WaitTimeout is the maximum time to wait for a request to be made. + WaitTimeout time.Duration `mapstructure:"waitTimeout"` +} + +// Set is part of config interface implementation. +func (c *RateLimitConfig) Set(dp config.DataProvider) (err error) { + enabled, err := dp.GetBool(cfgKeyRateLimitsEnabled) + if err != nil { + return err + } + c.Enabled = enabled + + if !c.Enabled { + return nil + } + + limit, err := dp.GetInt(cfgKeyRateLimitsLimit) + if err != nil { + return err + } + if limit <= 0 { + return errors.New("client rate limit must be positive") + } + c.Limit = limit + + burst, err := dp.GetInt(cfgKeyRateLimitsBurst) + if err != nil { + return err + } + if burst < 0 { + return errors.New("client burst must be positive") + } + c.Burst = burst + + waitTimeout, err := dp.GetDuration(cfgKeyRateLimitsWaitTimeout) + if err != nil { + return err + } + if waitTimeout < 0 { + return errors.New("client wait timeout must be positive") + } + c.WaitTimeout = waitTimeout + + return nil +} + +// SetProviderDefaults is part of config interface implementation. +func (c *RateLimitConfig) SetProviderDefaults(_ config.DataProvider) {} + +// TransportOpts returns transport options. +func (c *RateLimitConfig) TransportOpts() RateLimitingRoundTripperOpts { + return RateLimitingRoundTripperOpts{ + Burst: c.Burst, + WaitTimeout: c.WaitTimeout, + } +} + +// PolicyConfig represents configuration options for policy retry. +type PolicyConfig struct { + // Strategy is a strategy for retry policy. + Strategy string `mapstructure:"strategy"` + + // ExponentialBackoffInitialInterval is the initial interval for exponential backoff. + ExponentialBackoffInitialInterval time.Duration `mapstructure:"exponentialBackoffInitialInterval"` + + // ExponentialBackoffMultiplier is the multiplier for exponential backoff. + ExponentialBackoffMultiplier float64 `mapstructure:"exponentialBackoffMultiplier"` + + // ConstantBackoffInterval is the interval for constant backoff. + ConstantBackoffInterval time.Duration `mapstructure:"constantBackoffInterval"` +} + +// Set is part of config interface implementation. +func (c *PolicyConfig) Set(dp config.DataProvider) (err error) { + strategy, err := dp.GetString(cfgKeyRetriesPolicyStrategy) + if err != nil { + return err + } + c.Strategy = strategy + + if c.Strategy != "" && c.Strategy != RetryPolicyExponential && c.Strategy != RetryPolicyConstant { + return errors.New("client retry policy must be one of: [exponential, constant]") + } + + if c.Strategy == RetryPolicyExponential { + var interval time.Duration + interval, err = dp.GetDuration(cfgKeyRetriesPolicyExponentialInitialInterval) + if err != nil { + return nil + } + if interval < 0 { + return errors.New("client exponential backoff initial interval must be positive") + } + c.ExponentialBackoffInitialInterval = interval + + var multiplier float64 + multiplier, err = dp.GetFloat64(cfgKeyRetriesPolicyExponentialMultiplier) + if err != nil { + return err + } + if multiplier <= 1 { + return errors.New("client exponential backoff multiplier must be greater than 1") + } + c.ExponentialBackoffMultiplier = multiplier + + return nil + } else if c.Strategy == RetryPolicyConstant { + var interval time.Duration + interval, err = dp.GetDuration(cfgKeyRetriesPolicyConstantInternal) + if err != nil { + return err + } + if interval < 0 { + return errors.New("client constant backoff interval must be positive") + } + c.ConstantBackoffInterval = interval + } + + return nil +} + +// SetProviderDefaults is part of config interface implementation. +func (c *PolicyConfig) SetProviderDefaults(_ config.DataProvider) {} + +// RetriesConfig represents configuration options for HTTP client retries policy. +type RetriesConfig struct { + // Enabled is a flag that enables retries. + Enabled bool `mapstructure:"enabled"` + + // MaxAttempts is the maximum number of attempts to retry the request. + MaxAttempts int `mapstructure:"maxAttempts"` + + // Policy of a retry: [exponential, constant]. default is exponential. + Policy PolicyConfig `mapstructure:"policy"` +} + +// GetPolicy returns a retry policy based on strategy or nil if none is provided. +func (c *RetriesConfig) GetPolicy() retry.Policy { + if c.Policy.Strategy == RetryPolicyExponential { + return retry.PolicyFunc(func() backoff.BackOff { + bf := backoff.NewExponentialBackOff() + bf.InitialInterval = c.Policy.ExponentialBackoffInitialInterval + bf.Multiplier = c.Policy.ExponentialBackoffMultiplier + bf.Reset() + return bf + }) + } else if c.Policy.Strategy == RetryPolicyConstant { + return retry.PolicyFunc(func() backoff.BackOff { + bf := backoff.NewConstantBackOff(c.Policy.ConstantBackoffInterval) + bf.Reset() + return bf + }) + } + + return nil +} + +// Set is part of config interface implementation. +func (c *RetriesConfig) Set(dp config.DataProvider) error { + enabled, err := dp.GetBool(cfgKeyRetriesEnabled) + if err != nil { + return err + } + c.Enabled = enabled + + if !c.Enabled { + return nil + } + + maxAttempts, err := dp.GetInt(cfgKeyRetriesMax) + if err != nil { + return err + } + if maxAttempts < 0 { + return errors.New("client max retry attempts must be positive") + } + c.MaxAttempts = maxAttempts + + err = c.Policy.Set(config.NewKeyPrefixedDataProvider(dp, "")) + if err != nil { + return err + } + + return nil +} + +// SetProviderDefaults is part of config interface implementation. +func (c *RetriesConfig) SetProviderDefaults(_ config.DataProvider) {} + +// TransportOpts returns transport options. +func (c *RetriesConfig) TransportOpts() RetryableRoundTripperOpts { + return RetryableRoundTripperOpts{MaxRetryAttempts: c.MaxAttempts} +} + +// LoggerConfig represents configuration options for HTTP client logs. +type LoggerConfig struct { + // Enabled is a flag that enables logging. + Enabled bool `mapstructure:"enabled"` + + // SlowRequestThreshold is a threshold for slow requests. + SlowRequestThreshold time.Duration `mapstructure:"slowRequestThreshold"` + + // Mode of logging. + Mode string `mapstructure:"mode"` +} + +// Set is part of config interface implementation. +func (c *LoggerConfig) Set(dp config.DataProvider) error { + enabled, err := dp.GetBool(cfgKeyLoggerEnabled) + if err != nil { + return err + } + c.Enabled = enabled + + if !c.Enabled { + return nil + } + + slowRequestThreshold, err := dp.GetDuration(cfgKeyLoggerSlowRequestThreshold) + if err != nil { + return err + } + if slowRequestThreshold < 0 { + return errors.New("client logger slow request threshold can not be negative") + } + c.SlowRequestThreshold = slowRequestThreshold + + mode, err := dp.GetString(cfgKeyLoggerMode) + if err != nil { + return err + } + if !LoggerMode(mode).IsValid() { + return errors.New("client logger invalid mode, choose one of: [none, all, failed]") + } + c.Mode = mode + + return nil +} + +// SetProviderDefaults is part of config interface implementation. +func (c *LoggerConfig) SetProviderDefaults(_ config.DataProvider) {} + +// TransportOpts returns transport options. +func (c *LoggerConfig) TransportOpts() LoggingRoundTripperOpts { + return LoggingRoundTripperOpts{ + Mode: c.Mode, + SlowRequestThreshold: c.SlowRequestThreshold, + } +} + +// MetricsConfig represents configuration options for HTTP client logs. +type MetricsConfig struct { + // Enabled is a flag that enables metrics. + Enabled bool `mapstructure:"enabled"` +} + +// Set is part of config interface implementation. +func (c *MetricsConfig) Set(dp config.DataProvider) error { + enabled, err := dp.GetBool(cfgKeyMetricsEnabled) + if err != nil { + return err + } + c.Enabled = enabled + + return nil +} + +// SetProviderDefaults is part of config interface implementation. +func (c *MetricsConfig) SetProviderDefaults(_ config.DataProvider) {} + +// Config represents options for HTTP client configuration. +type Config struct { + // Retries is a configuration for HTTP client retries policy. + Retries RetriesConfig `mapstructure:"retries"` + + // RateLimits is a configuration for HTTP client rate limits. + RateLimits RateLimitConfig `mapstructure:"rateLimits"` + + // Logger is a configuration for HTTP client logs. + Logger LoggerConfig `mapstructure:"logger"` + + // Metrics is a configuration for HTTP client metrics. + Metrics MetricsConfig `mapstructure:"metrics"` + + // Timeout is the maximum time to wait for a request to be made. + Timeout time.Duration `mapstructure:"timeout"` + + // keyPrefix is a prefix for configuration parameters. + keyPrefix string +} + +// NewConfig creates a new instance of the Config. +func NewConfig() *Config { + return NewConfigWithKeyPrefix("") +} + +// NewConfigWithKeyPrefix creates a new instance of the Config. +// Allows specifying key prefix which will be used for parsing configuration parameters. +func NewConfigWithKeyPrefix(keyPrefix string) *Config { + return &Config{keyPrefix: keyPrefix} +} + +// KeyPrefix returns a key prefix with which all configuration parameters should be presented. +func (c *Config) KeyPrefix() string { + return c.keyPrefix +} + +// Set is part of config interface implementation. +func (c *Config) Set(dp config.DataProvider) error { + timeout, err := dp.GetDuration(cfgKeyTimeout) + if err != nil { + return err + } + c.Timeout = timeout + + err = c.Retries.Set(config.NewKeyPrefixedDataProvider(dp, c.keyPrefix)) + if err != nil { + return err + } + + err = c.RateLimits.Set(config.NewKeyPrefixedDataProvider(dp, c.keyPrefix)) + if err != nil { + return err + } + + err = c.Logger.Set(config.NewKeyPrefixedDataProvider(dp, c.keyPrefix)) + if err != nil { + return err + } + + err = c.Metrics.Set(config.NewKeyPrefixedDataProvider(dp, c.keyPrefix)) + if err != nil { + return err + } + + return nil +} + +// SetProviderDefaults is part of config interface implementation. +func (c *Config) SetProviderDefaults(_ config.DataProvider) { + c.Timeout = DefaultClientWaitTimeout +} diff --git a/httpclient/config_test.go b/httpclient/config_test.go new file mode 100644 index 0000000..3cbe31a --- /dev/null +++ b/httpclient/config_test.go @@ -0,0 +1,180 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package httpclient + +import ( + "bytes" + "github.com/acronis/go-appkit/config" + "github.com/acronis/go-appkit/retry" + "github.com/stretchr/testify/require" + "strings" + "testing" + "time" +) + +func TestConfigWithLoader(t *testing.T) { + yamlData := testYamlData(nil) + actualConfig := &Config{} + err := config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.NoError(t, err, "load configuration") + + expectedConfig := &Config{ + Retries: RetriesConfig{ + Enabled: true, + MaxAttempts: 30, + Policy: PolicyConfig{ + Strategy: RetryPolicyExponential, + ExponentialBackoffInitialInterval: 3 * time.Second, + ExponentialBackoffMultiplier: 2, + }, + }, + RateLimits: RateLimitConfig{ + Enabled: true, + Limit: 300, + Burst: 3000, + WaitTimeout: 3 * time.Second, + }, + Logger: LoggerConfig{ + Enabled: true, + SlowRequestThreshold: 5 * time.Second, + Mode: "all", + }, + Metrics: MetricsConfig{Enabled: true}, + Timeout: 30 * time.Second, + } + + require.Equal(t, expectedConfig, actualConfig, "configuration does not match expected") +} + +func TestConfigRateLimit(t *testing.T) { + yamlData := testYamlData([][]string{{"limit: 300", "limit: -300"}}) + actualConfig := &Config{} + err := config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.Error(t, err) + require.Equal(t, "client rate limit must be positive", err.Error()) + + yamlData = testYamlData([][]string{{"burst: 3000", "burst: -3"}}) + actualConfig = &Config{} + err = config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.Error(t, err) + require.Equal(t, "client burst must be positive", err.Error()) + + yamlData = testYamlData([][]string{{"waitTimeout: 3s", "waitTimeout: -3s"}}) + actualConfig = &Config{} + err = config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.Error(t, err) + require.Equal(t, "client wait timeout must be positive", err.Error()) +} + +func TestConfigRetries(t *testing.T) { + yamlData := testYamlData([][]string{{"maxAttempts: 30", "maxAttempts: -30"}}) + actualConfig := &Config{} + err := config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.Error(t, err) + require.Equal(t, "client max retry attempts must be positive", err.Error()) +} + +func TestConfigLogger(t *testing.T) { + yamlData := testYamlData([][]string{{"slowRequestThreshold: 5s", "slowRequestThreshold: -5s"}}) + actualConfig := &Config{} + err := config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.Error(t, err) + require.Equal(t, "client logger slow request threshold can not be negative", err.Error()) + + yamlData = testYamlData([][]string{{"mode: all", "mode: invalid"}}) + actualConfig = &Config{} + err = config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.Error(t, err) + require.Equal(t, "client logger invalid mode, choose one of: [none, all, failed]", err.Error()) +} + +func TestConfigRetriesPolicy(t *testing.T) { + yamlData := testYamlData([][]string{{"strategy: exponential", "strategy: invalid"}}) + actualConfig := &Config{} + err := config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.Error(t, err) + require.Equal(t, "client retry policy must be one of: [exponential, constant]", err.Error()) + + yamlData = testYamlData([][]string{ + {"exponentialBackoffInitialInterval: 3s", "exponentialBackoffInitialInterval: -1s"}, + }) + err = config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.Error(t, err) + require.Equal(t, "client exponential backoff initial interval must be positive", err.Error()) + + yamlData = testYamlData([][]string{{"exponentialBackoffMultiplier: 2", "exponentialBackoffMultiplier: 1"}}) + err = config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.Error(t, err) + require.Equal(t, "client exponential backoff multiplier must be greater than 1", err.Error()) + + yamlData = testYamlData([][]string{ + {"strategy: exponential", "strategy: constant"}, + {"constantBackoffInterval: 2s", "constantBackoffInterval: -3s"}, + }) + err = config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.Error(t, err) + require.Equal(t, "client constant backoff interval must be positive", err.Error()) + + yamlData = testYamlData([][]string{ + {"strategy: exponential", "strategy:"}, + }) + err = config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.NoError(t, err) + require.Nil(t, actualConfig.Retries.GetPolicy()) + + yamlData = testYamlData(nil) + err = config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.NoError(t, err) + require.Implements(t, (*retry.Policy)(nil), actualConfig.Retries.GetPolicy()) +} + +func TestConfigDisableWithLoader(t *testing.T) { + yamlData := []byte(` +retries: + enabled: false +rateLimits: + enabled: false +logger: + enabled: false +metrics: + enabled: false +timeout: 30s +`) + actualConfig := &Config{} + err := config.NewDefaultLoader("").LoadFromReader(bytes.NewReader(yamlData), config.DataTypeYAML, actualConfig) + require.NoError(t, err) +} + +func testYamlData(replacements [][]string) []byte { + yamlData := ` +retries: + enabled: true + maxAttempts: 30 + policy: + strategy: exponential + exponentialBackoffInitialInterval: 3s + exponentialBackoffMultiplier: 2 + constantBackoffInterval: 2s +rateLimits: + enabled: true + limit: 300 + burst: 3000 + waitTimeout: 3s +logger: + enabled: true + slowRequestThreshold: 5s + mode: all +metrics: + enabled: true +timeout: 30s +` + for i := range replacements { + yamlData = strings.Replace(yamlData, replacements[i][0], replacements[i][1], 1) + } + + return []byte(yamlData) +} diff --git a/httpclient/httpclient.go b/httpclient/httpclient.go index b293d7f..b41ca28 100644 --- a/httpclient/httpclient.go +++ b/httpclient/httpclient.go @@ -6,7 +6,12 @@ Released under MIT license. package httpclient -import "net/http" +import ( + "context" + "fmt" + "github.com/acronis/go-appkit/log" + "net/http" +) // CloneHTTPRequest creates a shallow copy of the request along with a deep copy of the Headers. func CloneHTTPRequest(req *http.Request) *http.Request { @@ -26,3 +31,113 @@ func CloneHTTPHeader(in http.Header) http.Header { } return out } + +// ClientProviders for further customization of the client logging and request id. +type ClientProviders struct { + // Logger is a function that provides a context-specific logger. + Logger func(ctx context.Context) log.FieldLogger + + // RequestID is a function that provides a request ID. + RequestID func(ctx context.Context) string +} + +// NewHTTPClient wraps delegate transports with logging, metrics, rate limiting, retryable, user agent, request id +// and returns an error if any occurs. +func NewHTTPClient( + cfg *Config, + userAgent string, + reqType string, + delegate http.RoundTripper, + providers ClientProviders, +) (*http.Client, error) { + var err error + + if delegate == nil { + delegate = http.DefaultTransport.(*http.Transport).Clone() + } + + if cfg.Logger.Enabled { + opts := cfg.Logger.TransportOpts() + opts.LoggerProvider = providers.Logger + delegate = NewLoggingRoundTripperWithOpts(delegate, reqType, opts) + } + + if cfg.Metrics.Enabled { + delegate = NewMetricsRoundTripper(delegate, reqType) + } + + if cfg.RateLimits.Enabled { + delegate, err = NewRateLimitingRoundTripperWithOpts( + delegate, cfg.RateLimits.Limit, cfg.RateLimits.TransportOpts(), + ) + if err != nil { + return nil, fmt.Errorf("create rate limiting round tripper: %w", err) + } + } + + if cfg.Retries.Enabled { + opts := cfg.Retries.TransportOpts() + opts.LoggerProvider = providers.Logger + opts.BackoffPolicy = cfg.Retries.GetPolicy() + delegate, err = NewRetryableRoundTripperWithOpts(delegate, opts) + if err != nil { + return nil, fmt.Errorf("create retryable round tripper: %w", err) + } + } + + delegate = NewUserAgentRoundTripper(delegate, userAgent) + delegate = NewRequestIDRoundTripperWithOpts(delegate, RequestIDRoundTripperOpts{ + RequestIDProvider: providers.RequestID, + }) + + return &http.Client{Transport: delegate, Timeout: cfg.Timeout}, nil +} + +// MustHTTPClient wraps delegate transports with logging, metrics, rate limiting, retryable, user agent, request id +// and panics if any error occurs. +func MustHTTPClient( + cfg *Config, + userAgent, + reqType string, + delegate http.RoundTripper, + providers ClientProviders, +) *http.Client { + client, err := NewHTTPClient(cfg, userAgent, reqType, delegate, providers) + if err != nil { + panic(err) + } + + return client +} + +// ClientOpts provides options for NewHTTPClientWithOpts and MustHTTPClientWithOpts functions. +type ClientOpts struct { + // Config is the configuration for the HTTP client. + Config Config + + // UserAgent is a user agent string. + UserAgent string + + // ReqType is a type of request. + ReqType string + + // Delegate is the next RoundTripper in the chain. + Delegate http.RoundTripper + + // Providers are the functions that provide a context-specific logger and request ID. + Providers ClientProviders +} + +// NewHTTPClientWithOpts wraps delegate transports with options +// logging, metrics, rate limiting, retryable, user agent, request id +// and returns an error if any occurs. +func NewHTTPClientWithOpts(opts ClientOpts) (*http.Client, error) { + return NewHTTPClient(&opts.Config, opts.UserAgent, opts.ReqType, opts.Delegate, opts.Providers) +} + +// MustHTTPClientWithOpts wraps delegate transports with options +// logging, metrics, rate limiting, retryable, user agent, request id +// and panics if any error occurs. +func MustHTTPClientWithOpts(opts ClientOpts) *http.Client { + return MustHTTPClient(&opts.Config, opts.UserAgent, opts.ReqType, opts.Delegate, opts.Providers) +} diff --git a/httpclient/httpclient_test.go b/httpclient/httpclient_test.go new file mode 100644 index 0000000..6f94e87 --- /dev/null +++ b/httpclient/httpclient_test.go @@ -0,0 +1,147 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package httpclient + +import ( + "context" + "github.com/acronis/go-appkit/httpserver/middleware" + "github.com/acronis/go-appkit/log/logtest" + "github.com/stretchr/testify/require" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestNewHTTPClientLoggingRoundTripper(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusTeapot) + })) + defer server.Close() + + logger := logtest.NewRecorder() + cfg := NewConfig() + cfg.Logger.Enabled = true + client, err := NewHTTPClient(cfg, "test-agent", "test-request", nil, ClientProviders{}) + require.NoError(t, err) + + ctx := middleware.NewContextWithLogger(context.Background(), logger) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) + require.NoError(t, err) + + r, err := client.Do(req) + defer func() { _ = r.Body.Close() }() + require.NoError(t, err) + require.NotEmpty(t, logger.Entries()) + + loggerEntry := logger.Entries()[0] + require.Contains(t, loggerEntry.Text, "client http request POST "+server.URL+" req type test-request status code 418") +} + +func TestMustHTTPClientLoggingRoundTripper(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusTeapot) + })) + defer server.Close() + + logger := logtest.NewRecorder() + cfg := NewConfig() + cfg.Logger.Enabled = true + client := MustHTTPClient(cfg, "test-agent", "test-request", nil, ClientProviders{}) + ctx := middleware.NewContextWithLogger(context.Background(), logger) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) + require.NoError(t, err) + + r, err := client.Do(req) + defer func() { _ = r.Body.Close() }() + require.NoError(t, err) + require.NotEmpty(t, logger.Entries()) + + loggerEntry := logger.Entries()[0] + require.Contains(t, loggerEntry.Text, "client http request POST "+server.URL+" req type test-request status code 418") +} + +func TestNewHTTPClientWithOptsRoundTripper(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusTeapot) + })) + defer server.Close() + + logger := logtest.NewRecorder() + cfg := NewConfig() + cfg.Logger.Enabled = true + client, err := NewHTTPClientWithOpts(ClientOpts{ + Config: *cfg, + UserAgent: "test-agent", + ReqType: "test-request", + Delegate: nil, + }) + require.NoError(t, err) + ctx := middleware.NewContextWithLogger(context.Background(), logger) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) + require.NoError(t, err) + + r, err := client.Do(req) + defer func() { _ = r.Body.Close() }() + require.NoError(t, err) + require.NotEmpty(t, logger.Entries()) +} + +func TestMustHTTPClientWithOptsRoundTripper(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusTeapot) + })) + defer server.Close() + + logger := logtest.NewRecorder() + cfg := NewConfig() + cfg.Logger.Enabled = true + client := MustHTTPClientWithOpts(ClientOpts{ + Config: *cfg, + UserAgent: "test-agent", + ReqType: "test-request", + Delegate: nil, + }) + ctx := middleware.NewContextWithLogger(context.Background(), logger) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) + require.NoError(t, err) + + r, err := client.Do(req) + defer func() { _ = r.Body.Close() }() + require.NoError(t, err) + require.NotEmpty(t, logger.Entries()) +} + +func TestMustHTTPClientWithOptsRoundTripperPolicy(t *testing.T) { + var retriesCount int + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + retriesCount++ + rw.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + cfg := NewConfig() + cfg.Retries.Enabled = true + cfg.Retries.MaxAttempts = 1 + cfg.Retries.Policy.Strategy = RetryPolicyExponential + cfg.Retries.Policy.ExponentialBackoffInitialInterval = 2 * time.Millisecond + cfg.Retries.Policy.ExponentialBackoffMultiplier = 1.1 + + client := MustHTTPClientWithOpts(ClientOpts{ + Config: *cfg, + UserAgent: "test-agent", + ReqType: "test-request", + Delegate: nil, + }) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, server.URL, nil) + require.NoError(t, err) + + r, err := client.Do(req) + defer func() { _ = r.Body.Close() }() + require.NoError(t, err) + require.Equal(t, 2, retriesCount) +} diff --git a/httpclient/logging_round_tripper.go b/httpclient/logging_round_tripper.go new file mode 100644 index 0000000..1fbf661 --- /dev/null +++ b/httpclient/logging_round_tripper.go @@ -0,0 +1,131 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package httpclient + +import ( + "context" + "fmt" + "github.com/acronis/go-appkit/httpserver/middleware" + "github.com/acronis/go-appkit/log" + "net/http" + "time" +) + +// LoggerMode represents a mode of logging. +type LoggerMode string + +const ( + LoggerModeNone LoggerMode = "none" + LoggerModeAll LoggerMode = "all" + LoggerModeFailed LoggerMode = "failed" +) + +// IsValid checks if the logger mode is valid. +func (lm LoggerMode) IsValid() bool { + switch lm { + case LoggerModeNone, LoggerModeAll, LoggerModeFailed: + return true + } + return false +} + +// LoggingRoundTripper implements http.RoundTripper for logging requests. +type LoggingRoundTripper struct { + // Delegate is the next RoundTripper in the chain. + Delegate http.RoundTripper + + // ReqType is a type of request. + ReqType string + + // Opts are the options for the logging round tripper. + Opts LoggingRoundTripperOpts +} + +// LoggingRoundTripperOpts represents an options for LoggingRoundTripper. +type LoggingRoundTripperOpts struct { + // LoggerProvider is a function that provides a context-specific logger. + LoggerProvider func(ctx context.Context) log.FieldLogger + + // Mode of logging: none, all, failed. + Mode string + + // SlowRequestThreshold is a threshold for slow requests. + SlowRequestThreshold time.Duration +} + +// NewLoggingRoundTripper creates an HTTP transport that log requests. +func NewLoggingRoundTripper(delegate http.RoundTripper, reqType string) http.RoundTripper { + return &LoggingRoundTripper{ + Delegate: delegate, + ReqType: reqType, + Opts: LoggingRoundTripperOpts{}, + } +} + +// NewLoggingRoundTripperWithOpts creates an HTTP transport that log requests with options. +func NewLoggingRoundTripperWithOpts( + delegate http.RoundTripper, reqType string, opts LoggingRoundTripperOpts, +) http.RoundTripper { + return &LoggingRoundTripper{ + Delegate: delegate, + ReqType: reqType, + Opts: opts, + } +} + +// getLogger returns a logger from the context or from the options. +func (rt *LoggingRoundTripper) getLogger(ctx context.Context) log.FieldLogger { + if rt.Opts.LoggerProvider != nil { + return rt.Opts.LoggerProvider(ctx) + } + + return middleware.GetLoggerFromContext(ctx) +} + +// RoundTrip adds logging capabilities to the HTTP transport. +func (rt *LoggingRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + if rt.Opts.Mode == string(LoggerModeNone) { + return rt.Delegate.RoundTrip(r) + } + + ctx := r.Context() + logger := rt.getLogger(ctx) + start := time.Now() + + resp, err := rt.Delegate.RoundTrip(r) + elapsed := time.Since(start) + if logger != nil && elapsed >= rt.Opts.SlowRequestThreshold { + common := "client http request %s %s req type %s " + args := []interface{}{r.Method, r.URL.String(), rt.ReqType, elapsed.Seconds(), err} + message := common + "time taken %.3f, err %+v" + loggerAtLevel := logger.Infof + shouldModeLog := true + + if resp != nil { + if rt.Opts.Mode == string(LoggerModeFailed) && resp.StatusCode < http.StatusBadRequest { + shouldModeLog = false + } + + args = []interface{}{r.Method, r.URL.String(), rt.ReqType, resp.StatusCode, elapsed.Seconds(), err} + message = common + "status code %d, time taken %.3f, err %+v" + } + + if err != nil { + loggerAtLevel = logger.Errorf + } + + if shouldModeLog { + loggerAtLevel(message, args...) + loggingParams := middleware.GetLoggingParamsFromContext(ctx) + if loggingParams != nil { + loggingParams.AddTimeSlotDurationInMs(fmt.Sprintf("external_request_%s_ms", rt.ReqType), elapsed) + } + } + } + + return resp, err +} diff --git a/httpclient/logging_round_tripper_test.go b/httpclient/logging_round_tripper_test.go new file mode 100644 index 0000000..8dc4813 --- /dev/null +++ b/httpclient/logging_round_tripper_test.go @@ -0,0 +1,181 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package httpclient + +import ( + "context" + "github.com/acronis/go-appkit/httpserver/middleware" + "github.com/acronis/go-appkit/log" + "github.com/acronis/go-appkit/log/logtest" + "github.com/stretchr/testify/require" + "net" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewLoggingRoundTripper(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusTeapot) + })) + defer server.Close() + + logger := logtest.NewRecorder() + loggerRoundTripper := NewLoggingRoundTripper(http.DefaultTransport, "test-request") + client := &http.Client{Transport: loggerRoundTripper} + ctx := middleware.NewContextWithLogger(context.Background(), logger) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) + require.NoError(t, err) + + r, err := client.Do(req) + defer func() { _ = r.Body.Close() }() + require.NoError(t, err) + require.NotEmpty(t, logger.Entries()) + + loggerEntry := logger.Entries()[0] + require.Contains(t, loggerEntry.Text, "client http request POST "+server.URL+" req type test-request status code 418") +} + +func TestMustHTTPClientLoggingRoundTripperError(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + serverURL := "http://" + ln.Addr().String() + _ = ln.Close() + + logger := logtest.NewRecorder() + cfg := NewConfig() + cfg.Logger.Enabled = true + client := MustHTTPClient(cfg, "test-agent", "test-request", nil, ClientProviders{}) + ctx := middleware.NewContextWithLogger(context.Background(), logger) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, serverURL, nil) + require.NoError(t, err) + + r, err := client.Do(req) + require.Error(t, err) + require.Nil(t, r) + require.NotEmpty(t, logger.Entries()) + + loggerEntry := logger.Entries()[0] + require.Contains(t, loggerEntry.Text, "client http request POST "+serverURL+" req type test-request") + require.Contains(t, loggerEntry.Text, "err dial tcp "+ln.Addr().String()+": connect: connection refused") + require.NotContains(t, loggerEntry.Text, "status code") +} + +func TestMustHTTPClientLoggingRoundTripperDisabled(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + serverURL := "http://" + ln.Addr().String() + _ = ln.Close() + + logger := logtest.NewRecorder() + cfg := NewConfig() + client := MustHTTPClient(cfg, "test-agent", "test-request", nil, ClientProviders{}) + ctx := middleware.NewContextWithLogger(context.Background(), logger) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, serverURL, nil) + require.NoError(t, err) + + r, err := client.Do(req) + require.Error(t, err) + require.Nil(t, r) + require.Empty(t, logger.Entries()) +} + +func TestNewLoggingRoundTripperModes(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + rw.WriteHeader(http.StatusBadRequest) + } else { + rw.WriteHeader(http.StatusOK) + } + })) + defer server.Close() + + tests := []struct { + name string + method string + mode string + want int + }{ + { + name: "no requests are logged because of 'none' mode", + method: http.MethodGet, + mode: "none", + }, + { + name: "4xx and 5xx requests are logged because of 'failed' mode", + method: http.MethodPost, + mode: "failed", + want: 1, + }, + { + name: "only 4xx and 5xx requests so no logs because of 'failed' mode for 2xx", + method: http.MethodGet, + mode: "failed", + }, + { + name: "success requests are logged because of 'all' mode", + method: http.MethodGet, + mode: "all", + want: 1, + }, + { + name: "failed requests are logged because of 'all' mode", + method: http.MethodPost, + mode: "all", + want: 1, + }, + } + + for i := range tests { + tt := tests[i] + t.Run(tt.name, func(t *testing.T) { + logger := logtest.NewRecorder() + loggerRoundTripper := NewLoggingRoundTripperWithOpts( + http.DefaultTransport, + "test-request", + LoggingRoundTripperOpts{Mode: tt.mode}, + ) + client := &http.Client{Transport: loggerRoundTripper} + ctx := middleware.NewContextWithLogger(context.Background(), logger) + req, err := http.NewRequestWithContext(ctx, tt.method, server.URL, nil) + require.NoError(t, err) + + r, err := client.Do(req) + defer func() { _ = r.Body.Close() }() + require.NoError(t, err) + require.Len(t, logger.Entries(), tt.want) + }) + } +} + +func TestMustHTTPClientLoggingRoundTripperOpts(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + serverURL := "http://" + ln.Addr().String() + _ = ln.Close() + + logger := logtest.NewRecorder() + cfg := NewConfig() + cfg.Logger.Enabled = true + + var loggerProviderCalled bool + client := MustHTTPClient(cfg, "test-agent", "test-request", nil, ClientProviders{ + Logger: func(ctx context.Context) log.FieldLogger { + loggerProviderCalled = true + return logger + }, + }) + ctx := middleware.NewContextWithLogger(context.Background(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, serverURL, nil) + require.NoError(t, err) + + r, err := client.Do(req) + require.Error(t, err) + require.Nil(t, r) + require.True(t, loggerProviderCalled) + require.Len(t, logger.Entries(), 1) +} diff --git a/httpclient/metrics_round_tripper.go b/httpclient/metrics_round_tripper.go new file mode 100644 index 0000000..faf27dd --- /dev/null +++ b/httpclient/metrics_round_tripper.go @@ -0,0 +1,86 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package httpclient + +import ( + "fmt" + "github.com/prometheus/client_golang/prometheus" + "net/http" + "strconv" + "time" +) + +var ( + ClientHTTPRequestDuration *prometheus.HistogramVec + ClassifyRequest func(r *http.Request, reqType string) string +) + +// MustInitAndRegisterMetrics initializes and registers external HTTP request duration metric. +// Panic will be raised in case of error. +func MustInitAndRegisterMetrics(namespace string) { + ClientHTTPRequestDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: namespace, + Name: "http_client_request_duration_seconds", + Help: "A histogram of the http client requests durations.", + Buckets: []float64{0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60, 150, 300, 600}, + }, + []string{"type", "remote_address", "summary", "status"}, + ) + prometheus.MustRegister(ClientHTTPRequestDuration) +} + +// UnregisterMetrics unregisters external HTTP request duration metric. +func UnregisterMetrics() { + if ClientHTTPRequestDuration != nil { + prometheus.Unregister(ClientHTTPRequestDuration) + } +} + +// MetricsRoundTripper is an HTTP transport that measures requests done. +type MetricsRoundTripper struct { + // Delegate is the next RoundTripper in the chain. + Delegate http.RoundTripper + + // ReqType is a type of request. + ReqType string +} + +// NewMetricsRoundTripper creates an HTTP transport that measures requests done. +func NewMetricsRoundTripper(delegate http.RoundTripper, reqType string) http.RoundTripper { + return &MetricsRoundTripper{Delegate: delegate, ReqType: reqType} +} + +// RoundTrip measures external requests done. +func (rt *MetricsRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + if ClientHTTPRequestDuration == nil { + return rt.Delegate.RoundTrip(r) + } + + status := "0" + start := time.Now() + + resp, err := rt.Delegate.RoundTrip(r) + if err == nil && resp != nil { + status = strconv.Itoa(resp.StatusCode) + } + + ClientHTTPRequestDuration.WithLabelValues( + rt.ReqType, r.Host, requestSummary(r, rt.ReqType), status, + ).Observe(time.Since(start).Seconds()) + + return resp, err +} + +// requestSummary does request classification, producing non-parameterized summary for given request. +func requestSummary(r *http.Request, reqType string) string { + if ClassifyRequest != nil { + return ClassifyRequest(r, reqType) + } + + return fmt.Sprintf("%s %s", r.Method, reqType) +} diff --git a/httpclient/metrics_round_tripper_test.go b/httpclient/metrics_round_tripper_test.go new file mode 100644 index 0000000..a532522 --- /dev/null +++ b/httpclient/metrics_round_tripper_test.go @@ -0,0 +1,48 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package httpclient + +import ( + "context" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/require" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewMetricsRoundTripper(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusTeapot) + })) + defer server.Close() + + MustInitAndRegisterMetrics("") + defer UnregisterMetrics() + + metricsRoundTripper := NewMetricsRoundTripper(http.DefaultTransport, "test-request") + client := &http.Client{Transport: metricsRoundTripper} + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, server.URL, nil) + require.NoError(t, err) + + r, err := client.Do(req) + defer func() { _ = r.Body.Close() }() + require.NoError(t, err) + + ch := make(chan prometheus.Metric, 1) + go func() { + ClientHTTPRequestDuration.Collect(ch) + close(ch) + }() + + var metricCount int + for range ch { + metricCount++ + } + + require.Equal(t, metricCount, 1) +} diff --git a/httpclient/request_id_round_tripper.go b/httpclient/request_id_round_tripper.go new file mode 100644 index 0000000..6e51e4d --- /dev/null +++ b/httpclient/request_id_round_tripper.go @@ -0,0 +1,64 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package httpclient + +import ( + "context" + "github.com/acronis/go-appkit/httpserver/middleware" + "net/http" +) + +// RequestIDRoundTripper for X-Request-ID header to the request. +type RequestIDRoundTripper struct { + // Delegate is the next RoundTripper in the chain. + Delegate http.RoundTripper + + // Opts are the options for the request ID round tripper. + Opts RequestIDRoundTripperOpts +} + +// RequestIDRoundTripperOpts for X-Request-ID header to the request options. +type RequestIDRoundTripperOpts struct { + // RequestIDProvider is a function that provides a request ID. + RequestIDProvider func(ctx context.Context) string +} + +// NewRequestIDRoundTripper creates an HTTP transport with X-Request-ID header support. +func NewRequestIDRoundTripper(delegate http.RoundTripper) http.RoundTripper { + return &RequestIDRoundTripper{ + Delegate: delegate, + } +} + +// NewRequestIDRoundTripperWithOpts creates an HTTP transport with X-Request-ID header support with options. +func NewRequestIDRoundTripperWithOpts(delegate http.RoundTripper, opts RequestIDRoundTripperOpts) http.RoundTripper { + return &RequestIDRoundTripper{ + Delegate: delegate, + Opts: opts, + } +} + +// getRequestIDProvider returns a function with the request ID provider. +func (rt *RequestIDRoundTripper) getRequestIDProvider() func(ctx context.Context) string { + if rt.Opts.RequestIDProvider != nil { + return rt.Opts.RequestIDProvider + } + + return middleware.GetRequestIDFromContext +} + +// RoundTrip adds X-Request-ID header to the request. +func (rt *RequestIDRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + requestID := rt.getRequestIDProvider()(r.Context()) + if r.Header.Get("X-Request-ID") != "" || requestID == "" { + return rt.Delegate.RoundTrip(r) + } + + r = CloneHTTPRequest(r) + r.Header.Set("X-Request-ID", requestID) + return rt.Delegate.RoundTrip(r) +} diff --git a/httpclient/request_id_round_tripper_test.go b/httpclient/request_id_round_tripper_test.go new file mode 100644 index 0000000..438c924 --- /dev/null +++ b/httpclient/request_id_round_tripper_test.go @@ -0,0 +1,60 @@ +/* +Copyright © 2024 Acronis International GmbH. + +Released under MIT license. +*/ + +package httpclient + +import ( + "context" + "github.com/acronis/go-appkit/httpserver/middleware" + "github.com/stretchr/testify/require" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewRequestIDRoundTripper(t *testing.T) { + requestID := "12345" + + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + require.Equal(t, requestID, r.Header.Get("X-Request-ID")) + rw.WriteHeader(http.StatusTeapot) + })) + defer server.Close() + + requestIDRoundTripper := NewRequestIDRoundTripper(http.DefaultTransport) + client := &http.Client{Transport: requestIDRoundTripper} + ctx := middleware.NewContextWithRequestID(context.Background(), requestID) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) + require.NoError(t, err) + + r, err := client.Do(req) + defer func() { _ = r.Body.Close() }() + require.NoError(t, err) +} + +func TestNewRequestIDRoundTripperWithOpts(t *testing.T) { + requestID := "12345" + prefix := "my_custom_request_provider" + + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + require.Equal(t, prefix+requestID, r.Header.Get("X-Request-ID")) + rw.WriteHeader(http.StatusTeapot) + })) + defer server.Close() + + requestIDRoundTripper := NewRequestIDRoundTripperWithOpts(http.DefaultTransport, RequestIDRoundTripperOpts{ + RequestIDProvider: func(ctx context.Context) string { + return prefix + requestID + }, + }) + client := &http.Client{Transport: requestIDRoundTripper} + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, server.URL, nil) + require.NoError(t, err) + + r, err := client.Do(req) + defer func() { _ = r.Body.Close() }() + require.NoError(t, err) +}