From 9c7483fd093b304e304ede67c6a5f5ab24331f6f Mon Sep 17 00:00:00 2001 From: Felipe Martin <812088+fmartingr@users.noreply.github.com> Date: Sat, 30 Mar 2024 08:33:05 +0100 Subject: [PATCH] fix: override configuration from flags only if set (#865) * fix: override configuration from flags only if set * use helper func and test it --- internal/cmd/server.go | 34 ++++++++++++++---- internal/cmd/server_test.go | 72 +++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 7 deletions(-) create mode 100644 internal/cmd/server_test.go diff --git a/internal/cmd/server.go b/internal/cmd/server.go index f8ccc47cf..4b3ad58ff 100644 --- a/internal/cmd/server.go +++ b/internal/cmd/server.go @@ -4,9 +4,11 @@ import ( "context" "strings" + "github.com/go-shiori/shiori/internal/config" "github.com/go-shiori/shiori/internal/http" "github.com/go-shiori/shiori/internal/model" "github.com/spf13/cobra" + "github.com/spf13/pflag" ) func newServerCommand() *cobra.Command { @@ -27,6 +29,12 @@ func newServerCommand() *cobra.Command { return cmd } +func setIfFlagChanged(flagName string, flags *pflag.FlagSet, cfg *config.Config, fn func(cfg *config.Config)) { + if flags.Changed(flagName) { + fn(cfg) + } +} + func newServerCommandHandler() func(cmd *cobra.Command, args []string) { return func(cmd *cobra.Command, args []string) { ctx := context.Background() @@ -54,13 +62,25 @@ func newServerCommandHandler() func(cmd *cobra.Command, args []string) { rootPath += "/" } - // Override configuration from flags - cfg.Http.Port = port - cfg.Http.Address = address + ":" - cfg.Http.RootPath = rootPath - cfg.Http.AccessLog = accessLog - cfg.Http.ServeWebUI = serveWebUI - cfg.Http.SecretKey = secretKey + // Override configuration from flags if needed + setIfFlagChanged("port", cmd.Flags(), cfg, func(cfg *config.Config) { + cfg.Http.Port = port + }) + setIfFlagChanged("address", cmd.Flags(), cfg, func(cfg *config.Config) { + cfg.Http.Address = address + ":" + }) + setIfFlagChanged("webroot", cmd.Flags(), cfg, func(cfg *config.Config) { + cfg.Http.RootPath = rootPath + }) + setIfFlagChanged("access-log", cmd.Flags(), cfg, func(cfg *config.Config) { + cfg.Http.AccessLog = accessLog + }) + setIfFlagChanged("serve-web-ui", cmd.Flags(), cfg, func(cfg *config.Config) { + cfg.Http.ServeWebUI = serveWebUI + }) + setIfFlagChanged("secret-key", cmd.Flags(), cfg, func(cfg *config.Config) { + cfg.Http.SecretKey = secretKey + }) dependencies.Log.Infof("Starting Shiori v%s", model.BuildVersion) diff --git a/internal/cmd/server_test.go b/internal/cmd/server_test.go new file mode 100644 index 000000000..626866689 --- /dev/null +++ b/internal/cmd/server_test.go @@ -0,0 +1,72 @@ +package cmd + +import ( + "testing" + + "github.com/go-shiori/shiori/internal/config" + "github.com/spf13/pflag" + "github.com/stretchr/testify/require" +) + +func Test_setIfFlagChanged(t *testing.T) { + type args struct { + flagName string + flags func() *pflag.FlagSet + cfg *config.Config + fn func(cfg *config.Config) + } + tests := []struct { + name string + args args + assertFn func(t *testing.T, cfg *config.Config) + }{ + { + name: "Flag didn't change", + args: args{ + flagName: "port", + flags: func() *pflag.FlagSet { + return &pflag.FlagSet{} + }, + cfg: &config.Config{ + Http: &config.HttpConfig{ + Port: 8080, + }, + }, + fn: func(cfg *config.Config) { + cfg.Http.Port = 9999 + }, + }, + assertFn: func(t *testing.T, cfg *config.Config) { + require.Equal(t, cfg.Http.Port, 8080) + }, + }, + { + name: "Flag changed", + args: args{ + flagName: "port", + flags: func() *pflag.FlagSet { + pf := &pflag.FlagSet{} + pf.IntP("port", "p", 8080, "Port used by the server") + pf.Set("port", "9999") + return pf + }, + cfg: &config.Config{ + Http: &config.HttpConfig{ + Port: 8080, + }, + }, + fn: func(cfg *config.Config) { + cfg.Http.Port = 9999 + }, + }, + assertFn: func(t *testing.T, cfg *config.Config) { + require.Equal(t, cfg.Http.Port, 9999) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setIfFlagChanged(tt.args.flagName, tt.args.flags(), tt.args.cfg, tt.args.fn) + }) + } +}