From e8892f18ca35e05325bf11351f4afd9173ecfe7a Mon Sep 17 00:00:00 2001 From: Peter Csajtai Date: Sun, 24 Mar 2024 01:24:57 +0100 Subject: [PATCH] Move SDK handling behind a service --- Dockerfile | 2 +- diag/server_test.go | 9 +- diag/status/mware_test.go | 9 +- diag/status/status.go | 94 +++++++++++-------- diag/status/status_test.go | 90 ++++++++++++++---- grpc/flag_service.go | 20 ++-- grpc/flag_service_test.go | 57 +++++------ grpc/mware.go | 8 +- grpc/mware_test.go | 11 ++- grpc/server.go | 4 +- grpc/server_test.go | 31 +++--- internal/testutils/utils.go | 55 ++++++++--- main.go | 25 ++--- .../{usr_attrs_test.go => user_attrs_test.go} | 0 sdk/sdk_registrar.go | 50 ++++++++++ sdk/sdk_registrar_test.go | 51 ++++++++++ sdk/sdk_test.go | 2 +- sdk/store/cache/cache_notify_test.go | 36 +++---- sdk/store/cache/cache_test.go | 2 +- sdk/store/file/file_test.go | 18 ++-- stream/benchmark_test.go | 8 +- stream/load_test.go | 10 +- stream/server.go | 4 +- stream/server_test.go | 4 +- web/api/api.go | 18 ++-- web/api/api_test.go | 63 ++++--------- web/cdnproxy/cdnproxy.go | 18 ++-- web/cdnproxy/cdnproxy_test.go | 38 +++----- web/router.go | 26 ++--- web/router_api_test.go | 4 +- web/router_cdnproxy_test.go | 4 +- web/router_sse_test.go | 4 +- web/router_status_test.go | 25 +---- web/router_webhook_test.go | 4 +- web/sse/sse.go | 4 +- web/sse/sse_test.go | 13 +-- web/webhook/webhook.go | 14 +-- web/webhook/webhook_test.go | 49 ++++------ 38 files changed, 496 insertions(+), 388 deletions(-) rename model/{usr_attrs_test.go => user_attrs_test.go} (100%) create mode 100644 sdk/sdk_registrar.go create mode 100644 sdk/sdk_registrar_test.go diff --git a/Dockerfile b/Dockerfile index 4ba36b5..a565d23 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.21-alpine3.18 AS build +FROM golang:1.22.1-alpine3.19 AS build WORKDIR /go/src/configcat_proxy diff --git a/diag/server_test.go b/diag/server_test.go index 9f68a5b..a267958 100644 --- a/diag/server_test.go +++ b/diag/server_test.go @@ -19,7 +19,10 @@ func TestNewServer(t *testing.T) { Status: config.StatusConfig{Enabled: true}, Metrics: config.MetricsConfig{Enabled: true}, } - srv := NewServer(&conf, status.NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"sdk": {Key: "key"}}}), metrics.NewReporter(), log.NewNullLogger(), errChan) + + reporter := status.NewEmptyReporter() + reporter.RegisterSdk("test", &config.SDKConfig{Key: "key"}) + srv := NewServer(&conf, reporter, metrics.NewReporter(), log.NewNullLogger(), errChan) srv.Listen() time.Sleep(500 * time.Millisecond) @@ -47,7 +50,9 @@ func TestNewServer_NotEnabled(t *testing.T) { Metrics: config.MetricsConfig{Enabled: false}, } - srv := NewServer(&conf, status.NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"sdk": {Key: "key"}}}), metrics.NewReporter(), log.NewNullLogger(), errChan) + reporter := status.NewEmptyReporter() + reporter.RegisterSdk("test", &config.SDKConfig{Key: "key"}) + srv := NewServer(&conf, reporter, metrics.NewReporter(), log.NewNullLogger(), errChan) srv.Listen() time.Sleep(500 * time.Millisecond) diff --git a/diag/status/mware_test.go b/diag/status/mware_test.go index 6a04200..9a402da 100644 --- a/diag/status/mware_test.go +++ b/diag/status/mware_test.go @@ -10,7 +10,8 @@ import ( func TestInterceptSdk(t *testing.T) { t.Run("ok", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"test": {}}}).(*reporter) + reporter := NewEmptyReporter().(*reporter) + reporter.RegisterSdk("test", &config.SDKConfig{Key: "key"}) repSrv := httptest.NewServer(reporter.HttpHandler()) h := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { writer.WriteHeader(http.StatusOK) @@ -33,7 +34,8 @@ func TestInterceptSdk(t *testing.T) { assert.Equal(t, 0, len(stat.Cache.Records)) }) t.Run("not modified", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"test": {}}}).(*reporter) + reporter := NewEmptyReporter().(*reporter) + reporter.RegisterSdk("test", &config.SDKConfig{Key: "key"}) repSrv := httptest.NewServer(reporter.HttpHandler()) h := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { writer.WriteHeader(http.StatusNotModified) @@ -56,7 +58,8 @@ func TestInterceptSdk(t *testing.T) { assert.Equal(t, 0, len(stat.Cache.Records)) }) t.Run("error", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"test": {}}}).(*reporter) + reporter := NewEmptyReporter().(*reporter) + reporter.RegisterSdk("test", &config.SDKConfig{Key: "key"}) repSrv := httptest.NewServer(reporter.HttpHandler()) h := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { writer.WriteHeader(http.StatusBadRequest) diff --git a/diag/status/status.go b/diag/status/status.go index 176ad4d..3cd6730 100644 --- a/diag/status/status.go +++ b/diag/status/status.go @@ -34,6 +34,7 @@ const maxRecordCount = 5 const maxLastErrorsMeaningDegraded = 2 type Reporter interface { + RegisterSdk(sdkId string, conf *config.SDKConfig) ReportOk(component string, message string) ReportError(component string, message string) GetStatus() Status @@ -74,14 +75,14 @@ type reporter struct { records map[string][]record mu sync.RWMutex status Status - conf *config.Config + conf *config.CacheConfig } -func NewNullReporter() Reporter { - return &reporter{records: make(map[string][]record), conf: &config.Config{SDKs: map[string]*config.SDKConfig{}}} +func NewEmptyReporter() Reporter { + return NewReporter(&config.CacheConfig{}) } -func NewReporter(conf *config.Config) Reporter { +func NewReporter(conf *config.CacheConfig) Reporter { r := &reporter{ conf: conf, records: make(map[string][]record), @@ -90,44 +91,54 @@ func NewReporter(conf *config.Config) Reporter { Cache: CacheStatus{ Status: Initializing, }, + SDKs: map[string]*SdkStatus{}, }, } - r.status.SDKs = make(map[string]*SdkStatus, len(conf.SDKs)) - for key, sdk := range conf.SDKs { - status := &SdkStatus{ - Mode: Online, - SdkKey: utils.Obfuscate(sdk.Key, 5), - Source: SdkSourceStatus{ - Type: RemoteSrc, - Status: Initializing, - }, - } - r.status.SDKs[key] = status - if sdk.Offline.Enabled { - status.Mode = Offline - if sdk.Offline.Local.FilePath != "" { - status.Source.Type = FileSrc - r.status.Cache.Status = NA - } else { - status.Source.Type = CacheSrc - } - } - if !conf.Cache.IsSet() { + return r +} + +func (r *reporter) RegisterSdk(sdkId string, conf *config.SDKConfig) { + r.mu.Lock() + defer r.mu.Unlock() + + status := &SdkStatus{ + Mode: Online, + SdkKey: utils.Obfuscate(conf.Key, 5), + Source: SdkSourceStatus{ + Type: RemoteSrc, + Status: Initializing, + }, + } + r.status.SDKs[sdkId] = status + if conf.Offline.Enabled { + status.Mode = Offline + if conf.Offline.Local.FilePath != "" { + status.Source.Type = FileSrc r.status.Cache.Status = NA - if status.Source.Type == CacheSrc { - r.ReportError(key, "cache offline source enabled without a configured cache") - } + } else { + status.Source.Type = CacheSrc + } + } + if !r.conf.IsSet() { + r.status.Cache.Status = NA + if status.Source.Type == CacheSrc { + r.appendRecord(sdkId, "cache offline source enabled without a configured cache", true) } } - return r } func (r *reporter) ReportOk(component string, message string) { - r.appendRecord(component, "[ok] "+message, false) + r.mu.Lock() + defer r.mu.Unlock() + + r.appendRecord(component, message, false) } func (r *reporter) ReportError(component string, message string) { - r.appendRecord(component, "[error] "+message, true) + r.mu.Lock() + defer r.mu.Unlock() + + r.appendRecord(component, message, true) } func (r *reporter) HttpHandler() http.HandlerFunc { @@ -169,8 +180,11 @@ func (r *reporter) checkStatus(records []record) ([]string, HealthStatus) { } func (r *reporter) appendRecord(component string, message string, isError bool) { - r.mu.Lock() - defer r.mu.Unlock() + if isError { + message = "[error] " + message + } else { + message = "[ok] " + message + } recs, ok := r.records[component] if !ok { @@ -194,14 +208,12 @@ func (r *reporter) appendRecord(component string, message string, isError bool) allSdksDown := true hasDegradedSdk := false - for key := range r.conf.SDKs { - if sdk, ok := r.status.SDKs[key]; ok { - if sdk.Source.Status != Down { - allSdksDown = false - } - if sdk.Source.Status != Healthy { - hasDegradedSdk = true - } + for _, sdk := range r.status.SDKs { + if sdk.Source.Status != Down { + allSdksDown = false + } + if sdk.Source.Status != Healthy { + hasDegradedSdk = true } } if !hasDegradedSdk && !allSdksDown { diff --git a/diag/status/status_test.go b/diag/status/status_test.go index 2da64ac..eb396ee 100644 --- a/diag/status/status_test.go +++ b/diag/status/status_test.go @@ -12,7 +12,8 @@ import ( func TestReporter_Online(t *testing.T) { t.Run("ok", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t": {}}}) + reporter := NewEmptyReporter() + reporter.RegisterSdk("t", &config.SDKConfig{}) srv := httptest.NewServer(reporter.HttpHandler()) stat := readStatus(srv.URL) @@ -35,9 +36,9 @@ func TestReporter_Online(t *testing.T) { assert.Equal(t, NA, stat.Cache.Status) assert.Equal(t, 0, len(stat.Cache.Records)) }) - t.Run("down after 1 error, then ok, then degraded", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t": {}}}) + reporter := NewEmptyReporter() + reporter.RegisterSdk("t", &config.SDKConfig{}) srv := httptest.NewServer(reporter.HttpHandler()) reporter.ReportError("t", "") stat := readStatus(srv.URL) @@ -74,7 +75,8 @@ func TestReporter_Online(t *testing.T) { assert.Equal(t, 0, len(stat.Cache.Records)) }) t.Run("max 5 records", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t": {}}}) + reporter := NewEmptyReporter() + reporter.RegisterSdk("t", &config.SDKConfig{}) srv := httptest.NewServer(reporter.HttpHandler()) reporter.ReportOk("t", "m1") reporter.ReportOk("t", "m2") @@ -93,9 +95,48 @@ func TestReporter_Online(t *testing.T) { }) } +func TestReporter_Report_NonExisting(t *testing.T) { + reporter := NewEmptyReporter() + srv := httptest.NewServer(reporter.HttpHandler()) + + reporter.ReportOk("t1", "") + reporter.ReportError("t1", "") + stat := readStatus(srv.URL) + + assert.Equal(t, Initializing, stat.Status) + assert.Empty(t, stat.SDKs) + assert.Equal(t, Initializing, stat.Cache.Status) + assert.Equal(t, 0, len(stat.Cache.Records)) + + reporter.RegisterSdk("t2", &config.SDKConfig{}) + reporter.ReportOk("t1", "") + reporter.ReportError("t1", "") + reporter.ReportOk("t2", "") + stat = readStatus(srv.URL) + + assert.Equal(t, Healthy, stat.Status) + assert.Equal(t, Healthy, stat.SDKs["t2"].Source.Status) + assert.Equal(t, Online, stat.SDKs["t2"].Mode) + assert.Equal(t, 1, len(stat.SDKs["t2"].Source.Records)) + assert.Equal(t, RemoteSrc, stat.SDKs["t2"].Source.Type) + assert.Equal(t, NA, stat.Cache.Status) + assert.Equal(t, 0, len(stat.Cache.Records)) +} + +func TestReporter_Key_Obfuscation(t *testing.T) { + reporter := NewEmptyReporter() + srv := httptest.NewServer(reporter.HttpHandler()) + + reporter.RegisterSdk("t", &config.SDKConfig{Key: "XxPbCKmzIUGORk4vsufpzw/iC_KABprDEueeQs3yovVnQ"}) + stat := readStatus(srv.URL) + + assert.Equal(t, "****************************************ovVnQ", stat.SDKs["t"].SdkKey) +} + func TestReporter_Offline(t *testing.T) { t.Run("file", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t": {Offline: config.OfflineConfig{Enabled: true, Local: config.LocalConfig{FilePath: "test"}}}}}) + reporter := NewEmptyReporter() + reporter.RegisterSdk("t", &config.SDKConfig{Offline: config.OfflineConfig{Enabled: true, Local: config.LocalConfig{FilePath: "test"}}}) srv := httptest.NewServer(reporter.HttpHandler()) reporter.ReportOk("t", "") stat := readStatus(srv.URL) @@ -109,7 +150,8 @@ func TestReporter_Offline(t *testing.T) { assert.Equal(t, 0, len(stat.Cache.Records)) }) t.Run("cache invalid", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t": {Offline: config.OfflineConfig{Enabled: true, UseCache: true}}}}) + reporter := NewEmptyReporter() + reporter.RegisterSdk("t", &config.SDKConfig{Offline: config.OfflineConfig{Enabled: true, UseCache: true}}) srv := httptest.NewServer(reporter.HttpHandler()) stat := readStatus(srv.URL) @@ -122,7 +164,8 @@ func TestReporter_Offline(t *testing.T) { assert.Equal(t, 0, len(stat.Cache.Records)) }) t.Run("cache err", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t": {Offline: config.OfflineConfig{Enabled: true, UseCache: true}}}, Cache: config.CacheConfig{Redis: config.RedisConfig{Enabled: true}}}) + reporter := NewReporter(&config.CacheConfig{Redis: config.RedisConfig{Enabled: true}}) + reporter.RegisterSdk("t", &config.SDKConfig{Offline: config.OfflineConfig{Enabled: true, UseCache: true}}) srv := httptest.NewServer(reporter.HttpHandler()) reporter.ReportError("t", "") reporter.ReportError("t", "") @@ -135,7 +178,8 @@ func TestReporter_Offline(t *testing.T) { assert.Equal(t, CacheSrc, stat.SDKs["t"].Source.Type) }) t.Run("cache valid", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t": {Offline: config.OfflineConfig{Enabled: true, UseCache: true}}}, Cache: config.CacheConfig{Redis: config.RedisConfig{Enabled: true}}}) + reporter := NewReporter(&config.CacheConfig{Redis: config.RedisConfig{Enabled: true}}) + reporter.RegisterSdk("t", &config.SDKConfig{Offline: config.OfflineConfig{Enabled: true, UseCache: true}}) srv := httptest.NewServer(reporter.HttpHandler()) reporter.ReportOk("t", "") reporter.ReportOk(Cache, "") @@ -153,7 +197,8 @@ func TestReporter_Offline(t *testing.T) { func TestReporter_Degraded_Calc(t *testing.T) { t.Run("1 record first, 1 error", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t": {}}}).(*reporter) + reporter := NewEmptyReporter().(*reporter) + reporter.RegisterSdk("t", &config.SDKConfig{}) reporter.ReportError("t", "") stat := reporter.GetStatus() @@ -161,7 +206,8 @@ func TestReporter_Degraded_Calc(t *testing.T) { assert.Equal(t, Down, stat.SDKs["t"].Source.Status) }) t.Run("2 records, 1 error then 1 ok", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t": {}}}).(*reporter) + reporter := NewEmptyReporter().(*reporter) + reporter.RegisterSdk("t", &config.SDKConfig{}) reporter.ReportError("t", "") reporter.ReportOk("t", "") stat := reporter.GetStatus() @@ -170,7 +216,8 @@ func TestReporter_Degraded_Calc(t *testing.T) { assert.Equal(t, Healthy, stat.SDKs["t"].Source.Status) }) t.Run("2 records, 1 ok then 1 error", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t": {}}}).(*reporter) + reporter := NewEmptyReporter().(*reporter) + reporter.RegisterSdk("t", &config.SDKConfig{}) reporter.ReportOk("t", "") reporter.ReportError("t", "") stat := reporter.GetStatus() @@ -179,7 +226,8 @@ func TestReporter_Degraded_Calc(t *testing.T) { assert.Equal(t, Healthy, stat.SDKs["t"].Source.Status) }) t.Run("3 records, 1 ok then 2 errors", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t": {}}}).(*reporter) + reporter := NewEmptyReporter().(*reporter) + reporter.RegisterSdk("t", &config.SDKConfig{}) reporter.ReportOk("t", "") reporter.ReportError("t", "") reporter.ReportError("t", "") @@ -189,7 +237,8 @@ func TestReporter_Degraded_Calc(t *testing.T) { assert.Equal(t, Degraded, stat.SDKs["t"].Source.Status) }) t.Run("3 records, 1 ok then 1 error then 1 ok", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t": {}}}).(*reporter) + reporter := NewEmptyReporter().(*reporter) + reporter.RegisterSdk("t", &config.SDKConfig{}) reporter.ReportOk("t", "") reporter.ReportError("t", "") reporter.ReportOk("t", "") @@ -199,7 +248,8 @@ func TestReporter_Degraded_Calc(t *testing.T) { assert.Equal(t, Healthy, stat.SDKs["t"].Source.Status) }) t.Run("3 records, 1 error then 1 ok then 1 error", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t": {}}}).(*reporter) + reporter := NewEmptyReporter().(*reporter) + reporter.RegisterSdk("t", &config.SDKConfig{}) reporter.ReportError("t", "") reporter.ReportOk("t", "") reporter.ReportError("t", "") @@ -209,7 +259,9 @@ func TestReporter_Degraded_Calc(t *testing.T) { assert.Equal(t, Healthy, stat.SDKs["t"].Source.Status) }) t.Run("2 envs 1 down", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t1": {}, "t2": {}}}).(*reporter) + reporter := NewEmptyReporter().(*reporter) + reporter.RegisterSdk("t1", &config.SDKConfig{}) + reporter.RegisterSdk("t2", &config.SDKConfig{}) reporter.ReportError("t1", "") reporter.ReportOk("t2", "") stat := reporter.GetStatus() @@ -219,7 +271,9 @@ func TestReporter_Degraded_Calc(t *testing.T) { assert.Equal(t, Healthy, stat.SDKs["t2"].Source.Status) }) t.Run("2 envs 1 degraded", func(t *testing.T) { - reporter := NewReporter(&config.Config{SDKs: map[string]*config.SDKConfig{"t1": {}, "t2": {}}}).(*reporter) + reporter := NewEmptyReporter().(*reporter) + reporter.RegisterSdk("t1", &config.SDKConfig{}) + reporter.RegisterSdk("t2", &config.SDKConfig{}) reporter.ReportError("t1", "") reporter.ReportOk("t1", "") reporter.ReportError("t1", "") @@ -234,9 +288,9 @@ func TestReporter_Degraded_Calc(t *testing.T) { } func TestNewNullReporter(t *testing.T) { - rep := NewNullReporter().(*reporter) + rep := NewEmptyReporter().(*reporter) assert.Empty(t, rep.records) - assert.Empty(t, rep.conf.SDKs) + assert.Empty(t, rep.GetStatus().SDKs) } func readStatus(url string) Status { diff --git a/grpc/flag_service.go b/grpc/flag_service.go index acc1486..aae6e03 100644 --- a/grpc/flag_service.go +++ b/grpc/flag_service.go @@ -19,15 +19,15 @@ type flagService struct { proto.UnimplementedFlagServiceServer streamServer stream.Server log log.Logger - sdkClients map[string]sdk.Client + sdkRegistrar sdk.Registrar closed chan struct{} } -func newFlagService(sdkClients map[string]sdk.Client, metrics metrics.Reporter, log log.Logger) *flagService { +func newFlagService(sdkRegistrar sdk.Registrar, metrics metrics.Reporter, log log.Logger) *flagService { return &flagService{ - streamServer: stream.NewServer(sdkClients, metrics, log, "grpc"), + streamServer: stream.NewServer(sdkRegistrar, metrics, log, "grpc"), log: log, - sdkClients: sdkClients, + sdkRegistrar: sdkRegistrar, closed: make(chan struct{}), } } @@ -132,8 +132,8 @@ func (s *flagService) GetKeys(_ context.Context, req *proto.KeysRequest) (*proto return nil, status.Error(codes.InvalidArgument, "sdk id parameter missing") } - sdkClient, ok := s.sdkClients[req.GetSdkId()] - if !ok { + sdkClient := s.sdkRegistrar.GetSdkOrNil(req.GetSdkId()) + if sdkClient == nil { return nil, status.Error(codes.InvalidArgument, "sdk not found for identifier: '"+req.GetSdkId()+"'") } if !sdkClient.IsInValidState() { @@ -149,8 +149,8 @@ func (s *flagService) Refresh(_ context.Context, req *proto.RefreshRequest) (*em return nil, status.Error(codes.InvalidArgument, "sdk id parameter missing") } - sdkClient, ok := s.sdkClients[req.GetSdkId()] - if !ok { + sdkClient := s.sdkRegistrar.GetSdkOrNil(req.GetSdkId()) + if sdkClient == nil { return nil, status.Error(codes.InvalidArgument, "sdk not found for identifier: '"+req.GetSdkId()+"'") } @@ -219,8 +219,8 @@ func (s *flagService) parseEvalRequest(req *proto.EvalRequest, user *model.UserA *user = getUserAttrs(req.GetUser()) } - sdkClient, ok := s.sdkClients[req.GetSdkId()] - if !ok { + sdkClient := s.sdkRegistrar.GetSdkOrNil(req.GetSdkId()) + if sdkClient == nil { return nil, status.Error(codes.InvalidArgument, "sdk not found for identifier: '"+req.GetSdkId()+"'") } if !sdkClient.IsInValidState() { diff --git a/grpc/flag_service_test.go b/grpc/flag_service_test.go index 842aa6a..e0d761a 100644 --- a/grpc/flag_service_test.go +++ b/grpc/flag_service_test.go @@ -7,7 +7,6 @@ import ( "github.com/configcat/configcat-proxy/internal/testutils" "github.com/configcat/configcat-proxy/internal/utils" "github.com/configcat/configcat-proxy/log" - "github.com/configcat/configcat-proxy/sdk" "github.com/configcat/go-sdk/v9/configcattest" "github.com/stretchr/testify/assert" "google.golang.org/grpc" @@ -30,10 +29,9 @@ func TestGrpc_EvalFlagStream(t *testing.T) { sdkSrv := httptest.NewServer(&h) defer sdkSrv.Close() - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1}, nil) - sdkClient := sdk.NewClient(ctx, log.NewNullLogger()) - defer sdkClient.Close() - flagSrv := newFlagService(map[string]sdk.Client{"test": sdkClient}, nil, log.NewNullLogger()) + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1}, nil) + defer reg.Close() + flagSrv := newFlagService(reg, nil, log.NewNullLogger()) lis := bufconn.Listen(1024 * 1024) @@ -98,10 +96,9 @@ func TestGrpc_EvalAllFlagsStream(t *testing.T) { sdkSrv := httptest.NewServer(&h) defer sdkSrv.Close() - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1}, nil) - sdkClient := sdk.NewClient(ctx, log.NewNullLogger()) - defer sdkClient.Close() - flagSrv := newFlagService(map[string]sdk.Client{"test": sdkClient}, nil, log.NewNullLogger()) + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1}, nil) + defer reg.Close() + flagSrv := newFlagService(reg, nil, log.NewNullLogger()) lis := bufconn.Listen(1024 * 1024) @@ -174,10 +171,9 @@ func TestGrpc_EvalFlag(t *testing.T) { sdkSrv := httptest.NewServer(&h) defer sdkSrv.Close() - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1}, nil) - sdkClient := sdk.NewClient(ctx, log.NewNullLogger()) - defer sdkClient.Close() - flagSrv := newFlagService(map[string]sdk.Client{"test": sdkClient}, nil, log.NewNullLogger()) + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1}, nil) + defer reg.Close() + flagSrv := newFlagService(reg, nil, log.NewNullLogger()) lis := bufconn.Listen(1024 * 1024) @@ -226,10 +222,9 @@ func TestGrpc_EvalFlag(t *testing.T) { } func TestGrpc_SDK_InvalidState(t *testing.T) { - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: "http://localhost", Key: configcattest.RandomSDKKey()}, nil) - sdkClient := sdk.NewClient(ctx, log.NewNullLogger()) - defer sdkClient.Close() - flagSrv := newFlagService(map[string]sdk.Client{"test": sdkClient}, nil, log.NewNullLogger()) + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: "http://localhost", Key: configcattest.RandomSDKKey()}, nil) + defer reg.Close() + flagSrv := newFlagService(reg, nil, log.NewNullLogger()) lis := bufconn.Listen(1024 * 1024) @@ -287,10 +282,9 @@ func TestGrpc_Invalid_SdkKey(t *testing.T) { }) sdkSrv := httptest.NewServer(&h) defer sdkSrv.Close() - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key}, nil) - sdkClient := sdk.NewClient(ctx, log.NewNullLogger()) - defer sdkClient.Close() - flagSrv := newFlagService(map[string]sdk.Client{"test": sdkClient}, nil, log.NewNullLogger()) + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key}, nil) + defer reg.Close() + flagSrv := newFlagService(reg, nil, log.NewNullLogger()) lis := bufconn.Listen(1024 * 1024) srv := grpc.NewServer() defer srv.GracefulStop() @@ -348,10 +342,9 @@ func TestGrpc_Invalid_FlagKey(t *testing.T) { }) sdkSrv := httptest.NewServer(&h) defer sdkSrv.Close() - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key}, nil) - sdkClient := sdk.NewClient(ctx, log.NewNullLogger()) - defer sdkClient.Close() - flagSrv := newFlagService(map[string]sdk.Client{"test": sdkClient}, nil, log.NewNullLogger()) + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key}, nil) + defer reg.Close() + flagSrv := newFlagService(reg, nil, log.NewNullLogger()) lis := bufconn.Listen(1024 * 1024) srv := grpc.NewServer() defer srv.GracefulStop() @@ -398,10 +391,9 @@ func TestGrpc_EvalAllFlags(t *testing.T) { sdkSrv := httptest.NewServer(&h) defer sdkSrv.Close() - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1}, nil) - sdkClient := sdk.NewClient(ctx, log.NewNullLogger()) - defer sdkClient.Close() - flagSrv := newFlagService(map[string]sdk.Client{"test": sdkClient}, nil, log.NewNullLogger()) + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1}, nil) + defer reg.Close() + flagSrv := newFlagService(reg, nil, log.NewNullLogger()) lis := bufconn.Listen(1024 * 1024) @@ -471,10 +463,9 @@ func TestGrpc_GetKeys(t *testing.T) { sdkSrv := httptest.NewServer(&h) defer sdkSrv.Close() - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1}, nil) - sdkClient := sdk.NewClient(ctx, log.NewNullLogger()) - defer sdkClient.Close() - flagSrv := newFlagService(map[string]sdk.Client{"test": sdkClient}, nil, log.NewNullLogger()) + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1}, nil) + defer reg.Close() + flagSrv := newFlagService(reg, nil, log.NewNullLogger()) lis := bufconn.Listen(1024 * 1024) diff --git a/grpc/mware.go b/grpc/mware.go index 0a69e90..f07e013 100644 --- a/grpc/mware.go +++ b/grpc/mware.go @@ -14,7 +14,7 @@ import ( func DebugLogUnaryInterceptor(log log.Logger) grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - if isHealthCheck(info.FullMethod) { + if shouldIgnore(info.FullMethod) { return handler(ctx, req) } @@ -44,7 +44,7 @@ func DebugLogUnaryInterceptor(log log.Logger) grpc.UnaryServerInterceptor { func DebugLogStreamInterceptor(log log.Logger) grpc.StreamServerInterceptor { return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - if isHealthCheck(info.FullMethod) { + if shouldIgnore(info.FullMethod) { return handler(srv, ss) } @@ -72,8 +72,8 @@ func DebugLogStreamInterceptor(log log.Logger) grpc.StreamServerInterceptor { } } -func isHealthCheck(method string) bool { - if strings.Contains(method, "grpc.health") { +func shouldIgnore(method string) bool { + if strings.Contains(method, "grpc.health") || strings.Contains(method, "grpc.reflection") { return true } return false diff --git a/grpc/mware_test.go b/grpc/mware_test.go index a92d616..b28bcb7 100644 --- a/grpc/mware_test.go +++ b/grpc/mware_test.go @@ -32,7 +32,7 @@ func TestDebug_UnaryInterceptor(t *testing.T) { outLog := out.String() assert.Contains(t, outLog, "[debug] rpc starting test-method [peer: 127.0.0.1/32]") - assert.Contains(t, outLog, "[debug] request finished test-method [peer: 127.0.0.1/32] [test-agent] [code: OK] [duration: 0ms]") + assert.Contains(t, outLog, "[debug] request finished test-method [peer: 127.0.0.1/32] [test-agent] [code: OK] [duration: ") } func TestDebug_StreamInterceptor(t *testing.T) { @@ -55,12 +55,13 @@ func TestDebug_StreamInterceptor(t *testing.T) { outLog := out.String() assert.Contains(t, outLog, "[debug] rpc starting test-method [peer: 127.0.0.1/32] [test-agent]") - assert.Contains(t, outLog, "[debug] request finished test-method [peer: 127.0.0.1/32] [test-agent] [code: OK] [duration: 0ms]") + assert.Contains(t, outLog, "[debug] request finished test-method [peer: 127.0.0.1/32] [test-agent] [code: OK] [duration: ") } -func TestIsHealthCheck(t *testing.T) { - assert.False(t, isHealthCheck("/configcat.FlagService/EvalFlag")) - assert.True(t, isHealthCheck("/grpc.health.v1.Health/Check")) +func TestIgnoreServiceNames(t *testing.T) { + assert.False(t, shouldIgnore("/configcat.FlagService/EvalFlag")) + assert.True(t, shouldIgnore("/grpc.health.v1.Health/Check")) + assert.True(t, shouldIgnore("/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo")) } type MockStreamServer struct { diff --git a/grpc/server.go b/grpc/server.go index 8d2eeda..b0f919b 100644 --- a/grpc/server.go +++ b/grpc/server.go @@ -32,7 +32,7 @@ type Server struct { errorChannel chan error } -func NewServer(sdkClients map[string]sdk.Client, metricsReporter metrics.Reporter, statusReporter status.Reporter, conf *config.Config, logger log.Logger, errorChan chan error) (*Server, error) { +func NewServer(sdkRegistrar sdk.Registrar, metricsReporter metrics.Reporter, statusReporter status.Reporter, conf *config.Config, logger log.Logger, errorChan chan error) (*Server, error) { grpcLog := logger.WithLevel(conf.Grpc.Log.GetLevel()).WithPrefix("grpc") opts := make([]grpc.ServerOption, 0) if conf.Tls.Enabled { @@ -64,7 +64,7 @@ func NewServer(sdkClients map[string]sdk.Client, metricsReporter metrics.Reporte opts = append(opts, grpc.KeepaliveParams(params)) } - flagService := newFlagService(sdkClients, metricsReporter, grpcLog) + flagService := newFlagService(sdkRegistrar, metricsReporter, grpcLog) grpcServer := grpc.NewServer(opts...) proto.RegisterFlagServiceServer(grpcServer, flagService) diff --git a/grpc/server_test.go b/grpc/server_test.go index ad81d07..6863c65 100644 --- a/grpc/server_test.go +++ b/grpc/server_test.go @@ -7,7 +7,6 @@ import ( "github.com/configcat/configcat-proxy/internal/testutils" "github.com/configcat/configcat-proxy/internal/utils" "github.com/configcat/configcat-proxy/log" - "github.com/configcat/configcat-proxy/sdk" "github.com/configcat/go-sdk/v9/configcattest" "github.com/stretchr/testify/assert" "net/http/httptest" @@ -30,12 +29,12 @@ func TestNewServer(t *testing.T) { sdkSrv := httptest.NewServer(&h) defer sdkSrv.Close() - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1}, nil) - sdkClient := sdk.NewClient(ctx, log.NewNullLogger()) - conf := config.Config{Grpc: config.GrpcConfig{Port: 5061, HealthCheckEnabled: true, ServerReflectionEnabled: true, KeepAlive: config.KeepAliveConfig{Timeout: 10}}, SDKs: map[string]*config.SDKConfig{key: ctx.SDKConf}} - defer sdkClient.Close() + sdkConf := config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1} + reg := testutils.NewTestRegistrar(&sdkConf, nil) + conf := config.Config{Grpc: config.GrpcConfig{Port: 5061, HealthCheckEnabled: true, ServerReflectionEnabled: true, KeepAlive: config.KeepAliveConfig{Timeout: 10}}, SDKs: map[string]*config.SDKConfig{key: &sdkConf}} + defer reg.Close() - srv, _ := NewServer(map[string]sdk.Client{"test": sdkClient}, metrics.NewReporter(), status.NewReporter(&conf), &conf, log.NewDebugLogger(), errChan) + srv, _ := NewServer(reg, metrics.NewReporter(), status.NewReporter(&conf.Cache), &conf, log.NewDebugLogger(), errChan) wg := sync.WaitGroup{} wg.Add(1) @@ -116,12 +115,12 @@ MK4Li/LGWcksyoF+hbPNXMFCIA== sdkSrv := httptest.NewServer(&h) defer sdkSrv.Close() - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1}, nil) - sdkClient := sdk.NewClient(ctx, log.NewNullLogger()) - conf := config.Config{Grpc: config.GrpcConfig{Port: 5062}, Tls: tlsConf, SDKs: map[string]*config.SDKConfig{key: ctx.SDKConf}} - defer sdkClient.Close() + sdkConf := config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1} + reg := testutils.NewTestRegistrar(&sdkConf, nil) + conf := config.Config{Grpc: config.GrpcConfig{Port: 5062}, Tls: tlsConf, SDKs: map[string]*config.SDKConfig{key: &sdkConf}} + defer reg.Close() - srv, _ := NewServer(map[string]sdk.Client{"test": sdkClient}, nil, status.NewReporter(&conf), &conf, log.NewNullLogger(), errChan) + srv, _ := NewServer(reg, nil, status.NewReporter(&conf.Cache), &conf, log.NewNullLogger(), errChan) wg := sync.WaitGroup{} wg.Add(1) @@ -157,12 +156,12 @@ func TestNewServer_TLS_Missing_Cert(t *testing.T) { sdkSrv := httptest.NewServer(&h) defer sdkSrv.Close() - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1}, nil) - sdkClient := sdk.NewClient(ctx, log.NewNullLogger()) - conf := config.Config{Grpc: config.GrpcConfig{Port: 5063}, Tls: tlsConf, SDKs: map[string]*config.SDKConfig{key: ctx.SDKConf}} - defer sdkClient.Close() + sdkConf := config.SDKConfig{BaseUrl: sdkSrv.URL, Key: key, PollInterval: 1} + reg := testutils.NewTestRegistrar(&sdkConf, nil) + conf := config.Config{Grpc: config.GrpcConfig{Port: 5063}, Tls: tlsConf, SDKs: map[string]*config.SDKConfig{key: &sdkConf}} + defer reg.Close() - _, err := NewServer(map[string]sdk.Client{"test": sdkClient}, nil, status.NewReporter(&conf), &conf, log.NewNullLogger(), errChan) + _, err := NewServer(reg, nil, status.NewReporter(&conf.Cache), &conf, log.NewNullLogger(), errChan) assert.Error(t, err) } diff --git a/internal/testutils/utils.go b/internal/testutils/utils.go index f24ccbc..84d9792 100644 --- a/internal/testutils/utils.go +++ b/internal/testutils/utils.go @@ -14,7 +14,23 @@ import ( "testing" ) -func NewTestSdkClient(t *testing.T) (map[string]sdk.Client, *configcattest.Handler, string) { +func NewTestRegistrar(conf *config.SDKConfig, cache configcat.ConfigCache) sdk.Registrar { + return NewTestRegistrarWithStatusReporter(conf, cache, status.NewEmptyReporter()) +} + +func NewTestRegistrarWithStatusReporter(conf *config.SDKConfig, cache configcat.ConfigCache, reporter status.Reporter) sdk.Registrar { + ctx := NewTestSdkContext(conf, cache) + reg := sdk.NewRegistrar(&config.Config{ + SDKs: map[string]*config.SDKConfig{"test": conf}, + }, ctx.MetricsReporter, reporter, cache, log.NewNullLogger()) + return reg +} + +func NewTestRegistrarT(t *testing.T) (sdk.Registrar, *configcattest.Handler, string) { + return NewTestRegistrarTWithStatusReporter(t, status.NewEmptyReporter()) +} + +func NewTestRegistrarTWithStatusReporter(t *testing.T, reporter status.Reporter) (sdk.Registrar, *configcattest.Handler, string) { key := configcattest.RandomSDKKey() var h configcattest.Handler _ = h.SetFlags(key, map[string]*configcattest.Flag{ @@ -31,25 +47,38 @@ func NewTestSdkClient(t *testing.T) (map[string]sdk.Client, *configcattest.Handl }, }) srv := httptest.NewServer(&h) - opts := config.SDKConfig{BaseUrl: srv.URL, Key: key} - ctx := NewTestSdkContext(&opts, nil) - client := sdk.NewClient(ctx, log.NewNullLogger()) + reg := NewTestRegistrarWithStatusReporter(&config.SDKConfig{BaseUrl: srv.URL, Key: key}, nil, reporter) + t.Cleanup(func() { + srv.Close() + reg.Close() + }) + return reg, &h, key +} + +func NewTestRegistrarTWithErrorServer(t *testing.T) sdk.Registrar { + key := configcattest.RandomSDKKey() + var h configcattest.Handler + srv := httptest.NewServer(&h) + reg := NewTestRegistrarWithStatusReporter(&config.SDKConfig{BaseUrl: srv.URL, Key: key}, nil, status.NewEmptyReporter()) t.Cleanup(func() { srv.Close() - client.Close() + reg.Close() }) - return map[string]sdk.Client{"test": client}, &h, key + return reg +} + +func NewTestSdkClient(t *testing.T) (map[string]sdk.Client, *configcattest.Handler, string) { + reg, h, k := NewTestRegistrarT(t) + return reg.GetAll(), h, k } func NewTestSdkContext(conf *config.SDKConfig, cache configcat.ConfigCache) *sdk.Context { return &sdk.Context{ - SDKConf: conf, - ProxyConf: &config.HttpProxyConfig{}, - StatusReporter: status.NewNullReporter(), - MetricsReporter: nil, - EvalReporter: nil, - SdkId: "test", - ExternalCache: cache, + SDKConf: conf, + ProxyConf: &config.HttpProxyConfig{}, + StatusReporter: status.NewEmptyReporter(), + SdkId: "test", + ExternalCache: cache, } } diff --git a/main.go b/main.go index f136acd..09c5466 100644 --- a/main.go +++ b/main.go @@ -56,7 +56,7 @@ func run(closeSignal chan os.Signal) int { // in the future we might implement an evaluation statistics reporter // var evalReporter statistics.Reporter - statusReporter := status.NewReporter(&conf) + statusReporter := status.NewReporter(&conf.Cache) var metricsReporter metrics.Reporter if conf.Diag.Metrics.Enabled { @@ -80,24 +80,12 @@ func run(closeSignal chan os.Signal) int { } } - sdkClients := make(map[string]sdk.Client) - for key, sdkConf := range conf.SDKs { - sdkClients[key] = sdk.NewClient(&sdk.Context{ - SDKConf: sdkConf, - EvalReporter: nil, - MetricsReporter: metricsReporter, - StatusReporter: statusReporter, - ProxyConf: &conf.HttpProxy, - GlobalDefaultAttrs: conf.DefaultAttrs, - SdkId: key, - ExternalCache: externalCache, - }, logger) - } + sdkRegistrar := sdk.NewRegistrar(&conf, metricsReporter, statusReporter, externalCache, logger) var httpServer *web.Server var router *web.HttpRouter if conf.Http.Enabled { - router = web.NewRouter(sdkClients, metricsReporter, statusReporter, &conf.Http, logger) + router = web.NewRouter(sdkRegistrar, metricsReporter, statusReporter, &conf.Http, logger) httpServer, err = web.NewServer(router.Handler(), logger, &conf, errorChan) if err != nil { return exitFailure @@ -107,7 +95,7 @@ func run(closeSignal chan os.Signal) int { var grpcServer *grpc.Server if conf.Grpc.Enabled { - grpcServer, err = grpc.NewServer(sdkClients, metricsReporter, statusReporter, &conf, logger, errorChan) + grpcServer, err = grpc.NewServer(sdkRegistrar, metricsReporter, statusReporter, &conf, logger, errorChan) if err != nil { return exitFailure } @@ -117,9 +105,8 @@ func run(closeSignal chan os.Signal) int { for { select { case <-closeSignal: - for _, sdkClient := range sdkClients { - sdkClient.Close() - } + sdkRegistrar.Close() + if router != nil { router.Close() } diff --git a/model/usr_attrs_test.go b/model/user_attrs_test.go similarity index 100% rename from model/usr_attrs_test.go rename to model/user_attrs_test.go diff --git a/sdk/sdk_registrar.go b/sdk/sdk_registrar.go new file mode 100644 index 0000000..486a7d4 --- /dev/null +++ b/sdk/sdk_registrar.go @@ -0,0 +1,50 @@ +package sdk + +import ( + "github.com/configcat/configcat-proxy/config" + "github.com/configcat/configcat-proxy/diag/metrics" + "github.com/configcat/configcat-proxy/diag/status" + "github.com/configcat/configcat-proxy/log" + configcat "github.com/configcat/go-sdk/v9" +) + +type Registrar interface { + GetSdkOrNil(sdkId string) Client + GetAll() map[string]Client + Close() +} + +type registrar struct { + sdkClients map[string]Client +} + +func NewRegistrar(conf *config.Config, metricsReporter metrics.Reporter, statusReporter status.Reporter, externalCache configcat.ConfigCache, log log.Logger) Registrar { + sdkClients := make(map[string]Client, len(conf.SDKs)) + for key, sdkConf := range conf.SDKs { + statusReporter.RegisterSdk(key, sdkConf) + sdkClients[key] = NewClient(&Context{ + SDKConf: sdkConf, + MetricsReporter: metricsReporter, + StatusReporter: statusReporter, + ProxyConf: &conf.HttpProxy, + GlobalDefaultAttrs: conf.DefaultAttrs, + SdkId: key, + ExternalCache: externalCache, + }, log) + } + return ®istrar{sdkClients: sdkClients} +} + +func (r *registrar) GetSdkOrNil(sdkId string) Client { + return r.sdkClients[sdkId] +} + +func (r *registrar) GetAll() map[string]Client { + return r.sdkClients +} + +func (r *registrar) Close() { + for _, sdkClient := range r.sdkClients { + sdkClient.Close() + } +} diff --git a/sdk/sdk_registrar_test.go b/sdk/sdk_registrar_test.go new file mode 100644 index 0000000..e3898f3 --- /dev/null +++ b/sdk/sdk_registrar_test.go @@ -0,0 +1,51 @@ +package sdk + +import ( + "github.com/configcat/configcat-proxy/config" + "github.com/configcat/configcat-proxy/diag/status" + "github.com/configcat/configcat-proxy/internal/utils" + "github.com/configcat/configcat-proxy/log" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestRegistrar_GetSdkOrNil(t *testing.T) { + reg := NewRegistrar(&config.Config{ + SDKs: map[string]*config.SDKConfig{"test": {Key: "key"}}, + }, nil, status.NewEmptyReporter(), nil, log.NewNullLogger()) + defer reg.Close() + + assert.NotNil(t, reg.GetSdkOrNil("test")) +} + +func TestRegistrar_All(t *testing.T) { + reg := NewRegistrar(&config.Config{ + SDKs: map[string]*config.SDKConfig{"test1": {Key: "key1"}, "test2": {Key: "key2"}}, + }, nil, status.NewEmptyReporter(), nil, log.NewNullLogger()) + defer reg.Close() + + assert.Equal(t, 2, len(reg.GetAll())) +} + +func TestClient_Close(t *testing.T) { + reg := NewRegistrar(&config.Config{ + SDKs: map[string]*config.SDKConfig{"test": {Key: "key"}}, + }, nil, status.NewEmptyReporter(), nil, log.NewNullLogger()) + + c := reg.GetSdkOrNil("test").(*client) + reg.Close() + utils.WithTimeout(1*time.Second, func() { + <-c.ctx.Done() + }) +} + +func TestRegistrar_Reporter(t *testing.T) { + reporter := status.NewEmptyReporter() + reg := NewRegistrar(&config.Config{ + SDKs: map[string]*config.SDKConfig{"test": {Key: "key"}}, + }, nil, reporter, nil, log.NewNullLogger()) + defer reg.Close() + + assert.NotEmpty(t, reporter.GetStatus().SDKs) +} diff --git a/sdk/sdk_test.go b/sdk/sdk_test.go index 0328511..5ac3a20 100644 --- a/sdk/sdk_test.go +++ b/sdk/sdk_test.go @@ -431,7 +431,7 @@ func newTestSdkContext(conf *config.SDKConfig, externalCache configcat.ConfigCac return &Context{ SDKConf: conf, ProxyConf: &config.HttpProxyConfig{}, - StatusReporter: status.NewNullReporter(), + StatusReporter: status.NewEmptyReporter(), MetricsReporter: nil, EvalReporter: nil, SdkId: "test", diff --git a/sdk/store/cache/cache_notify_test.go b/sdk/store/cache/cache_notify_test.go index 8946c4d..8559d60 100644 --- a/sdk/store/cache/cache_notify_test.go +++ b/sdk/store/cache/cache_notify_test.go @@ -21,8 +21,8 @@ func TestRedisNotify(t *testing.T) { s := miniredis.RunT(t) red, err := newRedis(&config.RedisConfig{Addresses: []string{s.Addr()}}, log.NewNullLogger()) assert.NoError(t, err) - r := NewCacheStore(red, status.NewNullReporter()) - srv := NewNotifyingCacheStore("test", cacheKey, r, &config.OfflineConfig{CachePollInterval: 1}, status.NewNullReporter(), log.NewNullLogger()).(*notifyingCacheStore) + r := NewCacheStore(red, status.NewEmptyReporter()) + srv := NewNotifyingCacheStore("test", cacheKey, r, &config.OfflineConfig{CachePollInterval: 1}, status.NewEmptyReporter(), log.NewNullLogger()).(*notifyingCacheStore) cacheEntry := configcatcache.CacheSegmentsToBytes(time.Now(), "etag", []byte(`{"f":{"flag":{"v":{"b":true}}},"p":null}`)) err = s.Set(cacheKey, string(cacheEntry)) assert.NoError(t, err) @@ -48,8 +48,8 @@ func TestRedisNotify_Initial(t *testing.T) { assert.NoError(t, err) red, err := newRedis(&config.RedisConfig{Addresses: []string{s.Addr()}}, log.NewNullLogger()) assert.NoError(t, err) - r := NewCacheStore(red, status.NewNullReporter()) - srv := NewNotifyingCacheStore("test", cacheKey, r, &config.OfflineConfig{CachePollInterval: 1}, status.NewNullReporter(), log.NewNullLogger()) + r := NewCacheStore(red, status.NewEmptyReporter()) + srv := NewNotifyingCacheStore("test", cacheKey, r, &config.OfflineConfig{CachePollInterval: 1}, status.NewEmptyReporter(), log.NewNullLogger()) s.CheckGet(t, cacheKey, string(cacheEntry)) res, err := srv.Get(context.Background(), "") _, _, j, _ := configcatcache.CacheSegmentsFromBytes(res) @@ -67,8 +67,8 @@ func TestRedisNotify_Notify(t *testing.T) { assert.NoError(t, err) red, err := newRedis(&config.RedisConfig{Addresses: []string{s.Addr()}}, log.NewNullLogger()) assert.NoError(t, err) - r := NewCacheStore(red, status.NewNullReporter()) - srv := NewNotifyingCacheStore("test", cacheKey, r, &config.OfflineConfig{CachePollInterval: 1}, status.NewNullReporter(), log.NewNullLogger()).(*notifyingCacheStore) + r := NewCacheStore(red, status.NewEmptyReporter()) + srv := NewNotifyingCacheStore("test", cacheKey, r, &config.OfflineConfig{CachePollInterval: 1}, status.NewEmptyReporter(), log.NewNullLogger()).(*notifyingCacheStore) s.CheckGet(t, cacheKey, string(cacheEntry)) res, err := srv.Get(context.Background(), "") _, _, j, _ := configcatcache.CacheSegmentsFromBytes(res) @@ -96,8 +96,8 @@ func TestRedisNotify_BadJson(t *testing.T) { assert.NoError(t, err) red, err := newRedis(&config.RedisConfig{Addresses: []string{s.Addr()}}, log.NewNullLogger()) assert.NoError(t, err) - r := NewCacheStore(red, status.NewNullReporter()) - srv := NewNotifyingCacheStore("test", cacheKey, r, &config.OfflineConfig{CachePollInterval: 1}, status.NewNullReporter(), log.NewNullLogger()) + r := NewCacheStore(red, status.NewEmptyReporter()) + srv := NewNotifyingCacheStore("test", cacheKey, r, &config.OfflineConfig{CachePollInterval: 1}, status.NewEmptyReporter(), log.NewNullLogger()) res, err := srv.Get(context.Background(), "") _, _, j, _ := configcatcache.CacheSegmentsFromBytes(res) assert.NoError(t, err) @@ -113,8 +113,8 @@ func TestRedisNotify_MalformedCacheEntry(t *testing.T) { assert.NoError(t, err) red, err := newRedis(&config.RedisConfig{Addresses: []string{s.Addr()}}, log.NewNullLogger()) assert.NoError(t, err) - r := NewCacheStore(red, status.NewNullReporter()) - srv := NewNotifyingCacheStore("test", cacheKey, r, &config.OfflineConfig{CachePollInterval: 1}, status.NewNullReporter(), log.NewNullLogger()) + r := NewCacheStore(red, status.NewEmptyReporter()) + srv := NewNotifyingCacheStore("test", cacheKey, r, &config.OfflineConfig{CachePollInterval: 1}, status.NewEmptyReporter(), log.NewNullLogger()) res, err := srv.Get(context.Background(), "") _, _, j, _ := configcatcache.CacheSegmentsFromBytes(res) assert.NoError(t, err) @@ -131,8 +131,8 @@ func TestRedisNotify_MalformedJson(t *testing.T) { assert.NoError(t, err) red, err := newRedis(&config.RedisConfig{Addresses: []string{s.Addr()}}, log.NewNullLogger()) assert.NoError(t, err) - r := NewCacheStore(red, status.NewNullReporter()) - srv := NewNotifyingCacheStore("test", cacheKey, r, &config.OfflineConfig{CachePollInterval: 1}, status.NewNullReporter(), log.NewNullLogger()) + r := NewCacheStore(red, status.NewEmptyReporter()) + srv := NewNotifyingCacheStore("test", cacheKey, r, &config.OfflineConfig{CachePollInterval: 1}, status.NewEmptyReporter(), log.NewNullLogger()) res, err := srv.Get(context.Background(), "") _, _, j, _ := configcatcache.CacheSegmentsFromBytes(res) assert.NoError(t, err) @@ -176,8 +176,8 @@ func TestRedisNotify_Reporter(t *testing.T) { func TestRedisNotify_Unavailable(t *testing.T) { red, err := newRedis(&config.RedisConfig{Addresses: []string{"nonexisting"}}, log.NewNullLogger()) assert.NoError(t, err) - r := NewCacheStore(red, status.NewNullReporter()) - srv := NewNotifyingCacheStore("test", "", r, &config.OfflineConfig{CachePollInterval: 1}, status.NewNullReporter(), log.NewNullLogger()) + r := NewCacheStore(red, status.NewEmptyReporter()) + srv := NewNotifyingCacheStore("test", "", r, &config.OfflineConfig{CachePollInterval: 1}, status.NewEmptyReporter(), log.NewNullLogger()) res, err := srv.Get(context.Background(), "") _, _, j, _ := configcatcache.CacheSegmentsFromBytes(res) assert.NoError(t, err) @@ -189,8 +189,8 @@ func TestRedisNotify_Close(t *testing.T) { s := miniredis.RunT(t) red, err := newRedis(&config.RedisConfig{Addresses: []string{s.Addr()}}, log.NewNullLogger()) assert.NoError(t, err) - r := NewCacheStore(red, status.NewNullReporter()) - srv := NewNotifyingCacheStore("test", "", r, &config.OfflineConfig{CachePollInterval: 1}, status.NewNullReporter(), log.NewNullLogger()).(*notifyingCacheStore) + r := NewCacheStore(red, status.NewEmptyReporter()) + srv := NewNotifyingCacheStore("test", "", r, &config.OfflineConfig{CachePollInterval: 1}, status.NewEmptyReporter(), log.NewNullLogger()).(*notifyingCacheStore) go func() { srv.Close() }() @@ -208,6 +208,10 @@ type testReporter struct { mu sync.RWMutex } +func (r *testReporter) RegisterSdk(_ string, _ *config.SDKConfig) { + // do nothing +} + func (r *testReporter) ReportOk(component string, message string) { r.mu.Lock() defer r.mu.Unlock() diff --git a/sdk/store/cache/cache_test.go b/sdk/store/cache/cache_test.go index bb6186e..85d9ff7 100644 --- a/sdk/store/cache/cache_test.go +++ b/sdk/store/cache/cache_test.go @@ -13,7 +13,7 @@ import ( ) func TestCacheStore(t *testing.T) { - store := NewCacheStore(&testCache{}, status.NewNullReporter()).(*cacheStore) + store := NewCacheStore(&testCache{}, status.NewEmptyReporter()).(*cacheStore) err := store.Set(context.Background(), "key", configcatcache.CacheSegmentsToBytes(time.Now(), "etag", []byte(`test`))) assert.NoError(t, err) diff --git a/sdk/store/file/file_test.go b/sdk/store/file/file_test.go index 667d351..d55bc01 100644 --- a/sdk/store/file/file_test.go +++ b/sdk/store/file/file_test.go @@ -14,7 +14,7 @@ import ( func TestFileStore_Existing(t *testing.T) { utils.UseTempFile("", func(path string) { - str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewNullReporter(), log.NewNullLogger()).(*fileStore) + str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewEmptyReporter(), log.NewNullLogger()).(*fileStore) utils.WriteIntoFile(path, `{"f":{"flag":{"v":{"b":true}}},"p":null}`) utils.WithTimeout(2*time.Second, func() { <-str.Modified() @@ -31,7 +31,7 @@ func TestFileStore_Existing(t *testing.T) { func TestFileStore_Existing_Initial(t *testing.T) { utils.UseTempFile(`{"f":{"flag":{"v":{"b":false}}},"p":null}`, func(path string) { - str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewNullReporter(), log.NewNullLogger()).(*fileStore) + str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewEmptyReporter(), log.NewNullLogger()).(*fileStore) res, err := str.Get(context.Background(), "") assert.NoError(t, err) _, _, j, _ := configcatcache.CacheSegmentsFromBytes(res) @@ -42,7 +42,7 @@ func TestFileStore_Existing_Initial(t *testing.T) { func TestFileStore_Existing_Initial_Gets_MalformedJson(t *testing.T) { utils.UseTempFile(`{"f":{"flag":{"v":{"b":false}}},"p":null}`, func(path string) { - str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewNullReporter(), log.NewNullLogger()).(*fileStore) + str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewEmptyReporter(), log.NewNullLogger()).(*fileStore) res, err := str.Get(context.Background(), "") assert.NoError(t, err) _, _, j, _ := configcatcache.CacheSegmentsFromBytes(res) @@ -60,7 +60,7 @@ func TestFileStore_Existing_Initial_Gets_MalformedJson(t *testing.T) { func TestFileStore_Existing_Initial_Notify(t *testing.T) { utils.UseTempFile(`{"f":{"flag":{"v":{"b":false}}},"p":null}`, func(path string) { - str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewNullReporter(), log.NewNullLogger()).(*fileStore) + str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewEmptyReporter(), log.NewNullLogger()).(*fileStore) res, err := str.Get(context.Background(), "") assert.NoError(t, err) _, _, j, _ := configcatcache.CacheSegmentsFromBytes(res) @@ -80,7 +80,7 @@ func TestFileStore_Existing_Initial_Notify(t *testing.T) { func TestFileStore_Existing_Initial_Gets_BadJson(t *testing.T) { utils.UseTempFile(`{"f":{"flag":{"v":{"b":false}}},"p":null}`, func(path string) { - str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewNullReporter(), log.NewNullLogger()).(*fileStore) + str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewEmptyReporter(), log.NewNullLogger()).(*fileStore) res, err := str.Get(context.Background(), "") assert.NoError(t, err) _, _, j, _ := configcatcache.CacheSegmentsFromBytes(res) @@ -98,7 +98,7 @@ func TestFileStore_Existing_Initial_Gets_BadJson(t *testing.T) { func TestFileStore_Existing_Initial_BadJson(t *testing.T) { utils.UseTempFile(`{"k":{"flag":{"v":{"b":false}}},"p":null}`, func(path string) { - str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewNullReporter(), log.NewNullLogger()).(*fileStore) + str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewEmptyReporter(), log.NewNullLogger()).(*fileStore) res, err := str.Get(context.Background(), "") assert.NoError(t, err) _, _, j, _ := configcatcache.CacheSegmentsFromBytes(res) @@ -109,7 +109,7 @@ func TestFileStore_Existing_Initial_BadJson(t *testing.T) { func TestFileStore_Existing_Initial_MalformedJson(t *testing.T) { utils.UseTempFile(`{"k":{"flag`, func(path string) { - str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewNullReporter(), log.NewNullLogger()).(*fileStore) + str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewEmptyReporter(), log.NewNullLogger()).(*fileStore) res, err := str.Get(context.Background(), "") assert.NoError(t, err) _, _, j, _ := configcatcache.CacheSegmentsFromBytes(res) @@ -120,7 +120,7 @@ func TestFileStore_Existing_Initial_MalformedJson(t *testing.T) { func TestFileStore_Stop(t *testing.T) { utils.UseTempFile("", func(path string) { - str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewNullReporter(), log.NewNullLogger()).(*fileStore) + str := NewFileStore("test", &config.LocalConfig{FilePath: path}, status.NewEmptyReporter(), log.NewNullLogger()).(*fileStore) go func() { str.Close() }() @@ -139,7 +139,7 @@ func TestFileStore_Stop(t *testing.T) { } func TestFileStore_NonExisting(t *testing.T) { - str := NewFileStore("test", &config.LocalConfig{FilePath: "nonexisting"}, status.NewNullReporter(), log.NewNullLogger()).(*fileStore) + str := NewFileStore("test", &config.LocalConfig{FilePath: "nonexisting"}, status.NewEmptyReporter(), log.NewNullLogger()).(*fileStore) defer str.Close() res, err := str.Get(context.Background(), "") diff --git a/stream/benchmark_test.go b/stream/benchmark_test.go index bbb6ee5..0ccc623 100644 --- a/stream/benchmark_test.go +++ b/stream/benchmark_test.go @@ -6,7 +6,6 @@ import ( "github.com/configcat/configcat-proxy/internal/testutils" "github.com/configcat/configcat-proxy/log" "github.com/configcat/configcat-proxy/model" - "github.com/configcat/configcat-proxy/sdk" "github.com/configcat/go-sdk/v9/configcattest" "net/http/httptest" "strconv" @@ -24,11 +23,10 @@ func BenchmarkStream(b *testing.B) { srv := httptest.NewServer(&h) defer srv.Close() - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: srv.URL, Key: key}, nil) - client := sdk.NewClient(ctx, log.NewNullLogger()) - defer client.Close() + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: srv.URL, Key: key}, nil) + defer reg.Close() - strServer := NewServer(map[string]sdk.Client{"test": client}, nil, log.NewNullLogger(), "test").(*server) + strServer := NewServer(reg, nil, log.NewNullLogger(), "test").(*server) defer strServer.Close() b.ResetTimer() diff --git a/stream/load_test.go b/stream/load_test.go index e4fa105..eef9920 100644 --- a/stream/load_test.go +++ b/stream/load_test.go @@ -6,7 +6,6 @@ import ( "github.com/configcat/configcat-proxy/internal/utils" "github.com/configcat/configcat-proxy/log" "github.com/configcat/configcat-proxy/model" - "github.com/configcat/configcat-proxy/sdk" "github.com/configcat/go-sdk/v9/configcattest" "github.com/stretchr/testify/assert" "net/http/httptest" @@ -28,11 +27,10 @@ func TestStreamServer_Load(t *testing.T) { srv := httptest.NewServer(&h) defer srv.Close() - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: srv.URL, Key: key}, nil) - client := sdk.NewClient(ctx, log.NewNullLogger()) - defer client.Close() + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: srv.URL, Key: key}, nil) + defer reg.Close() - strServer := NewServer(map[string]sdk.Client{"test": client}, nil, log.NewNullLogger(), "test").(*server) + strServer := NewServer(reg, nil, log.NewNullLogger(), "test").(*server) defer strServer.Close() t.Run("init", func(t *testing.T) { @@ -49,7 +47,7 @@ func TestStreamServer_Load(t *testing.T) { flags["flag"+strconv.Itoa(i)] = &configcattest.Flag{Default: true} } _ = h.SetFlags(key, flags) - _ = client.Refresh() + _ = reg.GetSdkOrNil("test").Refresh() assert.Equal(t, connCount, len(strServer.GetStreamOrNil("test").(*stream).channels[AllFlagsDiscriminator][0].(*allFlagsChannel).connections)) t.Run("check refresh", func(t *testing.T) { checkConnections(t, strServer) diff --git a/stream/server.go b/stream/server.go index fca0ae5..ceece1b 100644 --- a/stream/server.go +++ b/stream/server.go @@ -16,10 +16,10 @@ type server struct { log log.Logger } -func NewServer(sdkClients map[string]sdk.Client, metrics metrics.Reporter, log log.Logger, serverType string) Server { +func NewServer(sdkRegistrar sdk.Registrar, metrics metrics.Reporter, log log.Logger, serverType string) Server { strLog := log.WithPrefix("stream-server") streams := make(map[string]Stream) - for id, sdkClient := range sdkClients { + for id, sdkClient := range sdkRegistrar.GetAll() { streams[id] = NewStream(id, sdkClient, metrics, strLog, serverType) } return &server{ diff --git a/stream/server_test.go b/stream/server_test.go index 7a7e038..c86b93f 100644 --- a/stream/server_test.go +++ b/stream/server_test.go @@ -8,8 +8,8 @@ import ( ) func TestServer_GetStreamOrNil(t *testing.T) { - clients, _, _ := testutils.NewTestSdkClient(t) - srv := NewServer(clients, nil, log.NewNullLogger(), "test").(*server) + reg, _, _ := testutils.NewTestRegistrarT(t) + srv := NewServer(reg, nil, log.NewNullLogger(), "test").(*server) str := srv.GetStreamOrNil("test") assert.NotNil(t, str) diff --git a/web/api/api.go b/web/api/api.go index 90d102e..be89477 100644 --- a/web/api/api.go +++ b/web/api/api.go @@ -20,17 +20,17 @@ type keysResponse struct { } type Server struct { - sdkClients map[string]sdk.Client - config *config.ApiConfig - logger log.Logger + sdkRegistrar sdk.Registrar + config *config.ApiConfig + logger log.Logger } -func NewServer(sdkClients map[string]sdk.Client, config *config.ApiConfig, log log.Logger) *Server { +func NewServer(sdkRegistrar sdk.Registrar, config *config.ApiConfig, log log.Logger) *Server { cdnLogger := log.WithPrefix("api") return &Server{ - sdkClients: sdkClients, - config: config, - logger: cdnLogger, + sdkRegistrar: sdkRegistrar, + config: config, + logger: cdnLogger, } } @@ -137,8 +137,8 @@ func (s *Server) getSDKClient(ctx context.Context) (sdk.Client, error, int) { if sdkId == "" { return nil, fmt.Errorf("'sdkId' path parameter must be set"), http.StatusNotFound } - sdkClient, ok := s.sdkClients[sdkId] - if !ok { + sdkClient := s.sdkRegistrar.GetSdkOrNil(sdkId) + if sdkClient == nil { return nil, fmt.Errorf("invalid SDK identifier: '%s'", sdkId), http.StatusNotFound } if !sdkClient.IsInValidState() { diff --git a/web/api/api_test.go b/web/api/api_test.go index 5e268cd..749219b 100644 --- a/web/api/api_test.go +++ b/web/api/api_test.go @@ -5,7 +5,6 @@ import ( "github.com/configcat/configcat-proxy/internal/testutils" "github.com/configcat/configcat-proxy/internal/utils" "github.com/configcat/configcat-proxy/log" - "github.com/configcat/configcat-proxy/sdk" "github.com/configcat/go-sdk/v9/configcattest" "github.com/stretchr/testify/assert" "net/http" @@ -237,15 +236,7 @@ func TestAPI_Keys(t *testing.T) { } func TestAPI_Refresh(t *testing.T) { - key := configcattest.RandomSDKKey() - var h configcattest.Handler - _ = h.SetFlags(key, map[string]*configcattest.Flag{ - "flag": { - Default: true, - }, - }) - - srv := newServerWithHandler(t, &h, key, config.ApiConfig{Enabled: true}) + srv, h, key := newServerWithHandler(t, config.ApiConfig{Enabled: true}) res := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"key":"flag"}`)) @@ -322,16 +313,14 @@ func TestAPI_WrongSdkId(t *testing.T) { } func TestAPI_WrongSDKState(t *testing.T) { - opts := config.SDKConfig{BaseUrl: "http://localhost", Key: configcattest.RandomSDKKey()} - ctx := testutils.NewTestSdkContext(&opts, nil) - client := sdk.NewClient(ctx, log.NewNullLogger()) - defer client.Close() + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: "http://localhost", Key: configcattest.RandomSDKKey()}, nil) + defer reg.Close() t.Run("Eval", func(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"key":"flag"}`)) - srv := NewServer(map[string]sdk.Client{"test": client}, &config.ApiConfig{Enabled: true}, log.NewNullLogger()) + srv := NewServer(reg, &config.ApiConfig{Enabled: true}, log.NewNullLogger()) testutils.AddSdkIdContextParam(req) srv.Eval(res, req) @@ -342,7 +331,7 @@ func TestAPI_WrongSDKState(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(`{"key":"flag"}`)) - srv := NewServer(map[string]sdk.Client{"test": client}, &config.ApiConfig{Enabled: true}, log.NewNullLogger()) + srv := NewServer(reg, &config.ApiConfig{Enabled: true}, log.NewNullLogger()) testutils.AddSdkIdContextParam(req) srv.EvalAll(res, req) @@ -353,7 +342,7 @@ func TestAPI_WrongSDKState(t *testing.T) { res := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/", http.NoBody) - srv := NewServer(map[string]sdk.Client{"test": client}, &config.ApiConfig{Enabled: true}, log.NewNullLogger()) + srv := NewServer(reg, &config.ApiConfig{Enabled: true}, log.NewNullLogger()) testutils.AddSdkIdContextParam(req) srv.Keys(res, req) @@ -363,44 +352,24 @@ func TestAPI_WrongSDKState(t *testing.T) { } func newServer(t *testing.T, conf config.ApiConfig) *Server { - client, _, _ := testutils.NewTestSdkClient(t) - return NewServer(client, &conf, log.NewNullLogger()) + reg, _, _ := testutils.NewTestRegistrarT(t) + return NewServer(reg, &conf, log.NewNullLogger()) } -func newServerWithHandler(t *testing.T, h *configcattest.Handler, key string, conf config.ApiConfig) *Server { - _ = h.SetFlags(key, map[string]*configcattest.Flag{ - "flag": { - Default: true, - }, - }) - srv := httptest.NewServer(h) - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: srv.URL, Key: key}, nil) - client := sdk.NewClient(ctx, log.NewNullLogger()) - t.Cleanup(func() { - srv.Close() - client.Close() - }) - return NewServer(map[string]sdk.Client{"test": client}, &conf, log.NewNullLogger()) +func newServerWithHandler(t *testing.T, conf config.ApiConfig) (*Server, *configcattest.Handler, string) { + reg, h, k := testutils.NewTestRegistrarT(t) + return NewServer(reg, &conf, log.NewNullLogger()), h, k } func newErrorServer(t *testing.T, conf config.ApiConfig) *Server { - key := configcattest.RandomSDKKey() - var h configcattest.Handler - srv := httptest.NewServer(&h) - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: srv.URL, Key: key}, nil) - client := sdk.NewClient(ctx, log.NewNullLogger()) - t.Cleanup(func() { - srv.Close() - client.Close() - }) - return NewServer(map[string]sdk.Client{"test": client}, &conf, log.NewNullLogger()) + reg := testutils.NewTestRegistrarTWithErrorServer(t) + return NewServer(reg, &conf, log.NewNullLogger()) } func newOfflineServer(t *testing.T, path string, conf config.ApiConfig) *Server { - ctx := testutils.NewTestSdkContext(&config.SDKConfig{Key: "local", Offline: config.OfflineConfig{Enabled: true, Local: config.LocalConfig{FilePath: path, Polling: true, PollInterval: 30}}}, nil) - client := sdk.NewClient(ctx, log.NewNullLogger()) + reg := testutils.NewTestRegistrar(&config.SDKConfig{Key: "local", Offline: config.OfflineConfig{Enabled: true, Local: config.LocalConfig{FilePath: path, Polling: true, PollInterval: 30}}}, nil) t.Cleanup(func() { - client.Close() + reg.Close() }) - return NewServer(map[string]sdk.Client{"test": client}, &conf, log.NewNullLogger()) + return NewServer(reg, &conf, log.NewNullLogger()) } diff --git a/web/cdnproxy/cdnproxy.go b/web/cdnproxy/cdnproxy.go index 0d421e7..5acdb9f 100644 --- a/web/cdnproxy/cdnproxy.go +++ b/web/cdnproxy/cdnproxy.go @@ -11,17 +11,17 @@ import ( ) type Server struct { - sdkClients map[string]sdk.Client - config *config.CdnProxyConfig - logger log.Logger + sdkRegistrar sdk.Registrar + config *config.CdnProxyConfig + logger log.Logger } -func NewServer(sdkClients map[string]sdk.Client, config *config.CdnProxyConfig, log log.Logger) *Server { +func NewServer(sdkRegistrar sdk.Registrar, config *config.CdnProxyConfig, log log.Logger) *Server { cdnLogger := log.WithPrefix("cdn-proxy") return &Server{ - sdkClients: sdkClients, - config: config, - logger: cdnLogger, + sdkRegistrar: sdkRegistrar, + config: config, + logger: cdnLogger, } } @@ -51,8 +51,8 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) getSDKClient(ctx context.Context) (sdk.Client, error, int) { vars := httprouter.ParamsFromContext(ctx) sdkId := vars.ByName("sdkId") - sdkClient, ok := s.sdkClients[sdkId] - if !ok { + sdkClient := s.sdkRegistrar.GetSdkOrNil(sdkId) + if sdkClient == nil { return nil, fmt.Errorf("invalid SDK identifier: '%s'", sdkId), http.StatusNotFound } if !sdkClient.IsInValidState() { diff --git a/web/cdnproxy/cdnproxy_test.go b/web/cdnproxy/cdnproxy_test.go index 33b439a..ba29075 100644 --- a/web/cdnproxy/cdnproxy_test.go +++ b/web/cdnproxy/cdnproxy_test.go @@ -5,7 +5,6 @@ import ( "github.com/configcat/configcat-proxy/internal/testutils" "github.com/configcat/configcat-proxy/internal/utils" "github.com/configcat/configcat-proxy/log" - "github.com/configcat/configcat-proxy/sdk" "github.com/configcat/go-sdk/v9/configcattest" "github.com/stretchr/testify/assert" "net/http" @@ -133,7 +132,7 @@ func TestProxy_Get(t *testing.T) { Default: false, }, }) - _ = srv.sdkClients["test"].Refresh() + _ = srv.sdkRegistrar.GetSdkOrNil("test").Refresh() res = httptest.NewRecorder() req = &http.Request{Method: http.MethodGet, Header: map[string][]string{}} @@ -166,15 +165,13 @@ func TestProxy_Get(t *testing.T) { assert.Equal(t, http.StatusNotFound, res.Code) }) t.Run("SDK invalid state", func(t *testing.T) { - opts := config.SDKConfig{BaseUrl: "http://localhost", Key: configcattest.RandomSDKKey()} - ctx := testutils.NewTestSdkContext(&opts, nil) - client := sdk.NewClient(ctx, log.NewNullLogger()) - defer client.Close() + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: "http://localhost", Key: configcattest.RandomSDKKey()}, nil) + defer reg.Close() res := httptest.NewRecorder() req := &http.Request{Method: http.MethodGet} - srv := NewServer(map[string]sdk.Client{"test": client}, &config.CdnProxyConfig{Enabled: true}, log.NewNullLogger()) + srv := NewServer(reg, &config.CdnProxyConfig{Enabled: true}, log.NewNullLogger()) testutils.AddSdkIdContextParam(req) srv.ServeHTTP(res, req) @@ -184,33 +181,24 @@ func TestProxy_Get(t *testing.T) { } func newServer(t *testing.T, proxyConfig config.CdnProxyConfig) *Server { - client, _, _ := testutils.NewTestSdkClient(t) - return NewServer(client, &proxyConfig, log.NewNullLogger()) + reg, _, _ := testutils.NewTestRegistrarT(t) + return NewServer(reg, &proxyConfig, log.NewNullLogger()) } func newServerWithHandler(t *testing.T, proxyConfig config.CdnProxyConfig) (*Server, *configcattest.Handler, string) { - client, h, k := testutils.NewTestSdkClient(t) - return NewServer(client, &proxyConfig, log.NewNullLogger()), h, k + reg, h, k := testutils.NewTestRegistrarT(t) + return NewServer(reg, &proxyConfig, log.NewNullLogger()), h, k } func newErrorServer(t *testing.T, proxyConfig config.CdnProxyConfig) *Server { - key := configcattest.RandomSDKKey() - var h configcattest.Handler - srv := httptest.NewServer(&h) - ctx := testutils.NewTestSdkContext(&config.SDKConfig{BaseUrl: srv.URL, Key: key}, nil) - client := sdk.NewClient(ctx, log.NewNullLogger()) - t.Cleanup(func() { - srv.Close() - client.Close() - }) - return NewServer(map[string]sdk.Client{"test": client}, &proxyConfig, log.NewNullLogger()) + reg := testutils.NewTestRegistrarTWithErrorServer(t) + return NewServer(reg, &proxyConfig, log.NewNullLogger()) } func newOfflineServer(t *testing.T, path string, proxyConfig config.CdnProxyConfig) *Server { - ctx := testutils.NewTestSdkContext(&config.SDKConfig{Key: "local", Offline: config.OfflineConfig{Enabled: true, Local: config.LocalConfig{FilePath: path}}}, nil) - client := sdk.NewClient(ctx, log.NewNullLogger()) + reg := testutils.NewTestRegistrar(&config.SDKConfig{Key: "local", Offline: config.OfflineConfig{Enabled: true, Local: config.LocalConfig{FilePath: path}}}, nil) t.Cleanup(func() { - client.Close() + reg.Close() }) - return NewServer(map[string]sdk.Client{"test": client}, &proxyConfig, log.NewNullLogger()) + return NewServer(reg, &proxyConfig, log.NewNullLogger()) } diff --git a/web/router.go b/web/router.go index eeca6c3..ab50a0b 100644 --- a/web/router.go +++ b/web/router.go @@ -25,7 +25,7 @@ type HttpRouter struct { metrics metrics.Reporter } -func NewRouter(sdkClients map[string]sdk.Client, metrics metrics.Reporter, reporter status.Reporter, conf *config.HttpConfig, log log.Logger) *HttpRouter { +func NewRouter(sdkRegistrar sdk.Registrar, metrics metrics.Reporter, reporter status.Reporter, conf *config.HttpConfig, log log.Logger) *HttpRouter { httpLog := log.WithLevel(conf.Log.GetLevel()).WithPrefix("http") r := &HttpRouter{ @@ -37,16 +37,16 @@ func NewRouter(sdkClients map[string]sdk.Client, metrics metrics.Reporter, repor metrics: metrics, } if conf.Sse.Enabled { - r.setupSSERoutes(&conf.Sse, sdkClients, httpLog) + r.setupSSERoutes(&conf.Sse, sdkRegistrar, httpLog) } if conf.Webhook.Enabled { - r.setupWebhookRoutes(&conf.Webhook, sdkClients, httpLog) + r.setupWebhookRoutes(&conf.Webhook, sdkRegistrar, httpLog) } if conf.CdnProxy.Enabled { - r.setupCDNProxyRoutes(&conf.CdnProxy, sdkClients, httpLog) + r.setupCDNProxyRoutes(&conf.CdnProxy, sdkRegistrar, httpLog) } if conf.Api.Enabled { - r.setupAPIRoutes(&conf.Api, sdkClients, httpLog) + r.setupAPIRoutes(&conf.Api, sdkRegistrar, httpLog) } if conf.Status.Enabled { r.setupStatusRoutes(reporter, httpLog) @@ -64,8 +64,8 @@ func (s *HttpRouter) Close() { } } -func (s *HttpRouter) setupSSERoutes(conf *config.SseConfig, sdkClients map[string]sdk.Client, l log.Logger) { - s.sseServer = sse.NewServer(sdkClients, s.metrics, conf, l) +func (s *HttpRouter) setupSSERoutes(conf *config.SseConfig, sdkRegistrar sdk.Registrar, l log.Logger) { + s.sseServer = sse.NewServer(sdkRegistrar, s.metrics, conf, l) endpoints := []endpoint{ {path: "/sse/:sdkId/eval/:data", handler: http.HandlerFunc(s.sseServer.SingleFlag), method: http.MethodGet}, {path: "/sse/:sdkId/eval-all/:data", handler: http.HandlerFunc(s.sseServer.AllFlags), method: http.MethodGet}, @@ -88,8 +88,8 @@ func (s *HttpRouter) setupSSERoutes(conf *config.SseConfig, sdkClients map[strin l.Reportf("SSE enabled, accepting requests on path: /sse/:sdkId/*") } -func (s *HttpRouter) setupWebhookRoutes(conf *config.WebhookConfig, sdkClients map[string]sdk.Client, l log.Logger) { - s.webhookServer = webhook.NewServer(sdkClients, l) +func (s *HttpRouter) setupWebhookRoutes(conf *config.WebhookConfig, sdkRegistrar sdk.Registrar, l log.Logger) { + s.webhookServer = webhook.NewServer(sdkRegistrar, l) path := "/hook/:sdkId" handler := http.HandlerFunc(s.webhookServer.ServeHTTP) if conf.Auth.User != "" && conf.Auth.Password != "" { @@ -109,8 +109,8 @@ func (s *HttpRouter) setupWebhookRoutes(conf *config.WebhookConfig, sdkClients m l.Reportf("webhook enabled, accepting requests on path: %s", path) } -func (s *HttpRouter) setupCDNProxyRoutes(conf *config.CdnProxyConfig, sdkClients map[string]sdk.Client, l log.Logger) { - s.cdnProxyServer = cdnproxy.NewServer(sdkClients, conf, l) +func (s *HttpRouter) setupCDNProxyRoutes(conf *config.CdnProxyConfig, sdkRegistrar sdk.Registrar, l log.Logger) { + s.cdnProxyServer = cdnproxy.NewServer(sdkRegistrar, conf, l) path := "/configuration-files/configcat-proxy/:sdkId/config_v6.json" handler := mware.AutoOptions(mware.GZip(s.cdnProxyServer.ServeHTTP)) if len(conf.Headers) > 0 { @@ -147,8 +147,8 @@ type endpoint struct { path string } -func (s *HttpRouter) setupAPIRoutes(conf *config.ApiConfig, sdkClients map[string]sdk.Client, l log.Logger) { - s.apiServer = api.NewServer(sdkClients, conf, l) +func (s *HttpRouter) setupAPIRoutes(conf *config.ApiConfig, sdkRegistrar sdk.Registrar, l log.Logger) { + s.apiServer = api.NewServer(sdkRegistrar, conf, l) endpoints := []endpoint{ {path: "/api/:sdkId/eval", handler: mware.GZip(s.apiServer.Eval), method: http.MethodPost}, {path: "/api/:sdkId/eval-all", handler: mware.GZip(s.apiServer.EvalAll), method: http.MethodPost}, diff --git a/web/router_api_test.go b/web/router_api_test.go index 5fcfbe3..7622aee 100644 --- a/web/router_api_test.go +++ b/web/router_api_test.go @@ -425,6 +425,6 @@ func TestAPI_Refresh_Headers(t *testing.T) { } func newAPIRouter(t *testing.T, conf config.ApiConfig) *HttpRouter { - client, _, _ := testutils.NewTestSdkClient(t) - return NewRouter(client, nil, status.NewNullReporter(), &config.HttpConfig{Api: conf}, log.NewNullLogger()) + reg, _, _ := testutils.NewTestRegistrarT(t) + return NewRouter(reg, nil, status.NewEmptyReporter(), &config.HttpConfig{Api: conf}, log.NewNullLogger()) } diff --git a/web/router_cdnproxy_test.go b/web/router_cdnproxy_test.go index 6cf96c7..0168fe1 100644 --- a/web/router_cdnproxy_test.go +++ b/web/router_cdnproxy_test.go @@ -132,6 +132,6 @@ func TestCDNProxy_Get_Body_GZip(t *testing.T) { } func newCDNProxyRouter(t *testing.T, conf config.CdnProxyConfig) *HttpRouter { - client, _, _ := testutils.NewTestSdkClient(t) - return NewRouter(client, nil, status.NewNullReporter(), &config.HttpConfig{CdnProxy: conf}, log.NewNullLogger()) + reg, _, _ := testutils.NewTestRegistrarT(t) + return NewRouter(reg, nil, status.NewEmptyReporter(), &config.HttpConfig{CdnProxy: conf}, log.NewNullLogger()) } diff --git a/web/router_sse_test.go b/web/router_sse_test.go index a2fd6f5..ba58a6c 100644 --- a/web/router_sse_test.go +++ b/web/router_sse_test.go @@ -184,6 +184,6 @@ func TestSSE_EvalAllFlags_Not_Allowed_Methods(t *testing.T) { } func newSSERouter(t *testing.T, conf config.SseConfig) *HttpRouter { - client, _, _ := testutils.NewTestSdkClient(t) - return NewRouter(client, nil, status.NewNullReporter(), &config.HttpConfig{Sse: conf}, log.NewNullLogger()) + reg, _, _ := testutils.NewTestRegistrarT(t) + return NewRouter(reg, nil, status.NewEmptyReporter(), &config.HttpConfig{Sse: conf}, log.NewNullLogger()) } diff --git a/web/router_status_test.go b/web/router_status_test.go index ae05389..e4276bb 100644 --- a/web/router_status_test.go +++ b/web/router_status_test.go @@ -8,8 +8,6 @@ import ( "github.com/configcat/configcat-proxy/internal/testutils" "github.com/configcat/configcat-proxy/internal/utils" "github.com/configcat/configcat-proxy/log" - "github.com/configcat/configcat-proxy/sdk" - "github.com/configcat/go-sdk/v9/configcattest" "github.com/stretchr/testify/assert" "io" "net/http" @@ -74,26 +72,11 @@ func TestStatus_Not_Allowed_Methods(t *testing.T) { } func newStatusRouter(t *testing.T) *HttpRouter { - key := configcattest.RandomSDKKey() - var h configcattest.Handler - _ = h.SetFlags(key, map[string]*configcattest.Flag{ - "flag": { - Default: true, - }, - }) - srv := httptest.NewServer(&h) - opts := config.SDKConfig{BaseUrl: srv.URL, Key: key} - ctx := testutils.NewTestSdkContext(&opts, nil) - conf := config.Config{SDKs: map[string]*config.SDKConfig{"test": &opts}} - reporter := status.NewReporter(&conf) - ctx.StatusReporter = reporter - client := sdk.NewClient(ctx, log.NewNullLogger()) + reporter := status.NewEmptyReporter() + reg, _, _ := testutils.NewTestRegistrarTWithStatusReporter(t, reporter) + client := reg.GetSdkOrNil("test") utils.WithTimeout(2*time.Second, func() { <-client.Ready() }) - t.Cleanup(func() { - srv.Close() - client.Close() - }) - return NewRouter(map[string]sdk.Client{"test": client}, nil, reporter, &config.HttpConfig{Status: config.StatusConfig{Enabled: true}}, log.NewNullLogger()) + return NewRouter(reg, nil, reporter, &config.HttpConfig{Status: config.StatusConfig{Enabled: true}}, log.NewNullLogger()) } diff --git a/web/router_webhook_test.go b/web/router_webhook_test.go index e97e201..3a5fd5b 100644 --- a/web/router_webhook_test.go +++ b/web/router_webhook_test.go @@ -99,6 +99,6 @@ func TestWebhook_NotAllowed(t *testing.T) { } func newWebhookRouter(t *testing.T, conf config.WebhookConfig) *HttpRouter { - clients, _, _ := testutils.NewTestSdkClient(t) - return NewRouter(clients, nil, status.NewNullReporter(), &config.HttpConfig{Webhook: conf}, log.NewNullLogger()) + reg, _, _ := testutils.NewTestRegistrarT(t) + return NewRouter(reg, nil, status.NewEmptyReporter(), &config.HttpConfig{Webhook: conf}, log.NewNullLogger()) } diff --git a/web/sse/sse.go b/web/sse/sse.go index b53950f..ea45286 100644 --- a/web/sse/sse.go +++ b/web/sse/sse.go @@ -22,10 +22,10 @@ type Server struct { stop chan struct{} } -func NewServer(sdkClients map[string]sdk.Client, metrics metrics.Reporter, conf *config.SseConfig, logger log.Logger) *Server { +func NewServer(sdkRegistrar sdk.Registrar, metrics metrics.Reporter, conf *config.SseConfig, logger log.Logger) *Server { sseLog := logger.WithLevel(conf.Log.GetLevel()).WithPrefix("sse") return &Server{ - streamServer: stream.NewServer(sdkClients, metrics, sseLog, "sse"), + streamServer: stream.NewServer(sdkRegistrar, metrics, sseLog, "sse"), logger: sseLog, config: conf, stop: make(chan struct{}), diff --git a/web/sse/sse_test.go b/web/sse/sse_test.go index 11d9aa2..70faa3e 100644 --- a/web/sse/sse_test.go +++ b/web/sse/sse_test.go @@ -6,7 +6,6 @@ import ( "github.com/configcat/configcat-proxy/config" "github.com/configcat/configcat-proxy/internal/testutils" "github.com/configcat/configcat-proxy/log" - "github.com/configcat/configcat-proxy/sdk" "github.com/configcat/go-sdk/v9/configcattest" "github.com/julienschmidt/httprouter" "github.com/stretchr/testify/assert" @@ -80,14 +79,12 @@ func TestSSE_NonExisting_Flag(t *testing.T) { } func TestSSE_SDK_InvalidState(t *testing.T) { - opts := config.SDKConfig{BaseUrl: "http://localhost", Key: configcattest.RandomSDKKey()} - sdkCtx := testutils.NewTestSdkContext(&opts, nil) - client := sdk.NewClient(sdkCtx, log.NewNullLogger()) - defer client.Close() + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: "http://localhost", Key: configcattest.RandomSDKKey()}, nil) + defer reg.Close() req := httptest.NewRequest(http.MethodGet, "/", nil) - srv := NewServer(map[string]sdk.Client{"test": client}, nil, &config.SseConfig{Enabled: true}, log.NewNullLogger()) + srv := NewServer(reg, nil, &config.SseConfig{Enabled: true}, log.NewNullLogger()) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() @@ -216,6 +213,6 @@ func TestSSE_Get_All_User_Invalid(t *testing.T) { } func newServer(t *testing.T, conf *config.SseConfig) *Server { - client, _, _ := testutils.NewTestSdkClient(t) - return NewServer(client, nil, conf, log.NewNullLogger()) + reg, _, _ := testutils.NewTestRegistrarT(t) + return NewServer(reg, nil, conf, log.NewNullLogger()) } diff --git a/web/webhook/webhook.go b/web/webhook/webhook.go index 8bf2bc6..bef7c24 100644 --- a/web/webhook/webhook.go +++ b/web/webhook/webhook.go @@ -19,15 +19,15 @@ const idHeader = "X-ConfigCat-Webhook-ID" const timestampHeader = "X-ConfigCat-Webhook-Timestamp" type Server struct { - sdkClients map[string]sdk.Client - logger log.Logger + sdkRegistrar sdk.Registrar + logger log.Logger } -func NewServer(sdkClients map[string]sdk.Client, log log.Logger) *Server { +func NewServer(sdkRegistrar sdk.Registrar, log log.Logger) *Server { whLogger := log.WithPrefix("webhook") return &Server{ - sdkClients: sdkClients, - logger: whLogger, + sdkRegistrar: sdkRegistrar, + logger: whLogger, } } @@ -38,8 +38,8 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "'sdkId' path parameter must be set", http.StatusBadRequest) return } - sdkClient, ok := s.sdkClients[sdkId] - if !ok { + sdkClient := s.sdkRegistrar.GetSdkOrNil(sdkId) + if sdkClient == nil { http.Error(w, "SDK not found for identifier: '"+sdkId+"'", http.StatusNotFound) return } diff --git a/web/webhook/webhook_test.go b/web/webhook/webhook_test.go index 213afe0..73e9061 100644 --- a/web/webhook/webhook_test.go +++ b/web/webhook/webhook_test.go @@ -21,11 +21,8 @@ import ( ) func TestWebhook_Signature_Bad(t *testing.T) { - key := configcattest.RandomSDKKey() - var h = &configcattest.Handler{} - _ = h.SetFlags(key, map[string]*configcattest.Flag{"flag": {Default: true}}) - clients := newClient(t, h, key, "test-key", 300) - srv := NewServer(clients, log.NewNullLogger()) + reg, _, _ := newRegistrar(t, "test-key", 300) + srv := NewServer(reg, log.NewNullLogger()) t.Run("headers missing", func(t *testing.T) { res := httptest.NewRecorder() @@ -58,11 +55,8 @@ func TestWebhook_Signature_Bad(t *testing.T) { func TestWebhook_Signature_Ok(t *testing.T) { t.Run("signature OK GET", func(t *testing.T) { - key := configcattest.RandomSDKKey() - var h = &configcattest.Handler{} - _ = h.SetFlags(key, map[string]*configcattest.Flag{"flag": {Default: true}}) - clients := newClient(t, h, key, "test-key", 300) - srv := NewServer(clients, log.NewNullLogger()) + reg, h, key := newRegistrar(t, "test-key", 300) + srv := NewServer(reg, log.NewNullLogger()) res := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -76,9 +70,9 @@ func TestWebhook_Signature_Ok(t *testing.T) { req.Header.Set("X-ConfigCat-Webhook-ID", id) req.Header.Set("X-ConfigCat-Webhook-Timestamp", timestamp) testutils.AddSdkIdContextParam(req) - sub := clients["test"].SubConfigChanged("hook1") + sub := reg.GetSdkOrNil("test").SubConfigChanged("hook1") utils.WithTimeout(2*time.Second, func() { - <-clients["test"].Ready() + <-reg.GetSdkOrNil("test").Ready() }) // wait for the SDK to do the initialization _ = h.SetFlags(key, map[string]*configcattest.Flag{"flag": {Default: false}}) srv.ServeHTTP(res, req) @@ -88,11 +82,8 @@ func TestWebhook_Signature_Ok(t *testing.T) { assert.Equal(t, http.StatusOK, res.Code) }) t.Run("signature OK POST", func(t *testing.T) { - key := configcattest.RandomSDKKey() - var h = &configcattest.Handler{} - _ = h.SetFlags(key, map[string]*configcattest.Flag{"flag": {Default: true}}) - clients := newClient(t, h, key, "test-key", 300) - srv := NewServer(clients, log.NewNullLogger()) + reg, h, key := newRegistrar(t, "test-key", 300) + srv := NewServer(reg, log.NewNullLogger()) id := "1" timestamp := strconv.FormatInt(time.Now().Unix(), 10) @@ -107,9 +98,9 @@ func TestWebhook_Signature_Ok(t *testing.T) { req.Header.Set("X-ConfigCat-Webhook-ID", id) req.Header.Set("X-ConfigCat-Webhook-Timestamp", timestamp) testutils.AddSdkIdContextParam(req) - sub := clients["test"].SubConfigChanged("hook1") + sub := reg.GetSdkOrNil("test").SubConfigChanged("hook1") utils.WithTimeout(2*time.Second, func() { - <-clients["test"].Ready() + <-reg.GetSdkOrNil("test").Ready() }) // wait for the SDK to do the initialization _ = h.SetFlags(key, map[string]*configcattest.Flag{"flag": {Default: false}}) srv.ServeHTTP(res, req) @@ -121,11 +112,8 @@ func TestWebhook_Signature_Ok(t *testing.T) { } func TestWebhook_Signature_Replay_Reject(t *testing.T) { - key := configcattest.RandomSDKKey() - var h = &configcattest.Handler{} - _ = h.SetFlags(key, map[string]*configcattest.Flag{"flag": {Default: true}}) - clients := newClient(t, h, key, "test-key", 1) - srv := NewServer(clients, log.NewNullLogger()) + reg, _, _ := newRegistrar(t, "test-key", 1) + srv := NewServer(reg, log.NewNullLogger()) id := "1" timestamp := strconv.FormatInt(time.Now().Unix(), 10) @@ -145,14 +133,15 @@ func TestWebhook_Signature_Replay_Reject(t *testing.T) { assert.Equal(t, http.StatusBadRequest, res.Code) } -func newClient(t *testing.T, h *configcattest.Handler, key string, signingKey string, validFor int) map[string]sdk.Client { +func newRegistrar(t *testing.T, signingKey string, validFor int) (sdk.Registrar, *configcattest.Handler, string) { + key := configcattest.RandomSDKKey() + var h = &configcattest.Handler{} + _ = h.SetFlags(key, map[string]*configcattest.Flag{"flag": {Default: true}}) srv := httptest.NewServer(h) - sdkConf := &config.SDKConfig{BaseUrl: srv.URL, Key: key, WebhookSigningKey: signingKey, WebhookSignatureValidFor: validFor} - ctx := testutils.NewTestSdkContext(sdkConf, nil) - client := sdk.NewClient(ctx, log.NewNullLogger()) + reg := testutils.NewTestRegistrar(&config.SDKConfig{BaseUrl: srv.URL, Key: key, WebhookSigningKey: signingKey, WebhookSignatureValidFor: validFor}, nil) t.Cleanup(func() { srv.Close() - client.Close() + reg.Close() }) - return map[string]sdk.Client{"test": client} + return reg, h, key }