From c2c01b8671418a428f63f120b8c60c35dc11cf3f Mon Sep 17 00:00:00 2001 From: bubbajoe Date: Fri, 21 Jun 2024 02:10:21 +0900 Subject: [PATCH] refactor admin, raft, and tests: fix deadlocks, slow startup, and --- .github/workflows/e2e.yml | 2 +- cmd/dgate-server/main.go | 12 +- functional-tests/admin_tests/admin_test.sh | 3 +- go.mod | 4 +- internal/admin/admin_fsm.go | 10 +- internal/admin/admin_raft.go | 76 ++++--- internal/admin/admin_routes_test.go | 4 +- internal/admin/changestate/change_state.go | 3 +- .../changestate/testutil/change_state.go | 3 +- internal/admin/routes/collection_routes.go | 9 +- internal/admin/routes/domain_routes.go | 3 +- internal/admin/routes/misc_routes.go | 5 - internal/admin/routes/module_routes.go | 3 +- internal/admin/routes/module_routes_test.go | 54 +++-- internal/admin/routes/namespace_routes.go | 6 +- .../admin/routes/namespace_routes_test.go | 52 +++-- internal/admin/routes/route_routes.go | 6 +- internal/admin/routes/route_routes_test.go | 11 +- internal/admin/routes/service_routes.go | 7 +- internal/admin/routes/service_routes_test.go | 11 +- internal/config/configtest/dgate_configs.go | 6 +- internal/proxy/change_log.go | 104 +++++---- internal/proxy/dynamic_proxy.go | 145 ++++++++----- internal/proxy/proxy_replication.go | 13 +- internal/proxy/proxy_state.go | 201 +++++++++--------- internal/proxy/proxy_state_test.go | 6 +- .../{raftadmin_client.go => client.go} | 103 +++++---- pkg/raftadmin/{raftadmin.go => server.go} | 55 +++-- .../{raftadmin_test.go => server_test.go} | 7 +- pkg/rafthttp/rafthttp.go | 18 +- pkg/rafthttp/rafthttp_test.go | 4 +- pkg/util/queue/queue.go | 9 +- pkg/util/sliceutil/slice.go | 13 +- pkg/util/sliceutil/slice_test.go | 114 ++++++++++ 34 files changed, 639 insertions(+), 443 deletions(-) rename pkg/raftadmin/{raftadmin_client.go => client.go} (72%) rename pkg/raftadmin/{raftadmin.go => server.go} (82%) rename pkg/raftadmin/{raftadmin_test.go => server_test.go} (97%) create mode 100644 pkg/util/sliceutil/slice_test.go diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index fa11015..88f6c4b 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -52,7 +52,7 @@ jobs: - run: cd functional-tests/raft_tests && goreman start & - name: Wait for server to start - run: sleep 5 + run: sleep 10 - name: Functional Standalone Tests run: | diff --git a/cmd/dgate-server/main.go b/cmd/dgate-server/main.go index 0f1a8ed..95dc01e 100644 --- a/cmd/dgate-server/main.go +++ b/cmd/dgate-server/main.go @@ -58,19 +58,23 @@ func main() { } if dgateConfig, err := config.LoadConfig(*configPath); err != nil { fmt.Printf("Error loading config: %s\n", err) - os.Exit(1) + panic(err) } else { logger, err := dgateConfig.GetLogger() if err != nil { fmt.Printf("Error setting up logger: %s\n", err) - os.Exit(1) + panic(err) } defer logger.Sync() proxyState := proxy.NewProxyState(logger.Named("proxy"), dgateConfig) - admin.StartAdminAPI(version, dgateConfig, logger.Named("admin"), proxyState) + err = admin.StartAdminAPI(version, dgateConfig, logger.Named("admin"), proxyState) + if err != nil { + fmt.Printf("Error starting admin api: %s\n", err) + panic(err) + } if err := proxyState.Start(); err != nil { fmt.Printf("Error loading config: %s\n", err) - os.Exit(1) + panic(err) } sigchan := make(chan os.Signal, 1) diff --git a/functional-tests/admin_tests/admin_test.sh b/functional-tests/admin_tests/admin_test.sh index 1b3e91a..e1c56a3 100755 --- a/functional-tests/admin_tests/admin_test.sh +++ b/functional-tests/admin_tests/admin_test.sh @@ -8,7 +8,8 @@ TEST_URL=${TEST_URL:-"http://localhost:8888"} DIR="$( cd "$( dirname "$0" )" && pwd )" -# domain setup +export DGATE_ADMIN_API=$ADMIN_URL + # check if uuid is available if ! command -v uuid > /dev/null; then id=X$RANDOM-$RANDOM-$RANDOM diff --git a/go.mod b/go.mod index b77eff7..a9cda3c 100644 --- a/go.mod +++ b/go.mod @@ -26,12 +26,14 @@ require ( github.com/stoewer/go-strcase v1.3.0 github.com/stretchr/testify v1.9.0 github.com/urfave/cli/v2 v2.27.1 + go.etcd.io/bbolt v1.3.10 go.opentelemetry.io/otel v1.26.0 go.opentelemetry.io/otel/exporters/prometheus v0.48.0 go.opentelemetry.io/otel/metric v1.26.0 go.opentelemetry.io/otel/sdk/metric v1.26.0 go.uber.org/zap v1.27.0 golang.org/x/net v0.21.0 + golang.org/x/sync v0.7.0 golang.org/x/term v0.19.0 ) @@ -80,12 +82,10 @@ require ( github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect - go.etcd.io/bbolt v1.3.10 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/otel/sdk v1.26.0 // indirect go.opentelemetry.io/otel/trace v1.26.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.21.0 // indirect golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.33.0 // indirect diff --git a/internal/admin/admin_fsm.go b/internal/admin/admin_fsm.go index a243b10..daf7480 100644 --- a/internal/admin/admin_fsm.go +++ b/internal/admin/admin_fsm.go @@ -39,6 +39,7 @@ func newAdminFSM( logger.Warn("corrupted state detected", zap.ByteString("prev_state", stateBytes)) } else { logger.Info("found state in store", zap.Any("prev_state", fsm.localState)) + return fsm } } return fsm @@ -76,11 +77,10 @@ func (fsm *AdminFSM) applyLog(log *raft.Log, reload bool) (*spec.ChangeLog, erro } func (fsm *AdminFSM) Apply(log *raft.Log) any { - resps := fsm.ApplyBatch([]*raft.Log{log}) - if len(resps) != 1 { - panic("apply batch not returning the correct number of responses") + if resps := fsm.ApplyBatch([]*raft.Log{log}); len(resps) == 1 { + return resps[0] } - return resps[0] + panic("apply batch not returning the correct number of responses") } func (fsm *AdminFSM) ApplyBatch(logs []*raft.Log) []any { @@ -115,7 +115,7 @@ func (fsm *AdminFSM) ApplyBatch(logs []*raft.Log) []any { zap.Uint64("applied_index", lastLogIndex), ) } - fsm.cs.SetReady(true) + // defer fsm.cs.SetReady(true) } return results diff --git a/internal/admin/admin_raft.go b/internal/admin/admin_raft.go index 2393a63..8c90bf3 100644 --- a/internal/admin/admin_raft.go +++ b/internal/admin/admin_raft.go @@ -92,13 +92,9 @@ func setupRaft( address := raft.ServerAddress(advertAddr) raftHttpLogger := logger.Named("http") - if adminConfig.Replication.AdvertScheme != "http" && adminConfig.Replication.AdvertScheme != "https" { - panic(fmt.Errorf("invalid scheme: %s", adminConfig.Replication.AdvertScheme)) - } - transport := rafthttp.NewHTTPTransport( address, http.DefaultClient, raftHttpLogger, - adminConfig.Replication.AdvertScheme+"://(address)/raft", + adminConfig.Replication.AdvertScheme, ) fsmLogger := logger.Named("fsm") adminFSM := newAdminFSM(fsmLogger, configStore, cs) @@ -110,14 +106,13 @@ func setupRaft( panic(err) } - cs.SetupRaft(raftNode) - // Setup raft handler server.Handle("/raft/*", transport) raftAdminLogger := logger.Named("admin") - raftAdmin := raftadmin.NewRaftAdminHTTPServer( - raftNode, raftAdminLogger, []raft.ServerAddress{address}, + raftAdmin := raftadmin.NewServer( + raftNode, raftAdminLogger, + []raft.ServerAddress{address}, ) // Setup handler for raft admin @@ -138,6 +133,39 @@ func setupRaft( util.JsonResponse(w, http.StatusOK, raftNode.Stats()) })) + // Setup handler for readys + server.Handle("/raftadmin/readyz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Raft-State", raftNode.State().String()) + if err := cs.WaitForChanges(nil); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + leaderId, leaderAddr := raftNode.LeaderWithID() + util.JsonResponse(w, http.StatusOK, map[string]any{ + "status": "ok", + "proxy_ready": cs.Ready(), + "state": raftNode.State().String(), + "leader": leaderId, + "leader_addr": leaderAddr, + }) + })) + + doer := func(req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", "dgate") + if adminConfig.Replication.SharedKey != "" { + req.Header.Set("X-DGate-Shared-Key", adminConfig.Replication.SharedKey) + } + client := *http.DefaultClient + client.Timeout = time.Second * 10 + return client.Do(req) + } + adminClient := raftadmin.NewClient( + doer, logger.Named("raft-admin-client"), + adminConfig.Replication.AdvertScheme, + ) + + cs.SetupRaft(raftNode, adminClient) + configFuture := raftNode.GetConfiguration() if err = configFuture.Error(); err != nil { panic(err) @@ -203,33 +231,25 @@ func setupRaft( if len(addresses) > 0 { addresses = append(addresses, adminConfig.Replication.ClusterAddrs...) retries := 0 - doer := func(req *http.Request) (*http.Response, error) { - req.Header.Set("User-Agent", "dgate") - if adminConfig.Replication.SharedKey != "" { - req.Header.Set("X-DGate-Shared-Key", adminConfig.Replication.SharedKey) - } - return http.DefaultClient.Do(req) - } - adminClient := raftadmin.NewHTTPAdminClient(doer, - adminConfig.Replication.AdvertScheme+"://(address)/raftadmin", - logger.Named("raft-admin-client"), - ) RETRY: - for _, url := range addresses { - err = adminClient.VerifyLeader(context.Background(), raft.ServerAddress(url)) + for _, addr := range addresses { + err = adminClient.VerifyLeader( + context.Background(), + raft.ServerAddress(addr), + ) if err != nil { if err == raftadmin.ErrNotLeader { continue } if retries > 15 { logger.Error("Skipping verifying leader", - zap.String("url", url), zap.Error(err), + zap.String("url", addr), zap.Error(err), ) continue } retries += 1 logger.Debug("Retrying verifying leader", - zap.String("url", url), zap.Error(err)) + zap.String("url", addr), zap.Error(err)) <-time.After(3 * time.Second) goto RETRY } @@ -238,10 +258,10 @@ func setupRaft( logger.Info("Adding non-voter", zap.String("id", raftId), zap.String("leader", adminConfig.Replication.AdvertAddr), - zap.String("url", url), + zap.String("url", addr), ) resp, err := adminClient.AddNonvoter( - context.Background(), raft.ServerAddress(url), + context.Background(), raft.ServerAddress(addr), &raftadmin.AddNonvoterRequest{ ID: raftId, Address: adminConfig.Replication.AdvertAddr, @@ -257,9 +277,9 @@ func setupRaft( logger.Info("Adding voter: %s - leader: %s", zap.String("id", raftId), zap.String("leader", adminConfig.Replication.AdvertAddr), - zap.String("url", url), + zap.String("url", addr), ) - resp, err := adminClient.AddVoter(context.Background(), raft.ServerAddress(url), &raftadmin.AddVoterRequest{ + resp, err := adminClient.AddVoter(context.Background(), raft.ServerAddress(addr), &raftadmin.AddVoterRequest{ ID: raftId, Address: adminConfig.Replication.AdvertAddr, }) diff --git a/internal/admin/admin_routes_test.go b/internal/admin/admin_routes_test.go index 84923e0..272f997 100644 --- a/internal/admin/admin_routes_test.go +++ b/internal/admin/admin_routes_test.go @@ -17,7 +17,9 @@ func TestAdminRoutes_configureRoutes(t *testing.T) { cs.On("DocumentManager").Return(nil) conf := configtest.NewTestAdminConfig() if err := configureRoutes( - mux, "test", zap.NewNop(), cs, conf, + mux, "test", + zap.NewNop(), + cs, conf, ); err != nil { t.Fatal(err) } diff --git a/internal/admin/changestate/change_state.go b/internal/admin/changestate/change_state.go index dc2fdd4..2e95c52 100644 --- a/internal/admin/changestate/change_state.go +++ b/internal/admin/changestate/change_state.go @@ -2,6 +2,7 @@ package changestate import ( "github.com/dgate-io/dgate/internal/proxy" + "github.com/dgate-io/dgate/pkg/raftadmin" "github.com/dgate-io/dgate/pkg/resources" "github.com/dgate-io/dgate/pkg/spec" "github.com/hashicorp/raft" @@ -21,7 +22,7 @@ type ChangeState interface { SetReady(bool) // Replication - SetupRaft(*raft.Raft) + SetupRaft(*raft.Raft, *raftadmin.Client) Raft() *raft.Raft // Resources diff --git a/internal/admin/changestate/testutil/change_state.go b/internal/admin/changestate/testutil/change_state.go index f0cd74d..f0e5a82 100644 --- a/internal/admin/changestate/testutil/change_state.go +++ b/internal/admin/changestate/testutil/change_state.go @@ -7,6 +7,7 @@ import ( "github.com/dgate-io/dgate/internal/admin/changestate" "github.com/dgate-io/dgate/pkg/resources" "github.com/dgate-io/dgate/pkg/spec" + "github.com/dgate-io/dgate/pkg/raftadmin" "github.com/hashicorp/raft" "github.com/stretchr/testify/mock" ) @@ -71,7 +72,7 @@ func (m *MockChangeState) ReloadState(a bool, cls ...*spec.ChangeLog) error { } // SetupRaft implements changestate.ChangeState. -func (m *MockChangeState) SetupRaft(*raft.Raft) { +func (m *MockChangeState) SetupRaft(*raft.Raft, *raftadmin.Client) { m.Called().Error(0) } diff --git a/internal/admin/routes/collection_routes.go b/internal/admin/routes/collection_routes.go index 879bfa5..52340f8 100644 --- a/internal/admin/routes/collection_routes.go +++ b/internal/admin/routes/collection_routes.go @@ -63,8 +63,7 @@ func ConfigureCollectionAPI(server chi.Router, logger *zap.Logger, cs changestat } cl := spec.NewChangeLog(&collection, collection.NamespaceName, spec.AddCollectionCommand) - err = cs.ApplyChangeLog(cl) - if err != nil { + if err = cs.ApplyChangeLog(cl); err != nil { util.JsonError(w, http.StatusInternalServerError, err.Error()) return } @@ -268,8 +267,7 @@ func ConfigureCollectionAPI(server chi.Router, logger *zap.Logger, cs changestat } cl := spec.NewChangeLog(&doc, doc.NamespaceName, spec.AddDocumentCommand) - err = cs.ApplyChangeLog(cl) - if err != nil { + if err = cs.ApplyChangeLog(cl); err != nil { util.JsonError(w, http.StatusInternalServerError, err.Error()) return } @@ -350,8 +348,7 @@ func ConfigureCollectionAPI(server chi.Router, logger *zap.Logger, cs changestat return } cl := spec.NewChangeLog(document, namespaceName, spec.DeleteDocumentCommand) - err = cs.ApplyChangeLog(cl) - if err != nil { + if err = cs.ApplyChangeLog(cl); err != nil { util.JsonError(w, http.StatusInternalServerError, err.Error()) return } diff --git a/internal/admin/routes/domain_routes.go b/internal/admin/routes/domain_routes.go index a151090..096fd77 100644 --- a/internal/admin/routes/domain_routes.go +++ b/internal/admin/routes/domain_routes.go @@ -41,8 +41,7 @@ func ConfigureDomainAPI(server chi.Router, logger *zap.Logger, cs changestate.Ch domain.NamespaceName = spec.DefaultNamespace.Name } cl := spec.NewChangeLog(&domain, domain.NamespaceName, spec.AddDomainCommand) - err = cs.ApplyChangeLog(cl) - if err != nil { + if err = cs.ApplyChangeLog(cl); err != nil { util.JsonError(w, http.StatusBadRequest, err.Error()) return } diff --git a/internal/admin/routes/misc_routes.go b/internal/admin/routes/misc_routes.go index a281f79..4f169de 100644 --- a/internal/admin/routes/misc_routes.go +++ b/internal/admin/routes/misc_routes.go @@ -45,11 +45,6 @@ func ConfigureHealthAPI(server chi.Router, version string, cs changestate.Change w.Header().Set("Content-Type", "application/json") if cs.Ready() { if r := cs.Raft(); r != nil { - if err := cs.WaitForChanges(nil); err != nil { - w.WriteHeader(http.StatusServiceUnavailable) - w.Write([]byte(`{"status":"not ready"}`)) - return - } w.Header().Set("X-Raft-State", r.State().String()) if leaderAddr := r.Leader(); leaderAddr == "" { w.WriteHeader(http.StatusServiceUnavailable) diff --git a/internal/admin/routes/module_routes.go b/internal/admin/routes/module_routes.go index 3e5a4df..027da3d 100644 --- a/internal/admin/routes/module_routes.go +++ b/internal/admin/routes/module_routes.go @@ -78,8 +78,7 @@ func ConfigureModuleAPI(server chi.Router, logger *zap.Logger, cs changestate.Ch mod.NamespaceName = spec.DefaultNamespace.Name } cl := spec.NewChangeLog(&mod, mod.NamespaceName, spec.DeleteModuleCommand) - err = cs.ApplyChangeLog(cl) - if err != nil { + if err = cs.ApplyChangeLog(cl); err != nil { util.JsonError(w, http.StatusBadRequest, err.Error()) return } diff --git a/internal/admin/routes/module_routes_test.go b/internal/admin/routes/module_routes_test.go index 74db074..cfabb25 100644 --- a/internal/admin/routes/module_routes_test.go +++ b/internal/admin/routes/module_routes_test.go @@ -21,20 +21,21 @@ import ( func TestAdminRoutes_Module(t *testing.T) { namespaces := []string{"default", "test"} - for _, ns := range namespaces { - config := configtest.NewTest4DGateConfig() - ps := proxy.NewProxyState(zap.NewNop(), config) - mux := chi.NewMux() - mux.Route("/api/v1", func(r chi.Router) { - routes.ConfigureModuleAPI(r, zap.NewNop(), ps, config) - }) - server := httptest.NewServer(mux) - defer server.Close() + config := configtest.NewTest4DGateConfig() + ps := proxy.NewProxyState(zap.NewNop(), config) + if err := ps.Start(); err != nil { + t.Fatal(err) + } + mux := chi.NewMux() + mux.Route("/api/v1", func(r chi.Router) { + routes.ConfigureModuleAPI(r, zap.NewNop(), ps, config) + }) + server := httptest.NewServer(mux) + defer server.Close() + for _, ns := range namespaces { client := dgclient.NewDGateClient() - if err := client.Init(server.URL, server.Client(), - dgclient.WithVerboseLogging(true), - ); err != nil { + if err := client.Init(server.URL, server.Client()); err != nil { t.Fatal(err) } @@ -77,25 +78,22 @@ func TestAdminRoutes_Module(t *testing.T) { } func TestAdminRoutes_ModuleError(t *testing.T) { + config := configtest.NewTest3DGateConfig() + cs := testutil.NewMockChangeState() + rm := resources.NewManager() + cs.On("ApplyChangeLog", mock.Anything). + Return(errors.New("test error")) + cs.On("ResourceManager").Return(rm) + mux := chi.NewMux() + mux.Route("/api/v1", func(r chi.Router) { + routes.ConfigureModuleAPI(r, zap.NewNop(), cs, config) + }) + server := httptest.NewServer(mux) + defer server.Close() namespaces := []string{"default", "test", ""} for _, ns := range namespaces { - config := configtest.NewTest3DGateConfig() - rm := resources.NewManager() - cs := testutil.NewMockChangeState() - cs.On("ApplyChangeLog", mock.Anything). - Return(errors.New("test error")) - cs.On("ResourceManager").Return(rm) - mux := chi.NewMux() - mux.Route("/api/v1", func(r chi.Router) { - routes.ConfigureModuleAPI(r, zap.NewNop(), cs, config) - }) - server := httptest.NewServer(mux) - defer server.Close() - client := dgclient.NewDGateClient() - if err := client.Init(server.URL, server.Client(), - dgclient.WithVerboseLogging(true), - ); err != nil { + if err := client.Init(server.URL, server.Client()); err != nil { t.Fatal(err) } diff --git a/internal/admin/routes/namespace_routes.go b/internal/admin/routes/namespace_routes.go index daed9ab..f6b175c 100644 --- a/internal/admin/routes/namespace_routes.go +++ b/internal/admin/routes/namespace_routes.go @@ -35,8 +35,7 @@ func ConfigureNamespaceAPI(server chi.Router, logger *zap.Logger, cs changestate } cl := spec.NewChangeLog(&namespace, namespace.Name, spec.AddNamespaceCommand) - err = cs.ApplyChangeLog(cl) - if err != nil { + if err = cs.ApplyChangeLog(cl); err != nil { util.JsonError(w, http.StatusBadRequest, err.Error()) return } @@ -69,8 +68,7 @@ func ConfigureNamespaceAPI(server chi.Router, logger *zap.Logger, cs changestate } cl := spec.NewChangeLog(&namespace, namespace.Name, spec.DeleteNamespaceCommand) - err = cs.ApplyChangeLog(cl) - if err != nil { + if err = cs.ApplyChangeLog(cl); err != nil { util.JsonError(w, http.StatusBadRequest, err.Error()) return } diff --git a/internal/admin/routes/namespace_routes_test.go b/internal/admin/routes/namespace_routes_test.go index 00cc227..a03ae77 100644 --- a/internal/admin/routes/namespace_routes_test.go +++ b/internal/admin/routes/namespace_routes_test.go @@ -19,21 +19,22 @@ import ( ) func TestAdminRoutes_Namespace(t *testing.T) { + config := configtest.NewTest3DGateConfig() + ps := proxy.NewProxyState(zap.NewNop(), config) + if err := ps.Start(); err != nil { + t.Fatal(err) + } + mux := chi.NewMux() + mux.Route("/api/v1", func(r chi.Router) { + routes.ConfigureNamespaceAPI(r, zap.NewNop(), ps, config) + }) + server := httptest.NewServer(mux) + defer server.Close() namespaces := []string{"_test", "default"} for _, ns := range namespaces { - config := configtest.NewTest3DGateConfig() - ps := proxy.NewProxyState(zap.NewNop(), config) - mux := chi.NewMux() - mux.Route("/api/v1", func(r chi.Router) { - routes.ConfigureNamespaceAPI(r, zap.NewNop(), ps, config) - }) - server := httptest.NewServer(mux) - defer server.Close() client := dgclient.NewDGateClient() - if err := client.Init(server.URL, server.Client(), - dgclient.WithVerboseLogging(true), - ); err != nil { + if err := client.Init(server.URL, server.Client()); err != nil { t.Fatal(err) } @@ -71,25 +72,22 @@ func TestAdminRoutes_Namespace(t *testing.T) { } func TestAdminRoutes_NamespaceError(t *testing.T) { + config := configtest.NewTest3DGateConfig() + rm := resources.NewManager() + cs := testutil.NewMockChangeState() + cs.On("ApplyChangeLog", mock.Anything). + Return(errors.New("test error")) + cs.On("ResourceManager").Return(rm) + mux := chi.NewMux() + mux.Route("/api/v1", func(r chi.Router) { + routes.ConfigureNamespaceAPI(r, zap.NewNop(), cs, config) + }) + server := httptest.NewServer(mux) + defer server.Close() namespaces := []string{"default", "test", ""} for _, ns := range namespaces { - config := configtest.NewTest3DGateConfig() - rm := resources.NewManager() - cs := testutil.NewMockChangeState() - cs.On("ApplyChangeLog", mock.Anything). - Return(errors.New("test error")) - cs.On("ResourceManager").Return(rm) - mux := chi.NewMux() - mux.Route("/api/v1", func(r chi.Router) { - routes.ConfigureNamespaceAPI(r, zap.NewNop(), cs, config) - }) - server := httptest.NewServer(mux) - defer server.Close() - client := dgclient.NewDGateClient() - if err := client.Init(server.URL, server.Client(), - dgclient.WithVerboseLogging(true), - ); err != nil { + if err := client.Init(server.URL, server.Client()); err != nil { t.Fatal(err) } diff --git a/internal/admin/routes/route_routes.go b/internal/admin/routes/route_routes.go index a1c886e..23f064e 100644 --- a/internal/admin/routes/route_routes.go +++ b/internal/admin/routes/route_routes.go @@ -43,8 +43,7 @@ func ConfigureRouteAPI(server chi.Router, logger *zap.Logger, cs changestate.Cha } cl := spec.NewChangeLog(&route, route.NamespaceName, spec.AddRouteCommand) - err = cs.ApplyChangeLog(cl) - if err != nil { + if err = cs.ApplyChangeLog(cl); err != nil { util.JsonError(w, http.StatusBadRequest, err.Error()) return } @@ -86,8 +85,7 @@ func ConfigureRouteAPI(server chi.Router, logger *zap.Logger, cs changestate.Cha } cl := spec.NewChangeLog(&route, route.NamespaceName, spec.DeleteRouteCommand) - err = cs.ApplyChangeLog(cl) - if err != nil { + if err = cs.ApplyChangeLog(cl); err != nil { util.JsonError(w, http.StatusBadRequest, err.Error()) return } diff --git a/internal/admin/routes/route_routes_test.go b/internal/admin/routes/route_routes_test.go index e8efc07..42956bf 100644 --- a/internal/admin/routes/route_routes_test.go +++ b/internal/admin/routes/route_routes_test.go @@ -23,6 +23,9 @@ func TestAdminRoutes_Route(t *testing.T) { for _, ns := range namespaces { config := configtest.NewTest3DGateConfig() ps := proxy.NewProxyState(zap.NewNop(), config) + if err := ps.Start(); err != nil { + t.Fatal(err) + } mux := chi.NewMux() mux.Route("/api/v1", func(r chi.Router) { routes.ConfigureRouteAPI(r, zap.NewNop(), ps, config) @@ -31,9 +34,7 @@ func TestAdminRoutes_Route(t *testing.T) { defer server.Close() client := dgclient.NewDGateClient() - if err := client.Init(server.URL, server.Client(), - dgclient.WithVerboseLogging(true), - ); err != nil { + if err := client.Init(server.URL, server.Client()); err != nil { t.Fatal(err) } @@ -90,9 +91,7 @@ func TestAdminRoutes_RouteError(t *testing.T) { defer server.Close() client := dgclient.NewDGateClient() - if err := client.Init(server.URL, server.Client(), - dgclient.WithVerboseLogging(true), - ); err != nil { + if err := client.Init(server.URL, server.Client()); err != nil { t.Fatal(err) } diff --git a/internal/admin/routes/service_routes.go b/internal/admin/routes/service_routes.go index 3a41a1d..fb2c631 100644 --- a/internal/admin/routes/service_routes.go +++ b/internal/admin/routes/service_routes.go @@ -80,13 +80,11 @@ func ConfigureServiceAPI(server chi.Router, logger *zap.Logger, cs changestate.C } cl := spec.NewChangeLog(&svc, svc.NamespaceName, spec.AddServiceCommand) - err = cs.ApplyChangeLog(cl) - if err != nil { + if err = cs.ApplyChangeLog(cl); err != nil { util.JsonError(w, http.StatusBadRequest, err.Error()) return } - logger.Debug("Waiting for raft barrier") if err := cs.WaitForChanges(cl); err != nil { util.JsonError(w, http.StatusInternalServerError, err.Error()) return @@ -118,8 +116,7 @@ func ConfigureServiceAPI(server chi.Router, logger *zap.Logger, cs changestate.C svc.NamespaceName = spec.DefaultNamespace.Name } cl := spec.NewChangeLog(&svc, svc.NamespaceName, spec.DeleteServiceCommand) - err = cs.ApplyChangeLog(cl) - if err != nil { + if err = cs.ApplyChangeLog(cl); err != nil { util.JsonError(w, http.StatusBadRequest, err.Error()) return } diff --git a/internal/admin/routes/service_routes_test.go b/internal/admin/routes/service_routes_test.go index 555959f..81be839 100644 --- a/internal/admin/routes/service_routes_test.go +++ b/internal/admin/routes/service_routes_test.go @@ -23,6 +23,9 @@ func TestAdminRoutes_Service(t *testing.T) { for _, ns := range namespaces { config := configtest.NewTest4DGateConfig() ps := proxy.NewProxyState(zap.NewNop(), config) + if err := ps.Start(); err != nil { + t.Fatal(err) + } mux := chi.NewMux() mux.Route("/api/v1", func(r chi.Router) { routes.ConfigureServiceAPI(r, zap.NewNop(), ps, config) @@ -31,9 +34,7 @@ func TestAdminRoutes_Service(t *testing.T) { defer server.Close() client := dgclient.NewDGateClient() - if err := client.Init(server.URL, server.Client(), - dgclient.WithVerboseLogging(true), - ); err != nil { + if err := client.Init(server.URL, server.Client()); err != nil { t.Fatal(err) } @@ -90,9 +91,7 @@ func TestAdminRoutes_ServiceError(t *testing.T) { defer server.Close() client := dgclient.NewDGateClient() - if err := client.Init(server.URL, server.Client(), - dgclient.WithVerboseLogging(true), - ); err != nil { + if err := client.Init(server.URL, server.Client()); err != nil { t.Fatal(err) } diff --git a/internal/config/configtest/dgate_configs.go b/internal/config/configtest/dgate_configs.go index db1e93c..a45e8e6 100644 --- a/internal/config/configtest/dgate_configs.go +++ b/internal/config/configtest/dgate_configs.go @@ -23,7 +23,7 @@ func NewTestDGateConfig() *config.DGateConfig { ProxyConfig: config.DGateProxyConfig{ AllowedDomains: []string{"*test.com", "localhost"}, Host: "localhost", - Port: 8080, + Port: 0, InitResources: &config.DGateResources{ Namespaces: []spec.Namespace{ { @@ -69,7 +69,7 @@ func NewTest2DGateConfig() *config.DGateConfig { conf := NewTestDGateConfig() conf.ProxyConfig = config.DGateProxyConfig{ Host: "localhost", - Port: 16436, + Port: 0, InitResources: &config.DGateResources{ Namespaces: []spec.Namespace{ { @@ -112,7 +112,7 @@ func NewTest4DGateConfig() *config.DGateConfig { conf.DisableDefaultNamespace = false conf.ProxyConfig = config.DGateProxyConfig{ Host: "localhost", - Port: 16436, + Port: 0, InitResources: &config.DGateResources{ Namespaces: []spec.Namespace{ { diff --git a/internal/proxy/change_log.go b/internal/proxy/change_log.go index 57d8fc8..69ef075 100644 --- a/internal/proxy/change_log.go +++ b/internal/proxy/change_log.go @@ -30,7 +30,7 @@ func (ps *ProxyState) processChangeLog(cl *spec.ChangeLog, reload, store bool) ( if store && !cl.Cmd.IsNoop() { defer func() { if err == nil { - if !ps.replicationEnabled { + if !ps.raftEnabled { // dont store change logs if err = ps.store.StoreChangeLog(cl); err != nil { ps.logger.Error("Error storing change log, restarting state", zap.Error(err)) @@ -67,21 +67,29 @@ func (ps *ProxyState) processChangeLog(cl *spec.ChangeLog, reload, store bool) ( goto hash_retry } } else { - ps.restartState(func(err error) { + go ps.restartState(func(err error) { if err != nil { - go ps.Stop() + ps.Stop() } }) } }() if cl.Cmd.Resource() == spec.Documents { - if err = ps.processDocument(cl.Item.(*spec.Document), cl, store); err != nil { + var item *spec.Document + if item, err = decode[*spec.Document](cl.Item); err != nil { + return + } + if err = ps.processDocument(item, cl, store); err != nil { ps.logger.Error("error processing document change log", zap.Error(err)) return } } else { if err = ps.processResource(cl); err != nil { - ps.logger.Error("error processing change log", zap.Error(err)) + ps.logger.Error("error processing change log", + zap.String("id", cl.ID), + zap.Stringer("cmd", cl.Cmd), + zap.Error(err), + ) return } } @@ -91,7 +99,6 @@ func (ps *ProxyState) processChangeLog(cl *spec.ChangeLog, reload, store bool) ( if reload { overrideReload := cl.Cmd.IsNoop() || ps.pendingChanges if overrideReload || cl.Cmd.Resource().IsRelatedTo(spec.Routes) { - ps.logger.Debug("Storing cached documents", zap.String("id", cl.ID)) if err := ps.storeCachedDocuments(); err != nil { ps.logger.Error("error storing cached documents", zap.Error(err)) return err @@ -101,7 +108,6 @@ func (ps *ProxyState) processChangeLog(cl *spec.ChangeLog, reload, store bool) ( ps.logger.Error("Error registering change log", zap.Error(err)) return } - ps.ready.CompareAndSwap(false, true) ps.pendingChanges = false } } else if !cl.Cmd.IsNoop() { @@ -263,6 +269,10 @@ func (ps *ProxyState) processCollection(col *spec.Collection, cl *spec.ChangeLog var docCache = []*spec.Document{} func (ps *ProxyState) storeCachedDocuments() error { + if len(docCache) == 0 { + return nil + } + ps.logger.Debug("Storing cached documents", zap.Int("count", len(docCache))) err := ps.store.StoreDocuments(docCache) if err != nil { return err @@ -289,8 +299,14 @@ func (ps *ProxyState) processDocument(doc *spec.Document, cl *spec.ChangeLog, st case spec.Add: docCache = append(docCache, doc) case spec.Delete: - deletedIndex := sliceutil.BinarySearch(docCache, doc, func(doc1 *spec.Document, doc2 *spec.Document) bool { - return doc1.ID < doc2.ID + deletedIndex := sliceutil.BinarySearch(docCache, doc, func(doc1 *spec.Document, doc2 *spec.Document) int { + if doc1.ID == doc2.ID { + return 0 + } + if doc1.ID < doc2.ID { + return -1 + } + return 1 }) if deletedIndex >= 0 { docCache = append(docCache[:deletedIndex], docCache[deletedIndex+1:]...) @@ -319,45 +335,51 @@ func (ps *ProxyState) processSecret(scrt *spec.Secret, cl *spec.ChangeLog) (err // restoreFromChangeLogs - restores the proxy state from change logs; directApply is used to avoid locking the proxy state func (ps *ProxyState) restoreFromChangeLogs(directApply bool) error { - if logs, err := ps.store.FetchChangeLogs(); err != nil { + var logs []*spec.ChangeLog + var err error + if ps.raftEnabled { + if logs = ps.changeLogs; len(logs) == 0 { + return nil + } + } else if logs, err = ps.store.FetchChangeLogs(); err != nil { return errors.New("failed to get state change logs from storage: " + err.Error()) - } else { - ps.logger.Info("restoring state change logs from storage", zap.Int("count", len(logs))) - // we might need to sort the change logs by timestamp - for _, cl := range logs { - // skip documents as they are persisted in the store - if cl.Cmd.Resource() == spec.Documents { - continue - } - if err = ps.processChangeLog(cl, false, false); err != nil { - return err - } else { - ps.changeLogs = append(ps.changeLogs, cl) - } + } + ps.logger.Info("restoring state change logs from storage", zap.Int("count", len(logs))) + // we might need to sort the change logs by timestamp + for _, cl := range logs { + // skip documents as they are persisted in the store + if cl.Cmd.Resource() == spec.Documents { + continue } - if cl := spec.NewNoopChangeLog(); !directApply { - if err = ps.reconfigureState(cl); err != nil { - return err - } - } else if err = ps.processChangeLog(cl, true, false); err != nil { + if err = ps.processChangeLog(cl, false, false); err != nil { + return err + } else { + ps.changeLogs = append(ps.changeLogs, cl) + } + } + if cl := spec.NewNoopChangeLog(); !directApply { + if err = ps.reconfigureState(cl); err != nil { return err } + } else if err = ps.processChangeLog(cl, true, false); err != nil { + return err + } - // TODO: optionally compact change logs through a flag in config? - if len(logs) > 1 { - removed, err := ps.compactChangeLogs(logs) - if err != nil { - ps.logger.Error("failed to compact state change logs", zap.Error(err)) - return err - } - if removed > 0 { - ps.logger.Info("compacted change logs", - zap.Int("removed", removed), - zap.Int("total", len(logs)), - ) - } + // DISABLED: compaction of change logs needs to have better testing + if len(logs) < 0 { + removed, err := ps.compactChangeLogs(logs) + if err != nil { + ps.logger.Error("failed to compact state change logs", zap.Error(err)) + return err + } + if removed > 0 { + ps.logger.Info("compacted change logs", + zap.Int("removed", removed), + zap.Int("total", len(logs)), + ) } } + return nil } diff --git a/internal/proxy/dynamic_proxy.go b/internal/proxy/dynamic_proxy.go index 9094169..de74200 100644 --- a/internal/proxy/dynamic_proxy.go +++ b/internal/proxy/dynamic_proxy.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math" "net/http" "os" "time" @@ -23,21 +24,25 @@ import ( func (ps *ProxyState) reconfigureState(log *spec.ChangeLog) (err error) { defer func() { if err != nil { - ps.restartState(func(err error) { + ps.logger.Error("error occurred reloading state, restarting...", zap.Error(err)) + go ps.restartState(func(err error) { if err != nil { ps.logger.Error("Error restarting state", zap.Error(err)) - go ps.Stop() + ps.Stop() } }) } }() + ctx, cancel := context.WithTimeout(context.TODO(), 30*time.Second) + defer cancel() + start := time.Now() - if err = ps.setupModules(log); err != nil { + if err = ps.setupModules(ctx, log); err != nil { ps.logger.Error("Error setting up modules", zap.Error(err)) return } - if err = ps.setupRoutes(log); err != nil { + if err = ps.setupRoutes(ctx, log); err != nil { ps.logger.Error("Error setting up routes", zap.Error(err)) return } @@ -48,7 +53,18 @@ func (ps *ProxyState) reconfigureState(log *spec.ChangeLog) (err error) { return nil } -func (ps *ProxyState) setupModules(log *spec.ChangeLog) error { +func customErrGroup(ctx context.Context, count int) (*errgroup.Group, context.Context) { + grp, ctx := errgroup.WithContext(ctx) + limit := int(math.Log2(float64(count))) + limit = min(1, max(16, limit)) + grp.SetLimit(limit) + return grp, ctx +} + +func (ps *ProxyState) setupModules( + ctx context.Context, + log *spec.ChangeLog, +) error { var routes = []*spec.DGateRoute{} if log.Namespace == "" || ps.pendingChanges { routes = ps.rm.GetRoutes() @@ -56,8 +72,8 @@ func (ps *ProxyState) setupModules(log *spec.ChangeLog) error { routes = ps.rm.GetRoutesByNamespace(log.Namespace) } programs := avl.NewTree[string, *goja.Program]() - grp, ctx := errgroup.WithContext(context.TODO()) - grp.SetLimit(16) + grp, ctx := customErrGroup(ctx, len(routes)) + start := time.Now() for _, rt := range routes { if len(rt.Modules) > 0 { route := rt @@ -69,11 +85,24 @@ func (ps *ProxyState) setupModules(log *spec.ChangeLog) error { modPayload string = mod.Payload ) if mod.Type == spec.ModuleTypeTypescript { + tsBucket := ps.sharedCache.Bucket("typescript") + // hash the typescript module payload + tsHash, err := HashString(1337, modPayload) + if err != nil { + ps.logger.Error("Error hashing module: " + mod.Name) + } else if cacheData, ok := tsBucket.Get(tsHash); ok { + if modPayload, ok = cacheData.(string); ok { + goto compile + } + } if modPayload, err = typescript.Transpile(ctx, modPayload); err != nil { ps.logger.Error("Error transpiling module: " + mod.Name) return err + } else { + tsBucket.SetWithTTL(tsHash, modPayload, 5*time.Minute) } } + compile: if mod.Type == spec.ModuleTypeJavascript || mod.Type == spec.ModuleTypeTypescript { if program, err = goja.Compile(mod.Name, modPayload, true); err != nil { ps.logger.Error("Error compiling module: " + mod.Name) @@ -104,40 +133,55 @@ func (ps *ProxyState) setupModules(log *spec.ChangeLog) error { ps.modPrograms.Insert(s, p) return true }) - + ps.logger.Debug("Modules setup", + zap.Duration("elapsed", time.Since(start)), + ) return nil } -func (ps *ProxyState) setupRoutes(log *spec.ChangeLog) (err error) { +func (ps *ProxyState) setupRoutes( + ctx context.Context, + log *spec.ChangeLog, +) error { var rtMap map[string][]*spec.DGateRoute if log.Namespace == "" || ps.pendingChanges { rtMap = ps.rm.GetRouteNamespaceMap() + ps.providers.Clear() } else { rtMap = make(map[string][]*spec.DGateRoute) - rtMap[log.Namespace] = ps.rm.GetRoutesByNamespace(log.Namespace) + routes := ps.rm.GetRoutesByNamespace(log.Namespace) + if len(routes) > 0 { + rtMap[log.Namespace] = routes + } else { + // if namespace has no routes, delete the router + ps.routers.Delete(log.Namespace) + } } + start := time.Now() + grp, _ := customErrGroup(ctx, len(rtMap)) for namespaceName, routes := range rtMap { - mux := router.NewMux() - for _, rt := range routes { - reqCtxProvider := NewRequestContextProvider(rt, ps) - if len(rt.Modules) > 0 { - modExtFunc := ps.createModuleExtractorFunc(rt) - if modPool, err := NewModulePool( - 256, 1024, reqCtxProvider, modExtFunc, - ); err != nil { - ps.logger.Error("Error creating module buffer", zap.Error(err)) - return err - } else { - reqCtxProvider.SetModulePool(modPool) + namespaceName, routes := namespaceName, routes + grp.Go(func() (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) } - } - ps.providers.Insert(rt.Namespace.Name+"/"+rt.Name, reqCtxProvider) - err = func(rt *spec.DGateRoute) (err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("%v", r) + }() + mux := router.NewMux() + for _, rt := range routes { + reqCtxProvider := NewRequestContextProvider(rt, ps) + if len(rt.Modules) > 0 { + modExtFunc := ps.createModuleExtractorFunc(rt) + if modPool, err := NewModulePool( + 256, 1024, reqCtxProvider, modExtFunc, + ); err != nil { + ps.logger.Error("Error creating module buffer", zap.Error(err)) + return err + } else { + reqCtxProvider.SetModulePool(modPool) } - }() + } + ps.providers.Insert(rt.Namespace.Name+"/"+rt.Name, reqCtxProvider) for _, path := range rt.Paths { if len(rt.Methods) > 0 && rt.Methods[0] == "*" { if len(rt.Methods) > 1 { @@ -155,18 +199,23 @@ func (ps *ProxyState) setupRoutes(log *spec.ChangeLog) (err error) { } } } - return nil - }(rt) - } - - if dr, ok := ps.routers.Find(namespaceName); ok { - dr.ReplaceMux(mux) - } else { - dr := router.NewRouterWithMux(mux) - ps.routers.Insert(namespaceName, dr) - } + } + if dr, ok := ps.routers.Find(namespaceName); ok { + dr.ReplaceMux(mux) + } else { + dr := router.NewRouterWithMux(mux) + ps.routers.Insert(namespaceName, dr) + } + return nil + }) } - return + if err := grp.Wait(); err != nil { + return err + } + ps.logger.Debug("Routes setup", + zap.Duration("elapsed", time.Since(start)), + ) + return nil } func (ps *ProxyState) createModuleExtractorFunc(rt *spec.DGateRoute) ModuleExtractorFunc { @@ -245,8 +294,8 @@ func (ps *ProxyState) startProxyServer() { } } if err := server.ListenAndServe(); err != nil { - ps.logger.Error("Error starting proxy server", zap.Error(err)) - os.Exit(1) + ps.logger.Error("error starting proxy server", zap.Error(err)) + panic(err) } } @@ -285,14 +334,14 @@ func (ps *ProxyState) startProxyServerTLS() { } if err := secureServer.ListenAndServeTLS("", ""); err != nil { ps.logger.Error("Error starting secure proxy server", zap.Error(err)) - os.Exit(1) + panic(err) } } func (ps *ProxyState) Start() (err error) { defer func() { if err != nil { - ps.logger.Error("Error starting proxy server", zap.Error(err)) + ps.logger.Error("error starting proxy", zap.Error(err)) ps.Stop() } }() @@ -305,21 +354,20 @@ func (ps *ProxyState) Start() (err error) { go ps.startProxyServer() go ps.startProxyServerTLS() - if !ps.replicationEnabled { + if !ps.raftEnabled { if err = ps.restoreFromChangeLogs(false); err != nil { return err } else { - ps.ready.Store(true) + ps.SetReady(true) } } - return nil } func (ps *ProxyState) Stop() { go func() { defer os.Exit(3) - <-time.After(5 * time.Second) + <-time.After(7 * time.Second) ps.logger.Error("Failed to stop proxy server") }() @@ -329,6 +377,7 @@ func (ps *ProxyState) Stop() { ps.proxyLock.Lock() defer ps.proxyLock.Unlock() + ps.logger.Info("Shutting down raft") if raftNode := ps.Raft(); raftNode != nil { ps.logger.Info("Stopping Raft node") diff --git a/internal/proxy/proxy_replication.go b/internal/proxy/proxy_replication.go index 76f96e6..db1d46c 100644 --- a/internal/proxy/proxy_replication.go +++ b/internal/proxy/proxy_replication.go @@ -1,13 +1,18 @@ package proxy -import "github.com/hashicorp/raft" +import ( + "github.com/dgate-io/dgate/pkg/raftadmin" + "github.com/hashicorp/raft" +) type ProxyReplication struct { - raft *raft.Raft + raft *raft.Raft + client *raftadmin.Client } -func NewProxyReplication(raft *raft.Raft) *ProxyReplication { +func NewProxyReplication(raft *raft.Raft, client *raftadmin.Client) *ProxyReplication { return &ProxyReplication{ - raft: raft, + raft: raft, + client: client, } } diff --git a/internal/proxy/proxy_state.go b/internal/proxy/proxy_state.go index 355788d..834c6a4 100644 --- a/internal/proxy/proxy_state.go +++ b/internal/proxy/proxy_state.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "crypto/tls" "encoding/base64" "encoding/json" @@ -21,6 +22,7 @@ import ( "github.com/dgate-io/dgate/internal/router" "github.com/dgate-io/dgate/pkg/cache" "github.com/dgate-io/dgate/pkg/modules/extractors" + "github.com/dgate-io/dgate/pkg/raftadmin" "github.com/dgate-io/dgate/pkg/resources" "github.com/dgate-io/dgate/pkg/scheduler" "github.com/dgate-io/dgate/pkg/spec" @@ -41,22 +43,22 @@ type ProxyState struct { printer console.Printer config *config.DGateConfig store *proxystore.ProxyStore - changeLogs []*spec.ChangeLog - metrics *ProxyMetrics sharedCache cache.TCache proxyLock *sync.RWMutex + ready *atomic.Bool pendingChanges bool + metrics *ProxyMetrics - rm *resources.ResourceManager - skdr scheduler.Scheduler - + rm *resources.ResourceManager + skdr scheduler.Scheduler + changeLogs []*spec.ChangeLog providers avl.Tree[string, *RequestContextProvider] modPrograms avl.Tree[string, *goja.Program] + routers avl.Tree[string, *router.DynamicRouter] - ready *atomic.Bool - replicationSettings *ProxyReplication - replicationEnabled bool - routers avl.Tree[string, *router.DynamicRouter] + raft *raft.Raft + raftClient *raftadmin.Client + raftEnabled bool ReverseProxyBuilder reverse_proxy.Builder ProxyTransportBuilder proxy_transport.Builder @@ -101,9 +103,9 @@ func NewProxyState(logger *zap.Logger, conf *config.DGateConfig) *ProxyState { storeLogger := logger.Named("store") schedulerLogger := logger.Named("scheduler") - replicationEnabled := false + raftEnabled := false if conf.AdminConfig != nil && conf.AdminConfig.Replication != nil { - replicationEnabled = true + raftEnabled = true } state := &ProxyState{ startTime: time.Now(), @@ -119,12 +121,12 @@ func NewProxyState(logger *zap.Logger, conf *config.DGateConfig) *ProxyState { skdr: scheduler.New(scheduler.Options{ Logger: schedulerLogger, }), - providers: avl.NewTree[string, *RequestContextProvider](), - modPrograms: avl.NewTree[string, *goja.Program](), - proxyLock: new(sync.RWMutex), - sharedCache: cache.New(), - store: proxystore.New(dataStore, storeLogger), - replicationEnabled: replicationEnabled, + providers: avl.NewTree[string, *RequestContextProvider](), + modPrograms: avl.NewTree[string, *goja.Program](), + proxyLock: new(sync.RWMutex), + sharedCache: cache.New(), + store: proxystore.New(dataStore, storeLogger), + raftEnabled: raftEnabled, ReverseProxyBuilder: reverse_proxy.NewBuilder(). FlushInterval(-1). ErrorLogger(zap.NewStdLog(rpLogger)). @@ -174,27 +176,32 @@ func (ps *ProxyState) Ready() bool { } func (ps *ProxyState) SetReady(ready bool) { - ps.ready.CompareAndSwap(false, true) + if !ps.Ready() && ready { + ps.logger.Info("Proxy state is ready", + zap.Duration("uptime", time.Since(ps.startTime)), + ) + } + ps.ready.Store(ready) } func (ps *ProxyState) Raft() *raft.Raft { - if ps.replicationEnabled { - return ps.replicationSettings.raft + if ps.raftEnabled { + return ps.raft } return nil } -func (ps *ProxyState) SetupRaft(r *raft.Raft) { +func (ps *ProxyState) SetupRaft(r *raft.Raft, client *raftadmin.Client) { ps.proxyLock.Lock() defer ps.proxyLock.Unlock() + ps.raft = r + ps.raftClient = client + oc := make(chan raft.Observation, 32) r.RegisterObserver(raft.NewObserver(oc, false, func(o *raft.Observation) bool { switch o.Data.(type) { - case - raft.FailedHeartbeatObservation, - raft.LeaderObservation, - raft.PeerObservation: + case raft.LeaderObservation, raft.PeerObservation: return true } return false @@ -218,6 +225,7 @@ func (ps *ProxyState) SetupRaft(r *raft.Raft) { ) } case raft.LeaderObservation: + ps.SetReady(true) logger.Info("leader observation", zap.String("leader_addr", string(ro.LeaderAddr)), zap.String("leader_id", string(ro.LeaderID)), @@ -226,52 +234,60 @@ func (ps *ProxyState) SetupRaft(r *raft.Raft) { } panic("raft observer channel closed") }() - ps.replicationSettings = NewProxyReplication(r) } func (ps *ProxyState) WaitForChanges(log *spec.ChangeLog) error { - if r := ps.Raft(); r != nil && log != nil { - waitTime := time.Second * 5 - if r.State() != raft.Leader { - return r.Barrier(waitTime).Error() + if r := ps.Raft(); r != nil { + waitTime := time.Second * 10 + if r.State() == raft.Leader { + err := r.Barrier(waitTime).Error() + if err != nil && log != nil { + ps.logger.Error("error waiting for changes", + zap.String("id", log.ID), + zap.Stringer("command", log.Cmd), + zap.Error(err), + ) + } + return err } else { - hasChanges := func() bool { - ps.proxyLock.RLock() - defer ps.proxyLock.RUnlock() - if ps.changeLogs != nil { - lastLog := ps.changeLogs[len(ps.changeLogs)-1] - if lastLog.ID >= log.ID { - return true - } + if leaderAddr := r.Leader(); leaderAddr != "" { + ctx, cancel := context.WithTimeout( + context.Background(), waitTime) + defer cancel() + retries := 0 + RETRY: + await, err := ps.raftClient.Barrier(ctx, r.Leader()) + if err == nil && await.Error != "" { + err = errors.New(await.Error) } - return false - } - timeout := time.After(waitTime) - backoff := time.Millisecond * 10 - multiplier := 1.8 - for { - select { - case <-timeout: - if hasChanges() { - return nil - } - return errors.New("timeout waiting for changes") - case <-time.After(backoff): - if hasChanges() { + if err != nil && log != nil { + ps.logger.Error("error waiting for changes", + zap.String("id", log.ID), + zap.Stringer("command", log.Cmd), + zap.Error(err), + ) + } + if len(ps.changeLogs) > 0 && retries < 5 { + if log.ID >= ps.changeLogs[len(ps.changeLogs)-1].ID { return nil } - backoff = min(time.Duration(float64(backoff)*multiplier), time.Millisecond*500) + retries++ + goto RETRY } + return err + } else { + return errors.New("no leader found") } } - } else { - ps.proxyLock.RLock() - defer ps.proxyLock.RUnlock() } return nil } +// ApplyChangeLog - apply change log to the proxy state func (ps *ProxyState) ApplyChangeLog(log *spec.ChangeLog) error { + if !ps.Ready() { + return errors.New("proxy state not ready") + } if r := ps.Raft(); r != nil { if r.State() != raft.Leader { return raft.ErrNotLeader @@ -287,15 +303,17 @@ func (ps *ProxyState) ApplyChangeLog(log *spec.ChangeLog) error { now := time.Now() future := r.ApplyLog(raftLog, time.Second*15) err = future.Error() - ps.logger.With(). - Debug("waiting for reply from raft", - zap.String("id", log.ID), - zap.Stringer("command", log.Cmd), - zap.Stringer("command", time.Since(now)), - zap.Uint64("index", future.Index()), - zap.Any("response", future.Response()), - zap.Error(err), - ) + if err != nil { + ps.logger.With(). + Error("error at ApplyLog", + zap.String("id", log.ID), + zap.Stringer("command", log.Cmd), + zap.Stringer("command", time.Since(now)), + zap.Uint64("index", future.Index()), + zap.Any("response", future.Response()), + zap.Error(err), + ) + } return err } else { return ps.processChangeLog(log, true, true) @@ -317,36 +335,27 @@ func (ps *ProxyState) SharedCache() cache.TCache { // restartState - restart state clears the state and reloads the configuration // this is useful for rollbacks when broken changes are made. func (ps *ProxyState) restartState(fn func(error)) { + ps.logger.Info("Attempting to restart state...") ps.proxyLock.Lock() defer ps.proxyLock.Unlock() - - ps.logger.Info("Attempting to restart state...") - + ps.changeHash.Store(0) + ps.pendingChanges = false ps.rm.Empty() ps.modPrograms.Clear() ps.providers.Clear() ps.routers.Clear() ps.sharedCache.Clear() - ps.Scheduler().Stop() + ps.skdr.Stop() if err := ps.initConfigResources(ps.config.ProxyConfig.InitResources); err != nil { - fn(err) + go fn(err) return } - if ps.replicationEnabled { - raft := ps.Raft() - err := raft.ReloadConfig(raft.ReloadableConfig()) - if err != nil { - fn(err) - return - } - } else { - if err := ps.restoreFromChangeLogs(true); err != nil { - fn(err) - return - } + if err := ps.restoreFromChangeLogs(true); err != nil { + go fn(err) + return } ps.logger.Info("State successfully restarted") - fn(nil) + go fn(nil) } // ReloadState - reload state checks the change logs to see if a reload is required, @@ -464,6 +473,9 @@ func (ps *ProxyState) getDomainCertificate(domain string) (*tls.Certificate, err } func (ps *ProxyState) initConfigResources(resources *config.DGateResources) error { + processCL := func(cl *spec.ChangeLog) error { + return ps.processChangeLog(cl, false, false) + } if resources != nil { numChanges, err := resources.Validate() if err != nil { @@ -472,15 +484,14 @@ func (ps *ProxyState) initConfigResources(resources *config.DGateResources) erro if numChanges > 0 { defer func() { if err != nil { - err = ps.processChangeLog(nil, false, false) + err = processCL(nil) } }() } ps.logger.Info("Initializing resources") for _, ns := range resources.Namespaces { cl := spec.NewChangeLog(&ns, ns.Name, spec.AddNamespaceCommand) - err := ps.processChangeLog(cl, false, false) - if err != nil { + if err := processCL(cl); err != nil { return err } } @@ -498,22 +509,19 @@ func (ps *ProxyState) initConfigResources(resources *config.DGateResources) erro ) } cl := spec.NewChangeLog(&mod.Module, mod.NamespaceName, spec.AddModuleCommand) - err := ps.processChangeLog(cl, false, false) - if err != nil { + if err := processCL(cl); err != nil { return err } } for _, svc := range resources.Services { cl := spec.NewChangeLog(&svc, svc.NamespaceName, spec.AddServiceCommand) - err := ps.processChangeLog(cl, false, false) - if err != nil { + if err := processCL(cl); err != nil { return err } } for _, rt := range resources.Routes { cl := spec.NewChangeLog(&rt, rt.NamespaceName, spec.AddRouteCommand) - err := ps.processChangeLog(cl, false, false) - if err != nil { + if err := processCL(cl); err != nil { return err } } @@ -533,22 +541,19 @@ func (ps *ProxyState) initConfigResources(resources *config.DGateResources) erro dom.Key = string(key) } cl := spec.NewChangeLog(&dom.Domain, dom.NamespaceName, spec.AddDomainCommand) - err := ps.processChangeLog(cl, false, false) - if err != nil { + if err := processCL(cl); err != nil { return err } } for _, col := range resources.Collections { cl := spec.NewChangeLog(&col, col.NamespaceName, spec.AddCollectionCommand) - err := ps.processChangeLog(cl, false, false) - if err != nil { + if err := processCL(cl); err != nil { return err } } for _, doc := range resources.Documents { cl := spec.NewChangeLog(&doc, doc.NamespaceName, spec.AddDocumentCommand) - err := ps.processChangeLog(cl, false, false) - if err != nil { + if err := processCL(cl); err != nil { return err } } diff --git a/internal/proxy/proxy_state_test.go b/internal/proxy/proxy_state_test.go index 7cff142..f76cc2e 100644 --- a/internal/proxy/proxy_state_test.go +++ b/internal/proxy/proxy_state_test.go @@ -10,6 +10,7 @@ import ( "github.com/dgate-io/dgate/internal/proxy" "github.com/dgate-io/dgate/pkg/spec" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/zap" ) @@ -416,7 +417,6 @@ func TestProcessChangeLog_Document(t *testing.T) { if err := ps.Store().InitStore(); err != nil { t.Fatal(err) } - c := &spec.Collection{ Name: "test123", NamespaceName: "test", @@ -439,7 +439,7 @@ func TestProcessChangeLog_Document(t *testing.T) { } cl = spec.NewChangeLog(d, d.NamespaceName, spec.AddDocumentCommand) - err = ps.ProcessChangeLog(cl, false) + err = ps.ProcessChangeLog(cl, true) if !assert.Nil(t, err, "error should be nil") { return } @@ -449,7 +449,7 @@ func TestProcessChangeLog_Document(t *testing.T) { if !assert.Nil(t, err, "error should be nil") { return } - assert.Equal(t, 1, len(documents), "should have 1 item") + require.Equal(t, 1, len(documents), "should have 1 item") assert.Equal(t, d.ID, documents[0].ID, "should have the same id") assert.Equal(t, d.NamespaceName, documents[0].NamespaceName, "should have the same namespace") assert.Equal(t, d.CollectionName, documents[0].CollectionName, "should have the same collection") diff --git a/pkg/raftadmin/raftadmin_client.go b/pkg/raftadmin/client.go similarity index 72% rename from pkg/raftadmin/raftadmin_client.go rename to pkg/raftadmin/client.go index 905dfd2..e92041e 100644 --- a/pkg/raftadmin/raftadmin_client.go +++ b/pkg/raftadmin/client.go @@ -7,8 +7,6 @@ import ( "errors" "fmt" "net/http" - "strings" - "time" "github.com/hashicorp/raft" "go.uber.org/zap" @@ -16,39 +14,32 @@ import ( type Doer func(*http.Request) (*http.Response, error) -type HTTPAdminClient struct { +type Client struct { do Doer - urlFmt string + scheme string logger *zap.Logger } -func NewHTTPAdminClient(doer Doer, urlFmt string, logger *zap.Logger) *HTTPAdminClient { +func NewClient(doer Doer, logger *zap.Logger, scheme string) *Client { if doer == nil { doer = http.DefaultClient.Do } - if urlFmt == "" { - urlFmt = "http://(address)/raftadmin/" - } else { - if !strings.Contains(urlFmt, "(address)") { - panic("urlFmt must contain the string '(address)'") - } - if !strings.HasSuffix(urlFmt, "/") { - urlFmt += "/" - } + if scheme == "" { + scheme = "http" } - return &HTTPAdminClient{ + return &Client{ do: doer, - urlFmt: urlFmt, + scheme: scheme, logger: logger, } } -func (c *HTTPAdminClient) generateUrl(target raft.ServerAddress, action string) string { - return strings.ReplaceAll(c.urlFmt+action, - "(address)", string(target)) +func (c *Client) generateUrl(target raft.ServerAddress, action string) string { + uri := fmt.Sprintf("%s://%s/raftadmin/%s", c.scheme, target, action) + // c.logger.Debug("raftadmin: generated url", zap.String("url", uri)) + return uri } - -func (c *HTTPAdminClient) AddNonvoter(ctx context.Context, target raft.ServerAddress, req *AddNonvoterRequest) (*AwaitResponse, error) { +func (c *Client) AddNonvoter(ctx context.Context, target raft.ServerAddress, req *AddNonvoterRequest) (*AwaitResponse, error) { url := c.generateUrl(target, "AddNonvoter") buf, err := json.Marshal(req) if err != nil { @@ -58,7 +49,7 @@ func (c *HTTPAdminClient) AddNonvoter(ctx context.Context, target raft.ServerAdd if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -74,7 +65,7 @@ func (c *HTTPAdminClient) AddNonvoter(ctx context.Context, target raft.ServerAdd return &out, nil } -func (c *HTTPAdminClient) AddVoter(ctx context.Context, target raft.ServerAddress, req *AddVoterRequest) (*AwaitResponse, error) { +func (c *Client) AddVoter(ctx context.Context, target raft.ServerAddress, req *AddVoterRequest) (*AwaitResponse, error) { url := c.generateUrl(target, "AddVoter") buf, err := json.Marshal(req) if err != nil { @@ -84,7 +75,7 @@ func (c *HTTPAdminClient) AddVoter(ctx context.Context, target raft.ServerAddres if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -100,13 +91,13 @@ func (c *HTTPAdminClient) AddVoter(ctx context.Context, target raft.ServerAddres return &out, nil } -func (c *HTTPAdminClient) AppliedIndex(ctx context.Context, target raft.ServerAddress) (*AppliedIndexResponse, error) { +func (c *Client) AppliedIndex(ctx context.Context, target raft.ServerAddress) (*AppliedIndexResponse, error) { url := c.generateUrl(target, "AppliedIndex") r, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -122,7 +113,7 @@ func (c *HTTPAdminClient) AppliedIndex(ctx context.Context, target raft.ServerAd return &out, nil } -func (c *HTTPAdminClient) ApplyLog(ctx context.Context, target raft.ServerAddress, req *ApplyLogRequest) (*AwaitResponse, error) { +func (c *Client) ApplyLog(ctx context.Context, target raft.ServerAddress, req *ApplyLogRequest) (*AwaitResponse, error) { url := c.generateUrl(target, "ApplyLog") buf, err := json.Marshal(req) if err != nil { @@ -132,7 +123,7 @@ func (c *HTTPAdminClient) ApplyLog(ctx context.Context, target raft.ServerAddres if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -148,13 +139,13 @@ func (c *HTTPAdminClient) ApplyLog(ctx context.Context, target raft.ServerAddres return &out, nil } -func (c *HTTPAdminClient) Barrier(ctx context.Context, target raft.ServerAddress) (*AwaitResponse, error) { +func (c *Client) Barrier(ctx context.Context, target raft.ServerAddress) (*AwaitResponse, error) { url := c.generateUrl(target, "Barrier") r, err := http.NewRequest("POST", url, nil) if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -170,7 +161,7 @@ func (c *HTTPAdminClient) Barrier(ctx context.Context, target raft.ServerAddress return &out, nil } -func (c *HTTPAdminClient) DemoteVoter(ctx context.Context, target raft.ServerAddress, req *DemoteVoterRequest) (*AwaitResponse, error) { +func (c *Client) DemoteVoter(ctx context.Context, target raft.ServerAddress, req *DemoteVoterRequest) (*AwaitResponse, error) { url := c.generateUrl(target, "DemoteVoter") buf, err := json.Marshal(req) if err != nil { @@ -180,7 +171,7 @@ func (c *HTTPAdminClient) DemoteVoter(ctx context.Context, target raft.ServerAdd if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -196,13 +187,13 @@ func (c *HTTPAdminClient) DemoteVoter(ctx context.Context, target raft.ServerAdd return &out, nil } -func (c *HTTPAdminClient) GetConfiguration(ctx context.Context, target raft.ServerAddress) (*GetConfigurationResponse, error) { +func (c *Client) GetConfiguration(ctx context.Context, target raft.ServerAddress) (*GetConfigurationResponse, error) { url := c.generateUrl(target, "GetConfiguration") r, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -218,13 +209,13 @@ func (c *HTTPAdminClient) GetConfiguration(ctx context.Context, target raft.Serv return &out, nil } -func (c *HTTPAdminClient) LastContact(ctx context.Context, target raft.ServerAddress) (*LastContactResponse, error) { +func (c *Client) LastContact(ctx context.Context, target raft.ServerAddress) (*LastContactResponse, error) { url := c.generateUrl(target, "LastContact") r, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -240,13 +231,13 @@ func (c *HTTPAdminClient) LastContact(ctx context.Context, target raft.ServerAdd return &out, nil } -func (c *HTTPAdminClient) LastIndex(ctx context.Context, target raft.ServerAddress) (*LastIndexResponse, error) { +func (c *Client) LastIndex(ctx context.Context, target raft.ServerAddress) (*LastIndexResponse, error) { url := c.generateUrl(target, "LastIndex") r, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -264,13 +255,13 @@ func (c *HTTPAdminClient) LastIndex(ctx context.Context, target raft.ServerAddre var ErrNotLeader = errors.New("not leader") -func (c *HTTPAdminClient) Leader(ctx context.Context, target raft.ServerAddress) (*LeaderResponse, error) { +func (c *Client) Leader(ctx context.Context, target raft.ServerAddress) (*LeaderResponse, error) { url := c.generateUrl(target, "Leader") r, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -290,7 +281,7 @@ func (c *HTTPAdminClient) Leader(ctx context.Context, target raft.ServerAddress) } } -func (c *HTTPAdminClient) LeadershipTransfer(ctx context.Context, target raft.ServerAddress, req *LeadershipTransferToServerRequest) (*AwaitResponse, error) { +func (c *Client) LeadershipTransfer(ctx context.Context, target raft.ServerAddress, req *LeadershipTransferToServerRequest) (*AwaitResponse, error) { url := c.generateUrl(target, "LeadershipTransfer") buf, err := json.Marshal(req) if err != nil { @@ -300,7 +291,7 @@ func (c *HTTPAdminClient) LeadershipTransfer(ctx context.Context, target raft.Se if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -320,7 +311,7 @@ func (c *HTTPAdminClient) LeadershipTransfer(ctx context.Context, target raft.Se } } -func (c *HTTPAdminClient) RemoveServer(ctx context.Context, target raft.ServerAddress, req *RemoveServerRequest) (*AwaitResponse, error) { +func (c *Client) RemoveServer(ctx context.Context, target raft.ServerAddress, req *RemoveServerRequest) (*AwaitResponse, error) { url := c.generateUrl(target, "RemoveServer") buf, err := json.Marshal(req) if err != nil { @@ -330,7 +321,7 @@ func (c *HTTPAdminClient) RemoveServer(ctx context.Context, target raft.ServerAd if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -350,13 +341,13 @@ func (c *HTTPAdminClient) RemoveServer(ctx context.Context, target raft.ServerAd } } -func (c *HTTPAdminClient) Shutdown(ctx context.Context, target raft.ServerAddress) (*AwaitResponse, error) { +func (c *Client) Shutdown(ctx context.Context, target raft.ServerAddress) (*AwaitResponse, error) { url := c.generateUrl(target, "Shutdown") r, err := http.NewRequest("POST", url, nil) if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -376,13 +367,13 @@ func (c *HTTPAdminClient) Shutdown(ctx context.Context, target raft.ServerAddres } } -func (c *HTTPAdminClient) State(ctx context.Context, target raft.ServerAddress) (*StateResponse, error) { +func (c *Client) State(ctx context.Context, target raft.ServerAddress) (*StateResponse, error) { url := c.generateUrl(target, "State") r, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -402,13 +393,13 @@ func (c *HTTPAdminClient) State(ctx context.Context, target raft.ServerAddress) } } -func (c *HTTPAdminClient) Stats(ctx context.Context, target raft.ServerAddress) (*StatsResponse, error) { +func (c *Client) Stats(ctx context.Context, target raft.ServerAddress) (*StatsResponse, error) { url := c.generateUrl(target, "Stats") r, err := http.NewRequest("POST", url, nil) if err != nil { return nil, err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return nil, err } @@ -428,13 +419,13 @@ func (c *HTTPAdminClient) Stats(ctx context.Context, target raft.ServerAddress) } } -func (c *HTTPAdminClient) VerifyLeader(ctx context.Context, target raft.ServerAddress) error { +func (c *Client) VerifyLeader(ctx context.Context, target raft.ServerAddress) error { url := c.generateUrl(target, "VerifyLeader") r, err := http.NewRequest("POST", url, nil) if err != nil { return err } - res, err := c.clientRetry(r) + res, err := c.clientRetry(ctx, r) if err != nil { return err } @@ -449,15 +440,17 @@ func (c *HTTPAdminClient) VerifyLeader(ctx context.Context, target raft.ServerAd } } -func (c *HTTPAdminClient) clientRetry(r *http.Request) (*http.Response, error) { +func (c *Client) clientRetry(ctx context.Context, r *http.Request) (*http.Response, error) { retries := 0 + r = r.WithContext(ctx) RETRY: res, err := c.do(r) if err != nil { - if retries > 5 { + if retries > 3 { return nil, err + } else if ctx.Err() != nil { + return nil, ctx.Err() } - <-time.After(1 * time.Second) retries++ goto RETRY } diff --git a/pkg/raftadmin/raftadmin.go b/pkg/raftadmin/server.go similarity index 82% rename from pkg/raftadmin/raftadmin.go rename to pkg/raftadmin/server.go index 168424b..5d43ecb 100644 --- a/pkg/raftadmin/raftadmin.go +++ b/pkg/raftadmin/server.go @@ -16,20 +16,19 @@ import ( "go.uber.org/zap" ) -// RaftAdminHTTPServer provides a HTTP-based transport that can be used to +// Server provides a HTTP-based transport that can be used to // communicate with Raft on remote machines. It is convenient to use if your // application is an HTTP server already and you do not want to use multiple // different transports (if not, you can use raft.NetworkTransport). -type RaftAdminHTTPServer struct { +type Server struct { logger *zap.Logger r *raft.Raft - // addrs map[raft.ServerID]raft.ServerAddress addrs []raft.ServerAddress } -// NewRaftAdminHTTPServer creates a new HTTP transport on the given addr. -func NewRaftAdminHTTPServer(r *raft.Raft, logger *zap.Logger, addrs []raft.ServerAddress) *RaftAdminHTTPServer { - return &RaftAdminHTTPServer{ +// NewServer creates a new HTTP transport on the given addr. +func NewServer(r *raft.Raft, logger *zap.Logger, addrs []raft.ServerAddress) *Server { + return &Server{ logger: logger, r: r, addrs: addrs, @@ -71,7 +70,7 @@ func toFuture(f raft.Future) (*Future, error) { }, nil } -func (a *RaftAdminHTTPServer) Await(ctx context.Context, req *Future) (*AwaitResponse, error) { +func (a *Server) Await(ctx context.Context, req *Future) (*AwaitResponse, error) { mtx.Lock() f, ok := operations[req.OperationToken] defer func() { @@ -107,7 +106,7 @@ func (a *RaftAdminHTTPServer) Await(ctx context.Context, req *Future) (*AwaitRes return r, nil } -func (a *RaftAdminHTTPServer) Forget(ctx context.Context, req *Future) (*ForgetResponse, error) { +func (a *Server) Forget(ctx context.Context, req *Future) (*ForgetResponse, error) { mtx.Lock() delete(operations, req.OperationToken) mtx.Unlock() @@ -116,29 +115,29 @@ func (a *RaftAdminHTTPServer) Forget(ctx context.Context, req *Future) (*ForgetR }, nil } -func (a *RaftAdminHTTPServer) AddNonvoter(ctx context.Context, req *AddNonvoterRequest) (*Future, error) { +func (a *Server) AddNonvoter(ctx context.Context, req *AddNonvoterRequest) (*Future, error) { return toFuture(a.r.AddNonvoter(raft.ServerID(req.ID), raft.ServerAddress(req.Address), uint64(req.PrevIndex), timeout(ctx))) } -func (a *RaftAdminHTTPServer) AddVoter(ctx context.Context, req *AddVoterRequest) (*Future, error) { +func (a *Server) AddVoter(ctx context.Context, req *AddVoterRequest) (*Future, error) { return toFuture(a.r.AddVoter(raft.ServerID(req.ID), raft.ServerAddress(req.Address), uint64(req.PrevIndex), timeout(ctx))) } -func (a *RaftAdminHTTPServer) AppliedIndex(ctx context.Context) (*AppliedIndexResponse, error) { +func (a *Server) AppliedIndex(ctx context.Context) (*AppliedIndexResponse, error) { return &AppliedIndexResponse{ Index: a.r.AppliedIndex(), }, nil } -func (a *RaftAdminHTTPServer) Barrier(ctx context.Context) (*Future, error) { +func (a *Server) Barrier(ctx context.Context) (*Future, error) { return toFuture(a.r.Barrier(timeout(ctx))) } -func (a *RaftAdminHTTPServer) DemoteVoter(ctx context.Context, req *DemoteVoterRequest) (*Future, error) { +func (a *Server) DemoteVoter(ctx context.Context, req *DemoteVoterRequest) (*Future, error) { return toFuture(a.r.DemoteVoter(raft.ServerID(req.ID), req.PrevIndex, timeout(ctx))) } -func (a *RaftAdminHTTPServer) GetConfiguration(ctx context.Context) (*GetConfigurationResponse, error) { +func (a *Server) GetConfiguration(ctx context.Context) (*GetConfigurationResponse, error) { f := a.r.GetConfiguration() if err := f.Error(); err != nil { return nil, err @@ -162,24 +161,24 @@ func (a *RaftAdminHTTPServer) GetConfiguration(ctx context.Context) (*GetConfigu return resp, nil } -func (a *RaftAdminHTTPServer) LastContact(ctx context.Context) (*LastContactResponse, error) { +func (a *Server) LastContact(ctx context.Context) (*LastContactResponse, error) { t := a.r.LastContact() return &LastContactResponse{ UnixNano: t.UnixNano(), }, nil } -func (a *RaftAdminHTTPServer) LastIndex(ctx context.Context) (*LastIndexResponse, error) { +func (a *Server) LastIndex(ctx context.Context) (*LastIndexResponse, error) { return &LastIndexResponse{ Index: a.r.LastIndex(), }, nil } -func (a *RaftAdminHTTPServer) CurrentNodeIsLeader(ctx context.Context) bool { +func (a *Server) CurrentNodeIsLeader(ctx context.Context) bool { return a.r.State() == raft.Leader } -func (a *RaftAdminHTTPServer) Leader(ctx context.Context) (*LeaderResponse, error) { +func (a *Server) Leader(ctx context.Context) (*LeaderResponse, error) { for _, s := range a.r.GetConfiguration().Configuration().Servers { if s.Suffrage == raft.Voter && s.Address == a.r.Leader() { return &LeaderResponse{ @@ -193,27 +192,27 @@ func (a *RaftAdminHTTPServer) Leader(ctx context.Context) (*LeaderResponse, erro }, nil } -func (a *RaftAdminHTTPServer) LeadershipTransfer(ctx context.Context) (*Future, error) { +func (a *Server) LeadershipTransfer(ctx context.Context) (*Future, error) { return toFuture(a.r.LeadershipTransfer()) } -func (a *RaftAdminHTTPServer) LeadershipTransferToServer(ctx context.Context, req *LeadershipTransferToServerRequest) (*Future, error) { +func (a *Server) LeadershipTransferToServer(ctx context.Context, req *LeadershipTransferToServerRequest) (*Future, error) { return toFuture(a.r.LeadershipTransferToServer(raft.ServerID(req.ID), raft.ServerAddress(req.Address))) } -func (a *RaftAdminHTTPServer) RemoveServer(ctx context.Context, req *RemoveServerRequest) (*Future, error) { +func (a *Server) RemoveServer(ctx context.Context, req *RemoveServerRequest) (*Future, error) { return toFuture(a.r.RemoveServer(raft.ServerID(req.ID), req.PrevIndex, timeout(ctx))) } -func (a *RaftAdminHTTPServer) Shutdown(ctx context.Context) (*Future, error) { +func (a *Server) Shutdown(ctx context.Context) (*Future, error) { return toFuture(a.r.Shutdown()) } -func (a *RaftAdminHTTPServer) Snapshot(ctx context.Context) (*Future, error) { +func (a *Server) Snapshot(ctx context.Context) (*Future, error) { return toFuture(a.r.Snapshot()) } -func (a *RaftAdminHTTPServer) State(ctx context.Context) (*StateResponse, error) { +func (a *Server) State(ctx context.Context) (*StateResponse, error) { switch s := a.r.State(); s { case raft.Follower: return &StateResponse{State: RaftStateFollower}, nil @@ -228,7 +227,7 @@ func (a *RaftAdminHTTPServer) State(ctx context.Context) (*StateResponse, error) } } -func (a *RaftAdminHTTPServer) Stats(ctx context.Context) (*StatsResponse, error) { +func (a *Server) Stats(ctx context.Context) (*StatsResponse, error) { ret := &StatsResponse{} ret.Stats = map[string]string{} for k, v := range a.r.Stats() { @@ -237,12 +236,12 @@ func (a *RaftAdminHTTPServer) Stats(ctx context.Context) (*StatsResponse, error) return ret, nil } -func (a *RaftAdminHTTPServer) VerifyLeader(ctx context.Context) (*Future, error) { +func (a *Server) VerifyLeader(ctx context.Context) (*Future, error) { return toFuture(a.r.VerifyLeader()) } // ServeHTTP implements the net/http.Handler interface, so that you can use -func (t *RaftAdminHTTPServer) ServeHTTP(res http.ResponseWriter, req *http.Request) { +func (t *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) { cmd := path.Base(req.URL.Path) if cmdRequiresLeader(cmd) && t.r.State() != raft.Leader { @@ -467,7 +466,7 @@ func cmdRequiresLeader(cmd string) bool { } } -func (t *RaftAdminHTTPServer) genericResponse(req *http.Request, res http.ResponseWriter, f *Future, cmd string) { +func (t *Server) genericResponse(req *http.Request, res http.ResponseWriter, f *Future, cmd string) { resp, err := t.Await(req.Context(), f) if err != nil { http.Error(res, err.Error(), http.StatusInternalServerError) diff --git a/pkg/raftadmin/raftadmin_test.go b/pkg/raftadmin/server_test.go similarity index 97% rename from pkg/raftadmin/raftadmin_test.go rename to pkg/raftadmin/server_test.go index b4eca87..44621e6 100644 --- a/pkg/raftadmin/raftadmin_test.go +++ b/pkg/raftadmin/server_test.go @@ -129,7 +129,7 @@ func setupRaftAdmin(t *testing.T) *httptest.Server { } <-time.After(time.Second * 5) - raftAdmin := NewRaftAdminHTTPServer( + raftAdmin := NewServer( raftNode, zap.NewNop(), []raft.ServerAddress{ "localhost:9090", @@ -169,10 +169,9 @@ func TestRaft(t *testing.T) { Return(mockClient.res, nil) ctx := context.Background() - client := NewHTTPAdminClient( + client := NewClient( server.Client().Do, - "http://(address)/raftadmin", - zap.NewNop(), + zap.NewNop(), "http", ) serverAddr := raft.ServerAddress(server.Listener.Addr().String()) leader, err := client.Leader(ctx, serverAddr) diff --git a/pkg/rafthttp/rafthttp.go b/pkg/rafthttp/rafthttp.go index af28bdc..17e9d9b 100644 --- a/pkg/rafthttp/rafthttp.go +++ b/pkg/rafthttp/rafthttp.go @@ -32,28 +32,25 @@ type HTTPTransport struct { consumer chan raft.RPC addr raft.ServerAddress client Doer - urlFmt string + scheme string } var _ raft.Transport = (*HTTPTransport)(nil) var _ raft.WithPreVote = (*HTTPTransport)(nil) -func NewHTTPTransport(addr raft.ServerAddress, client Doer, logger *zap.Logger, urlFmt string) *HTTPTransport { +func NewHTTPTransport(addr raft.ServerAddress, client Doer, logger *zap.Logger, scheme string) *HTTPTransport { if client == nil { client = http.DefaultClient } - if !strings.Contains(urlFmt, "(address)") { - panic("urlFmt must contain the string '(address)'") - } - if !strings.HasSuffix(urlFmt, "/") { - urlFmt += "/" + if scheme == "" { + scheme = "http" } return &HTTPTransport{ logger: logger, consumer: make(chan raft.RPC), addr: addr, client: client, - urlFmt: urlFmt, + scheme: scheme, } } @@ -100,8 +97,9 @@ RETRY: } func (t *HTTPTransport) generateUrl(target raft.ServerAddress, action string) string { - return strings.ReplaceAll(t.urlFmt+action, - "(address)", string(target)) + uri := fmt.Sprintf("%s://%s/raft/%s", t.scheme, target, action) + // t.logger.Debug("rafthttp: generated url", zap.String("url", uri)) + return uri } // Consumer implements the raft.Transport interface. diff --git a/pkg/rafthttp/rafthttp_test.go b/pkg/rafthttp/rafthttp_test.go index 6656096..b1cfe40 100644 --- a/pkg/rafthttp/rafthttp_test.go +++ b/pkg/rafthttp/rafthttp_test.go @@ -44,8 +44,8 @@ func TestExample(t *testing.T) { log.Printf("Listening on %s", ln.Addr().String()) srvAddr := raft.ServerAddress(ln.Addr().String()) transport := rafthttp.NewHTTPTransport( - srvAddr, http.DefaultClient, zap.NewNop(), - "http://(address)/raft", + srvAddr, http.DefaultClient, + zap.NewNop(), "http", ) srv := &http.Server{ Handler: transport, diff --git a/pkg/util/queue/queue.go b/pkg/util/queue/queue.go index 348c74d..a8988e7 100644 --- a/pkg/util/queue/queue.go +++ b/pkg/util/queue/queue.go @@ -17,7 +17,14 @@ type queueImpl[V any] struct { } // New returns a new queue. -func New[V any]() Queue[V] { +func New[V any](vs ...V) Queue[V] { + if len(vs) > 0 { + q := newQueue[V](len(vs)) + for _, v := range vs { + q.Push(v) + } + return q + } return newQueue[V](128) } diff --git a/pkg/util/sliceutil/slice.go b/pkg/util/sliceutil/slice.go index 4f7eae9..1775a4a 100644 --- a/pkg/util/sliceutil/slice.go +++ b/pkg/util/sliceutil/slice.go @@ -68,20 +68,19 @@ func SliceCopy[T any](arr []T) []T { return append([]T(nil), arr...) } - // BinarySearch searches for a value in a sorted slice and returns the index of the value. // If the value is not found, it returns -1 -func BinarySearch[T any](slice []T, val T, less func(T, T) bool) int { +func BinarySearch[T any](slice []T, val T, compare func(T, T) int) int { low, high := 0, len(slice)-1 for low <= high { mid := low + (high-low)/2 - if less(slice[mid], val) { - low = mid + 1 - } else if less(val, slice[mid]) { + if i := compare(slice[mid], val); i == 0 { + return mid + } else if i > 0 { high = mid - 1 } else { - return mid + low = mid + 1 } } return -1 -} \ No newline at end of file +} diff --git a/pkg/util/sliceutil/slice_test.go b/pkg/util/sliceutil/slice_test.go new file mode 100644 index 0000000..91b7858 --- /dev/null +++ b/pkg/util/sliceutil/slice_test.go @@ -0,0 +1,114 @@ +package sliceutil_test + +import ( + "testing" + + "github.com/dgate-io/dgate/pkg/util/sliceutil" +) + +func TestBinarySearch(t *testing.T) { + tests := []struct { + name string + items []int + search int + expected int + iterations int + }{ + { + name: "empty", + items: []int{}, + search: 1, + expected: -1, + iterations: 0, + }, + { + name: "not found/1", + items: []int{1, 3, 5, 7, 9}, + search: 6, + expected: -1, + iterations: 2, + }, + { + name: "not found/2", + items: []int{1, 3, 5, 7, 9}, + search: 10, + expected: -1, + iterations: 3, + }, + { + name: "not found/3", + search: 6, + expected: -1, + iterations: 4, + items: []int{ + 1, 2, 3, 4, 5, + 7, 7, 8, 9, 10, + 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, + }, + }, + { + name: "found/1", + items: []int{1, 2, 3, 4, 5}, + search: 4, + expected: 3, + iterations: 2, + }, + { + name: "found/2", + search: 13, + expected: 12, + iterations: 4, + items: []int{ + 1, 2, 3, 4, 5, + 7, 7, 8, 9, 10, + 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + iters := 0 + actual := sliceutil.BinarySearch(tt.items, tt.search, func(a, b int) int { + iters++ + return a - b + }) + if actual != tt.expected { + t.Errorf("expected %d, got %d", tt.expected, actual) + } + if iters != tt.iterations { + t.Errorf("expected %d iterations, got %d", tt.iterations, iters) + } + }) + } +} + +func BenchmarkCompareLinearAndBinarySearch(b *testing.B) { + items := make([]int, 1000000) + for i := range items { + items[i] = i + } + + b.Run("linear", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, item := range items { + if item == 999999 { + break + } + } + } + }) + + b.Run("binary", func(b *testing.B) { + for i := 0; i < b.N; i++ { + opts := 0 + sliceutil.BinarySearch(items, 999999, func(a, b int) int { + opts++ + return a - b + }) + b.ReportMetric(float64(opts), "opts/op") + } + }) +}