Skip to content

Commit

Permalink
fix: command arg chain parsing
Browse files Browse the repository at this point in the history
Signed-off-by: Marko Kungla <[email protected]>
  • Loading branch information
mkungla committed Mar 8, 2024
1 parent 3468715 commit 8327d66
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 41 deletions.
81 changes: 45 additions & 36 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ import (
)

var (
ErrCommand = errors.New("command error")
ErrCommandFlags = errors.New("command flags error")
ErrCommandAction = errors.New("command action error")
ErrCommand = errors.New("command error")
ErrCommandFlags = errors.New("command flags error")
ErrCommandHasNoParent = errors.New("command has no parent command")
)

type Command struct {
Expand Down Expand Up @@ -373,29 +373,30 @@ func (c *Command) getFlags() varflag.Flags {
return c.flags
}

func (c *Command) getSharedFlags() varflag.Flags {

func (c *Command) getSharedFlags() (varflag.Flags, error) {
if c.parent == nil {
return nil
flags, _ := varflag.NewFlagSet(c.name+"-noparent", 0)
return flags, ErrCommandHasNoParent
}

var flags varflag.Flags
if c.parent.beforeActionShared {
flags = c.parent.getFlags()
}
flags := c.parent.getFlags()
if flags == nil {
flags, _ = varflag.NewFlagSet(c.name, 0)
flags, _ = varflag.NewFlagSet(c.parent.name, 0)
}
parentFlags, err := c.parent.getSharedFlags()
if err != nil {
return nil, err
}

parentFlags := c.parent.getSharedFlags()

if parentFlags != nil {
for _, flag := range parentFlags.Flags() {
_ = flags.Add(flag)
if err := flags.Add(flag); err != nil {
return nil, err
}
}
}

return flags
return flags, nil
}

func (c *Command) getSubCommand(name string) (cmd *Command, exists bool) {
Expand Down Expand Up @@ -439,6 +440,7 @@ func (c *Command) getActiveCommand() (*Command, error) {

func (c *Command) callSharedBeforeAction(sess *Session) error {
if c.parent != nil {
c.parent.isWrapperCommand = true // prevents args from being added to parent command
if err := c.parent.callSharedBeforeAction(sess); err != nil {
return err
}
Expand All @@ -454,38 +456,42 @@ func (c *Command) callBeforeAction(sess *Session) error {
c.mu.Lock()
defer c.mu.Unlock()

if c.parent != nil && !c.sharedCalled {
if err := c.parent.callSharedBeforeAction(sess); err != nil {
return err
}
}
if c.beforeAction == nil {
return nil
pflags, err := c.getSharedFlags()
if err != nil && !errors.Is(err, ErrCommandHasNoParent) {
return err
}

pflags := c.getSharedFlags()
if pflags != nil {
for _, flag := range pflags.Flags() {
if err := c.flags.Add(flag); err != nil {
return fmt.Errorf("%w: %s: %w", ErrCommandAction, c.name, err)
return fmt.Errorf("%w: %s: %w", ErrCommand, c.name, err)
}
}
}
args := sdk.NewArgs(c.flags)

args := sdk.NewArgs(c.flags)
if c.argnmin == 0 && c.argnmax == 0 && args.Argn() > 0 {
return fmt.Errorf("%w: %s: %s", ErrCommandAction, c.name, "command does not accept arguments")
return fmt.Errorf("%w: %s: %s", ErrCommand, c.name, "does not accept arguments")
}

if args.Argn() < c.argnmin {
return fmt.Errorf("%w: %s: command requires min %d arguments, %d provided", ErrCommandAction, c.name, c.argnmin, args.Argn())
return fmt.Errorf("%w: %s: requires min %d arguments, %d provided", ErrCommand, c.name, c.argnmin, args.Argn())
}
if args.Argn() > c.argnmax {
return fmt.Errorf("%w: %s: command accepts max %d arguments, %d provided, extra %v", ErrCommandAction, c.name, c.argnmax, args.Argn(), args.Args()[c.argnmax:args.Argn()])
return fmt.Errorf("%w: %s: accepts max %d arguments, %d provided, extra %v", ErrCommand, c.name, c.argnmax, args.Argn(), args.Args()[c.argnmax:args.Argn()])
}

if c.parent != nil && !c.sharedCalled {
if err := c.parent.callSharedBeforeAction(sess); err != nil {
return err
}
}

if c.beforeAction == nil {
return nil
}

if err := c.beforeAction(sess, args); err != nil {
return fmt.Errorf("%w: %s: %w", ErrCommandAction, c.name, err)
return fmt.Errorf("%w: %s: %w", ErrCommand, c.name, err)
}
return nil
}
Expand All @@ -498,19 +504,22 @@ func (c *Command) callDoAction(session *Session) error {
return nil
}

pflags := c.getSharedFlags()
pflags, err := c.getSharedFlags()
if err != nil && !errors.Is(err, ErrCommandHasNoParent) {
return err
}
if pflags != nil {
for _, flag := range pflags.Flags() {
if err := c.flags.Add(flag); err != nil {
return fmt.Errorf("%w: %s: %w", ErrCommandAction, c.name, err)
return fmt.Errorf("%w: %s: %w", ErrCommand, c.name, err)
}
}
}

args := sdk.NewArgs(c.flags)

if err := c.doAction(session, args); err != nil {
return fmt.Errorf("%w: %s: %w", ErrCommandAction, c.name, err)
return fmt.Errorf("%w: %s: %w", ErrCommand, c.name, err)
}
return nil
}
Expand All @@ -524,7 +533,7 @@ func (c *Command) callAfterFailureAction(session *Session, err error) error {
}

if err := c.afterFailureAction(session, err); err != nil {
return fmt.Errorf("%w: %s: %w", ErrCommandAction, c.name, err)
return fmt.Errorf("%w: %s: %w", ErrCommand, c.name, err)
}
return nil
}
Expand All @@ -538,7 +547,7 @@ func (c *Command) callAfterSuccessAction(session *Session) error {
}

if err := c.afterSuccessAction(session); err != nil {
return fmt.Errorf("%w: %s: %w", ErrCommandAction, c.name, err)
return fmt.Errorf("%w: %s: %w", ErrCommand, c.name, err)
}
return nil
}
Expand All @@ -552,7 +561,7 @@ func (c *Command) callAfterAlwaysAction(session *Session, err error) error {
}

if err := c.afterAlwaysAction(session, err); err != nil {
return fmt.Errorf("%w: %s: %w", ErrCommandAction, c.name, err)
return fmt.Errorf("%w: %s: %w", ErrCommand, c.name, err)
}
return nil
}
Expand Down
11 changes: 8 additions & 3 deletions happy-main.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ func (m *Main) run() {
}

if err := m.executeBeforeActions(); err != nil {
m.sess.Log().Error("failed to call before always", slog.String("err", err.Error()))
m.sess.Log().Error(err.Error(), slog.String("action", "before"))
m.exit(1)
return
}
Expand Down Expand Up @@ -342,6 +342,7 @@ func (m *Main) executeBeforeActions() error {
return err
}
}

return m.cmd.callBeforeAction(m.sess)
}

Expand Down Expand Up @@ -377,7 +378,11 @@ func (m *Main) help() error {

if m.cmd != m.root {
h.AddCommandFlags(m.cmd.getFlags())
h.AddSharedFlags(m.cmd.getSharedFlags())
flags, err := m.cmd.getSharedFlags()
if err != nil {
return err
}
h.AddSharedFlags(flags)
}

h.AddGlobalFlags(m.root.getFlags())
Expand Down Expand Up @@ -449,7 +454,7 @@ func (m *Main) exit(code int) {
}
if !testing.Testing() {
if err := m.save(); err != nil {
m.sess.Log().Error("failed to save state", slog.String("err", err.Error()))
m.sess.Log().Error(err.Error())
}
}

Expand Down
2 changes: 0 additions & 2 deletions sdk/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import (
)

var (
// ErrCommand = happyx.NewError("command error")
// ErrCommandAction = happyx.NewError("command action error")
ErrCommandInvalid = errors.New("invalid command definition")
ErrCommandArgs = errors.New("command arguments error")
ErrCommandFlags = errors.New("command flags error")
Expand Down

0 comments on commit 8327d66

Please sign in to comment.