diff --git a/functional-tests/admin_tests/modify_response.ts b/functional-tests/admin_tests/modify_response.ts index 2c44ede..01e2357 100644 --- a/functional-tests/admin_tests/modify_response.ts +++ b/functional-tests/admin_tests/modify_response.ts @@ -4,9 +4,9 @@ export const responseModifier = async (ctx) => { const uuidData = await res.json(); console.log("INFO uuid", JSON.stringify(uuidData)); const resp = ctx.upstream(); - const results = await resp.getJson(); + const results = await resp.readJson(); results.data.uuid = uuidData.uuid; - return resp.setStatus(200).setJson(results); + return resp.status(200).writeJson(results); }); }; diff --git a/functional-tests/admin_tests/multi_module_test.sh b/functional-tests/admin_tests/multi_module_test.sh new file mode 100755 index 0000000..f5071f9 --- /dev/null +++ b/functional-tests/admin_tests/multi_module_test.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +set -eo xtrace + +ADMIN_URL=${ADMIN_URL:-"http://localhost:9080"} +PROXY_URL=${PROXY_URL:-"http://localhost"} + +DIR="$( cd "$( dirname "$0" )" && pwd )" + +export DGATE_ADMIN_API=$ADMIN_URL + +dgate-cli namespace create \ + name=multimod-test-ns + +dgate-cli domain create \ + name=multimod-test-dm \ + patterns:='["multimod-test.com"]' \ + namespace=multimod-test-ns + +MOD_B64=$(base64 <<-END +export { + requestModifier, +} from './multimod2'; +import { + responseModifier as resMod, +} from './multimod2'; +const responseModifier = async (ctx) => { + console.log('responseModifier executed from multimod1') + return resMod(ctx); +}; +END + +) + +dgate-cli module create \ + name=multimod1 \ + payload="$MOD_B64" \ + namespace=multimod-test-ns + +MOD_B64=$(base64 <<-END +const reqMod = (ctx) => ctx.request().writeJson({a: 1}); +const resMod = async (ctx) => ctx.upstream()?.writeJson({ + upstream_body: await ctx.upstream()?.readJson(), + upstream_headers: ctx.upstream()?.headers, + upsteam_status: ctx.upstream()?.statusCode, + upstream_statusText: ctx.upstream()?.statusText, +}); +export { + reqMod as requestModifier, + resMod as responseModifier, +}; +END + +) + +dgate-cli module create name=multimod2 \ + payload="$MOD_B64" namespace=multimod-test-ns + +URL='http://localhost:8888' +dgate-cli service create name=base_svc \ + urls="$URL/a","$URL/b","$URL/c" \ + namespace=multimod-test-ns + +dgate-cli route create name=base_rt \ + paths=/,/multimod-test \ + methods:='["GET"]' \ + modules=multimod1,multimod2 \ + service=base_svc \ + stripPath:=true \ + preserveHost:=true \ + namespace=multimod-test-ns + + +curl -s --fail-with-body ${PROXY_URL}/ -H Host:multimod-test.com +curl -s --fail-with-body ${PROXY_URL}/multimod-test -H Host:multimod-test.com + +echo "Multi Module Test Passed" \ No newline at end of file diff --git a/internal/admin/admin_api.go b/internal/admin/admin_api.go index 4bb79fb..ef46f56 100644 --- a/internal/admin/admin_api.go +++ b/internal/admin/admin_api.go @@ -2,6 +2,7 @@ package admin import ( "fmt" + "io" "log" "net/http" "os" @@ -89,8 +90,11 @@ func StartAdminAPI(conf *config.DGateConfig, proxyState *proxy.ProxyState) { } respMap["method"] = r.Method respMap["path"] = r.URL.String() - respMap["remote_addr"] = r.RemoteAddr + if body, err := io.ReadAll(r.Body); err == nil { + respMap["body"] = string(body) + } respMap["host"] = r.Host + respMap["remote_addr"] = r.RemoteAddr respMap["req_headers"] = r.Header if conf.TestServerConfig.EnableEnvVars { respMap["env"] = os.Environ() diff --git a/internal/admin/admin_raft.go b/internal/admin/admin_raft.go index 0d29699..73a1893 100644 --- a/internal/admin/admin_raft.go +++ b/internal/admin/admin_raft.go @@ -77,25 +77,25 @@ func setupRaft(conf *config.DGateConfig, server *chi.Mux, ps *proxy.ProxyState) panic(fmt.Errorf("invalid scheme: %s", adminConfig.Replication.AdvertScheme)) } - trans := rafthttp.NewHTTPTransport(address, - http.DefaultClient, raftHttpLogger, - adminConfig.Replication.AdvertScheme+ - "://(address)/raft", + transport := rafthttp.NewHTTPTransport( + address, http.DefaultClient, raftHttpLogger, + adminConfig.Replication.AdvertScheme+"://(address)/raft", + ) + raftNode, err := raft.NewRaft( + raftConfig, newDGateAdminFSM(ps), + lstore, sstore, snapstore, transport, ) - raftNode, err := raft.NewRaft(raftConfig, newDGateAdminFSM(ps), - lstore, sstore, snapstore, trans) if err != nil { panic(err) } ps.SetupRaft(raftNode, raftConfig) // Setup raft handler - server.Handle("/raft/*", trans) + server.Handle("/raft/*", transport) raftAdminLogger := ps.Logger().With().Str("component", "raftAdmin").Logger() raftAdmin := raftadmin.NewRaftAdminHTTPServer( - raftNode, raftAdminLogger, - []raft.ServerAddress{address}, + raftNode, raftAdminLogger, []raft.ServerAddress{address}, ) // Setup handler raft diff --git a/internal/admin/routes/module_routes.go b/internal/admin/routes/module_routes.go index 02f09b6..99f7447 100644 --- a/internal/admin/routes/module_routes.go +++ b/internal/admin/routes/module_routes.go @@ -120,6 +120,6 @@ func ConfigureModuleAPI(server chi.Router, proxyState *proxy.ProxyState, appConf util.JsonError(w, http.StatusNotFound, "module not found") return } - util.JsonResponse(w, http.StatusOK, mod) + util.JsonResponse(w, http.StatusOK, spec.TransformDGateModule(mod)) }) } diff --git a/internal/proxy/change_log.go b/internal/proxy/change_log.go index 2e72599..b4e7474 100644 --- a/internal/proxy/change_log.go +++ b/internal/proxy/change_log.go @@ -14,11 +14,11 @@ import ( ) // processChangeLog - processes a change log and applies the change to the proxy state -func (ps *ProxyState) processChangeLog( - cl *spec.ChangeLog, reload, store bool, -) (err error) { +func (ps *ProxyState) processChangeLog(cl *spec.ChangeLog, reload, store bool) (err error) { if cl == nil { - cl = &spec.ChangeLog{Cmd: spec.NoopCommand} + cl = &spec.ChangeLog{ + Cmd: spec.NoopCommand, + } } else if !cl.Cmd.IsNoop() { switch cl.Cmd.Resource() { case spec.Namespaces: @@ -81,7 +81,7 @@ func (ps *ProxyState) processChangeLog( err = fmt.Errorf("unknown command: %s", cl.Cmd) } if err != nil { - ps.logger.Err(err).Msg("error processing change log") + ps.logger.Err(err).Msg("decoding or processing change log") return } } @@ -273,11 +273,7 @@ func (ps *ProxyState) applyChange(changeLog *spec.ChangeLog) <-chan error { return done } -func (ps *ProxyState) rollbackChange(changeLog *spec.ChangeLog) { - panic("not implemented") -} - -func (ps *ProxyState) restoreFromChangeLogs() error { +func (ps *ProxyState) restoreFromChangeLogs(directApply bool) error { // restore state change logs logs, err := ps.store.FetchChangeLogs() if err != nil { @@ -294,12 +290,23 @@ func (ps *ProxyState) restoreFromChangeLogs() error { Interface("changeLog: "+cl.Name, cl).Msgf("restoring change log index: %d", i) err = ps.processChangeLog(cl, false, false) if err != nil { + if ps.config.Debug { + ps.logger.Err(err). + Str("namespace", cl.Namespace). + Msg("error restorng from change logs") + continue + } return err } } - - if err = ps.processChangeLog(nil, true, false); err != nil { - return err + if !directApply { + if err = ps.processChangeLog(nil, true, false); err != nil { + return err + } + } else { + if err = ps.reconfigureState(false, nil); err != nil { + return nil + } } // TODO: optionally compact change logs through a flag in config? @@ -327,13 +334,11 @@ func (ps *ProxyState) compactChangeLogs(logs []*spec.ChangeLog) (int, error) { } /* -compactChangeLogsRemoveList - compacts a list of change logs by removing redundant logs - -TODO: perhaps add flag for compacting change logs on startup (mark as experimental) +compactChangeLogsRemoveList - compacts a list of change logs by removing redundant logs. compaction rules: -- if an add command is followed by a delete command with matching keys, remove both commands -- if an add command is followed by another add command with matching keys, remove the first add command + - if an add command is followed by a delete command with matching keys, remove both commands + - if an add command is followed by another add command with matching keys, remove the first add command */ func compactChangeLogsRemoveList(logger *zerolog.Logger, logs []*spec.ChangeLog) []*spec.ChangeLog { removeList := make([]*spec.ChangeLog, 0) @@ -348,11 +353,13 @@ START: if prevLog.Cmd.IsNoop() { removeList = append(removeList, prevLog) logs = append(logs[:i-1], logs[i:]...) - goto START + continue } + commonResource := prevLog.Cmd.Resource() == curLog.Cmd.Resource() if prevLog.Cmd.Action() == spec.Add && curLog.Cmd.Action() == spec.Delete && commonResource { - // Rule 1: if an add command is followed by a delete command with matching keys, remove both commands + // Rule 1: if an add command is followed by a delete + // command with matching keys, remove both commands if prevLog.Name == curLog.Name && prevLog.Namespace == curLog.Namespace { removeList = append(removeList, prevLog, curLog) logs = append(logs[:i-1], logs[i+1:]...) @@ -362,7 +369,8 @@ START: commonAction := prevLog.Cmd.Action() == curLog.Cmd.Action() if prevLog.Cmd.Action() == spec.Add && commonAction && commonResource { - // Rule 2: if an add command is followed by another add command with matching keys, remove the first add command + // Rule 2: if an add command is followed by another add + // command with matching keys, remove the first add command if prevLog.Name == curLog.Name && prevLog.Namespace == curLog.Namespace { removeList = append(removeList, prevLog) logs = append(logs[:i-1], logs[i:]...) diff --git a/internal/proxy/dynamic_proxy.go b/internal/proxy/dynamic_proxy.go index 2b6badf..d805ce8 100644 --- a/internal/proxy/dynamic_proxy.go +++ b/internal/proxy/dynamic_proxy.go @@ -14,7 +14,6 @@ import ( "github.com/dgate-io/dgate/pkg/modules/extractors" "github.com/dgate-io/dgate/pkg/spec" "github.com/dgate-io/dgate/pkg/typescript" - "github.com/dgate-io/dgate/pkg/util/sliceutil" "github.com/dgate-io/dgate/pkg/util/tree/avl" "github.com/dop251/goja" "github.com/rs/zerolog" @@ -23,7 +22,7 @@ import ( "golang.org/x/sync/errgroup" ) -func (state *ProxyState) reconfigureState(init bool, log *spec.ChangeLog) error { +func (state *ProxyState) reconfigureState(init bool, _ *spec.ChangeLog) error { start := time.Now() if err := state.setupModules(); err != nil { return err @@ -31,14 +30,16 @@ func (state *ProxyState) reconfigureState(init bool, log *spec.ChangeLog) error if err := state.setupRoutes(); err != nil { return err } - if !init && log != nil { + if !init { state.logger.Debug().Msgf( "State reloaded in %s", - time.Since(start)) - } else if init { + time.Since(start), + ) + } else { state.logger.Info().Msgf( "State initialized in %s", - time.Since(start)) + time.Since(start), + ) } return nil } @@ -46,57 +47,49 @@ func (state *ProxyState) reconfigureState(init bool, log *spec.ChangeLog) error func (ps *ProxyState) setupModules() error { ps.logger.Debug().Msg("Setting up modules") eg, _ := errgroup.WithContext(context.TODO()) - newModPrograms := avl.NewTree[string, *goja.Program]() for _, route := range ps.rm.GetRoutes() { route := route - for _, mod := range route.Modules { - mod := mod - eg.Go(func() error { - var ( - err error - program *goja.Program - modPayload string = mod.Payload - ) - start := time.Now() - if mod.Type == spec.ModuleTypeTypescript { - modPayload, err = typescript.Transpile(modPayload) - if err != nil { - ps.logger.Err(err).Msg("Error transpiling module: " + mod.Name) - return err - } - } - if mod.Type == spec.ModuleTypeJavascript || mod.Type == spec.ModuleTypeTypescript { - program, err = goja.Compile(mod.Name+".js", modPayload, true) - if err != nil { - ps.logger.Err(err).Msg("Error compiling module: " + mod.Name) - return err - } - } else { - return errors.New("invalid module type: " + mod.Type.String()) + if len(route.Modules) > 0 { + mod := route.Modules[0] + var ( + err error + program *goja.Program + modPayload string = mod.Payload + ) + start := time.Now() + if mod.Type == spec.ModuleTypeTypescript { + if modPayload, err = typescript.Transpile(modPayload); err != nil { + ps.logger.Err(err).Msg("Error transpiling module: " + mod.Name) + return err } - - testRtCtx := NewRuntimeContext(ps, route, mod) - defer testRtCtx.Clean() - err = extractors.SetupModuleEventLoop(ps.printer, testRtCtx) - if err != nil { - ps.logger.Err(err). - Msgf("Error applying module '%s' changes", mod.Name) + } + if mod.Type == spec.ModuleTypeJavascript || mod.Type == spec.ModuleTypeTypescript { + if program, err = goja.Compile(mod.Name, modPayload, true); err != nil { + ps.logger.Err(err).Msg("Error compiling module: " + mod.Name) return err } - newModPrograms.Insert(mod.Name+"/"+mod.Namespace.Name, program) - elapsed := time.Since(start) - ps.logger.Debug(). - Msgf("Module '%s/%s' changed applied in %s", mod.Name, mod.Namespace.Name, elapsed) - return nil - }) + } else { + return errors.New("invalid module type: " + mod.Type.String()) + } + + testRtCtx := NewRuntimeContext(ps, route, mod) + defer testRtCtx.Clean() + err = extractors.SetupModuleEventLoop(ps.printer, testRtCtx) + if err != nil { + ps.logger.Err(err). + Msgf("Error applying module '%s' changes", mod.Name) + return err + } + ps.modPrograms.Insert(mod.Name+"/"+mod.Namespace.Name, program) + elapsed := time.Since(start) + ps.logger.Debug(). + Msgf("Module '%s/%s' changed applied in %s", mod.Name, mod.Namespace.Name, elapsed) + return nil } } if err := eg.Wait(); err != nil { ps.logger.Err(err).Msg("Error setting up modules") return err - } else { - ps.modPrograms = newModPrograms - } return nil } @@ -159,51 +152,53 @@ func (ps *ProxyState) setupRoutes() (err error) { } func (ps *ProxyState) createModuleExtractorFunc(r *spec.DGateRoute) ModuleExtractorFunc { - return func(reqCtx *RequestContextProvider) ModuleExtractor { - programs := sliceutil.SliceMapper(r.Modules, func(m *spec.DGateModule) *goja.Program { - program, ok := ps.modPrograms.Find(m.Name + "/" + r.Namespace.Name) - if !ok { - ps.logger.Error().Msg("Error getting module program: invalid state") - panic("Error getting module program: invalid state") - } - return program - }) - rtCtx := NewRuntimeContext(ps, r, r.Modules...) - if err := extractors.SetupModuleEventLoop(ps.printer, rtCtx, programs...); err != nil { - ps.logger.Err(err).Msg("Error creating runtime for route: " + reqCtx.route.Name) - return nil + return func(reqCtx *RequestContextProvider) (_ ModuleExtractor, err error) { + if len(r.Modules) == 0 { + return nil, fmt.Errorf("no modules found for route: %s/%s", r.Name, r.Namespace.Name) + } + // TODO: Perhaps have some entrypoint flag to determine which module to use + m := r.Modules[0] + if program, ok := ps.modPrograms.Find(m.Name + "/" + r.Namespace.Name); !ok { + ps.logger.Error().Msg("Error getting module program: invalid state") + return nil, fmt.Errorf("cannot find module program: %s/%s", m.Name, r.Namespace.Name) } else { - loop := rtCtx.EventLoop() - errorHandler, err := extractors.ExtractErrorHandlerFunction(loop) - if err != nil { - ps.logger.Err(err).Msg("Error extracting error handler function") - return nil - } - fetchUpstream, err := extractors.ExtractFetchUpstreamFunction(loop) - if err != nil { - ps.logger.Err(err).Msg("Error extracting fetch upstream function") - return nil - } - reqModifier, err := extractors.ExtractRequestModifierFunction(loop) - if err != nil { - ps.logger.Err(err).Msg("Error extracting request modifier function") - return nil - } - resModifier, err := extractors.ExtractResponseModifierFunction(loop) - if err != nil { - ps.logger.Err(err).Msg("Error extracting response modifier function") - return nil - } - reqHandler, err := extractors.ExtractRequestHandlerFunction(loop) - if err != nil { - ps.logger.Err(err).Msg("Error extracting request handler function") - return nil + rtCtx := NewRuntimeContext(ps, r, r.Modules...) + if err := extractors.SetupModuleEventLoop(ps.printer, rtCtx, program); err != nil { + ps.logger.Err(err).Msg("Error creating runtime for route: " + reqCtx.route.Name) + return nil, err + } else { + loop := rtCtx.EventLoop() + errorHandler, err := extractors.ExtractErrorHandlerFunction(loop) + if err != nil { + ps.logger.Err(err).Msg("Error extracting error handler function") + return nil, err + } + fetchUpstream, err := extractors.ExtractFetchUpstreamFunction(loop) + if err != nil { + ps.logger.Err(err).Msg("Error extracting fetch upstream function") + return nil, err + } + reqModifier, err := extractors.ExtractRequestModifierFunction(loop) + if err != nil { + ps.logger.Err(err).Msg("Error extracting request modifier function") + return nil, err + } + resModifier, err := extractors.ExtractResponseModifierFunction(loop) + if err != nil { + ps.logger.Err(err).Msg("Error extracting response modifier function") + return nil, err + } + reqHandler, err := extractors.ExtractRequestHandlerFunction(loop) + if err != nil { + ps.logger.Err(err).Msg("Error extracting request handler function") + return nil, err + } + return NewModuleExtractor( + rtCtx, fetchUpstream, + reqModifier, resModifier, + errorHandler, reqHandler, + ), nil } - return NewModuleExtractor( - rtCtx, fetchUpstream, - reqModifier, resModifier, - errorHandler, reqHandler, - ) } } } @@ -219,11 +214,22 @@ func (ps *ProxyState) startChangeLoop() { for { log := <-ps.changeChan - if log.Cmd == spec.StopCommand { + switch log.Cmd { + case spec.ShutdownCommand: ps.logger.Warn(). - Msg("Stop command received, closing change loop") + Msg("Shutdown command received, closing change loop") log.PushError(nil) return + case spec.RestartCommand: + ps.logger.Warn(). + Msg("Restart command received, not supported") + // ps.logger.Warn(). + // Msg("Restart command received, restarting state") + // go ps.RestartState(func(err error) { + // ps.logger.Err(err). + // Msg("Error restarting state") + // os.Exit(1) + // }) } func() { @@ -234,7 +240,13 @@ func (ps *ProxyState) startChangeLoop() { if log.PushError(err); err != nil { ps.logger.Err(err). Msgf("Error reconfiguring state @namespace:%s", log.Namespace) - // ps.rollbackChange(log) + go ps.RestartState(func(err error) { + ps.logger.Err(err). + Msg("Error restarting state, exiting") + ps.changeChan <- &spec.ChangeLog{ + Cmd: spec.ShutdownCommand, + } + }) } }() } @@ -330,7 +342,7 @@ func (ps *ProxyState) Start() (err error) { } if !ps.replicationEnabled { - if err = ps.restoreFromChangeLogs(); err != nil { + if err = ps.restoreFromChangeLogs(false); err != nil { return err } } @@ -340,7 +352,7 @@ func (ps *ProxyState) Start() (err error) { func (ps *ProxyState) Stop() { cl := &spec.ChangeLog{ - Cmd: spec.StopCommand, + Cmd: spec.ShutdownCommand, } done := make(chan error, 1) cl.SetErrorChan(done) diff --git a/internal/proxy/module_executor.go b/internal/proxy/module_executor.go index c6ce81f..229230e 100644 --- a/internal/proxy/module_executor.go +++ b/internal/proxy/module_executor.go @@ -2,7 +2,6 @@ package proxy import ( "context" - "errors" ) type ModulePool interface { @@ -18,13 +17,13 @@ type modulePool struct { ctxCancel context.CancelFunc ctx context.Context - createModuleExtract func() ModuleExtractor + createModuleExtract func() (ModuleExtractor, error) } func NewModulePool( minBuffers, maxBuffers int, reqCtxProvider *RequestContextProvider, - createModExts func(*RequestContextProvider) ModuleExtractor, + createModExts ModuleExtractorFunc, ) (ModulePool, error) { if minBuffers < 1 { panic("module concurrency must be greater than 0") @@ -33,16 +32,15 @@ func NewModulePool( panic("maxBuffers must be greater than minBuffers") } - me := createModExts(reqCtxProvider) - if me == nil { - return nil, errors.New("could not load moduleExtract") + if _, err := createModExts(reqCtxProvider); err != nil { + return nil, err } mb := &modulePool{ min: minBuffers, max: maxBuffers, modExtBuffer: make(chan ModuleExtractor, maxBuffers), } - mb.createModuleExtract = func() ModuleExtractor { + mb.createModuleExtract = func() (ModuleExtractor, error) { return createModExts(reqCtxProvider) } mb.ctx, mb.ctxCancel = context.WithCancel(reqCtxProvider.ctx) @@ -53,13 +51,19 @@ func (mb *modulePool) Borrow() ModuleExtractor { if mb == nil || mb.ctx == nil || mb.ctx.Err() != nil { return nil } - var me ModuleExtractor + var ( + me ModuleExtractor + err error + ) select { case me = <-mb.modExtBuffer: break // NOTE: important for performance default: - me = mb.createModuleExtract() + me, err = mb.createModuleExtract() + if err != nil { + return nil + } } return me } diff --git a/internal/proxy/module_extractor.go b/internal/proxy/module_extractor.go index f3f344c..63ef3c0 100644 --- a/internal/proxy/module_extractor.go +++ b/internal/proxy/module_extractor.go @@ -100,8 +100,11 @@ func (me *moduleExtract) RequestHandlerFunc() (extractors.RequestHandlerFunc, bo return me.requestHandler, me.requestHandler != nil } -func NewEmptyModuleExtractor() ModuleExtractor { - return &moduleExtract{} +func NewDefaultModuleExtractor() ModuleExtractor { + return &moduleExtract{ + fetchUpstreamUrl: extractors.DefaultFetchUpstreamFunction(), + errorHandler: extractors.DefaultErrorHandlerFunction(), + } } -type ModuleExtractorFunc func(*RequestContextProvider) ModuleExtractor +type ModuleExtractorFunc func(*RequestContextProvider) (ModuleExtractor, error) diff --git a/internal/proxy/proxy_handler.go b/internal/proxy/proxy_handler.go index 34334e8..1ac31be 100644 --- a/internal/proxy/proxy_handler.go +++ b/internal/proxy/proxy_handler.go @@ -59,7 +59,7 @@ func proxyHandler(ps *ProxyState, reqCtx *RequestContext) { runtimeStart, nil, ) } else { - modExt = NewEmptyModuleExtractor() + modExt = NewDefaultModuleExtractor() } if reqCtx.route.Service != nil { @@ -192,7 +192,6 @@ func handleServiceProxy(ps *ProxyState, reqCtx *RequestContext, modExt ModuleExt func requestHandlerModule(ps *ProxyState, reqCtx *RequestContext, modExt ModuleExtractor) { var err error if requestModifier, ok := modExt.RequestModifierFunc(); ok { - // extract request modifier function from module reqModifierStart := time.Now() err = requestModifier(modExt.ModuleContext()) ps.metrics.MeasureModuleDuration( diff --git a/internal/proxy/proxy_state.go b/internal/proxy/proxy_state.go index 5785a58..94491a1 100644 --- a/internal/proxy/proxy_state.go +++ b/internal/proxy/proxy_state.go @@ -128,22 +128,29 @@ func NewProxyState(conf *config.DGateConfig) *ProxyState { WithComponentLogger("proxy_store"), WithDefaultLevel(zerolog.InfoLevel), ) + schedulerLogger := Logger(&logger, + WithComponentLogger("scheduler"), + WithDefaultLevel(zerolog.InfoLevel), + ) replicationEnabled := false if conf.AdminConfig != nil && conf.AdminConfig.Replication != nil { replicationEnabled = true } state := &ProxyState{ - version: "unknown", - startTime: time.Now(), - raftReady: atomic.Bool{}, - logger: logger, - debugMode: conf.Debug, - config: conf, - metrics: NewProxyMetrics(), - printer: printer, - routers: avl.NewTree[string, *router.DynamicRouter](), - changeChan: make(chan *spec.ChangeLog, 1), - rm: resources.NewManager(opt), + version: "unknown", + startTime: time.Now(), + raftReady: atomic.Bool{}, + logger: logger, + debugMode: conf.Debug, + config: conf, + metrics: NewProxyMetrics(), + printer: printer, + routers: avl.NewTree[string, *router.DynamicRouter](), + changeChan: make(chan *spec.ChangeLog, 1), + rm: resources.NewManager(opt), + skdr: scheduler.New(scheduler.Options{ + Logger: schedulerLogger, + }), providers: avl.NewTree[string, *RequestContextProvider](), modPrograms: avl.NewTree[string, *goja.Program](), proxyLock: new(sync.RWMutex), @@ -298,6 +305,37 @@ func (ps *ProxyState) SharedCache() cache.TCache { return ps.sharedCache } +// RestartState - restart state clears the state and reloads the configuration +// this is useful for rollbacks when broken changes are made. +func (ps *ProxyState) RestartState(errFn func(error)) { + ps.proxyLock.Lock() + defer ps.proxyLock.Unlock() + + ps.rm.Empty() + ps.modPrograms.Clear() + ps.providers.Clear() + ps.routers.Clear() + ps.sharedCache.Clear() + ps.Scheduler().Stop() + if err := ps.initConfigResources(ps.config.ProxyConfig.InitResources); err != nil { + errFn(err) + return + } + if ps.replicationEnabled { + raft := ps.Raft() + err := raft.ReloadConfig(raft.ReloadableConfig()) + if err != nil { + errFn(err) + return + } + } + if err := ps.restoreFromChangeLogs(true); err != nil { + errFn(err) + return + } + ps.logger.Info().Msg("State successfully restarted") +} + // ReloadState - reload state checks the change logs to see if a reload is required, // specifying check as false skips this step and automatically reloads func (ps *ProxyState) ReloadState(check bool, logs ...*spec.ChangeLog) error { @@ -319,7 +357,7 @@ func (ps *ProxyState) ReloadState(check bool, logs ...*spec.ChangeLog) error { func (ps *ProxyState) ProcessChangeLog(log *spec.ChangeLog, reload bool) error { err := ps.processChangeLog(log, reload, !ps.replicationEnabled) if err != nil { - ps.logger.Error().Err(err).Msg("error processing change log") + ps.logger.Error().Err(err).Msg("processing error") } return err } diff --git a/internal/proxy/runtime_context.go b/internal/proxy/runtime_context.go index a9e3bcf..0c901b8 100644 --- a/internal/proxy/runtime_context.go +++ b/internal/proxy/runtime_context.go @@ -2,12 +2,17 @@ package proxy import ( "context" + "errors" + "strings" + "time" "github.com/dgate-io/dgate/pkg/eventloop" "github.com/dgate-io/dgate/pkg/modules" "github.com/dgate-io/dgate/pkg/resources" "github.com/dgate-io/dgate/pkg/spec" + "github.com/dgate-io/dgate/pkg/typescript" "github.com/dop251/goja" + "github.com/dop251/goja_nodejs/require" ) // RuntimeContext is the context for the runtime. one per request @@ -32,40 +37,36 @@ func NewRuntimeContext( route: spec.TransformDGateRoute(route), } - // TODO: setup module import logic - // sort.Slice(rtCtx.modules, func(i, j int) bool { - // return rtCtx.modules[i].Name < rtCtx.modules[j].Name - // }) - // reg := require.NewRegistryWithLoader(func(path string) ([]byte, error) { - // requireMod := strings.Replace(path, "node_modules/", "", 1) - // // 'https://' - requires network permissions and must be enabled in the config - // // 'file://' - requires file system permissions and must be enabled in the config - // // 'module://' - requires a module lookup and module permissions - // if mod, ok := findInSortedWith(rtCtx.modules, requireMod, - // func(m *spec.Module) string { return m.Name }); !ok { - // return nil, errors.New(requireMod + " not found") - // } else { - // if mod.Type == spec.ModuleTypeJavascript { - // return []byte(mod.Payload), nil - // } - // var err error - // var key string - // transpileBucket := proxyState.sharedCache.Bucket("ts-transpile") - // if key, err = HashString(0, mod.Payload); err == nil { - // if code, ok := transpileBucket.Get(key); ok { - // return code.([]byte), nil - // } - // } - // payload, err := typescript.Transpile(mod.Payload) - // if err != nil { - // return nil, err - // } - // transpileBucket.SetWithTTL(key, []byte(payload), time.Minute*30) - // return []byte(payload), nil - // } - // }) + reg := require.NewRegistryWithLoader(func(path string) ([]byte, error) { + requireMod := strings.Replace(path, "node_modules/", "", 1) + // 'https://' - requires network permissions and must be enabled in the config + // 'file://' - requires file system permissions and must be enabled in the config + // 'module://' - requires a module lookup and module permissions + if mod, ok := findInSortedWith(modules, requireMod, + func(m *spec.DGateModule) string { return m.Name }); !ok { + return nil, errors.New(requireMod + " not found") + } else { + if mod.Type == spec.ModuleTypeJavascript { + return []byte(mod.Payload), nil + } + var err error + var key string + transpileBucket := proxyState.sharedCache.Bucket("ts-transpile") + if key, err = HashString(0, mod.Payload); err == nil { + if code, ok := transpileBucket.Get(key); ok { + return code.([]byte), nil + } + } + payload, err := typescript.Transpile(mod.Payload) + if err != nil { + return nil, err + } + transpileBucket.SetWithTTL(key, []byte(payload), time.Minute*30) + return []byte(payload), nil + } + }) rtCtx.loop = eventloop.NewEventLoop( - // eventloop.WithRegistry(reg), + eventloop.WithRegistry(reg), ) return rtCtx } diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 3c7cd59..7e87275 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -14,6 +14,7 @@ import ( type TCache interface { Bucket(string) Bucket BucketWithOpts(string, BucketOptions) Bucket + Clear() } type Bucket interface { @@ -227,3 +228,17 @@ func (b *bucketImpl) delete(key string) bool { delete(b.items, key) return true } + +func (cache *cacheImpl) Clear() { + cache.mutex.Lock() + defer cache.mutex.Unlock() + for _, b := range cache.buckets { + if bkt, ok := b.(*bucketImpl); ok { + bkt.mutex.Lock() + bkt.items = make(map[string]*cacheEntry) + bkt.ttlQueue = heap.NewHeap[int64, *cacheEntry](heap.MinHeapType) + bkt.limitQueue = heap.NewHeap[int64, *cacheEntry](heap.MinHeapType) + bkt.mutex.Unlock() + } + } +} diff --git a/pkg/modules/extractors/extractors.go b/pkg/modules/extractors/extractors.go index 1120400..d598c92 100644 --- a/pkg/modules/extractors/extractors.go +++ b/pkg/modules/extractors/extractors.go @@ -3,6 +3,7 @@ package extractors import ( "context" "errors" + "fmt" "io" "net/http" "net/url" @@ -28,6 +29,15 @@ type Results struct { IsError bool } +func RunAndWait( + rt *goja.Runtime, + fn goja.Callable, + args ...goja.Value, +) error { + _, err := RunAndWaitForResult(rt, fn, args...) + return err +} + // RunAndWaitForResult can execute a goja function and wait for the result // if the result is a promise, it will wait for the promise to resolve func RunAndWaitForResult( @@ -59,6 +69,8 @@ func RunAndWaitForResult( return nil, nil } return results, nil + } else if err != nil { + return nil, err } else { return res, nil } @@ -88,29 +100,29 @@ func ExtractFetchUpstreamFunction( loop *eventloop.EventLoop, ) (fetchUpstream FetchUpstreamUrlFunc, err error) { rt := loop.Runtime() - fetchUpstreamRaw := rt.Get("fetchUpstream") - if call, ok := goja.AssertFunction(fetchUpstreamRaw); ok { + if fn, ok, err := functionExtractor(rt, "fetchUpstream"); ok { fetchUpstream = func(modCtx *types.ModuleContext) (*url.URL, error) { - res, err := RunAndWaitForResult( - rt, call, rt.ToValue(modCtx), - ) - if err != nil { + if res, err := RunAndWaitForResult( + rt, fn, rt.ToValue(modCtx), + ); err != nil { return nil, err - } - upstreamUrlString := res.String() - if goja.IsUndefined(res) || goja.IsNull(res) || upstreamUrlString == "" { + } else if nully(res) || res.String() == "" { return nil, errors.New("fetchUpstream returned an invalid URL") + } else { + upstreamUrlString := res.String() + if !strings.Contains(upstreamUrlString, "://") { + upstreamUrlString += "http://" + } + upstreamUrl, err := url.Parse(upstreamUrlString) + if err != nil { + return nil, err + } + // perhaps add default scheme if not present + return upstreamUrl, err } - if !strings.Contains(upstreamUrlString, "://") { - upstreamUrlString += "http://" - } - upstreamUrl, err := url.Parse(upstreamUrlString) - if err != nil { - return nil, err - } - // perhaps add default scheme if not present - return upstreamUrl, err } + } else if err != nil { + return nil, err } else { fetchUpstream = DefaultFetchUpstreamFunction() } @@ -121,13 +133,14 @@ func ExtractRequestModifierFunction( loop *eventloop.EventLoop, ) (requestModifier RequestModifierFunc, err error) { rt := loop.Runtime() - if call, ok := goja.AssertFunction(rt.Get("requestModifier")); ok { + if fn, ok, err := functionExtractor(rt, "requestModifier"); ok { requestModifier = func(modCtx *types.ModuleContext) error { - _, err := RunAndWaitForResult( - rt, call, rt.ToValue(modCtx), - ) - return err + return RunAndWait(rt, fn, rt.ToValue(modCtx)) } + } else if err != nil { + return nil, err + } else { + return nil, nil } return requestModifier, nil } @@ -136,14 +149,15 @@ func ExtractResponseModifierFunction( loop *eventloop.EventLoop, ) (responseModifier ResponseModifierFunc, err error) { rt := loop.Runtime() - if call, ok := goja.AssertFunction(rt.Get("responseModifier")); ok { + if fn, ok, err := functionExtractor(rt, "responseModifier"); ok { responseModifier = func(modCtx *types.ModuleContext, res *http.Response) error { modCtx = types.ModuleContextWithResponse(modCtx, res) - _, err := RunAndWaitForResult( - rt, call, rt.ToValue(modCtx), - ) - return err + return RunAndWait(rt, fn, rt.ToValue(modCtx)) } + } else if err != nil { + return nil, err + } else { + return nil, nil } return responseModifier, nil } @@ -173,18 +187,16 @@ func ExtractErrorHandlerFunction( loop *eventloop.EventLoop, ) (errorHandler ErrorHandlerFunc, err error) { rt := loop.Runtime() - if call, ok := goja.AssertFunction(rt.Get("errorHandler")); ok { + if fn, ok, err := functionExtractor(rt, "errorHandler"); ok { errorHandler = func(modCtx *types.ModuleContext, upstreamErr error) error { - if modCtx == nil { - return upstreamErr - } modCtx = types.ModuleContextWithError(modCtx, upstreamErr) - _, err := RunAndWaitForResult( - rt, call, rt.ToValue(modCtx), + return RunAndWait( + rt, fn, rt.ToValue(modCtx), rt.ToValue(rt.NewGoError(upstreamErr)), ) - return err } + } else if err != nil { + return nil, err } else { errorHandler = DefaultErrorHandlerFunction() } @@ -195,16 +207,32 @@ func ExtractRequestHandlerFunction( loop *eventloop.EventLoop, ) (requestHandler RequestHandlerFunc, err error) { rt := loop.Runtime() - if call, ok := goja.AssertFunction(rt.Get("requestHandler")); ok { + if fn, ok, err := functionExtractor(rt, "requestHandler"); ok { requestHandler = func(modCtx *types.ModuleContext) error { - if modCtx == nil { - return errors.New("module context is nil") - } - _, err := RunAndWaitForResult( - rt, call, rt.ToValue(modCtx), + return RunAndWait( + rt, fn, rt.ToValue(modCtx), ) - return err } + } else if err != nil { + return nil, err + } else { + return nil, err } return requestHandler, nil } + +func functionExtractor(rt *goja.Runtime, varName string) (goja.Callable, bool, error) { + check := fmt.Sprintf( + "exports?.%s ?? (typeof %s === 'function' ? %s : void 0)", + varName, varName, varName, + ) + if fnRef, err := rt.RunString(check); err != nil { + return nil, false, err + } else if fn, ok := goja.AssertFunction(fnRef); ok { + return fn, true, nil + } else if nully(fnRef) { + return nil, false, nil + } else { + return nil, false, errors.New("extractors: invalid function -> " + varName) + } +} diff --git a/pkg/modules/types/module_context.go b/pkg/modules/types/module_context.go index 413b21c..7a13d75 100644 --- a/pkg/modules/types/module_context.go +++ b/pkg/modules/types/module_context.go @@ -20,7 +20,7 @@ type ModuleContext struct { loop *eventloop.EventLoop req *RequestWrapper rwt *ResponseWriterWrapper - resp *ResponseWrapper + upResp *ResponseWrapper cache map[string]interface{} } @@ -78,7 +78,7 @@ func (modCtx *ModuleContext) Request() *RequestWrapper { } func (modCtx *ModuleContext) Upstream() *ResponseWrapper { - return modCtx.resp + return modCtx.upResp } func (modCtx *ModuleContext) Response() *ResponseWriterWrapper { @@ -89,14 +89,15 @@ func ModuleContextWithResponse( modCtx *ModuleContext, resp *http.Response, ) *ModuleContext { - modCtx.resp = NewResponseWrapper(resp, modCtx.loop) + modCtx.upResp = NewResponseWrapper(resp, modCtx.loop) + modCtx.rwt = nil return modCtx } func ModuleContextWithError( modCtx *ModuleContext, err error, ) *ModuleContext { - modCtx.resp = nil + modCtx.upResp = nil return modCtx } @@ -119,7 +120,7 @@ func GetModuleContextRequest(modCtx *ModuleContext) *RequestWrapper { } func GetModuleContextResponse(modCtx *ModuleContext) *ResponseWrapper { - return modCtx.resp + return modCtx.upResp } func GetModuleContextResponseWriterTracker(modCtx *ModuleContext) spec.ResponseWriterTracker { diff --git a/pkg/modules/types/request.go b/pkg/modules/types/request.go index 87f346e..99d27dd 100644 --- a/pkg/modules/types/request.go +++ b/pkg/modules/types/request.go @@ -1,6 +1,8 @@ package types import ( + "bytes" + "encoding/json" "errors" "io" "net" @@ -8,6 +10,7 @@ import ( "net/url" "github.com/dgate-io/dgate/pkg/eventloop" + "github.com/dgate-io/dgate/pkg/util" "github.com/dop251/goja" ) @@ -15,7 +18,6 @@ type RequestWrapper struct { req *http.Request loop *eventloop.EventLoop - Body io.ReadCloser Method string URL string Headers http.Header @@ -43,22 +45,64 @@ func NewRequestWrapper( Host: req.Host, Proto: req.Proto, Headers: req.Header, - Body: req.Body, Method: req.Method, RemoteAddress: ip, ContentLength: req.ContentLength, } } -func (g *RequestWrapper) GetBody() (*goja.ArrayBuffer, error) { - if g.Body == nil { +func (g *RequestWrapper) clearBody() { + if g.req.Body != nil { + // read all data from body + io.ReadAll(g.req.Body) + g.req.Body.Close() + g.req.Body = nil + } +} + +func (g *RequestWrapper) WriteJson(data any) error { + g.req.Header.Set("Content-Type", "application/json") + buf, err := json.Marshal(data) + if err != nil { + return err + } + return g.WriteBody(buf) +} + +func (g *RequestWrapper) ReadJson() (any, error) { + if ab, err := g.ReadBody(); err != nil { + return nil, err + } else { + var data any + err := json.Unmarshal(ab.Bytes(), &data) + if err != nil { + return nil, err + } + return data, nil + + } +} + +func (g *RequestWrapper) WriteBody(data any) error { + g.clearBody() + buf, err := util.ToBytes(data) + if err != nil { + return err + } + g.req.Body = io.NopCloser(bytes.NewReader(buf)) + g.req.ContentLength = int64(len(buf)) + return nil +} + +func (g *RequestWrapper) ReadBody() (*goja.ArrayBuffer, error) { + if g.req.Body == nil { return nil, errors.New("body is not set") } - buf, err := io.ReadAll(g.Body) + buf, err := io.ReadAll(g.req.Body) if err != nil { return nil, err } - defer g.Body.Close() + defer g.req.Body.Close() arrBuf := g.loop.Runtime().NewArrayBuffer(buf) return &arrBuf, nil } diff --git a/pkg/modules/types/response_writer.go b/pkg/modules/types/response_writer.go index 8e8c5ba..4bda57f 100644 --- a/pkg/modules/types/response_writer.go +++ b/pkg/modules/types/response_writer.go @@ -46,7 +46,7 @@ type CookieOptions struct { SameSite string `json:"sameSite"` } -func (g *ResponseWriterWrapper) SetCookie(name string, value string, opts ...*CookieOptions) (*ResponseWriterWrapper, error) { +func (g *ResponseWriterWrapper) Cookie(name string, value string, opts ...*CookieOptions) (*ResponseWriterWrapper, error) { if len(opts) > 1 { return nil, errors.New("too many auguments") } @@ -130,11 +130,17 @@ func (g *ResponseWriterWrapper) Location(url string) *ResponseWriterWrapper { return g } -func (g *ResponseWriterWrapper) Cookie() []*http.Cookie { +func (g *ResponseWriterWrapper) GetCookies() []*http.Cookie { return g.req.Cookies() } -func (g *ResponseWriterWrapper) Header() http.Header { - return g.rw.Header() +func (g *ResponseWriterWrapper) GetCookie(name string) *http.Cookie { + cookies := g.req.Cookies() + for _, cookie := range cookies { + if cookie.Name == name { + return cookie + } + } + return nil } diff --git a/pkg/modules/types/upstream_response.go b/pkg/modules/types/upstream_response.go index 377bb7a..8de0d3e 100644 --- a/pkg/modules/types/upstream_response.go +++ b/pkg/modules/types/upstream_response.go @@ -18,6 +18,7 @@ type ResponseWrapper struct { response *http.Response loop *eventloop.EventLoop + Headers http.Header `json:"headers"` StatusCode int `json:"statusCode"` StatusText string `json:"statusText"` Trailer http.Header `json:"trailer"` @@ -34,6 +35,7 @@ func NewResponseWrapper( return &ResponseWrapper{ response: resp, loop: loop, + Headers: resp.Header, Protocol: resp.Proto, StatusText: resp.Status, Trailer: resp.Trailer, @@ -44,7 +46,16 @@ func NewResponseWrapper( } } -func (rw *ResponseWrapper) GetBody() *goja.Promise { +func (rw *ResponseWrapper) clearBody() { + if rw.response.Body != nil { + io.ReadAll(rw.response.Body) + rw.response.Body.Close() + rw.response.Body = nil + } + rw.response.ContentLength = 0 +} + +func (rw *ResponseWrapper) ReadBody() *goja.Promise { prom, res, rej := rw.loop.Runtime().NewPromise() rw.loop.RunOnLoop(func(r *goja.Runtime) { buf, err := io.ReadAll(rw.response.Body) @@ -58,7 +69,7 @@ func (rw *ResponseWrapper) GetBody() *goja.Promise { return prom } -func (rw *ResponseWrapper) GetJson() *goja.Promise { +func (rw *ResponseWrapper) ReadJson() *goja.Promise { prom, res, rej := rw.loop.Runtime().NewPromise() rw.loop.RunOnLoop(func(r *goja.Runtime) { var data any @@ -78,16 +89,17 @@ func (rw *ResponseWrapper) GetJson() *goja.Promise { return prom } -func (rw *ResponseWrapper) SetJson(data any) error { - rw.Header().Set("Content-Type", "application/json") +func (rw *ResponseWrapper) WriteJson(data any) error { + rw.Headers.Set("Content-Type", "application/json") b, err := json.Marshal(data) if err != nil { return err } - return rw.SetBody(b) + return rw.WriteBody(b) } -func (rw *ResponseWrapper) SetBody(data any) error { +func (rw *ResponseWrapper) WriteBody(data any) error { + rw.clearBody() if rw.StatusCode <= 0 { rw.StatusCode = http.StatusOK rw.response.Status = rw.StatusText @@ -100,14 +112,13 @@ func (rw *ResponseWrapper) SetBody(data any) error { if err != nil { return err } - rw.response.Body.Close() rw.response.ContentLength = int64(len(buf)) rw.response.Header.Set("Content-Length", strconv.FormatInt(rw.response.ContentLength, 10)) rw.response.Body = io.NopCloser(bytes.NewReader(buf)) return nil } -func (rw *ResponseWrapper) SetStatus(status int) *ResponseWrapper { +func (rw *ResponseWrapper) Status(status int) *ResponseWrapper { rw.response.StatusCode = status rw.StatusCode = rw.response.StatusCode rw.response.Status = http.StatusText(status) @@ -115,17 +126,16 @@ func (rw *ResponseWrapper) SetStatus(status int) *ResponseWrapper { return rw } -func (rw *ResponseWrapper) SetRedirect(url string) { - rw.response.Body = nil - rw.Header().Set("Location", url) - rw.SetStatus(http.StatusTemporaryRedirect) +func (rw *ResponseWrapper) Redirect(url string) { + rw.clearBody() + rw.Headers.Set("Location", url) + rw.Status(http.StatusTemporaryRedirect) } -func (rw *ResponseWrapper) SetRedirectPermanent(url string) { - rw.response.Body.Close() - rw.response.Body = nil - rw.Header().Set("Location", url) - rw.SetStatus(http.StatusMovedPermanently) +func (rw *ResponseWrapper) RedirectPermanent(url string) { + rw.clearBody() + rw.Headers.Set("Location", url) + rw.Status(http.StatusMovedPermanently) } func (rw *ResponseWrapper) Query() url.Values { @@ -135,7 +145,3 @@ func (rw *ResponseWrapper) Query() url.Values { func (rw *ResponseWrapper) Cookie() []*http.Cookie { return rw.response.Cookies() } - -func (rw *ResponseWrapper) Header() http.Header { - return rw.response.Header -} diff --git a/pkg/resources/resource_manager.go b/pkg/resources/resource_manager.go index f20fd0a..a2b8d81 100644 --- a/pkg/resources/resource_manager.go +++ b/pkg/resources/resource_manager.go @@ -3,9 +3,9 @@ package resources import ( "errors" "sort" - "sync" "github.com/dgate-io/dgate/pkg/spec" + "github.com/dgate-io/dgate/pkg/util/keylock" "github.com/dgate-io/dgate/pkg/util/linker" "github.com/dgate-io/dgate/pkg/util/safe" "github.com/dgate-io/dgate/pkg/util/tree/avl" @@ -22,7 +22,7 @@ type ResourceManager struct { routes avlTreeLinker[spec.DGateRoute] secrets avlTreeLinker[spec.DGateSecret] collections avlTreeLinker[spec.DGateCollection] - mutex *sync.RWMutex + mutex *keylock.KeyLock } type Options func(*ResourceManager) @@ -36,7 +36,7 @@ func NewManager(opts ...Options) *ResourceManager { routes: avl.NewTree[string, *linker.Link[string, safe.Ref[spec.DGateRoute]]](), collections: avl.NewTree[string, *linker.Link[string, safe.Ref[spec.DGateCollection]]](), secrets: avl.NewTree[string, *linker.Link[string, safe.Ref[spec.DGateSecret]]](), - mutex: &sync.RWMutex{}, + mutex: keylock.NewKeyLock(), } for _, opt := range opts { if opt != nil { @@ -57,8 +57,7 @@ func WithDefaultNamespace(ns *spec.Namespace) Options { */ func (rm *ResourceManager) GetNamespace(namespace string) (*spec.DGateNamespace, bool) { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() return rm.getNamespace(namespace) } @@ -74,8 +73,6 @@ func (rm *ResourceManager) NamespaceCountEquals(target int) bool { if target < 0 { panic("target must be greater than or equal to 0") } - rm.mutex.RLock() - defer rm.mutex.RUnlock() rm.namespaces.Each(func(_ string, lk *linker.Link[string, safe.Ref[spec.DGateNamespace]]) bool { target -= 1 return target > 0 @@ -84,8 +81,6 @@ func (rm *ResourceManager) NamespaceCountEquals(target int) bool { } func (rm *ResourceManager) GetFirstNamespace() *spec.DGateNamespace { - rm.mutex.RLock() - defer rm.mutex.RUnlock() if _, nsLink, ok := rm.namespaces.RootKeyValue(); ok { return nsLink.Item().Read() } @@ -94,8 +89,7 @@ func (rm *ResourceManager) GetFirstNamespace() *spec.DGateNamespace { // GetNamespaces returns a list of all namespaces func (rm *ResourceManager) GetNamespaces() []*spec.DGateNamespace { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLockMain()() var namespaces []*spec.DGateNamespace rm.namespaces.Each(func(_ string, lk *linker.Link[string, safe.Ref[spec.DGateNamespace]]) bool { namespaces = append(namespaces, lk.Item().Read()) @@ -112,8 +106,7 @@ func (rm *ResourceManager) transformNamespace(ns *spec.Namespace) *spec.DGateNam } func (rm *ResourceManager) AddNamespace(ns *spec.Namespace) *spec.DGateNamespace { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.Lock(ns.Name)() namespace := rm.transformNamespace(ns) if nsLk, ok := rm.namespaces.Find(ns.Name); ok { nsLk.Item().Replace(namespace) @@ -130,8 +123,7 @@ func (rm *ResourceManager) AddNamespace(ns *spec.Namespace) *spec.DGateNamespace } func (rm *ResourceManager) RemoveNamespace(namespace string) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.Lock(namespace)() if nsLk, ok := rm.namespaces.Find(namespace); ok { if nsLk.Len("routes") > 0 { return ErrCannotDeleteNamespace(namespace, "routes still linked") @@ -159,8 +151,7 @@ func (rm *ResourceManager) RemoveNamespace(namespace string) error { /* Route functions */ func (rm *ResourceManager) GetRoute(name, namespace string) (*spec.DGateRoute, bool) { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() return rm.getRoute(name, namespace) } @@ -173,8 +164,7 @@ func (rm *ResourceManager) getRoute(name, namespace string) (*spec.DGateRoute, b // GetRoutes returns a list of all routes func (rm *ResourceManager) GetRoutes() []*spec.DGateRoute { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLockMain()() var routes []*spec.DGateRoute rm.routes.Each(func(_ string, rtlk *linker.Link[string, safe.Ref[spec.DGateRoute]]) bool { routes = append(routes, rtlk.Item().Read()) @@ -185,8 +175,7 @@ func (rm *ResourceManager) GetRoutes() []*spec.DGateRoute { // GetRoutesByNamespace returns a list of all routes in a namespace func (rm *ResourceManager) GetRoutesByNamespace(namespace string) []*spec.DGateRoute { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() var routes []*spec.DGateRoute if nsLk, ok := rm.namespaces.Find(namespace); ok { nsLk.Each("routes", func(_ string, lk linker.Linker[string]) { @@ -199,8 +188,7 @@ func (rm *ResourceManager) GetRoutesByNamespace(namespace string) []*spec.DGateR // GetRouteNamespaceMap returns a map of all routes and their namespaces as the key func (rm *ResourceManager) GetRouteNamespaceMap() map[string][]*spec.DGateRoute { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLockMain()() routeMap := make(map[string][]*spec.DGateRoute) rm.namespaces.Each(func(ns string, lk *linker.Link[string, safe.Ref[spec.DGateNamespace]]) bool { routes := []*spec.DGateRoute{} @@ -217,8 +205,7 @@ func (rm *ResourceManager) GetRouteNamespaceMap() map[string][]*spec.DGateRoute } func (rm *ResourceManager) AddRoute(route *spec.Route) (rt *spec.DGateRoute, err error) { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.RLockMain()() if rt, err = rm.transformRoute(route); err != nil { return nil, err } else if nsLk, ok := rm.namespaces.Find(route.NamespaceName); !ok { @@ -277,8 +264,7 @@ func (rm *ResourceManager) transformRoute(route *spec.Route) (*spec.DGateRoute, // RemoveRoute removes a route from the resource manager func (rm *ResourceManager) RemoveRoute(name, namespace string) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.Lock(namespace)() if nsLk, ok := rm.namespaces.Find(namespace); !ok { return ErrNamespaceNotFound(namespace) } else if lk, ok := rm.routes.Find(name + "/" + namespace); ok { @@ -347,8 +333,7 @@ func (rm *ResourceManager) relinkRoute( /* Service functions */ func (rm *ResourceManager) GetService(name, namespace string) (*spec.DGateService, bool) { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() return rm.getService(name, namespace) } @@ -361,8 +346,7 @@ func (rm *ResourceManager) getService(name, namespace string) (*spec.DGateServic // GetServicesByNamespace returns a list of all services in a namespace func (rm *ResourceManager) GetServicesByNamespace(namespace string) []*spec.DGateService { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() var services []*spec.DGateService if nsLk, ok := rm.namespaces.Find(namespace); ok { nsLk.Each("services", func(_ string, lk linker.Linker[string]) { @@ -375,8 +359,7 @@ func (rm *ResourceManager) GetServicesByNamespace(namespace string) []*spec.DGat // GetServices returns a list of all services func (rm *ResourceManager) GetServices() []*spec.DGateService { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLockMain()() var services []*spec.DGateService rm.services.Each(func(_ string, lk *linker.Link[string, safe.Ref[spec.DGateService]]) bool { services = append(services, lk.Item().Read()) @@ -386,8 +369,7 @@ func (rm *ResourceManager) GetServices() []*spec.DGateService { } func (rm *ResourceManager) AddService(service *spec.Service) (*spec.DGateService, error) { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.Lock(service.NamespaceName)() svc, err := rm.transformService(service) if err != nil { return nil, err @@ -413,8 +395,7 @@ func (rm *ResourceManager) transformService(service *spec.Service) (*spec.DGateS } func (rm *ResourceManager) RemoveService(name, namespace string) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.Lock(namespace)() if lk, ok := rm.services.Find(name + "/" + namespace); ok { if nsLk, ok := rm.namespaces.Find(namespace); ok { if rtsLk := lk.Get("routes"); rtsLk != nil { @@ -437,8 +418,7 @@ func (rm *ResourceManager) RemoveService(name, namespace string) error { /* Domain functions */ func (rm *ResourceManager) GetDomain(name, namespace string) (*spec.DGateDomain, bool) { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() return rm.getDomain(name, namespace) } @@ -451,8 +431,7 @@ func (rm *ResourceManager) getDomain(name, namespace string) (*spec.DGateDomain, // GetDomains returns a list of all domains func (rm *ResourceManager) GetDomains() []*spec.DGateDomain { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLockMain()() var domains []*spec.DGateDomain rm.domains.Each(func(_ string, lk *linker.Link[string, safe.Ref[spec.DGateDomain]]) bool { domains = append(domains, lk.Item().Read()) @@ -465,8 +444,7 @@ func (rm *ResourceManager) DomainCountEquals(target int) bool { if target < 0 { panic("target must be greater than or equal to 0") } - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLockMain()() rm.domains.Each(func(_ string, lk *linker.Link[string, safe.Ref[spec.DGateDomain]]) bool { target -= 1 return target > 0 @@ -476,8 +454,7 @@ func (rm *ResourceManager) DomainCountEquals(target int) bool { // GetDomainsByPriority returns a list of all domains sorted by priority and name func (rm *ResourceManager) GetDomainsByPriority() []*spec.DGateDomain { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLockMain()() var domains []*spec.DGateDomain rm.domains.Each(func(_ string, lk *linker.Link[string, safe.Ref[spec.DGateDomain]]) bool { domains = append(domains, lk.Item().Read()) @@ -497,8 +474,7 @@ func (rm *ResourceManager) GetDomainsByPriority() []*spec.DGateDomain { // GetDomainsByNamespace returns a list of all domains in a namespace func (rm *ResourceManager) GetDomainsByNamespace(namespace string) []*spec.DGateDomain { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() var domains []*spec.DGateDomain if nsLk, ok := rm.namespaces.Find(namespace); ok { nsLk.Each("domains", func(_ string, lk linker.Linker[string]) { @@ -512,8 +488,7 @@ func (rm *ResourceManager) GetDomainsByNamespace(namespace string) []*spec.DGate // AddDomain adds a domain to the resource manager func (rm *ResourceManager) AddDomain(domain *spec.Domain) (*spec.DGateDomain, error) { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.Lock(domain.NamespaceName)() dm, err := rm.transformDomain(domain) if err != nil { return nil, err @@ -546,8 +521,7 @@ func (rm *ResourceManager) transformDomain(domain *spec.Domain) (*spec.DGateDoma } func (rm *ResourceManager) RemoveDomain(name, namespace string) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.Lock(namespace)() if dmLk, ok := rm.domains.Find(name + "/" + namespace); ok { if nsLk, ok := rm.namespaces.Find(namespace); ok { nsLk.UnlinkOneMany("domains", name) @@ -567,8 +541,7 @@ func (rm *ResourceManager) RemoveDomain(name, namespace string) error { /* Module functions */ func (rm *ResourceManager) GetModule(name, namespace string) (*spec.DGateModule, bool) { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() return rm.getModule(name, namespace) } @@ -581,8 +554,7 @@ func (rm *ResourceManager) getModule(name, namespace string) (*spec.DGateModule, // GetModules returns a list of all modules func (rm *ResourceManager) GetModules() []*spec.DGateModule { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLockMain()() var modules []*spec.DGateModule rm.modules.Each(func(_ string, lk *linker.Link[string, safe.Ref[spec.DGateModule]]) bool { modules = append(modules, lk.Item().Read()) @@ -593,8 +565,7 @@ func (rm *ResourceManager) GetModules() []*spec.DGateModule { // GetRouteModules returns a list of all modules in a route func (rm *ResourceManager) GetRouteModules(name, namespace string) ([]*spec.DGateModule, bool) { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() route, ok := rm.getRoute(name, namespace) if !ok { return nil, false @@ -611,8 +582,7 @@ func (rm *ResourceManager) GetRouteModules(name, namespace string) ([]*spec.DGat // GetModulesByNamespace returns a list of all modules in a namespace func (rm *ResourceManager) GetModulesByNamespace(namespace string) []*spec.DGateModule { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() var modules []*spec.DGateModule if nsLk, ok := rm.namespaces.Find(namespace); ok { nsLk.Each("modules", func(_ string, lk linker.Linker[string]) { @@ -624,8 +594,7 @@ func (rm *ResourceManager) GetModulesByNamespace(namespace string) []*spec.DGate } func (rm *ResourceManager) AddModule(module *spec.Module) (*spec.DGateModule, error) { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.Lock(module.NamespaceName)() md, err := rm.transformModule(module) if err != nil { return nil, err @@ -651,8 +620,7 @@ func (rm *ResourceManager) transformModule(module *spec.Module) (*spec.DGateModu } func (rm *ResourceManager) RemoveModule(name, namespace string) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.Lock(namespace)() if modLink, ok := rm.modules.Find(name + "/" + namespace); ok { if modLink.Len("routes") > 0 { return ErrCannotDeleteModule(name, "routes still linked") @@ -675,8 +643,7 @@ func (rm *ResourceManager) RemoveModule(name, namespace string) error { /* Collection functions */ func (rm *ResourceManager) GetCollection(name, namespace string) (*spec.DGateCollection, bool) { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() return getCollection(rm, name, namespace) } @@ -688,8 +655,7 @@ func getCollection(rm *ResourceManager, name, namespace string) (*spec.DGateColl } func (rm *ResourceManager) GetCollectionsByNamespace(namespace string) []*spec.DGateCollection { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() var collections []*spec.DGateCollection if nsLk, ok := rm.namespaces.Find(namespace); ok { nsLk.Each("collections", func(_ string, lk linker.Linker[string]) { @@ -702,8 +668,7 @@ func (rm *ResourceManager) GetCollectionsByNamespace(namespace string) []*spec.D // GetCollections returns a list of all collections func (rm *ResourceManager) GetCollections() []*spec.DGateCollection { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLockMain()() var collections []*spec.DGateCollection rm.collections.Each(func(_ string, lk *linker.Link[string, safe.Ref[spec.DGateCollection]]) bool { collections = append(collections, lk.Item().Read()) @@ -713,8 +678,7 @@ func (rm *ResourceManager) GetCollections() []*spec.DGateCollection { } func (rm *ResourceManager) AddCollection(collection *spec.Collection) (*spec.DGateCollection, error) { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.Lock(collection.NamespaceName)() cl, err := rm.transformCollection(collection) if err != nil { return nil, err @@ -750,8 +714,7 @@ func (rm *ResourceManager) transformCollection(collection *spec.Collection) (*sp } func (rm *ResourceManager) RemoveCollection(name, namespace string) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.Lock(namespace)() if colLk, ok := rm.collections.Find(name + "/" + namespace); ok { if nsLk, ok := rm.namespaces.Find(namespace); ok { // unlink namespace to collection @@ -773,8 +736,7 @@ func (rm *ResourceManager) RemoveCollection(name, namespace string) error { /* Secret functions */ func (rm *ResourceManager) GetSecret(name, namespace string) (*spec.DGateSecret, bool) { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() return rm.getSecret(name, namespace) } @@ -787,8 +749,7 @@ func (rm *ResourceManager) getSecret(name, namespace string) (*spec.DGateSecret, // GetSecrets returns a list of all secrets func (rm *ResourceManager) GetSecrets() []*spec.DGateSecret { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLockMain()() var secrets []*spec.DGateSecret rm.secrets.Each(func(_ string, lk *linker.Link[string, safe.Ref[spec.DGateSecret]]) bool { secrets = append(secrets, lk.Item().Read()) @@ -799,8 +760,7 @@ func (rm *ResourceManager) GetSecrets() []*spec.DGateSecret { // GetSecretsByNamespace returns a list of all secrets in a namespace func (rm *ResourceManager) GetSecretsByNamespace(namespace string) []*spec.DGateSecret { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLock(namespace)() var secrets []*spec.DGateSecret if nsLk, ok := rm.namespaces.Find(namespace); ok { nsLk.Each("secrets", func(_ string, lk linker.Linker[string]) { @@ -812,8 +772,7 @@ func (rm *ResourceManager) GetSecretsByNamespace(namespace string) []*spec.DGate } func (rm *ResourceManager) AddSecret(secret *spec.Secret) (*spec.DGateSecret, error) { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.Lock(secret.NamespaceName)() md, err := rm.transformSecret(secret) if err != nil { return nil, err @@ -839,8 +798,7 @@ func (rm *ResourceManager) transformSecret(secret *spec.Secret) (*spec.DGateSecr } func (rm *ResourceManager) RemoveSecret(name, namespace string) error { - rm.mutex.Lock() - defer rm.mutex.Unlock() + defer rm.mutex.Lock(namespace)() if scrtLink, ok := rm.secrets.Find(name + "/" + namespace); ok { if scrtLink.Len("routes") > 0 { return ErrCannotDeleteSecret(name, "routes still linked") @@ -860,9 +818,21 @@ func (rm *ResourceManager) RemoveSecret(name, namespace string) error { } } +// Clear removes all resources from the resource manager +func (rm *ResourceManager) Clear() { + defer rm.mutex.LockMain()() + rm.namespaces.Clear() + rm.services.Clear() + rm.domains.Clear() + rm.modules.Clear() + rm.routes.Clear() + rm.collections.Clear() + rm.secrets.Clear() +} + +// Empty returns true if the resource manager is empty func (rm *ResourceManager) Empty() bool { - rm.mutex.RLock() - defer rm.mutex.RUnlock() + defer rm.mutex.RLockMain()() return rm.namespaces.Empty() && rm.services.Empty() && rm.domains.Empty() && diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index a1d8cb0..189be47 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -43,12 +43,13 @@ type Scheduler interface { type scheduler struct { opts Options + ctx context.Context + cancel context.CancelFunc logger *zerolog.Logger tasks map[string]*TaskDefinition pendingJobs priorityQueue mutex *sync.RWMutex running bool - end chan struct{} } type TaskDefinition struct { @@ -83,9 +84,9 @@ func New(opts Options) Scheduler { } return &scheduler{ opts: opts, + ctx: context.TODO(), logger: opts.Logger, mutex: &sync.RWMutex{}, - end: make(chan struct{}, 1), pendingJobs: heap.NewHeap[int64, *TaskDefinition](heap.MinHeapType), tasks: make(map[string]*TaskDefinition), } @@ -103,8 +104,7 @@ func (s *scheduler) Start() error { func (s *scheduler) start() { s.running = true - // replace the end channel to allow for multiple starts - s.end = make(chan struct{}, 1) + s.ctx, s.cancel = context.WithCancel(s.ctx) go func() { ticker := time.NewTicker(s.opts.Interval) defer ticker.Stop() @@ -119,7 +119,7 @@ func (s *scheduler) start() { return } select { - case <-s.end: + case <-s.ctx.Done(): s.running = false done = true return @@ -172,7 +172,10 @@ func (s *scheduler) executeTask(tdt time.Time, taskDef *TaskDefinition) { func (s *scheduler) Stop() { s.mutex.Lock() defer s.mutex.Unlock() - close(s.end) + if !s.running { + return + } + s.cancel() } func (s *scheduler) GetTask(taskId string) (TaskDefinition, bool) { diff --git a/pkg/spec/change_log.go b/pkg/spec/change_log.go index 29b2ee4..664f68b 100644 --- a/pkg/spec/change_log.go +++ b/pkg/spec/change_log.go @@ -83,8 +83,10 @@ var ( DeleteDocumentCommand Command = newCommand(Delete, Documents) DeleteSecretCommand Command = newCommand(Delete, Secrets) - NoopCommand Command = Command("noop") - StopCommand Command = Command("stop") + // internal commands + NoopCommand Command = Command("noop") + ShutdownCommand Command = Command("shutdown") + RestartCommand Command = Command("restart") ) func newCommand(action Action, resource Resource) Command { diff --git a/pkg/spec/response_writer_tracker.go b/pkg/spec/response_writer_tracker.go index 25b6740..26dcb08 100644 --- a/pkg/spec/response_writer_tracker.go +++ b/pkg/spec/response_writer_tracker.go @@ -33,6 +33,7 @@ func (t *rwTracker) Header() http.Header { } func (t *rwTracker) Write(b []byte) (int, error) { + // to write the body, we need to write the headers first if !t.HeadersSent() { t.WriteHeader(http.StatusOK) } diff --git a/pkg/typescript/typescript.go b/pkg/typescript/typescript.go index f549ae1..74f263d 100644 --- a/pkg/typescript/typescript.go +++ b/pkg/typescript/typescript.go @@ -8,6 +8,7 @@ import ( ) // typescript v5.3.3 +// //go:embed typescript.min.js var tscSource string @@ -15,9 +16,17 @@ func Transpile(src string) (string, error) { // transpiles TS into JS with commonjs module and targets es5 return typescript.TranspileString(src, WithCachedTypescriptSource(), + typescript.WithPreventCancellation(), typescript.WithCompileOptions(map[string]any{ - "module": "commonjs", - "target": "es5", + "module": "commonjs", + "target": "es5", + "inlineSourceMap": true, + "inlineSources": true, + "noLib": true, + "noErrorTruncation": true, + "noEmit": true, + "noEmitOnError": true, + "skipLibCheck": true, }), ) } diff --git a/pkg/util/keylock/keylock.go b/pkg/util/keylock/keylock.go new file mode 100644 index 0000000..3714eb6 --- /dev/null +++ b/pkg/util/keylock/keylock.go @@ -0,0 +1,58 @@ +package keylock + +import "sync" + +type KeyLock struct { + locks map[string]*sync.RWMutex + mapLock sync.RWMutex // to make the map safe concurrently +} + +type UnlockFunc func() + +func NewKeyLock() *KeyLock { + return &KeyLock{locks: make(map[string]*sync.RWMutex, 32)} +} + +func (l *KeyLock) getLockBy(key string) *sync.RWMutex { + if mtx, ok := l.findLock(key); ok { + return mtx + } + + l.mapLock.Lock() + defer l.mapLock.Unlock() + ret := &sync.RWMutex{} + l.locks[key] = ret + return ret +} + +func (l *KeyLock) findLock(key string) (*sync.RWMutex, bool) { + l.mapLock.RLock() + defer l.mapLock.RUnlock() + + if ret, found := l.locks[key]; found { + return ret, true + } + return nil, false +} + +func (l *KeyLock) RLock(key string) UnlockFunc { + mtx := l.getLockBy(key) + mtx.RLock() + return mtx.RUnlock +} + +func (l *KeyLock) Lock(key string) UnlockFunc { + mtx := l.getLockBy(key) + mtx.Lock() + return mtx.Unlock +} + +func (l *KeyLock) RLockMain() UnlockFunc { + l.mapLock.RLock() + return l.mapLock.RUnlock +} + +func (l *KeyLock) LockMain() UnlockFunc { + l.mapLock.Lock() + return l.mapLock.Unlock +}