diff --git a/.gitignore b/.gitignore index 3259e0b3..2f128a51 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ tests-offsets testdata .idea/ +.vscode .DS_Store diff --git a/plugin/output/elasticsearch/elasticsearch.go b/plugin/output/elasticsearch/elasticsearch.go index bfdc6de3..c43381d3 100644 --- a/plugin/output/elasticsearch/elasticsearch.go +++ b/plugin/output/elasticsearch/elasticsearch.go @@ -4,19 +4,16 @@ import ( "context" "encoding/base64" "fmt" - "math/rand" "net/http" "sync" "time" "github.com/ozontech/file.d/cfg" "github.com/ozontech/file.d/fd" - "github.com/ozontech/file.d/logger" "github.com/ozontech/file.d/metric" "github.com/ozontech/file.d/pipeline" - "github.com/ozontech/file.d/xtls" + "github.com/ozontech/file.d/xhttp" "github.com/prometheus/client_golang/prometheus" - "github.com/valyala/fasthttp" insaneJSON "github.com/vitkovskii/insane-json" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -30,45 +27,13 @@ If a network error occurs, the batch will infinitely try to be delivered to the const ( outPluginType = "elasticsearch" - NDJSONContentType = "application/x-ndjson" - gzipContentEncoding = "gzip" -) - -type gzipCompressionLevel int - -const ( - gzipCompressionLevelDefault gzipCompressionLevel = iota - gzipCompressionLevelNo - gzipCompressionLevelBestSpeed - gzipCompressionLevelBestCompression - gzipCompressionLevelHuffmanOnly -) - -func (l gzipCompressionLevel) toFastHTTP() int { - switch l { - case gzipCompressionLevelNo: - return fasthttp.CompressNoCompression - case gzipCompressionLevelBestSpeed: - return fasthttp.CompressBestSpeed - case gzipCompressionLevelBestCompression: - return fasthttp.CompressBestCompression - case gzipCompressionLevelHuffmanOnly: - return fasthttp.CompressHuffmanOnly - default: - return fasthttp.CompressDefaultCompression - } -} - -var ( - strAuthorization = []byte(fasthttp.HeaderAuthorization) + NDJSONContentType = "application/x-ndjson" ) type Plugin struct { config *Config - client *fasthttp.Client - endpoints []*fasthttp.URI - authHeader []byte + client *xhttp.Client logger *zap.Logger controller pipeline.OutputPluginController @@ -102,8 +67,7 @@ type Config struct { // > @3@4@5@6 // > // > Gzip compression level. Used if `use_gzip=true`. - GzipCompressionLevel string `json:"gzip_compression_level" default:"default" options:"default|no|best-speed|best-compression|huffman-only"` // * - GzipCompressionLevel_ gzipCompressionLevel + GzipCompressionLevel string `json:"gzip_compression_level" default:"default" options:"default|no|best-speed|best-compression|huffman-only"` // * // > @3@4@5@6 // > @@ -322,37 +286,46 @@ func (p *Plugin) registerMetrics(ctl *metric.Ctl) { } func (p *Plugin) prepareClient() { - p.client = &fasthttp.Client{ - ReadTimeout: p.config.ConnectionTimeout_ * 2, - WriteTimeout: p.config.ConnectionTimeout_ * 2, - - MaxIdleConnDuration: p.config.KeepAlive.MaxIdleConnDuration_, - MaxConnDuration: p.config.KeepAlive.MaxConnDuration_, + config := &xhttp.ClientConfig{ + Endpoints: prepareEndpoints(p.config.Endpoints), + ConnectionTimeout: p.config.ConnectionTimeout_ * 2, + AuthHeader: p.getAuthHeader(), + KeepAlive: &xhttp.ClientKeepAliveConfig{ + MaxConnDuration: p.config.KeepAlive.MaxConnDuration_, + MaxIdleConnDuration: p.config.KeepAlive.MaxIdleConnDuration_, + }, } if p.config.CACert != "" { - b := xtls.NewConfigBuilder() - err := b.AppendCARoot(p.config.CACert) - if err != nil { - p.logger.Fatal("can't append CA root", zap.Error(err)) + config.TLS = &xhttp.ClientTLSConfig{ + CACert: p.config.CACert, } - - p.client.TLSConfig = b.Build() + } + if p.config.UseGzip { + config.GzipCompressionLevel = p.config.GzipCompressionLevel } - for _, endpoint := range p.config.Endpoints { - if endpoint[len(endpoint)-1] == '/' { - endpoint = endpoint[:len(endpoint)-1] - } + var err error + p.client, err = xhttp.NewClient(config) + if err != nil { + p.logger.Fatal("can't create http client", zap.Error(err)) + } +} - uri := &fasthttp.URI{} - if err := uri.Parse(nil, []byte(endpoint+"/_bulk?_source=false")); err != nil { - logger.Fatalf("can't parse ES endpoint %s: %s", endpoint, err.Error()) +func prepareEndpoints(endpoints []string) []string { + res := make([]string, 0, len(endpoints)) + for _, e := range endpoints { + if e[len(e)-1] == '/' { + e = e[:len(e)-1] } - - p.endpoints = append(p.endpoints, uri) + res = append(res, e+"/_bulk?_source=false") } + return res +} - p.authHeader = p.getAuthHeader() +func (p *Plugin) maintenance(_ *pipeline.WorkerData) { + p.mu.Lock() + p.time = time.Now().Format(p.config.TimeFormat) + p.mu.Unlock() } func (p *Plugin) out(workerData *pipeline.WorkerData, batch *pipeline.Batch) error { @@ -373,63 +346,15 @@ func (p *Plugin) out(workerData *pipeline.WorkerData, batch *pipeline.Batch) err data.outBuf = p.appendEvent(data.outBuf, event) }) - err := p.send(data.outBuf) + _, err := p.client.DoTimeout(http.MethodPost, NDJSONContentType, data.outBuf, + p.config.ConnectionTimeout_, p.reportESErrors) + if err != nil { p.sendErrorMetric.Inc() p.logger.Error("can't send to the elastic, will try other endpoint", zap.Error(err)) } - return err -} - -func (p *Plugin) send(body []byte) error { - req := fasthttp.AcquireRequest() - defer fasthttp.ReleaseRequest(req) - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseResponse(resp) - - endpoint := p.endpoints[rand.Int()%len(p.endpoints)] - p.prepareRequest(req, endpoint, body) - - if err := p.client.DoTimeout(req, resp, p.config.ConnectionTimeout_); err != nil { - return fmt.Errorf("can't send batch to %s: %s", endpoint.String(), err.Error()) - } - - respContent := resp.Body() - - if statusCode := resp.Header.StatusCode(); statusCode < http.StatusOK || statusCode > http.StatusAccepted { - return fmt.Errorf("response status from %s isn't OK: status=%d, body=%s", endpoint.String(), statusCode, string(respContent)) - } - - root, err := insaneJSON.DecodeBytes(respContent) - if err != nil { - return fmt.Errorf("wrong response from %s: %s", endpoint.String(), err.Error()) - } - defer insaneJSON.Release(root) - - p.reportESErrors(root) - - return nil -} - -func (p *Plugin) prepareRequest(req *fasthttp.Request, endpoint *fasthttp.URI, body []byte) { - req.SetURI(endpoint) - - req.Header.SetMethod(fasthttp.MethodPost) - req.Header.SetContentType(NDJSONContentType) - - if p.authHeader != nil { - req.Header.SetBytesKV(strAuthorization, p.authHeader) - } - if p.config.UseGzip { - if _, err := fasthttp.WriteGzipLevel(req.BodyWriter(), body, p.config.GzipCompressionLevel_.toFastHTTP()); err != nil { - req.SetBodyRaw(body) - } else { - req.Header.SetContentEncoding(gzipContentEncoding) - } - } else { - req.SetBodyRaw(body) - } + return err } func (p *Plugin) appendEvent(outBuf []byte, event *pipeline.Event) []byte { @@ -473,23 +398,15 @@ func (p *Plugin) appendIndexName(outBuf []byte, event *pipeline.Event) []byte { return outBuf } -func (p *Plugin) maintenance(_ *pipeline.WorkerData) { - p.mu.Lock() - p.time = time.Now().Format(p.config.TimeFormat) - p.mu.Unlock() -} - -func (p *Plugin) getAuthHeader() []byte { +func (p *Plugin) getAuthHeader() string { if p.config.APIKey != "" { - return []byte("ApiKey " + p.config.APIKey) + return "ApiKey " + p.config.APIKey } if p.config.Username != "" && p.config.Password != "" { credentials := []byte(p.config.Username + ":" + p.config.Password) - buf := make([]byte, base64.StdEncoding.EncodedLen(len(credentials))) - base64.StdEncoding.Encode(buf, credentials) - return append([]byte("Basic "), buf...) + return "Basic " + base64.StdEncoding.EncodeToString(credentials) } - return nil + return "" } // example of an ElasticSearch response that returned an indexing error for the first log: @@ -533,9 +450,15 @@ func (p *Plugin) getAuthHeader() []byte { // } // ] // } -func (p *Plugin) reportESErrors(root *insaneJSON.Root) { +func (p *Plugin) reportESErrors(data []byte) error { + root, err := insaneJSON.DecodeBytes(data) + defer insaneJSON.Release(root) + if err != nil { + return fmt.Errorf("can't decode response: %w", err) + } + if !root.Dig("errors").AsBool() { - return + return nil } items := root.Dig("items").AsArray() @@ -543,7 +466,7 @@ func (p *Plugin) reportESErrors(root *insaneJSON.Root) { p.logger.Error("unknown elasticsearch error, 'items' field in the response is empty", zap.String("response", root.EncodeToString()), ) - return + return nil } indexingErrors := 0 @@ -575,4 +498,5 @@ func (p *Plugin) reportESErrors(root *insaneJSON.Root) { } p.logger.Error("some events from batch aren't written, check previous logs for more information") + return nil } diff --git a/plugin/output/elasticsearch/elasticsearch_test.go b/plugin/output/elasticsearch/elasticsearch_test.go index 6b376310..1d1f0f11 100644 --- a/plugin/output/elasticsearch/elasticsearch_test.go +++ b/plugin/output/elasticsearch/elasticsearch_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/valyala/fasthttp" insaneJSON "github.com/vitkovskii/insane-json" "github.com/ozontech/file.d/pipeline" @@ -81,112 +80,24 @@ func TestAppendEventWithCreateOpType(t *testing.T) { assert.Equal(t, expected, string(result), "wrong request content") } -func TestConfig(t *testing.T) { - p := &Plugin{} - config := &Config{ - IndexFormat: "test-%", - Endpoints: []string{ - "http://endpoint_1:9000", - "http://endpoint_2:9000/", - "https://endpoint_3:9000", - "https://endpoint_4:9000/", - }, - BatchSize: "1", +func TestPrepareEndpoints(t *testing.T) { + in := []string{ + "http://endpoint_1:9000", + "http://endpoint_2:9000/", + "https://endpoint_3:9000", + "https://endpoint_4:9000/", } - test.NewConfig(config, map[string]int{"gomaxprocs": 1}) - - p.Start(config, test.NewEmptyOutputPluginParams()) - - results := []string{ + want := []string{ "http://endpoint_1:9000/_bulk?_source=false", "http://endpoint_2:9000/_bulk?_source=false", "https://endpoint_3:9000/_bulk?_source=false", "https://endpoint_4:9000/_bulk?_source=false", } - require.Len(t, p.endpoints, len(results)) - for i := range results { - assert.Equal(t, results[i], p.endpoints[i].String()) - } -} + got := prepareEndpoints(in) -func TestPrepareRequest(t *testing.T) { - type wantData struct { - uri string - method []byte - contentType []byte - contentEncoding []byte - auth []byte - body []byte - } - - cases := []struct { - name string - config *Config - - body string - want wantData - }{ - { - name: "raw", - config: &Config{ - Endpoints: []string{"http://endpoint:9000"}, - APIKey: "test", - }, - body: "test", - want: wantData{ - uri: "http://endpoint:9000/_bulk?_source=false", - method: []byte(fasthttp.MethodPost), - contentType: []byte(NDJSONContentType), - auth: []byte("ApiKey test"), - body: []byte("test"), - }, - }, - { - name: "gzip", - config: &Config{ - Endpoints: []string{"http://endpoint:9000"}, - UseGzip: true, - GzipCompressionLevel_: gzipCompressionLevelBestSpeed, - }, - body: "test", - want: wantData{ - uri: "http://endpoint:9000/_bulk?_source=false", - method: []byte(fasthttp.MethodPost), - contentType: []byte(NDJSONContentType), - contentEncoding: []byte(gzipContentEncoding), - body: []byte("test"), - }, - }, - } - for _, tt := range cases { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - p := Plugin{ - config: tt.config, - } - p.prepareClient() - - req := fasthttp.AcquireRequest() - defer fasthttp.ReleaseRequest(req) - - p.prepareRequest(req, p.endpoints[0], []byte(tt.body)) - - require.Equal(t, tt.want.uri, req.URI().String(), "wrong uri") - require.Equal(t, tt.want.method, req.Header.Method(), "wrong method") - require.Equal(t, tt.want.contentType, req.Header.ContentType(), "wrong content type") - require.Equal(t, tt.want.contentEncoding, req.Header.ContentEncoding(), "wrong content encoding") - require.Equal(t, tt.want.auth, req.Header.PeekBytes(strAuthorization), "wrong auth") - - var body []byte - if tt.config.UseGzip { - body, _ = req.BodyUncompressed() - } else { - body = req.Body() - } - require.Equal(t, tt.want.body, body, "wrong body") - }) + require.Len(t, got, len(want)) + for i := range got { + assert.Equal(t, want[i], got[i]) } } diff --git a/plugin/output/splunk/splunk.go b/plugin/output/splunk/splunk.go index d5410a3e..325fa66b 100644 --- a/plugin/output/splunk/splunk.go +++ b/plugin/output/splunk/splunk.go @@ -3,7 +3,6 @@ package splunk import ( "context" - "crypto/tls" "fmt" "net/http" "strconv" @@ -14,8 +13,8 @@ import ( "github.com/ozontech/file.d/fd" "github.com/ozontech/file.d/metric" "github.com/ozontech/file.d/pipeline" + "github.com/ozontech/file.d/xhttp" "github.com/prometheus/client_golang/prometheus" - "github.com/valyala/fasthttp" insaneJSON "github.com/vitkovskii/insane-json" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -73,35 +72,8 @@ Out: const ( outPluginType = "splunk" - - gzipContentEncoding = "gzip" ) -type gzipCompressionLevel int - -const ( - gzipCompressionLevelDefault gzipCompressionLevel = iota - gzipCompressionLevelNo - gzipCompressionLevelBestSpeed - gzipCompressionLevelBestCompression - gzipCompressionLevelHuffmanOnly -) - -func (l gzipCompressionLevel) toFastHTTP() int { - switch l { - case gzipCompressionLevelNo: - return fasthttp.CompressNoCompression - case gzipCompressionLevelBestSpeed: - return fasthttp.CompressBestSpeed - case gzipCompressionLevelBestCompression: - return fasthttp.CompressBestCompression - case gzipCompressionLevelHuffmanOnly: - return fasthttp.CompressHuffmanOnly - default: - return fasthttp.CompressDefaultCompression - } -} - type copyFieldPaths struct { fromPath []string toPath []string @@ -110,9 +82,7 @@ type copyFieldPaths struct { type Plugin struct { config *Config - client *fasthttp.Client - endpoint *fasthttp.URI - authHeader string + client *xhttp.Client copyFieldsPaths []copyFieldPaths @@ -149,8 +119,7 @@ type Config struct { // > @3@4@5@6 // > // > Gzip compression level. Used if `use_gzip=true`. - GzipCompressionLevel string `json:"gzip_compression_level" default:"default" options:"default|no|best-speed|best-compression|huffman-only"` // * - GzipCompressionLevel_ gzipCompressionLevel + GzipCompressionLevel string `json:"gzip_compression_level" default:"default" options:"default|no|best-speed|best-compression|huffman-only"` // * // > @3@4@5@6 // > @@ -325,6 +294,15 @@ func (p *Plugin) Start(config pipeline.AnyConfig, params *pipeline.OutputPluginP p.batcher.Start(ctx) } +func (p *Plugin) Stop() { + p.batcher.Stop() + p.cancel() +} + +func (p *Plugin) Out(event *pipeline.Event) { + p.batcher.Add(event) +} + func (p *Plugin) registerMetrics(ctl *metric.Ctl) { p.sendErrorMetric = ctl.RegisterCounterVec( "output_splunk_send_error", @@ -334,34 +312,28 @@ func (p *Plugin) registerMetrics(ctl *metric.Ctl) { } func (p *Plugin) prepareClient() { - p.client = &fasthttp.Client{ - ReadTimeout: p.config.RequestTimeout_, - WriteTimeout: p.config.RequestTimeout_, - - MaxIdleConnDuration: p.config.KeepAlive.MaxIdleConnDuration_, - MaxConnDuration: p.config.KeepAlive.MaxConnDuration_, - - TLSConfig: &tls.Config{ + config := &xhttp.ClientConfig{ + Endpoints: []string{p.config.Endpoint}, + ConnectionTimeout: p.config.RequestTimeout_, + AuthHeader: "Splunk " + p.config.Token, + KeepAlive: &xhttp.ClientKeepAliveConfig{ + MaxConnDuration: p.config.KeepAlive.MaxConnDuration_, + MaxIdleConnDuration: p.config.KeepAlive.MaxIdleConnDuration_, + }, + TLS: &xhttp.ClientTLSConfig{ // TODO: make this configuration option and false by default InsecureSkipVerify: true, }, } - - p.endpoint = &fasthttp.URI{} - if err := p.endpoint.Parse(nil, []byte(p.config.Endpoint)); err != nil { - p.logger.Fatalf("can't parse splunk endpoint %s: %s", p.config.Endpoint, err.Error()) + if p.config.UseGzip { + config.GzipCompressionLevel = p.config.GzipCompressionLevel } - p.authHeader = "Splunk " + p.config.Token -} - -func (p *Plugin) Stop() { - p.batcher.Stop() - p.cancel() -} - -func (p *Plugin) Out(event *pipeline.Event) { - p.batcher.Add(event) + var err error + p.client, err = xhttp.NewClient(config) + if err != nil { + p.logger.Fatal("can't create http client", zap.Error(err)) + } } func (p *Plugin) out(workerData *pipeline.WorkerData, batch *pipeline.Batch) error { @@ -399,7 +371,9 @@ func (p *Plugin) out(workerData *pipeline.WorkerData, batch *pipeline.Batch) err p.logger.Debugf("trying to send: %s", outBuf) - code, err := p.send(outBuf) + code, err := p.client.DoTimeout(http.MethodPost, "", outBuf, + p.config.RequestTimeout_, parseSplunkError) + if err != nil { p.sendErrorMetric.WithLabelValues(strconv.Itoa(code)).Inc() p.logger.Errorf("can't send data to splunk address=%s: %s", p.config.Endpoint, err.Error()) @@ -417,45 +391,6 @@ func (p *Plugin) out(workerData *pipeline.WorkerData, batch *pipeline.Batch) err func (p *Plugin) maintenance(_ *pipeline.WorkerData) {} -func (p *Plugin) send(data []byte) (int, error) { - req := fasthttp.AcquireRequest() - defer fasthttp.ReleaseRequest(req) - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseResponse(resp) - - p.prepareRequest(req, data) - - if err := p.client.DoTimeout(req, resp, p.config.RequestTimeout_); err != nil { - return 0, fmt.Errorf("can't send request: %w", err) - } - - respBody := resp.Body() - - var statusCode int - if statusCode = resp.Header.StatusCode(); statusCode != http.StatusOK { - return statusCode, fmt.Errorf("bad response: code=%s, body=%s", resp.Header.StatusMessage(), respBody) - } - - return statusCode, parseSplunkError(respBody) -} - -func (p *Plugin) prepareRequest(req *fasthttp.Request, body []byte) { - req.SetURI(p.endpoint) - - req.Header.SetMethod(fasthttp.MethodPost) - req.Header.Set(fasthttp.HeaderAuthorization, p.authHeader) - - if p.config.UseGzip { - if _, err := fasthttp.WriteGzipLevel(req.BodyWriter(), body, p.config.GzipCompressionLevel_.toFastHTTP()); err != nil { - req.SetBodyRaw(body) - } else { - req.Header.SetContentEncoding(gzipContentEncoding) - } - } else { - req.SetBodyRaw(body) - } -} - func parseSplunkError(data []byte) error { root, err := insaneJSON.DecodeBytes(data) defer insaneJSON.Release(root) diff --git a/plugin/output/splunk/splunk_test.go b/plugin/output/splunk/splunk_test.go index a40e0f6b..bbd91867 100644 --- a/plugin/output/splunk/splunk_test.go +++ b/plugin/output/splunk/splunk_test.go @@ -9,7 +9,6 @@ import ( "github.com/ozontech/file.d/cfg" "github.com/ozontech/file.d/pipeline" "github.com/stretchr/testify/assert" - "github.com/valyala/fasthttp" insaneJSON "github.com/vitkovskii/insane-json" "go.uber.org/zap" ) @@ -70,85 +69,6 @@ func TestSplunk(t *testing.T) { } } -func TestPrepareRequest(t *testing.T) { - type wantData struct { - uri string - method []byte - contentEncoding []byte - auth []byte - body []byte - } - - cases := []struct { - name string - config *Config - - body string - want wantData - }{ - { - name: "raw", - config: &Config{ - Endpoint: "http://endpoint:9000", - Token: "test", - }, - body: "test", - want: wantData{ - uri: "http://endpoint:9000/", - method: []byte(fasthttp.MethodPost), - auth: []byte("Splunk test"), - body: []byte("test"), - }, - }, - { - name: "gzip", - config: &Config{ - Endpoint: "http://endpoint:9000", - Token: "test", - UseGzip: true, - GzipCompressionLevel_: gzipCompressionLevelBestCompression, - }, - body: "test", - want: wantData{ - uri: "http://endpoint:9000/", - method: []byte(fasthttp.MethodPost), - contentEncoding: []byte(gzipContentEncoding), - auth: []byte("Splunk test"), - body: []byte("test"), - }, - }, - } - for _, tt := range cases { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - p := Plugin{ - config: tt.config, - } - p.prepareClient() - - req := fasthttp.AcquireRequest() - defer fasthttp.ReleaseRequest(req) - - p.prepareRequest(req, []byte(tt.body)) - - assert.Equal(t, tt.want.uri, req.URI().String(), "wrong uri") - assert.Equal(t, tt.want.method, req.Header.Method(), "wrong method") - assert.Equal(t, tt.want.contentEncoding, req.Header.ContentEncoding(), "wrong content encoding") - assert.Equal(t, tt.want.auth, req.Header.Peek(fasthttp.HeaderAuthorization), "wrong auth") - - var body []byte - if tt.config.UseGzip { - body, _ = req.BodyUncompressed() - } else { - body = req.Body() - } - assert.Equal(t, tt.want.body, body, "wrong body") - }) - } -} - func TestParseSplunkError(t *testing.T) { cases := []struct { name string diff --git a/xhttp/client.go b/xhttp/client.go new file mode 100644 index 00000000..314a9760 --- /dev/null +++ b/xhttp/client.go @@ -0,0 +1,162 @@ +package xhttp + +import ( + "fmt" + "math/rand" + "net/http" + "time" + + "github.com/ozontech/file.d/xtls" + "github.com/valyala/fasthttp" +) + +const gzipContentEncoding = "gzip" + +type ClientTLSConfig struct { + CACert string + InsecureSkipVerify bool +} + +type ClientKeepAliveConfig struct { + MaxConnDuration time.Duration + MaxIdleConnDuration time.Duration +} + +type ClientConfig struct { + Endpoints []string + ConnectionTimeout time.Duration + AuthHeader string + GzipCompressionLevel string + TLS *ClientTLSConfig + KeepAlive *ClientKeepAliveConfig +} + +type Client struct { + client *fasthttp.Client + endpoints []*fasthttp.URI + authHeader string + gzipCompressionLevel int +} + +func NewClient(cfg *ClientConfig) (*Client, error) { + client := &fasthttp.Client{ + ReadTimeout: cfg.ConnectionTimeout, + WriteTimeout: cfg.ConnectionTimeout, + } + + if cfg.KeepAlive != nil { + client.MaxConnDuration = cfg.KeepAlive.MaxConnDuration + client.MaxIdleConnDuration = cfg.KeepAlive.MaxIdleConnDuration + } + + if cfg.TLS != nil { + b := xtls.NewConfigBuilder() + if cfg.TLS.CACert != "" { + err := b.AppendCARoot(cfg.TLS.CACert) + if err != nil { + return nil, fmt.Errorf("can't append CA root: %w", err) + } + } + b.SetSkipVerify(cfg.TLS.InsecureSkipVerify) + + client.TLSConfig = b.Build() + } + + endpoints, err := parseEndpoints(cfg.Endpoints) + if err != nil { + return nil, err + } + + return &Client{ + client: client, + endpoints: endpoints, + authHeader: cfg.AuthHeader, + gzipCompressionLevel: parseGzipCompressionLevel(cfg.GzipCompressionLevel), + }, nil +} + +func (c *Client) DoTimeout( + method, contentType string, + body []byte, + timeout time.Duration, + processResponse func([]byte) error, +) (int, error) { + req := fasthttp.AcquireRequest() + defer fasthttp.ReleaseRequest(req) + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseResponse(resp) + + var endpoint *fasthttp.URI + if len(c.endpoints) == 1 { + endpoint = c.endpoints[0] + } else { + endpoint = c.endpoints[rand.Int()%len(c.endpoints)] + } + + c.prepareRequest(req, endpoint, method, contentType, body) + + if err := c.client.DoTimeout(req, resp, timeout); err != nil { + return 0, fmt.Errorf("can't send request to %s: %w", endpoint.String(), err) + } + + respContent := resp.Body() + statusCode := resp.Header.StatusCode() + + if statusCode < http.StatusOK || statusCode > http.StatusAccepted { + return statusCode, fmt.Errorf("response status from %s isn't OK: status=%d, body=%s", endpoint.String(), statusCode, string(respContent)) + } + + if processResponse != nil { + return statusCode, processResponse(respContent) + } + return statusCode, nil +} + +func (c *Client) prepareRequest(req *fasthttp.Request, endpoint *fasthttp.URI, method, contentType string, body []byte) { + req.SetURI(endpoint) + req.Header.SetMethod(method) + if contentType != "" { + req.Header.SetContentType(contentType) + } + if c.authHeader != "" { + req.Header.Set(fasthttp.HeaderAuthorization, c.authHeader) + } + if c.gzipCompressionLevel != -1 { + if _, err := fasthttp.WriteGzipLevel(req.BodyWriter(), body, c.gzipCompressionLevel); err != nil { + req.SetBodyRaw(body) + } else { + req.Header.SetContentEncoding(gzipContentEncoding) + } + } else { + req.SetBodyRaw(body) + } +} + +func parseEndpoints(endpoints []string) ([]*fasthttp.URI, error) { + res := make([]*fasthttp.URI, 0, len(endpoints)) + for _, e := range endpoints { + uri := &fasthttp.URI{} + if err := uri.Parse(nil, []byte(e)); err != nil { + return nil, fmt.Errorf("can't parse endpoint %s: %w", e, err) + } + res = append(res, uri) + } + return res, nil +} + +func parseGzipCompressionLevel(level string) int { + switch level { + case "default": + return fasthttp.CompressDefaultCompression + case "no": + return fasthttp.CompressNoCompression + case "best-speed": + return fasthttp.CompressBestSpeed + case "best-compression": + return fasthttp.CompressBestCompression + case "huffman-only": + return fasthttp.CompressHuffmanOnly + default: + return -1 + } +} diff --git a/xhttp/client_test.go b/xhttp/client_test.go new file mode 100644 index 00000000..0966460c --- /dev/null +++ b/xhttp/client_test.go @@ -0,0 +1,118 @@ +package xhttp + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func TestPrepareRequest(t *testing.T) { + type inputData struct { + endpoint string + method string + contentType string + body string + authHeader string + gzipCompressionLevel int + } + + type wantData struct { + uri string + method []byte + contentType []byte + contentEncoding []byte + body []byte + auth []byte + } + + cases := []struct { + name string + in inputData + want wantData + }{ + { + name: "simple", + in: inputData{ + endpoint: "http://endpoint:1", + method: fasthttp.MethodPost, + contentType: "application/json", + body: "test simple", + gzipCompressionLevel: -1, + }, + want: wantData{ + uri: "http://endpoint:1/", + method: []byte(fasthttp.MethodPost), + contentType: []byte("application/json"), + body: []byte("test simple"), + }, + }, + { + name: "auth", + in: inputData{ + endpoint: "http://endpoint:3", + method: fasthttp.MethodPost, + contentType: "application/json", + body: "test auth", + authHeader: "Auth Header", + gzipCompressionLevel: -1, + }, + want: wantData{ + uri: "http://endpoint:3/", + method: []byte(fasthttp.MethodPost), + contentType: []byte("application/json"), + body: []byte("test auth"), + auth: []byte("Auth Header"), + }, + }, + { + name: "gzip", + in: inputData{ + endpoint: "http://endpoint:4", + method: fasthttp.MethodPost, + contentType: "application/json", + body: "test gzip", + gzipCompressionLevel: 1, + }, + want: wantData{ + uri: "http://endpoint:4/", + method: []byte(fasthttp.MethodPost), + contentType: []byte("application/json"), + contentEncoding: []byte(gzipContentEncoding), + body: []byte("test gzip"), + }, + }, + } + for _, tt := range cases { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + c := &Client{ + authHeader: tt.in.authHeader, + gzipCompressionLevel: tt.in.gzipCompressionLevel, + } + + req := fasthttp.AcquireRequest() + defer fasthttp.ReleaseRequest(req) + + endpoints, _ := parseEndpoints([]string{tt.in.endpoint}) + + c.prepareRequest(req, endpoints[0], tt.in.method, tt.in.contentType, []byte(tt.in.body)) + + require.Equal(t, tt.want.uri, req.URI().String(), "wrong uri") + require.Equal(t, tt.want.method, req.Header.Method(), "wrong method") + require.Equal(t, tt.want.contentType, req.Header.ContentType(), "wrong content type") + require.Equal(t, tt.want.contentEncoding, req.Header.ContentEncoding(), "wrong content encoding") + require.Equal(t, tt.want.auth, req.Header.Peek(fasthttp.HeaderAuthorization), "wrong auth") + + var body []byte + if tt.in.gzipCompressionLevel != -1 { + body, _ = req.BodyUncompressed() + } else { + body = req.Body() + } + require.Equal(t, tt.want.body, body, "wrong body") + }) + } +}