From 109609b3e55d772ca0366b3e01e247aa03ebc7dd Mon Sep 17 00:00:00 2001 From: bubbajoe Date: Thu, 9 May 2024 08:54:04 +0900 Subject: [PATCH] fix: test improvments --- internal/config/loader.go | 45 +++++++------ internal/config/loader_test.go | 83 ++++++++++++++++++++++++ internal/config/testdata/env.config.yaml | 21 ++++++ internal/config/utils.go | 16 ++++- internal/proxy/dynamic_proxy.go | 8 +-- internal/proxy/module_executor.go | 17 +++-- internal/proxy/module_mock_test.go | 24 +++---- internal/proxy/proxy_handler.go | 6 +- internal/proxy/proxy_handler_test.go | 31 +++++---- internal/proxy/request_context.go | 18 +++-- internal/router/router.go | 16 +---- performance-tests/long-perf-test.js | 22 ++++++- pkg/util/tree/avl/avl.go | 14 ---- pkg/util/tree/avl/avl_test.go | 38 ++++++++++- 14 files changed, 250 insertions(+), 109 deletions(-) create mode 100644 internal/config/loader_test.go create mode 100644 internal/config/testdata/env.config.yaml diff --git a/internal/config/loader.go b/internal/config/loader.go index ba2a777..f7928e8 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -10,7 +10,6 @@ import ( "path" "regexp" "strings" - "time" "github.com/dgate-io/dgate/pkg/util" "github.com/hashicorp/raft" @@ -24,13 +23,13 @@ import ( "github.com/mitchellh/mapstructure" ) -func LoadConfig(dgateConfigPath string) (*DGateConfig, error) { - ctx, cancel := context.WithTimeout( - context.Background(), - time.Second*10, - ) - defer cancel() +var ( + EnvVarRegex = regexp.MustCompile(`\${(?P[a-zA-Z0-9_]{1,})(:-(?P.*?)?)?}`) + CommandRegex = regexp.MustCompile(`\$\((?P.*?)\)`) +) +func LoadConfig(dgateConfigPath string) (*DGateConfig, error) { + ctx := context.Background() var dgateConfigData string if dgateConfigPath == "" { dgateConfigPath = os.Getenv("DG_CONFIG_PATH") @@ -91,33 +90,34 @@ func LoadConfig(dgateConfigPath string) (*DGateConfig, error) { } panicVars := []string{} - data := k.All() if !util.EnvVarCheckBool("DG_DISABLE_SHELL_PARSER") { - commandRegex := regexp.MustCompile(`\$\((?P.*?)\)`) + data := k.All() shell := "/bin/sh" - if shellEnv, exists := os.LookupEnv("SHELL"); exists { + if shellEnv := os.Getenv("SHELL"); shellEnv != "" { shell = shellEnv - } - resolveConfigStringPattern(data, commandRegex, func(value string, results map[string]string) (string, error) { - cmdResult, err := exec.CommandContext(ctx, shell, "-c", results["cmd"]).Output() + } + resolveConfigStringPattern(data, CommandRegex, func(value string, results map[string]string) (string, error) { + cmdResult, err := exec.CommandContext( + ctx, shell, "-c", results["cmd"]).Output() if err != nil { + panicVars = append(panicVars, results["cmd"]) return "", err } return strings.TrimSpace(string(cmdResult)), nil }, func(results map[string]string, err error) { panic("error on command - `" + results["cmd"] + "`: " + err.Error()) }) + k.Load(confmap.Provider(data, "."), nil) } if !util.EnvVarCheckBool("DG_DISABLE_ENV_PARSER") { - envVarRe := regexp.MustCompile(`\${(?P[a-zA-Z0-9_]{1,})(:-(?P.*?)?)?}`) - resolveConfigStringPattern(data, envVarRe, func(value string, results map[string]string) (string, error) { + data := k.All() + resolveConfigStringPattern(data, EnvVarRegex, func(value string, results map[string]string) (string, error) { if envVar := os.Getenv(results["var_name"]); envVar != "" { return envVar, nil } else if strings.Contains(value, results["var_name"]+":-") { return results["default"], nil } - panicVars = append(panicVars, results["var_name"]) return "", nil }, func(results map[string]string, err error) { panicVars = append(panicVars, results["var_name"]) @@ -126,10 +126,9 @@ func LoadConfig(dgateConfigPath string) (*DGateConfig, error) { if len(panicVars) > 0 { panic("required env vars not set: " + strings.Join(panicVars, ", ")) } + k.Load(confmap.Provider(data, "."), nil) } - k.Load(confmap.Provider(data, "."), nil) - // validate configuration var err error kDefault(k, "log_level", "info") @@ -168,11 +167,11 @@ func LoadConfig(dgateConfigPath string) (*DGateConfig, error) { return nil, err } - kDefault(k, "proxy.transport.max_idle_conns", 100) - kDefault(k, "proxy.transport.force_attempt_http2", true) - kDefault(k, "proxy.transport.idle_conn_timeout", "90s") - kDefault(k, "proxy.transport.tls_handshake_timeout", "10s") - kDefault(k, "proxy.transport.expect_continue_timeout", "1s") + // kDefault(k, "proxy.transport.max_idle_conns", 100) + // kDefault(k, "proxy.transport.force_attempt_http2", true) + // kDefault(k, "proxy.transport.idle_conn_timeout", "90s") + // kDefault(k, "proxy.transport.tls_handshake_timeout", "10s") + // kDefault(k, "proxy.transport.expect_continue_timeout", "1s") if k.Exists("test_server") { kDefault(k, "test_server.enable_h2c", true) kDefault(k, "test_server.enable_http2", true) diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go new file mode 100644 index 0000000..4c28a3d --- /dev/null +++ b/internal/config/loader_test.go @@ -0,0 +1,83 @@ +package config_test + +import ( + "os" + "testing" + + "github.com/dgate-io/dgate/internal/config" + "github.com/stretchr/testify/assert" +) + +func TestConfig_EnvVarRegex(t *testing.T) { + re := config.EnvVarRegex + inputs := []string{ + "${var_name1}", + "${var_name2:-default}", + "${var_name3:-}", + "${var_name4:default}", + "${var_name5:}", + } + results := make(map[string]string) + for _, input := range inputs { + matches := re.FindAllStringSubmatch(input, -1) + for _, match := range matches { + results[input] = match[1] + "//" + match[3] + } + } + assert.Equal(t, results, map[string]string{ + "${var_name1}": "var_name1//", + "${var_name2:-default}": "var_name2//default", + "${var_name3:-}": "var_name3//", + }) +} + +func TestConfig_CommandRegex(t *testing.T) { + re := config.CommandRegex + inputs := []string{ + "$(cmd1)", + "$(cmd2 arg1 arg2)", + "$(cmd3 \"arg1\" 'arg2')", + } + results := make(map[string]string) + for _, input := range inputs { + matches := re.FindAllStringSubmatch(input, -1) + for _, match := range matches { + results[input] = match[1] + } + } + assert.Equal(t, results, map[string]string{ + "$(cmd1)": "cmd1", + "$(cmd2 arg1 arg2)": "cmd2 arg1 arg2", + "$(cmd3 \"arg1\" 'arg2')": "cmd3 \"arg1\" 'arg2'", + }) +} + +func TestConfig_LoaderVariables(t *testing.T) { + os.Setenv("ENV1", "test1") + os.Setenv("ENV2", "test2") + os.Setenv("ENV3", "test3") + os.Setenv("ENV4", "") + os.Setenv("ENV5", "test5") + os.Setenv("ADMIN_PORT", "8080") + conf, err := config.LoadConfig("testdata/env.config.yaml") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, []string{ + "test1", + "$ENV2", + "test3", + "test4", + "testing", + "test test5", + }, conf.Tags) + assert.Equal(t, "v1", conf.Version) + assert.Equal(t, true, conf.TestServerConfig.EnableH2C) + assert.Equal(t, true, conf.TestServerConfig.EnableHTTP2) + assert.Equal(t, false, conf.TestServerConfig.EnableEnvVars) + assert.Equal(t, 80, conf.ProxyConfig.Port) + assert.Equal(t, 8080, conf.AdminConfig.Port) + assert.Equal(t, 1, len(conf.Storage.Config)) + assert.Equal(t, "test1-test2-testing", + conf.Storage.Config["testing"]) +} diff --git a/internal/config/testdata/env.config.yaml b/internal/config/testdata/env.config.yaml new file mode 100644 index 0000000..e52144c --- /dev/null +++ b/internal/config/testdata/env.config.yaml @@ -0,0 +1,21 @@ +version: v1 +debug: true +tags: + - ${ENV1} + - $ENV2 + - ${ENV3:-abc123} + - ${ENV4:-test4} + - $(echo "testing") + - $(echo "test $ENV5") +test_server: + port: 8080 +storage: + type: debug + testing: ${ENV1}-${ENV2}-$(echo "testing") +proxy: + port: ${PROXY_PORT:-80} + host: 0.0.0.0 +admin: + port: ${ADMIN_PORT:-9080} + host: 0.0.0.0 + diff --git a/internal/config/utils.go b/internal/config/utils.go index e1fc92a..f4f386a 100644 --- a/internal/config/utils.go +++ b/internal/config/utils.go @@ -12,15 +12,21 @@ import ( ) type FoundFunc func(string, map[string]string) (string, error) - type ErrorFunc func(map[string]string, error) -func resolveConfigStringPattern(data map[string]any, re *regexp.Regexp, foundFunc FoundFunc, errorFunc ErrorFunc) { +func resolveConfigStringPattern( + data map[string]any, + re *regexp.Regexp, + foundFunc FoundFunc, + errorFunc ErrorFunc, +) { for k, v := range data { var values []string switch vt := v.(type) { case string: values = []string{vt} + case []string: + values = vt case []any: values = sliceMap(vt, func(val any) string { if s, ok := val.(string); ok { @@ -31,6 +37,12 @@ func resolveConfigStringPattern(data map[string]any, re *regexp.Regexp, foundFun if len(values) == 0 { continue } + case any: + if vv, ok := vt.(string); ok { + values = []string{vv} + } else { + continue + } default: continue } diff --git a/internal/proxy/dynamic_proxy.go b/internal/proxy/dynamic_proxy.go index 35a780f..2b6badf 100644 --- a/internal/proxy/dynamic_proxy.go +++ b/internal/proxy/dynamic_proxy.go @@ -110,7 +110,7 @@ func (ps *ProxyState) setupRoutes() (err error) { reqCtxProvider := NewRequestContextProvider(r, ps) reqCtxProviders.Insert(r.Namespace.Name+"/"+r.Name, reqCtxProvider) if len(r.Modules) > 0 { - modBuf, err := NewModuleBuffer( + modPool, err := NewModulePool( 256, 1024, reqCtxProvider, ps.createModuleExtractorFunc(r), ) @@ -118,7 +118,7 @@ func (ps *ProxyState) setupRoutes() (err error) { ps.logger.Err(err).Msg("Error creating module buffer") return err } - reqCtxProvider.SetModuleBuffer(modBuf) + reqCtxProvider.SetModulePool(modPool) } err = func() (err error) { defer func() { @@ -229,7 +229,7 @@ func (ps *ProxyState) startChangeLoop() { func() { ps.proxyLock.Lock() defer ps.proxyLock.Unlock() - + err := ps.reconfigureState(false, log) if log.PushError(err); err != nil { ps.logger.Err(err). @@ -352,7 +352,7 @@ func (ps *ProxyState) Stop() { func (ps *ProxyState) HandleRoute(requestCtxProvider *RequestContextProvider, pattern string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - // ctx, cancel := context.WithCancel(requestCtxProvider.ctx) + // ctx, cancel := context.WithCancel(requestCtxPrdovider.ctx) // defer cancel() ps.ProxyHandlerFunc(ps, requestCtxProvider. CreateRequestContext(requestCtxProvider.ctx, w, r, pattern)) diff --git a/internal/proxy/module_executor.go b/internal/proxy/module_executor.go index 94ccd0f..c6ce81f 100644 --- a/internal/proxy/module_executor.go +++ b/internal/proxy/module_executor.go @@ -5,14 +5,13 @@ import ( "errors" ) -type ModuleBuffer interface { - // Load(cb func()) +type ModulePool interface { Borrow() ModuleExtractor Return(me ModuleExtractor) Close() } -type moduleBuffer struct { +type modulePool struct { modExtBuffer chan ModuleExtractor min, max int @@ -22,11 +21,11 @@ type moduleBuffer struct { createModuleExtract func() ModuleExtractor } -func NewModuleBuffer( +func NewModulePool( minBuffers, maxBuffers int, reqCtxProvider *RequestContextProvider, createModExts func(*RequestContextProvider) ModuleExtractor, -) (ModuleBuffer, error) { +) (ModulePool, error) { if minBuffers < 1 { panic("module concurrency must be greater than 0") } @@ -38,7 +37,7 @@ func NewModuleBuffer( if me == nil { return nil, errors.New("could not load moduleExtract") } - mb := &moduleBuffer{ + mb := &modulePool{ min: minBuffers, max: maxBuffers, modExtBuffer: make(chan ModuleExtractor, maxBuffers), @@ -50,7 +49,7 @@ func NewModuleBuffer( return mb, nil } -func (mb *moduleBuffer) Borrow() ModuleExtractor { +func (mb *modulePool) Borrow() ModuleExtractor { if mb == nil || mb.ctx == nil || mb.ctx.Err() != nil { return nil } @@ -65,7 +64,7 @@ func (mb *moduleBuffer) Borrow() ModuleExtractor { return me } -func (mb *moduleBuffer) Return(me ModuleExtractor) { +func (mb *modulePool) Return(me ModuleExtractor) { // if context is canceled, do not return module extract if mb.ctx != nil && mb.ctx.Err() == nil { select { @@ -78,7 +77,7 @@ func (mb *moduleBuffer) Return(me ModuleExtractor) { me.Stop(true) } -func (mb *moduleBuffer) Close() { +func (mb *modulePool) Close() { if mb.ctxCancel != nil { mb.ctxCancel() } diff --git a/internal/proxy/module_mock_test.go b/internal/proxy/module_mock_test.go index b006bff..9ef01e6 100644 --- a/internal/proxy/module_mock_test.go +++ b/internal/proxy/module_mock_test.go @@ -11,34 +11,34 @@ import ( "github.com/stretchr/testify/mock" ) -type mockModuleBuffer struct { +type mockModulePool struct { mock.Mock } -var _ proxy.ModuleBuffer = &mockModuleBuffer{} +var _ proxy.ModulePool = &mockModulePool{} -func NewMockModuleBuffer() *mockModuleBuffer { - return &mockModuleBuffer{} +func NewMockModulePool() *mockModulePool { + return &mockModulePool{} } -// Borrow implements proxy.ModuleBuffer. -func (mb *mockModuleBuffer) Borrow() proxy.ModuleExtractor { +// Borrow implements proxy.ModulePool. +func (mb *mockModulePool) Borrow() proxy.ModuleExtractor { args := mb.Called() return args.Get(0).(proxy.ModuleExtractor) } -// Close implements proxy.ModuleBuffer. -func (mb *mockModuleBuffer) Close() { +// Close implements proxy.ModulePool. +func (mb *mockModulePool) Close() { mb.Called() } -// Load implements proxy.ModuleBuffer. -func (mb *mockModuleBuffer) Load(cb func()) { +// Load implements proxy.ModulePool. +func (mb *mockModulePool) Load(cb func()) { mb.Called(cb) } -// Return implements proxy.ModuleBuffer. -func (mb *mockModuleBuffer) Return(me proxy.ModuleExtractor) { +// Return implements proxy.ModulePool. +func (mb *mockModulePool) Return(me proxy.ModuleExtractor) { mb.Called(me) } diff --git a/internal/proxy/proxy_handler.go b/internal/proxy/proxy_handler.go index f2239c8..34334e8 100644 --- a/internal/proxy/proxy_handler.go +++ b/internal/proxy/proxy_handler.go @@ -36,12 +36,12 @@ func proxyHandler(ps *ProxyState, reqCtx *RequestContext) { var modExt ModuleExtractor if len(reqCtx.route.Modules) != 0 { runtimeStart := time.Now() - if reqCtx.provider.modBuf == nil { + if reqCtx.provider.modPool == nil { ps.logger.Error().Msg("Error getting module buffer: invalid state") util.WriteStatusCodeError(reqCtx.rw, http.StatusInternalServerError) return } - if modExt = reqCtx.provider.modBuf.Borrow(); modExt == nil { + if modExt = reqCtx.provider.modPool.Borrow(); modExt == nil { ps.metrics.MeasureModuleDuration( reqCtx, "module_extract", runtimeStart, errors.New("error borrowing module"), @@ -50,7 +50,7 @@ func proxyHandler(ps *ProxyState, reqCtx *RequestContext) { util.WriteStatusCodeError(reqCtx.rw, http.StatusInternalServerError) return } - defer reqCtx.provider.modBuf.Return(modExt) + defer reqCtx.provider.modPool.Return(modExt) modExt.Start(reqCtx) defer modExt.Stop(true) diff --git a/internal/proxy/proxy_handler_test.go b/internal/proxy/proxy_handler_test.go index 977db17..e934a8a 100644 --- a/internal/proxy/proxy_handler_test.go +++ b/internal/proxy/proxy_handler_test.go @@ -61,14 +61,19 @@ func TestProxyHandler_ReverseProxy(t *testing.T) { modExt := NewMockModuleExtractor() modExt.ConfigureDefaultMock(req, wr, ps, rt) - modBuf := NewMockModuleBuffer() + modBuf := NewMockModulePool() modBuf.On("Borrow").Return(modExt).Once() modBuf.On("Return", modExt).Return().Once() - reqCtxProvider.SetModuleBuffer(modBuf) + reqCtxProvider.SetModulePool(modBuf) + + modPool := NewMockModulePool() + modPool.On("Borrow").Return(modExt).Once() + modPool.On("Return", modExt).Return().Once() + reqCtxProvider.SetModulePool(modPool) ps.ProxyHandlerFunc(ps, reqCtx) wr.AssertExpectations(t) - modBuf.AssertExpectations(t) + modPool.AssertExpectations(t) modExt.AssertExpectations(t) rpBuilder.AssertExpectations(t) // rpe.AssertExpectations(t) @@ -117,17 +122,17 @@ func TestProxyHandler_ProxyHandler(t *testing.T) { reqCtxProvider := proxy.NewRequestContextProvider(rt, ps) modExt := NewMockModuleExtractor() modExt.ConfigureDefaultMock(req, wr, ps, rt) - modBuf := NewMockModuleBuffer() - modBuf.On("Borrow").Return(modExt).Once() - modBuf.On("Return", modExt).Return().Once() - reqCtxProvider.SetModuleBuffer(modBuf) + modPool := NewMockModulePool() + modPool.On("Borrow").Return(modExt).Once() + modPool.On("Return", modExt).Return().Once() + reqCtxProvider.SetModulePool(modPool) reqCtx := reqCtxProvider.CreateRequestContext( context.Background(), wr, req, "/") ps.ProxyHandlerFunc(ps, reqCtx) wr.AssertExpectations(t) - modBuf.AssertExpectations(t) + modPool.AssertExpectations(t) modExt.AssertExpectations(t) } } @@ -169,18 +174,18 @@ func TestProxyHandler_ProxyHandlerError(t *testing.T) { modExt := NewMockModuleExtractor() modExt.ConfigureDefaultMock(req, wr, ps, rt) - modBuf := NewMockModuleBuffer() - modBuf.On("Borrow").Return(modExt).Once() - modBuf.On("Return", modExt).Return().Once() + modPool := NewMockModulePool() + modPool.On("Borrow").Return(modExt).Once() + modPool.On("Return", modExt).Return().Once() reqCtxProvider := proxy.NewRequestContextProvider(rt, ps) - reqCtxProvider.SetModuleBuffer(modBuf) + reqCtxProvider.SetModulePool(modPool) reqCtx := reqCtxProvider.CreateRequestContext( context.Background(), wr, req, "/") ps.ProxyHandlerFunc(ps, reqCtx) wr.AssertExpectations(t) - modBuf.AssertExpectations(t) + modPool.AssertExpectations(t) modExt.AssertExpectations(t) } } diff --git a/internal/proxy/request_context.go b/internal/proxy/request_context.go index 08a7faf..7e2b911 100644 --- a/internal/proxy/request_context.go +++ b/internal/proxy/request_context.go @@ -14,10 +14,10 @@ import ( type S string type RequestContextProvider struct { - ctx context.Context - route *spec.DGateRoute - rpb reverse_proxy.Builder - modBuf ModuleBuffer + ctx context.Context + route *spec.DGateRoute + rpb reverse_proxy.Builder + modPool ModulePool } type RequestContext struct { @@ -76,15 +76,13 @@ func NewRequestContextProvider(route *spec.DGateRoute, ps *ProxyState) *RequestC } } -func (reqCtxProvider *RequestContextProvider) SetModuleBuffer(mb ModuleBuffer) { - reqCtxProvider.modBuf = mb +func (reqCtxProvider *RequestContextProvider) SetModulePool(mb ModulePool) { + reqCtxProvider.modPool = mb } func (reqCtxProvider *RequestContextProvider) CreateRequestContext( - ctx context.Context, - rw http.ResponseWriter, - req *http.Request, - pattern string, + ctx context.Context, rw http.ResponseWriter, + req *http.Request, pattern string, ) *RequestContext { pathParams := make(map[string]string) if chiCtx := chi.RouteContext(req.Context()); chiCtx != nil { diff --git a/internal/router/router.go b/internal/router/router.go index adf679a..ff3582f 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -14,12 +14,6 @@ type DynamicRouter struct { lock sync.RWMutex } -// NewRouter creates a new router -func NewRouter() *DynamicRouter { - chiMux := chi.NewRouter() - return NewRouterWithMux(chiMux) -} - // NewRouter creates a new router func NewRouterWithMux(mux *chi.Mux) *DynamicRouter { return &DynamicRouter{ @@ -40,12 +34,6 @@ func (r *DynamicRouter) ModifyMux(fn func(*chi.Mux)) { fn(r.router) } -func (r *DynamicRouter) Match(method, path string) bool { - r.lock.Lock() - defer r.lock.Unlock() - return r.router.Match(r.routeCtx, method, path) -} - // ReplaceRouter replaces the router func (r *DynamicRouter) ReplaceMux(router *chi.Mux) { r.lock.Lock() @@ -56,7 +44,7 @@ func (r *DynamicRouter) ReplaceMux(router *chi.Mux) { // ServeHTTP is a wrapper around chi.Router.ServeHTTP func (r *DynamicRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { - r.lock.RLock() - defer r.lock.RUnlock() + // r.lock.RLock() + // defer r.lock.RUnlock() r.router.ServeHTTP(w, req) } diff --git a/performance-tests/long-perf-test.js b/performance-tests/long-perf-test.js index 1ecb17c..c0412f1 100644 --- a/performance-tests/long-perf-test.js +++ b/performance-tests/long-perf-test.js @@ -17,7 +17,7 @@ export let options = { }, modtest_wait: { executor: 'constant-vus', - vus: n*5, + vus: n*3, duration: inc + 'm', startTime: (curWait += inc) + 'm', exec: 'dgatePath', @@ -35,7 +35,7 @@ export let options = { }, svctest_wait: { executor: 'constant-vus', - vus: n*5, + vus: n*3, duration: inc + 'm', startTime: (curWait += inc) + 'm', exec: 'dgatePath', @@ -61,6 +61,24 @@ export let options = { gracefulStop: '5s', }, }, + // test_server_direct: { + // executor: 'constant-vus', + // vus: n, + // duration: inc + 'm', + // startTime: (curWait += inc) + 'm', + // exec: 'dgatePath', + // env: { DGATE_PATH: ":8888/direct" }, + // gracefulStop: '5s', + // }, + // test_server_direct_wait: { + // executor: 'constant-vus', + // vus: n*3, + // duration: inc + 'm', + // startTime: (curWait += inc) + 'm', + // exec: 'dgatePath', + // env: { DGATE_PATH: ":8888/svctest?wait=30ms" }, + // gracefulStop: '5s', + // }, discardResponseBodies: true, }; diff --git a/pkg/util/tree/avl/avl.go b/pkg/util/tree/avl/avl.go index c75291b..d517a42 100644 --- a/pkg/util/tree/avl/avl.go +++ b/pkg/util/tree/avl/avl.go @@ -405,17 +405,3 @@ func clone[K cmp.Ordered, V any](root *node[K, V], fn func(K, V) V) *node[K, V] right: clone(root.right, fn), } } - -// MarshalJSON returns the JSON encoding of the AVL tree. -// func (t *tree[K, V]) MarshalJSON() ([]byte, error) { -// t.mtx.RLock() -// defer t.mtx.RUnlock() -// return json.Marshal(t.root) -// } - -// UnmarshalJSON decodes the JSON encoding of the AVL tree. -// func (t *tree[K, V]) UnmarshalJSON(data []byte) error { -// t.mtx.Lock() -// defer t.mtx.Unlock() -// return json.Unmarshal(data, &t.root) -// } diff --git a/pkg/util/tree/avl/avl_test.go b/pkg/util/tree/avl/avl_test.go index d887f19..dc98a98 100644 --- a/pkg/util/tree/avl/avl_test.go +++ b/pkg/util/tree/avl/avl_test.go @@ -242,8 +242,8 @@ func treeLength(t *testing.T, tree avl.Tree[int, int]) int { return totalTreeLength } -// Benchmark AVL Tree Insertion -func BenchmarkAVLTreeInsert(b *testing.B) { +// Benchmark AVL Tree Insertion in ascending order +func BenchmarkAVLTreeInsertAsc(b *testing.B) { tree := avl.NewTree[int, int]() // Example with string keys and int values // Run the insertion operation b.N times @@ -252,7 +252,17 @@ func BenchmarkAVLTreeInsert(b *testing.B) { } } -// Benchmark AVL Tree Insertion +// Benchmark AVL Tree Insertion in descending order +func BenchmarkAVLTreeInsertDesc(b *testing.B) { + tree := avl.NewTree[int, int]() // Example with string keys and int values + + // Run the insertion operation b.N times + for i := 0; i < b.N; i++ { + tree.Insert(b.N-i, i) + } +} + +// Benchmark AVL Tree Find operation func BenchmarkAVLTreeFind(b *testing.B) { tree := avl.NewTree[int, int]() @@ -272,6 +282,28 @@ func BenchmarkAVLTreeFind(b *testing.B) { } } +// Benchmark AVL Tree Each operation +func BenchmarkAVLTreeEach(b *testing.B) { + tree := avl.NewTree[int, int]() + + // Insert k nodes into the tree + k := 10_000 + { + b.StopTimer() + for i := 0; i < k; i++ { + tree.Insert(i, i) + } + b.StartTimer() + } + + // Run the each operation b.N times + for i := 0; i < b.N; i++ { + tree.Each(func(k int, v int) bool { + return true + }) + } +} + func BenchmarkAVLTreeInsertAndFindParallel(b *testing.B) { tree := avl.NewTree[int, int]() b.RunParallel(func(pb *testing.PB) {