diff --git a/cmd/container.go b/cmd/container.go index c05681d8d..086b778b5 100644 --- a/cmd/container.go +++ b/cmd/container.go @@ -1,7 +1,6 @@ package cmd import ( - "context" "crypto/tls" "fmt" "io" @@ -11,7 +10,6 @@ import ( "os" "strings" - "github.com/Masterminds/semver/v3" "github.com/Shopify/sarama" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" @@ -25,7 +23,6 @@ import ( "github.com/numary/go-libs/sharedpublish/sharedpublishhttp" "github.com/numary/go-libs/sharedpublish/sharedpublishkafka" "github.com/numary/ledger/cmd/internal" - "github.com/numary/ledger/pkg/analytics" "github.com/numary/ledger/pkg/api" "github.com/numary/ledger/pkg/api/middlewares" "github.com/numary/ledger/pkg/api/routes" @@ -208,45 +205,7 @@ func NewContainer(v *viper.Viper, userOptions ...fx.Option) *fx.App { }(), })) - if v.GetBool(telemetryEnabledFlag) || v.GetBool(segmentEnabledFlag) { - applicationId := viper.GetString(telemetryApplicationIdFlag) - if applicationId == "" { - applicationId = viper.GetString(segmentApplicationIdFlag) - } - var appIdProviderModule fx.Option - if applicationId == "" { - appIdProviderModule = fx.Provide(analytics.FromStorageAppIdProvider) - } else { - appIdProviderModule = fx.Provide(func() analytics.AppIdProvider { - return analytics.AppIdProviderFn(func(ctx context.Context) (string, error) { - return applicationId, nil - }) - }) - } - writeKey := viper.GetString(telemetryWriteKeyFlag) - if writeKey == "" { - writeKey = viper.GetString(segmentWriteKeyFlag) - } - interval := viper.GetDuration(telemetryHeartbeatIntervalFlag) - if interval == 0 { - interval = viper.GetDuration(segmentHeartbeatIntervalFlag) - } - if writeKey == "" { - sharedlogging.GetLogger(context.Background()).Infof("telemetry enabled but no write key provided") - } else if interval == 0 { - sharedlogging.GetLogger(context.Background()).Error("telemetry heartbeat interval is 0") - } else { - _, err := semver.NewVersion(Version) - if err != nil { - sharedlogging.GetLogger(context.Background()).Infof("telemetry enabled but version '%s' is not semver, skip", Version) - } else { - options = append(options, - appIdProviderModule, - analytics.NewHeartbeatModule(Version, writeKey, interval), - ) - } - } - } + options = append(options, internal.NewAnalyticsModule(v, Version)) options = append(options, fx.Provide( fx.Annotate(func() []ledger.LedgerOption { diff --git a/cmd/internal/analytics.go b/cmd/internal/analytics.go new file mode 100644 index 000000000..c06ce2e6e --- /dev/null +++ b/cmd/internal/analytics.go @@ -0,0 +1,83 @@ +package internal + +import ( + "context" + "time" + + "github.com/Masterminds/semver/v3" + "github.com/numary/go-libs/sharedlogging" + "github.com/numary/ledger/pkg/analytics" + "github.com/spf13/cobra" + "github.com/spf13/viper" + "go.uber.org/fx" +) + +const ( + // deprecated + segmentEnabledFlag = "segment-enabled" + // deprecated + segmentWriteKeyFlag = "segment-write-key" + // deprecated + segmentApplicationIdFlag = "segment-application-id" + // deprecated + segmentHeartbeatIntervalFlag = "segment-heartbeat-interval" + + telemetryEnabledFlag = "telemetry-enabled" + telemetryWriteKeyFlag = "telemetry-write-key" + telemetryApplicationIdFlag = "telemetry-application-id" + telemetryHeartbeatIntervalFlag = "telemetry-heartbeat-interval" +) + +func InitAnalyticsFlags(cmd *cobra.Command, defaultWriteKey string) { + cmd.PersistentFlags().Bool(segmentEnabledFlag, true, "Is segment enabled") + cmd.PersistentFlags().String(segmentApplicationIdFlag, "", "Segment application id") + cmd.PersistentFlags().String(segmentWriteKeyFlag, defaultWriteKey, "Segment write key") + cmd.PersistentFlags().Duration(segmentHeartbeatIntervalFlag, 4*time.Hour, "Segment heartbeat interval") + cmd.PersistentFlags().Bool(telemetryEnabledFlag, true, "Is telemetry enabled") + cmd.PersistentFlags().String(telemetryApplicationIdFlag, "", "telemetry application id") + cmd.PersistentFlags().String(telemetryWriteKeyFlag, defaultWriteKey, "telemetry write key") + cmd.PersistentFlags().Duration(telemetryHeartbeatIntervalFlag, 4*time.Hour, "telemetry heartbeat interval") +} + +func NewAnalyticsModule(v *viper.Viper, version string) fx.Option { + if v.GetBool(telemetryEnabledFlag) || v.GetBool(segmentEnabledFlag) { + applicationId := viper.GetString(telemetryApplicationIdFlag) + if applicationId == "" { + applicationId = viper.GetString(segmentApplicationIdFlag) + } + var appIdProviderModule fx.Option + if applicationId == "" { + appIdProviderModule = fx.Provide(analytics.FromStorageAppIdProvider) + } else { + appIdProviderModule = fx.Provide(func() analytics.AppIdProvider { + return analytics.AppIdProviderFn(func(ctx context.Context) (string, error) { + return applicationId, nil + }) + }) + } + writeKey := viper.GetString(telemetryWriteKeyFlag) + if writeKey == "" { + writeKey = viper.GetString(segmentWriteKeyFlag) + } + interval := viper.GetDuration(telemetryHeartbeatIntervalFlag) + if interval == 0 { + interval = viper.GetDuration(segmentHeartbeatIntervalFlag) + } + if writeKey == "" { + sharedlogging.GetLogger(context.Background()).Infof("telemetry enabled but no write key provided") + } else if interval == 0 { + sharedlogging.GetLogger(context.Background()).Error("telemetry heartbeat interval is 0") + } else { + _, err := semver.NewVersion(version) + if err != nil { + sharedlogging.GetLogger(context.Background()).Infof("telemetry enabled but version '%s' is not semver, skip", version) + } else { + return fx.Options( + appIdProviderModule, + analytics.NewHeartbeatModule(version, writeKey, interval), + ) + } + } + } + return fx.Options() +} diff --git a/cmd/internal/analytics_test.go b/cmd/internal/analytics_test.go new file mode 100644 index 000000000..c83971bf4 --- /dev/null +++ b/cmd/internal/analytics_test.go @@ -0,0 +1,169 @@ +package internal + +import ( + "context" + "net/http" + "os" + "reflect" + "testing" + "time" + + "github.com/numary/ledger/pkg/storage" + "github.com/numary/ledger/pkg/storage/sqlstorage" + "github.com/pborman/uuid" + "github.com/spf13/cobra" + "github.com/spf13/viper" + "github.com/stretchr/testify/require" + "go.uber.org/fx" + "gopkg.in/segmentio/analytics-go.v3" +) + +func TestAnalyticsFlags(t *testing.T) { + type testCase struct { + name string + key string + envValue string + viperMethod interface{} + expectedValue interface{} + } + + for _, testCase := range []testCase{ + { + name: "using deprecated segment enabled flag", + key: segmentEnabledFlag, + envValue: "true", + viperMethod: (*viper.Viper).GetBool, + expectedValue: true, + }, + { + name: "using deprecated segment write key flag", + key: segmentWriteKeyFlag, + envValue: "foo:bar", + viperMethod: (*viper.Viper).GetString, + expectedValue: "foo:bar", + }, + { + name: "using deprecated segment heartbeat interval flag", + key: segmentHeartbeatIntervalFlag, + envValue: "10s", + viperMethod: (*viper.Viper).GetDuration, + expectedValue: 10 * time.Second, + }, + { + name: "using deprecated segment application id flag", + key: segmentApplicationIdFlag, + envValue: "foo:bar", + viperMethod: (*viper.Viper).GetString, + expectedValue: "foo:bar", + }, + { + name: "using telemetry enabled flag", + key: telemetryEnabledFlag, + envValue: "true", + viperMethod: (*viper.Viper).GetBool, + expectedValue: true, + }, + { + name: "using telemetry write key flag", + key: telemetryWriteKeyFlag, + envValue: "foo:bar", + viperMethod: (*viper.Viper).GetString, + expectedValue: "foo:bar", + }, + { + name: "using telemetry heartbeat interval flag", + key: telemetryHeartbeatIntervalFlag, + envValue: "10s", + viperMethod: (*viper.Viper).GetDuration, + expectedValue: 10 * time.Second, + }, + { + name: "using telemetry application id flag", + key: telemetryApplicationIdFlag, + envValue: "foo:bar", + viperMethod: (*viper.Viper).GetString, + expectedValue: "foo:bar", + }, + } { + t.Run(testCase.name, func(t *testing.T) { + v := viper.GetViper() + cmd := &cobra.Command{ + Run: func(cmd *cobra.Command, args []string) { + ret := reflect.ValueOf(testCase.viperMethod).Call([]reflect.Value{ + reflect.ValueOf(v), + reflect.ValueOf(testCase.key), + }) + require.Len(t, ret, 1) + + rValue := ret[0].Interface() + require.Equal(t, testCase.expectedValue, rValue) + }, + } + InitHTTPBasicFlags(cmd) + BindEnv(v) + + restoreEnvVar := setEnvVar(testCase.key, testCase.envValue) + defer restoreEnvVar() + + require.NoError(t, v.BindPFlags(cmd.PersistentFlags())) + require.NoError(t, cmd.Execute()) + }) + } +} + +func TestAnalyticsModule(t *testing.T) { + v := viper.GetViper() + v.Set(telemetryEnabledFlag, true) + v.Set(telemetryWriteKeyFlag, "XXX") + v.Set(telemetryApplicationIdFlag, "appId") + v.Set(telemetryHeartbeatIntervalFlag, 10*time.Second) + + handled := make(chan struct{}) + + module := NewAnalyticsModule(v, "1.0.0") + app := fx.New( + module, + fx.Provide(func(lc fx.Lifecycle) (storage.Driver, error) { + id := uuid.New() + driver := sqlstorage.NewDriver("sqlite", sqlstorage.NewSQLiteDB(os.TempDir(), id)) + lc.Append(fx.Hook{ + OnStart: driver.Initialize, + }) + return driver, nil + }), + fx.Replace(analytics.Config{ + BatchSize: 1, + Transport: roundTripperFn(func(req *http.Request) (*http.Response, error) { + select { + case <-handled: + // Nothing to do, the chan has already been closed + default: + close(handled) + } + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }), + })) + require.NoError(t, app.Start(context.Background())) + defer func() { + require.NoError(t, app.Stop(context.Background())) + }() + + select { + case <-time.After(time.Second): + require.Fail(t, "Timeout waiting first stats from analytics module") + case <-handled: + } + +} + +func TestAnalyticsModuleDisabled(t *testing.T) { + v := viper.GetViper() + v.Set(telemetryEnabledFlag, false) + + module := NewAnalyticsModule(v, "1.0.0") + app := fx.New(module) + require.NoError(t, app.Start(context.Background())) + require.NoError(t, app.Stop(context.Background())) +} diff --git a/cmd/internal/http_basic_test.go b/cmd/internal/http_basic_test.go index 48053b569..0b164e75e 100644 --- a/cmd/internal/http_basic_test.go +++ b/cmd/internal/http_basic_test.go @@ -2,9 +2,7 @@ package internal import ( "fmt" - "os" "reflect" - "strings" "testing" "github.com/numary/go-libs/sharedauth" @@ -13,19 +11,6 @@ import ( "github.com/stretchr/testify/require" ) -func withPrefix(flag string) string { - return strings.ToUpper(fmt.Sprintf("%s_%s", envPrefix, EnvVarReplacer.Replace(flag))) -} - -func setEnvVar(key, value string) func() { - prefixedFlag := withPrefix(key) - oldEnv := os.Getenv(prefixedFlag) - os.Setenv(prefixedFlag, value) - return func() { - os.Setenv(prefixedFlag, oldEnv) - } -} - func TestViperEnvBinding(t *testing.T) { type testCase struct { diff --git a/cmd/internal/utils.go b/cmd/internal/utils.go new file mode 100644 index 000000000..52940bae5 --- /dev/null +++ b/cmd/internal/utils.go @@ -0,0 +1,27 @@ +package internal + +import ( + "fmt" + "net/http" + "os" + "strings" +) + +func withPrefix(flag string) string { + return strings.ToUpper(fmt.Sprintf("%s_%s", envPrefix, EnvVarReplacer.Replace(flag))) +} + +func setEnvVar(key, value string) func() { + prefixedFlag := withPrefix(key) + oldEnv := os.Getenv(prefixedFlag) + os.Setenv(prefixedFlag, value) + return func() { + os.Setenv(prefixedFlag, oldEnv) + } +} + +type roundTripperFn func(req *http.Request) (*http.Response, error) + +func (fn roundTripperFn) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} diff --git a/cmd/root.go b/cmd/root.go index 1f4f2416e..f30343365 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -4,7 +4,6 @@ import ( "fmt" "os" "path" - "time" "github.com/numary/ledger/cmd/internal" "github.com/numary/ledger/pkg/redis" @@ -62,20 +61,6 @@ const ( authBearerAudiencesWildcardFlag = "auth-bearer-audiences-wildcard" authBearerUseScopesFlag = "auth-bearer-use-scopes" - // deprecated - segmentEnabledFlag = "segment-enabled" - // deprecated - segmentWriteKeyFlag = "segment-write-key" - // deprecated - segmentApplicationIdFlag = "segment-application-id" - // deprecated - segmentHeartbeatIntervalFlag = "segment-heartbeat-interval" - - telemetryEnabledFlag = "telemetry-enabled" - telemetryWriteKeyFlag = "telemetry-write-key" - telemetryApplicationIdFlag = "telemetry-application-id" - telemetryHeartbeatIntervalFlag = "telemetry-heartbeat-interval" - commitPolicyFlag = "commit-policy" ) @@ -178,17 +163,10 @@ func NewRootCommand() *cobra.Command { root.PersistentFlags().StringSlice(authBearerAudienceFlag, []string{}, "Allowed audiences") root.PersistentFlags().Bool(authBearerAudiencesWildcardFlag, false, "Don't check audience") root.PersistentFlags().Bool(authBearerUseScopesFlag, false, "Use scopes as defined by rfc https://datatracker.ietf.org/doc/html/rfc8693") - root.PersistentFlags().Bool(segmentEnabledFlag, true, "Is segment enabled") - root.PersistentFlags().String(segmentApplicationIdFlag, "", "Segment application id") - root.PersistentFlags().String(segmentWriteKeyFlag, DefaultSegmentWriteKey, "Segment write key") - root.PersistentFlags().Duration(segmentHeartbeatIntervalFlag, 4*time.Hour, "Segment heartbeat interval") - root.PersistentFlags().Bool(telemetryEnabledFlag, true, "Is telemetry enabled") - root.PersistentFlags().String(telemetryApplicationIdFlag, "", "telemetry application id") - root.PersistentFlags().String(telemetryWriteKeyFlag, DefaultSegmentWriteKey, "telemetry write key") - root.PersistentFlags().Duration(telemetryHeartbeatIntervalFlag, 4*time.Hour, "telemetry heartbeat interval") root.PersistentFlags().String(commitPolicyFlag, "", "Transaction commit policy (default or allow-past-timestamps)") internal.InitHTTPBasicFlags(root) + internal.InitAnalyticsFlags(root, DefaultSegmentWriteKey) if err = viper.BindPFlags(root.PersistentFlags()); err != nil { panic(err)