diff --git a/config.go b/config.go index 2ba90f7..e09cdc6 100644 --- a/config.go +++ b/config.go @@ -21,6 +21,7 @@ import ( "io" "io/ioutil" "net/http" + "net/http/httputil" "time" yaml "gopkg.in/yaml.v2" @@ -60,9 +61,9 @@ type httpConfig struct { Address string `yaml:"address"` // 127.0.0.1 XXX map[string]interface{} `yaml:",inline"` - tlsConfig *tls.Config - httpClient *http.Client - mcfg *moduleConfig + tlsConfig *tls.Config + mcfg *moduleConfig + *httputil.ReverseProxy } type execConfig struct { @@ -116,6 +117,8 @@ func checkModuleConfig(name string, cfg *moduleConfig) error { return fmt.Errorf("unknown module configuration fields: %v", cfg.XXX) } + cfg.name = name + switch cfg.Method { case "http": if len(cfg.HTTP.XXX) != 0 { @@ -144,10 +147,13 @@ func checkModuleConfig(name string, cfg *moduleConfig) error { return fmt.Errorf("could not create tls config, %w", err) } cfg.HTTP.tlsConfig = tlsConfig - cfg.HTTP.httpClient = &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: tlsConfig, - }, + cfg.HTTP.ReverseProxy = &httputil.ReverseProxy{ + Transport: &http.Transport{TLSClientConfig: tlsConfig}, + Director: cfg.getReverseProxyDirectorFunc(), + ErrorHandler: cfg.getReverseProxyErrorHandlerFunc(), + } + if *cfg.HTTP.Verify { + cfg.HTTP.ReverseProxy.ModifyResponse = cfg.getReverseProxyModifyResponseFunc() } case "exec": if len(cfg.Exec.XXX) != 0 { diff --git a/http.go b/http.go index 20c06a1..8218a65 100644 --- a/http.go +++ b/http.go @@ -14,60 +14,81 @@ package main import ( + "bytes" + "compress/gzip" "context" - "fmt" + "errors" "io" + "io/ioutil" "net" "net/http" - "net/http/httputil" - "net/url" "strconv" "strings" - "golang.org/x/net/context/ctxhttp" - - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" dto "github.com/prometheus/client_model/go" "github.com/prometheus/common/expfmt" log "github.com/sirupsen/logrus" ) -func (c httpConfig) GatherWithContext(ctx context.Context, r *http.Request) prometheus.GathererFunc { - return func() ([]*dto.MetricFamily, error) { - qvs := r.URL.Query() - qvs["module"] = qvs["module"][1:] +const ( + // Msg to send in response body when verification of proxied server + // response is failed + VerificationErrorMsg = "Internal Server Error: " + + "Response from proxied server failed verification. " + + "See server logs for details" +) - url, err := url.Parse(c.Path) - uvs := url.Query() - for k, vs := range uvs { - for _, v := range vs { - qvs.Add(k, v) - } - } +type VerifyError struct { + msg string + cause error +} - url.Host = net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) - url.Scheme = c.Scheme - url.RawQuery = qvs.Encode() +func (e *VerifyError) Error() string { return e.msg + ": " + e.cause.Error() } +func (e *VerifyError) Unwrap() error { return e.cause } - resp, err := ctxhttp.Get(ctx, c.httpClient, url.String()) - if err != nil { - log.Errorf("http proxy for module %v failed %+v", c.mcfg.name, err) - proxyErrorCount.WithLabelValues(c.mcfg.name).Inc() - if err == context.DeadlineExceeded { - proxyTimeoutCount.WithLabelValues(c.mcfg.name).Inc() - } - return nil, err - } - defer resp.Body.Close() +func (cfg moduleConfig) getReverseProxyDirectorFunc() func(*http.Request) { + return func(r *http.Request) { + vs := r.URL.Query() + vs["module"] = vs["module"][1:] + r.URL.RawQuery = vs.Encode() + + r.URL.Scheme = cfg.HTTP.Scheme + r.URL.Host = net.JoinHostPort(cfg.HTTP.Address, strconv.Itoa(cfg.HTTP.Port)) + r.URL.Path = cfg.HTTP.Path + } +} +func (cfg moduleConfig) getReverseProxyModifyResponseFunc() func(*http.Response) error { + return func(resp *http.Response) error { if resp.StatusCode != 200 { - return nil, fmt.Errorf("server responded %v, %q", resp.StatusCode, resp.Status) + return nil + } + + var ( + err error + body bytes.Buffer + oldBody = resp.Body + ) + defer oldBody.Close() + + if _, err = body.ReadFrom(oldBody); err != nil { + return &VerifyError{"Failed to read body from proxied server", err} } - dec := expfmt.NewDecoder(resp.Body, expfmt.ResponseFormat(resp.Header)) + resp.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes())) + + var bodyReader io.ReadCloser + if resp.Header.Get("Content-Encoding") == "gzip" { + bodyReader, err = gzip.NewReader(bytes.NewReader(body.Bytes())) + if err != nil { + return &VerifyError{"Failed to decode gzipped response", err} + } + } else { + bodyReader = ioutil.NopCloser(bytes.NewReader(body.Bytes())) + } + defer bodyReader.Close() - result := []*dto.MetricFamily{} + dec := expfmt.NewDecoder(bodyReader, expfmt.ResponseFormat(resp.Header)) for { mf := dto.MetricFamily{} err := dec.Decode(&mf) @@ -75,49 +96,33 @@ func (c httpConfig) GatherWithContext(ctx context.Context, r *http.Request) prom break } if err != nil { - proxyMalformedCount.WithLabelValues(c.mcfg.name).Inc() - log.Errorf("err %+v", err) - return nil, err + proxyMalformedCount.WithLabelValues(cfg.name).Inc() + return &VerifyError{"Failed to decode metrics from proxied server", err} } - - result = append(result, &mf) } - return result, nil + return nil } } -func (c httpConfig) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var h http.Handler - - if !(*c.Verify) { - // proxy directly - rt := &http.Transport{ - Dial: (&net.Dialer{ - Timeout: c.mcfg.Timeout, - }).Dial, - TLSHandshakeTimeout: c.mcfg.Timeout, - TLSClientConfig: c.tlsConfig, +func (cfg moduleConfig) getReverseProxyErrorHandlerFunc() func(http.ResponseWriter, *http.Request, error) { + return func(w http.ResponseWriter, r *http.Request, err error) { + var verifyError *VerifyError + if errors.As(err, &verifyError) { + log.Errorf("Verification for module '%s' failed: %v", cfg.name, err) + http.Error(w, VerificationErrorMsg, http.StatusInternalServerError) + return } - h = &httputil.ReverseProxy{ - Transport: rt, - Director: func(r *http.Request) { - vs := r.URL.Query() - vs["module"] = vs["module"][1:] - r.URL.RawQuery = vs.Encode() - - r.URL.Scheme = c.Scheme - r.URL.Host = net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) - r.URL.Path = c.Path - }, + + if errors.Is(err, context.DeadlineExceeded) { + log.Errorf("Request time out for module '%s'", cfg.name) + http.Error(w, http.StatusText(http.StatusGatewayTimeout), http.StatusGatewayTimeout) + return } - } else { - ctx := r.Context() - g := c.GatherWithContext(ctx, r) - h = promhttp.HandlerFor(g, promhttp.HandlerOpts{}) - } - h.ServeHTTP(w, r) + log.Errorf("Proxy error for module '%s': %v", cfg.name, err) + http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) + } } // BearerAuthMiddleware diff --git a/http_test.go b/http_test.go new file mode 100644 index 0000000..d7a5cae --- /dev/null +++ b/http_test.go @@ -0,0 +1,114 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "testing" + "time" + + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/common/expfmt" +) + +func BenchmarkReverseProxyHandler(b *testing.B) { + body := genRandomMetricsResponse(10000, 10) + + test_exporter := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reader := bytes.NewReader(body.Bytes()) + io.Copy(w, reader) + })) + defer test_exporter.Close() + + URL, _ := url.Parse(test_exporter.URL) + verify := true + port, _ := strconv.ParseInt(URL.Port(), 0, 0) + modCfg := &moduleConfig{ + Method: "http", + Timeout: 5 * time.Second, + HTTP: httpConfig{ + Verify: &verify, + Scheme: URL.Scheme, + Address: URL.Hostname(), + Port: int(port), + Path: "/", + }, + } + + if err := checkModuleConfig("test", modCfg); err != nil { + b.Fatalf("Failed to check module config: %v", err) + } + + cfg := &config{ + Modules: map[string]*moduleConfig{ + "test": modCfg, + }, + } + + req := httptest.NewRequest("GET", "/proxy?module=test", nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rr := httptest.NewRecorder() + cfg.doProxy(rr, req) + if rr.Code != http.StatusOK { + b.Fatalf("Bad response status %d", rr.Code) + } + if len(rr.Body.Bytes()) <= 0 { + b.Fatal("Response body is absent") + } + } +} + +// genRandomMetricsResponse generates http response body which contains random set of +// prometheus metrics. mf_num sets number of metric families in response which has +// metric names in format 'metric{random number}'. m_num controls number of metrics +// inside each metric family. Metrics inside metric families differ in values of +// label 'label'. +func genRandomMetricsResponse(mf_num int, m_num int) *bytes.Buffer { + rand.Seed(time.Now().UnixNano()) + helpMsg := "help msg" + labelName := "label" + metricFamilies := make([]*dto.MetricFamily, mf_num) + metricType := dto.MetricType_GAUGE + for i, _ := range metricFamilies { + metrics := make([]*dto.Metric, m_num) + for i, _ := range metrics { + labelValue := fmt.Sprint(rand.Int63()) + value := rand.Float64() + ts := time.Now().UnixNano() + metrics[i] = &dto.Metric{ + Label: []*dto.LabelPair{ + &dto.LabelPair{ + Name: &labelName, + Value: &labelValue, + }, + }, + Gauge: &dto.Gauge{ + Value: &value, + }, + TimestampMs: &ts, + } + } + metricName := fmt.Sprintf("metric%d", rand.Int63()) + metricFamilies[i] = &dto.MetricFamily{ + Name: &metricName, + Help: &helpMsg, + Type: &metricType, + Metric: metrics, + } + } + + buf := &bytes.Buffer{} + enc := expfmt.NewEncoder(buf, expfmt.FmtText) + for _, mf := range metricFamilies { + enc.Encode(mf) + } + + return buf +} diff --git a/main.go b/main.go index 93c09a2..a35d962 100644 --- a/main.go +++ b/main.go @@ -346,7 +346,6 @@ func (cfg *config) doProxy(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("unknown module %v\n", mod), http.StatusNotFound) return } else { - m.name = mod[0] h = m }