diff --git a/config.dgate.yaml b/config.dgate.yaml index d3fa32b..dc946f1 100644 --- a/config.dgate.yaml +++ b/config.dgate.yaml @@ -1,6 +1,6 @@ version: v1 debug: true -log_level: ${LOG_LEVEL:-debug} +log_level: ${LOG_LEVEL:-info} disable_default_namespace: true tags: [debug, local, test] storage: @@ -15,7 +15,8 @@ proxy: port: ${PORT:-80} host: 0.0.0.0 console_log_level: info - transport: + client_transport: + disable_private_ips: false dns_prefer_go: true init_resources: namespaces: diff --git a/internal/config/config.go b/internal/config/config.go index b95e427..e9c6441 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -29,13 +29,7 @@ type ( } LoggingConfig struct { - ZapConfig *zap.Config `koanf:",squash"` - LogOutputs []*LogOutput `koanf:"log_outputs"` - } - - LogOutput struct { - Name string `koanf:"name"` - Config map[string]any `koanf:",remain"` + ZapConfig *zap.Config `koanf:",squash"` } DGateProxyConfig struct { @@ -176,6 +170,7 @@ type ( KeepAlive time.Duration `koanf:"keep_alive"` ResponseHeaderTimeout time.Duration `koanf:"response_header_timeout"` DialTimeout time.Duration `koanf:"dial_timeout"` + DisablePrivateIPs bool `koanf:"disable_private_ips"` } DGateStorageConfig struct { @@ -255,6 +250,7 @@ func (conf *DGateConfig) GetLogger() (*zap.Logger, error) { if logger, err := conf.Logging.ZapConfig.Build(); err != nil { return nil, err } else { + zap.ReplaceGlobals(logger) return logger, nil } } diff --git a/internal/proxy/change_log.go b/internal/proxy/change_log.go index 69ef075..758e4d1 100644 --- a/internal/proxy/change_log.go +++ b/internal/proxy/change_log.go @@ -17,10 +17,12 @@ import ( func (ps *ProxyState) processChangeLog(cl *spec.ChangeLog, reload, store bool) (err error) { if reload { defer func(start time.Time) { - ps.logger.Debug("processing change log", - zap.String("id", cl.ID), - zap.Duration("duration", time.Since(start)), - ) + if err != nil { + ps.logger.Debug("processed change log", + zap.String("id", cl.ID), + zap.Duration("duration", time.Since(start)), + ) + } }(time.Now()) } ps.proxyLock.Lock() diff --git a/internal/proxy/dynamic_proxy.go b/internal/proxy/dynamic_proxy.go index de74200..89f5086 100644 --- a/internal/proxy/dynamic_proxy.go +++ b/internal/proxy/dynamic_proxy.go @@ -173,15 +173,19 @@ func (ps *ProxyState) setupRoutes( if len(rt.Modules) > 0 { modExtFunc := ps.createModuleExtractorFunc(rt) if modPool, err := NewModulePool( - 256, 1024, reqCtxProvider, modExtFunc, + 0, 1024, time.Minute*5, + reqCtxProvider, modExtFunc, ); err != nil { ps.logger.Error("Error creating module buffer", zap.Error(err)) return err } else { - reqCtxProvider.SetModulePool(modPool) + reqCtxProvider.UpdateModulePool(modPool) } } - ps.providers.Insert(rt.Namespace.Name+"/"+rt.Name, reqCtxProvider) + oldReqCtxProvider := ps.providers.Insert(rt.Namespace.Name+"/"+rt.Name, reqCtxProvider) + if oldReqCtxProvider != nil { + oldReqCtxProvider.Close() + } for _, path := range rt.Paths { if len(rt.Methods) > 0 && rt.Methods[0] == "*" { if len(rt.Methods) > 1 { diff --git a/internal/proxy/module_executor.go b/internal/proxy/module_executor.go index 229230e..8dbe2d7 100644 --- a/internal/proxy/module_executor.go +++ b/internal/proxy/module_executor.go @@ -2,6 +2,9 @@ package proxy import ( "context" + "time" + + "go.uber.org/zap" ) type ModulePool interface { @@ -11,44 +14,55 @@ type ModulePool interface { } type modulePool struct { - modExtBuffer chan ModuleExtractor - min, max int - - ctxCancel context.CancelFunc - ctx context.Context + modExtChan chan ModuleExtractor + min, max int + cancel context.CancelFunc + ctx context.Context - createModuleExtract func() (ModuleExtractor, error) + createModExt func() (ModuleExtractor, error) } func NewModulePool( minBuffers, maxBuffers int, + bufferTimeout time.Duration, reqCtxProvider *RequestContextProvider, createModExts ModuleExtractorFunc, ) (ModulePool, error) { - if minBuffers < 1 { - panic("module concurrency must be greater than 0") - } if maxBuffers < minBuffers { panic("maxBuffers must be greater than minBuffers") } - if _, err := createModExts(reqCtxProvider); err != nil { return nil, err } + modExtChan := make(chan ModuleExtractor, maxBuffers) mb := &modulePool{ - min: minBuffers, - max: maxBuffers, - modExtBuffer: make(chan ModuleExtractor, maxBuffers), + min: minBuffers, + max: maxBuffers, + modExtChan: modExtChan, + createModExt: func() (ModuleExtractor, error) { + return createModExts(reqCtxProvider) + }, } - mb.createModuleExtract = func() (ModuleExtractor, error) { - return createModExts(reqCtxProvider) - } - mb.ctx, mb.ctxCancel = context.WithCancel(reqCtxProvider.ctx) + mb.ctx, mb.cancel = context.WithCancel(reqCtxProvider.ctx) + + // add min module extractors to the pool + defer func() { + for i := 0; i < minBuffers; i++ { + me, err := mb.createModExt() + if err == nil { + mb.modExtChan <- me + } + } + }() + return mb, nil } func (mb *modulePool) Borrow() ModuleExtractor { - if mb == nil || mb.ctx == nil || mb.ctx.Err() != nil { + if mb == nil || mb.ctx.Err() != nil { + zap.L().Warn("stale use of module pool", + zap.Any("modPool", mb), + ) return nil } var ( @@ -56,12 +70,10 @@ func (mb *modulePool) Borrow() ModuleExtractor { err error ) select { - case me = <-mb.modExtBuffer: + case me = <-mb.modExtChan: break - // NOTE: important for performance default: - me, err = mb.createModuleExtract() - if err != nil { + if me, err = mb.createModExt(); err != nil { return nil } } @@ -72,18 +84,18 @@ func (mb *modulePool) Return(me ModuleExtractor) { // if context is canceled, do not return module extract if mb.ctx != nil && mb.ctx.Err() == nil { select { - case mb.modExtBuffer <- me: + case mb.modExtChan <- me: return default: // if buffer is full, discard module extract } } - me.Stop(true) + me.Stop(false) } func (mb *modulePool) Close() { - if mb.ctxCancel != nil { - mb.ctxCancel() + if mb.cancel != nil { + mb.cancel() } - close(mb.modExtBuffer) + close(mb.modExtChan) } diff --git a/internal/proxy/proxy_handler.go b/internal/proxy/proxy_handler.go index 547fe2e..a86b60d 100644 --- a/internal/proxy/proxy_handler.go +++ b/internal/proxy/proxy_handler.go @@ -41,7 +41,7 @@ func proxyHandler(ps *ProxyState, reqCtx *RequestContext) { event.Debug("Request log") }() - defer ps.metrics.MeasureProxyRequest(reqCtx, time.Now()) + defer ps.metrics.MeasureProxyRequest(reqCtx.ctx, reqCtx, time.Now()) var modExt ModuleExtractor if len(reqCtx.route.Modules) != 0 { @@ -53,7 +53,7 @@ func proxyHandler(ps *ProxyState, reqCtx *RequestContext) { } else { if modExt = modPool.Borrow(); modExt == nil { ps.metrics.MeasureModuleDuration( - reqCtx, "module_extract", runtimeStart, + reqCtx.ctx, reqCtx, "module_extract", runtimeStart, errors.New("error borrowing module"), ) ps.logger.Error("Error borrowing module") @@ -66,7 +66,7 @@ func proxyHandler(ps *ProxyState, reqCtx *RequestContext) { modExt.Start(reqCtx) defer modExt.Stop(true) ps.metrics.MeasureModuleDuration( - reqCtx, "module_extract", + reqCtx.ctx, reqCtx, "module_extract", runtimeStart, nil, ) } else { @@ -86,7 +86,7 @@ func handleServiceProxy(ps *ProxyState, reqCtx *RequestContext, modExt ModuleExt fetchUpstreamStart := time.Now() hostUrl, err := fetchUpstreamUrl(modExt.ModuleContext()) ps.metrics.MeasureModuleDuration( - reqCtx, "fetch_upstream", + reqCtx.ctx, reqCtx, "fetch_upstream", fetchUpstreamStart, err, ) if err != nil { @@ -147,7 +147,8 @@ func handleServiceProxy(ps *ProxyState, reqCtx *RequestContext, modExt ModuleExt resModifierStart := time.Now() err = responseModifier(modExt.ModuleContext(), res) ps.metrics.MeasureModuleDuration( - reqCtx, "response_modifier", + reqCtx.ctx, reqCtx, + "response_modifier", resModifierStart, err, ) if err != nil { @@ -164,7 +165,7 @@ func handleServiceProxy(ps *ProxyState, reqCtx *RequestContext, modExt ModuleExt }). ErrorHandler(func(w http.ResponseWriter, r *http.Request, reqErr error) { upstreamErr = reqErr - ps.logger.Debug("Error proxying request", + ps.logger.Error("Error proxying request", zap.String("error", reqErr.Error()), zap.String("route", reqCtx.route.Name), zap.String("service", reqCtx.route.Service.Name), @@ -178,7 +179,7 @@ func handleServiceProxy(ps *ProxyState, reqCtx *RequestContext, modExt ModuleExt errorHandlerStart := time.Now() err = errorHandler(modExt.ModuleContext(), reqErr) ps.metrics.MeasureModuleDuration( - reqCtx, "error_handler", + reqCtx.ctx, reqCtx, "error_handler", errorHandlerStart, err, ) if err != nil { @@ -193,12 +194,6 @@ func handleServiceProxy(ps *ProxyState, reqCtx *RequestContext, modExt ModuleExt } } if !reqCtx.rw.HeadersSent() && reqCtx.rw.BytesWritten() == 0 { - ps.logger.Error("Writing error response", - zap.String("error", reqErr.Error()), - zap.String("route", reqCtx.route.Name), - zap.String("service", reqCtx.route.Service.Name), - zap.String("namespace", reqCtx.route.Namespace.Name), - ) util.WriteStatusCodeError(reqCtx.rw, http.StatusBadGateway) } }) @@ -207,7 +202,8 @@ func handleServiceProxy(ps *ProxyState, reqCtx *RequestContext, modExt ModuleExt reqModifierStart := time.Now() err = requestModifier(modExt.ModuleContext()) ps.metrics.MeasureModuleDuration( - reqCtx, "request_modifier", + reqCtx.ctx, reqCtx, + "request_modifier", reqModifierStart, err, ) if err != nil { @@ -241,8 +237,10 @@ func handleServiceProxy(ps *ProxyState, reqCtx *RequestContext, modExt ModuleExt upstreamStart := time.Now() rp.ServeHTTP(reqCtx.rw, reqCtx.req) ps.metrics.MeasureUpstreamDuration( - reqCtx, upstreamStart, - upstreamUrl.String(), upstreamErr, + reqCtx.ctx, reqCtx, + upstreamStart, + upstreamUrl.String(), + upstreamErr, ) } @@ -252,7 +250,8 @@ func requestHandlerModule(ps *ProxyState, reqCtx *RequestContext, modExt ModuleE reqModifierStart := time.Now() err = requestModifier(modExt.ModuleContext()) ps.metrics.MeasureModuleDuration( - reqCtx, "request_modifier", + reqCtx.ctx, reqCtx, + "request_modifier", reqModifierStart, err, ) if err != nil { @@ -269,7 +268,8 @@ func requestHandlerModule(ps *ProxyState, reqCtx *RequestContext, modExt ModuleE requestHandlerStart := time.Now() err := requestHandler(modExt.ModuleContext()) defer ps.metrics.MeasureModuleDuration( - reqCtx, "request_handler", + reqCtx.ctx, reqCtx, + "request_handler", requestHandlerStart, err, ) if err != nil { @@ -283,7 +283,8 @@ func requestHandlerModule(ps *ProxyState, reqCtx *RequestContext, modExt ModuleE errorHandlerStart := time.Now() err = errorHandler(modExt.ModuleContext(), err) ps.metrics.MeasureModuleDuration( - reqCtx, "error_handler", + reqCtx.ctx, reqCtx, + "error_handler", errorHandlerStart, err, ) if err != nil { diff --git a/internal/proxy/proxy_handler_test.go b/internal/proxy/proxy_handler_test.go index de38c67..08e76ff 100644 --- a/internal/proxy/proxy_handler_test.go +++ b/internal/proxy/proxy_handler_test.go @@ -66,12 +66,12 @@ func TestProxyHandler_ReverseProxy(t *testing.T) { modBuf.On("Borrow").Return(modExt).Once() modBuf.On("Return", modExt).Return().Once() modBuf.On("Close").Return().Once() - reqCtxProvider.SetModulePool(modBuf) + reqCtxProvider.UpdateModulePool(modBuf) modPool := NewMockModulePool() modPool.On("Borrow").Return(modExt).Once() modPool.On("Return", modExt).Return().Once() - reqCtxProvider.SetModulePool(modPool) + reqCtxProvider.UpdateModulePool(modPool) ps.ProxyHandler(ps, reqCtx) wr.AssertExpectations(t) @@ -127,7 +127,7 @@ func TestProxyHandler_ProxyHandler(t *testing.T) { modPool := NewMockModulePool() modPool.On("Borrow").Return(modExt).Once() modPool.On("Return", modExt).Return().Once() - reqCtxProvider.SetModulePool(modPool) + reqCtxProvider.UpdateModulePool(modPool) reqCtx := reqCtxProvider.CreateRequestContext( context.Background(), wr, req, "/") @@ -180,7 +180,7 @@ func TestProxyHandler_ProxyHandlerError(t *testing.T) { modPool.On("Borrow").Return(modExt).Once() modPool.On("Return", modExt).Return().Once() reqCtxProvider := proxy.NewRequestContextProvider(rt, ps) - reqCtxProvider.SetModulePool(modPool) + reqCtxProvider.UpdateModulePool(modPool) reqCtx := reqCtxProvider.CreateRequestContext( context.Background(), wr, req, "/") ps.ProxyHandler(ps, reqCtx) diff --git a/internal/proxy/proxy_metrics.go b/internal/proxy/proxy_metrics.go index 3f7bd29..d07d767 100644 --- a/internal/proxy/proxy_metrics.go +++ b/internal/proxy/proxy_metrics.go @@ -60,7 +60,9 @@ func (pm *ProxyMetrics) Setup(config *config.DGateConfig) { } func (pm *ProxyMetrics) MeasureProxyRequest( - reqCtx *RequestContext, start time.Time, + ctx context.Context, + reqCtx *RequestContext, + start time.Time, ) { if pm.proxyDurInstrument == nil || pm.proxyCountInstrument == nil { return @@ -100,8 +102,8 @@ func (pm *ProxyMetrics) MeasureProxyRequest( } func (pm *ProxyMetrics) MeasureModuleDuration( - reqCtx *RequestContext, moduleFunc string, - start time.Time, err error, + ctx context.Context, reqCtx *RequestContext, + moduleFunc string, start time.Time, err error, ) { if pm.moduleDurInstrument == nil || pm.moduleRunCountInstrument == nil { return @@ -117,7 +119,7 @@ func (pm *ProxyMetrics) MeasureModuleDuration( attribute.String("pattern", reqCtx.pattern), attribute.String("host", reqCtx.req.Host), ) - pm.addError(moduleFunc, err, attrSet) + pm.addError(ctx, moduleFunc, err, attrSet) pm.moduleDurInstrument.Record(reqCtx.ctx, float64(elasped)/float64(time.Millisecond), @@ -128,8 +130,8 @@ func (pm *ProxyMetrics) MeasureModuleDuration( } func (pm *ProxyMetrics) MeasureUpstreamDuration( - reqCtx *RequestContext, start time.Time, - upstreamHost string, err error, + ctx context.Context, reqCtx *RequestContext, + start time.Time, upstreamHost string, err error, ) { if pm.upstreamDurInstrument == nil { return @@ -146,7 +148,7 @@ func (pm *ProxyMetrics) MeasureUpstreamDuration( attribute.String("service", reqCtx.route.Service.Name), attribute.String("upstream_host", upstreamHost), ) - pm.addError("upstream_request", err, attrSet) + pm.addError(ctx, "upstream_request", err, attrSet) pm.upstreamDurInstrument.Record(reqCtx.ctx, float64(elasped)/float64(time.Millisecond), @@ -154,7 +156,8 @@ func (pm *ProxyMetrics) MeasureUpstreamDuration( } func (pm *ProxyMetrics) MeasureNamespaceResolutionDuration( - start time.Time, host, namespace string, err error, + ctx context.Context, start time.Time, + host, namespace string, err error, ) { if pm.resolveNamespaceDurInstrument == nil { return @@ -164,7 +167,7 @@ func (pm *ProxyMetrics) MeasureNamespaceResolutionDuration( attribute.String("host", host), attribute.String("namespace", namespace), ) - pm.addError("namespace_resolution", err, attrSet) + pm.addError(ctx, "namespace_resolution", err, attrSet) pm.resolveNamespaceDurInstrument.Record(context.TODO(), float64(elasped)/float64(time.Microsecond), @@ -172,7 +175,8 @@ func (pm *ProxyMetrics) MeasureNamespaceResolutionDuration( } func (pm *ProxyMetrics) MeasureCertResolutionDuration( - start time.Time, host string, cache bool, err error, + ctx context.Context, start time.Time, + host string, cache bool, err error, ) { if pm.resolveCertDurInstrument == nil { return @@ -184,7 +188,7 @@ func (pm *ProxyMetrics) MeasureCertResolutionDuration( attribute.String("host", host), attribute.Bool("cache", cache), ) - pm.addError("cert_resolution", err, attrSet) + pm.addError(ctx, "cert_resolution", err, attrSet) pm.resolveCertDurInstrument.Record(context.TODO(), float64(elasped)/float64(time.Millisecond), @@ -192,6 +196,7 @@ func (pm *ProxyMetrics) MeasureCertResolutionDuration( } func (pm *ProxyMetrics) addError( + ctx context.Context, namespace string, err error, attrs ...attribute.Set, ) { @@ -210,5 +215,5 @@ func (pm *ProxyMetrics) addError( attrSets = append(attrSets, api.WithAttributeSet(attr)) } - pm.errorCountInstrument.Add(context.TODO(), 1, attrSets...) + pm.errorCountInstrument.Add(ctx, 1, attrSets...) } diff --git a/internal/proxy/proxy_state.go b/internal/proxy/proxy_state.go index 834c6a4..2bcda9b 100644 --- a/internal/proxy/proxy_state.go +++ b/internal/proxy/proxy_state.go @@ -396,9 +396,9 @@ func (ps *ProxyState) DynamicTLSConfig(certFile, keyFile string) *tls.Config { return &tls.Config{ GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - if cert, err := ps.getDomainCertificate(info.ServerName); err != nil { + if cert, found, err := ps.getDomainCertificate(info.Context(), info.ServerName); err != nil { return nil, err - } else if cert == nil { + } else if !found { if fallbackCert != nil { return fallbackCert, nil } else { @@ -420,7 +420,9 @@ func loadCertFromFile(certFile, keyFile string) (*tls.Certificate, error) { return &cert, nil } -func (ps *ProxyState) getDomainCertificate(domain string) (*tls.Certificate, error) { +func (ps *ProxyState) getDomainCertificate( + ctx context.Context, domain string, +) (*tls.Certificate, bool, error) { start := time.Now() allowedDomains := ps.config.ProxyConfig.AllowedDomains domainAllowed := len(allowedDomains) == 0 @@ -430,7 +432,7 @@ func (ps *ProxyState) getDomainCertificate(domain string) (*tls.Certificate, err ps.logger.Error("Error checking domain match list", zap.Error(err), ) - return nil, err + return nil, false, err } domainAllowed = domainMatch } @@ -441,17 +443,19 @@ func (ps *ProxyState) getDomainCertificate(domain string) (*tls.Certificate, err ps.logger.Error("Error checking domain match list", zap.Error(err), ) - return nil, err + return nil, false, err } else if match && d.Cert != "" && d.Key != "" { var err error + var cached bool defer ps.metrics.MeasureCertResolutionDuration( - start, domain, false, err) - + ctx, start, domain,cached, err, + ) certBucket := ps.sharedCache.Bucket("certs") key := fmt.Sprintf("cert:%s:%s:%d", d.Namespace.Name, - d.Name, d.CreatedAt.UnixMilli()) + d.Name, d.UpdatedAt.Unix()) if cert, ok := certBucket.Get(key); ok { - return cert.(*tls.Certificate), nil + cached = true + return cert.(*tls.Certificate), true, nil } var serverCert tls.Certificate serverCert, err = tls.X509KeyPair( @@ -462,14 +466,14 @@ func (ps *ProxyState) getDomainCertificate(domain string) (*tls.Certificate, err zap.String("domain_name", d.Name), zap.String("namespace", d.Namespace.Name), ) - return nil, err + return nil, false, err } certBucket.Set(key, &serverCert) - return &serverCert, nil + return &serverCert, true, nil } } } - return nil, nil + return nil, false, nil } func (ps *ProxyState) initConfigResources(resources *config.DGateResources) error { @@ -610,7 +614,7 @@ func (ps *ProxyState) ServeHTTP(w http.ResponseWriter, r *http.Request) { err = nil } defer ps.metrics.MeasureNamespaceResolutionDuration( - start, host, ns.Name, err, + r.Context(), start, host, ns.Name, err, ) var ok bool if len(allowedDomains) > 0 { diff --git a/internal/proxy/proxy_state_test.go b/internal/proxy/proxy_state_test.go index f76cc2e..59922b1 100644 --- a/internal/proxy/proxy_state_test.go +++ b/internal/proxy/proxy_state_test.go @@ -48,7 +48,7 @@ func TestDynamicTLSConfig_DomainCertCache(t *testing.T) { } d := domains[0] key := fmt.Sprintf("cert:%s:%s:%d", d.Namespace.Name, - d.Name, d.CreatedAt.UnixMilli()) + d.Name, d.CreatedAt.Unix()) tlsConfig := ps.DynamicTLSConfig("", "") clientHello := &tls.ClientHelloInfo{ ServerName: "abc.test.com", diff --git a/internal/proxy/proxy_transport.go b/internal/proxy/proxy_transport.go index 397c010..1b4fb20 100644 --- a/internal/proxy/proxy_transport.go +++ b/internal/proxy/proxy_transport.go @@ -3,6 +3,7 @@ package proxy import ( "context" "crypto/tls" + "errors" "net" "net/http" @@ -10,28 +11,26 @@ import ( "golang.org/x/net/http2" ) +func validateAddress(c *config.DGateHttpTransportConfig, address string) error { + if c.DisablePrivateIPs { + ip, _, err := net.SplitHostPort(address) + if err != nil { + ip = address + } + if ipAddr := net.ParseIP(ip); ipAddr == nil { + return errors.New("could not parse IP: " + ip) + } else if ipAddr.IsLoopback() || ipAddr.IsPrivate() { + return errors.New("private IP address not allowed: " + ipAddr.String()) + } + } + return nil +} + func setupTranportsFromConfig( - c config.DGateHttpTransportConfig, + c *config.DGateHttpTransportConfig, modifyTransport func(*net.Dialer, *http.Transport), ) http.RoundTripper { t1 := http.DefaultTransport.(*http.Transport).Clone() - dailer := &net.Dialer{ - Timeout: c.DialTimeout, - KeepAlive: c.KeepAlive, - Resolver: &net.Resolver{ - PreferGo: c.DNSPreferGo, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - d := net.Dialer{ - Timeout: c.DNSTimeout, - } - return d.DialContext(ctx, network, c.DNSServer) - }, - }, - } - if t1.DisableKeepAlives { - dailer.KeepAlive = -1 - } - t1.DialContext = dailer.DialContext t1.MaxIdleConns = c.MaxIdleConns t1.IdleConnTimeout = c.IdleConnTimeout t1.TLSHandshakeTimeout = c.TLSHandshakeTimeout @@ -45,10 +44,38 @@ func setupTranportsFromConfig( t1.DisableCompression = c.DisableCompression t1.ForceAttemptHTTP2 = c.ForceAttemptHttp2 t1.ResponseHeaderTimeout = c.ResponseHeaderTimeout - if modifyTransport != nil { - modifyTransport(dailer, t1) + dailer := &net.Dialer{ + Timeout: c.DialTimeout, + KeepAlive: c.KeepAlive, } - + if t1.DisableKeepAlives { + dailer.KeepAlive = -1 + } + resolver := &net.Resolver{ + PreferGo: c.DNSPreferGo, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + if err := validateAddress(c, address); err != nil { + return nil, err + } + if c.DNSServer != "" { + address = c.DNSServer + } + return dailer.DialContext(ctx, network, address) + }, + } + dailer.Resolver = resolver + t1.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := dailer.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + if err := validateAddress(c, conn.RemoteAddr().String()); err != nil { + return nil, err + } + return conn, nil + } + t1.DialTLSContext = t1.DialContext + modifyTransport(dailer, t1) return newRoundTripper(t1) } diff --git a/internal/proxy/proxy_transport/proxy_transport.go b/internal/proxy/proxy_transport/proxy_transport.go index 1fef421..6c280c5 100644 --- a/internal/proxy/proxy_transport/proxy_transport.go +++ b/internal/proxy/proxy_transport/proxy_transport.go @@ -6,6 +6,8 @@ import ( "time" "errors" + + "github.com/dgate-io/dgate/internal/proxy/proxyerrors" ) type Builder interface { @@ -112,8 +114,13 @@ func (m *retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) defer cancel() } resp, err = m.transport.RoundTrip(req) + // Retry only on network errors or if the request is a PUT or POST if err == nil || req.Method == http.MethodPut || req.Method == http.MethodPost { break + } else if pxyErr := proxyerrors.GetProxyError(err); pxyErr != nil { + if !pxyErr.DisableRetry { + break + } } if m.retryTimeout != 0 { <-time.After(m.retryTimeout) diff --git a/internal/proxy/proxyerrors/proxyerrors.go b/internal/proxy/proxyerrors/proxyerrors.go new file mode 100644 index 0000000..3f48719 --- /dev/null +++ b/internal/proxy/proxyerrors/proxyerrors.go @@ -0,0 +1,27 @@ +package proxyerrors + +import "errors" + +type ProxyError struct { + DisableRetry bool + StatusCode int + Err error +} + +func NewProxyError(text string) error { + return &ProxyError{ + Err: errors.New(text), + } +} + +func (e *ProxyError) Error() string { + return e.Err.Error() +} + +func GetProxyError(err error) *ProxyError { + if err == nil { + return nil + } + pxyErr, _ := err.(*ProxyError) + return pxyErr +} diff --git a/internal/proxy/request_context.go b/internal/proxy/request_context.go index 0f01dd9..9589d48 100644 --- a/internal/proxy/request_context.go +++ b/internal/proxy/request_context.go @@ -16,6 +16,7 @@ type S string type RequestContextProvider struct { ctx context.Context + cancel context.CancelFunc route *spec.DGateRoute rpb reverse_proxy.Builder mtx *sync.Mutex @@ -42,7 +43,7 @@ func NewRequestContextProvider(route *spec.DGateRoute, ps *ProxyState) *RequestC if route.Service != nil { ctx = context.WithValue(ctx, spec.Name("service"), route.Service.Name) transport := setupTranportsFromConfig( - ps.config.ProxyConfig.Transport, + &ps.config.ProxyConfig.Transport, func(dialer *net.Dialer, t *http.Transport) { t.TLSClientConfig = &tls.Config{ InsecureSkipVerify: route.Service.TLSSkipVerify, @@ -68,18 +69,19 @@ func NewRequestContextProvider(route *spec.DGateRoute, ps *ProxyState) *RequestC route.Service.DisableQueryParams, ps.config.ProxyConfig.DisableXForwardedHeaders, ) - } + ctx, cancel := context.WithCancel(ctx) return &RequestContextProvider{ - ctx: ctx, - route: route, - rpb: rpb, - mtx: &sync.Mutex{}, + ctx: ctx, + cancel: cancel, + route: route, + rpb: rpb, + mtx: &sync.Mutex{}, } } -func (reqCtxProvider *RequestContextProvider) SetModulePool(mb ModulePool) { +func (reqCtxProvider *RequestContextProvider) UpdateModulePool(mb ModulePool) { reqCtxProvider.mtx.Lock() defer reqCtxProvider.mtx.Unlock() if reqCtxProvider.modBuf != nil { @@ -115,6 +117,16 @@ func (reqCtxProvider *RequestContextProvider) CreateRequestContext( } } +func (reqCtxProvider *RequestContextProvider) Close() { + reqCtxProvider.mtx.Lock() + defer reqCtxProvider.mtx.Unlock() + if reqCtxProvider.modBuf != nil { + reqCtxProvider.modBuf.Close() + reqCtxProvider.modBuf = nil + } + reqCtxProvider.cancel() +} + func (reqCtx *RequestContext) Context() context.Context { return reqCtx.ctx } diff --git a/performance-tests/long-perf-test.js b/performance-tests/long-perf-test.js index c835d21..4dc1c58 100644 --- a/performance-tests/long-perf-test.js +++ b/performance-tests/long-perf-test.js @@ -42,25 +42,6 @@ export let options = { env: { DGATE_PATH: "/svctest?wait=30ms" }, gracefulStop: '5s', }, - test_server_direct: { - executor: 'constant-vus', - vus: n, - duration: inc + 'm', - startTime: (curWait += inc) + 'm', - exec: 'dgatePath', - env: { DGATE_PATH: ":8888/direct" }, - gracefulStop: '5s', - }, - test_server_direct_wait: { - executor: 'constant-vus', - vus: n*5, - duration: inc + 'm', - startTime: (curWait += inc) + 'm', - exec: 'dgatePath', - env: { DGATE_PATH: ":8888/svctest?wait=30ms" }, - gracefulStop: '5s', - }, - }, // test_server_direct: { // executor: 'constant-vus', // vus: n, @@ -72,13 +53,14 @@ export let options = { // }, // test_server_direct_wait: { // executor: 'constant-vus', - // vus: n*3, + // vus: n*5, // duration: inc + 'm', // startTime: (curWait += inc) + 'm', // exec: 'dgatePath', // env: { DGATE_PATH: ":8888/svctest?wait=30ms" }, // gracefulStop: '5s', // }, + }, discardResponseBodies: true, }; diff --git a/pkg/resources/resource_manager.go b/pkg/resources/resource_manager.go index cee0d69..66bc76f 100644 --- a/pkg/resources/resource_manager.go +++ b/pkg/resources/resource_manager.go @@ -23,6 +23,9 @@ type ResourceManager struct { secrets avlTreeLinker[spec.DGateSecret] collections avlTreeLinker[spec.DGateCollection] mutex *keylock.KeyLock + + // sorting can be expensive, so we cache the sorted list of domains + priorityDomainCache []*spec.DGateDomain } type Options func(*ResourceManager) @@ -481,12 +484,18 @@ func (rm *ResourceManager) DomainCountEquals(target int) bool { // GetDomainsByPriority returns a list of all domains sorted by priority and name func (rm *ResourceManager) GetDomainsByPriority() []*spec.DGateDomain { defer rm.mutex.RLockMain()() + if rm.priorityDomainCache != nil { + return rm.priorityDomainCache + } + return rm.domainsByPriority() +} + +func (rm *ResourceManager) domainsByPriority() []*spec.DGateDomain { var domains []*spec.DGateDomain rm.domains.Each(func(_ string, lk *linker.Link[string, safe.Ref[spec.DGateDomain]]) bool { domains = append(domains, lk.Item().Read()) return true }) - sort.Slice(domains, func(i, j int) bool { d1, d2 := domains[j], domains[i] if d1.Priority == d2.Priority { @@ -494,7 +503,7 @@ func (rm *ResourceManager) GetDomainsByPriority() []*spec.DGateDomain { } return d1.Priority < d2.Priority }) - + rm.priorityDomainCache = domains return domains } @@ -519,6 +528,7 @@ func (rm *ResourceManager) AddDomain(domain *spec.Domain) (*spec.DGateDomain, er if err != nil { return nil, err } + defer func() { rm.priorityDomainCache = nil }() if dmLk, ok := rm.domains.Find(domain.Name + "/" + domain.NamespaceName); ok { dmLk.Item().Replace(dm) return dm, nil @@ -553,6 +563,7 @@ func (rm *ResourceManager) transformDomain(domain *spec.Domain) (*spec.DGateDoma func (rm *ResourceManager) RemoveDomain(name, namespace string) error { defer rm.mutex.Lock(namespace)() + defer func() { rm.priorityDomainCache = nil }() if dmLk, ok := rm.domains.Find(name + "/" + namespace); ok { if nsLk, ok := rm.namespaces.Find(namespace); ok { nsLk.UnlinkOneMany("domains", name) diff --git a/pkg/util/iplist/iplist.go b/pkg/util/iplist/iplist.go index d0bf119..fb5da35 100644 --- a/pkg/util/iplist/iplist.go +++ b/pkg/util/iplist/iplist.go @@ -59,6 +59,7 @@ func (l *IPList) AddIPString(ipstr string) error { func (l *IPList) Len() int { return l.v4s.Len() + l.v6s.Len() } + func (l *IPList) Contains(ipstr string) (bool, error) { if l.Len() == 0 { return false, nil