From df98c09e6f63849f77b2a916dfb1acd71691ff9a Mon Sep 17 00:00:00 2001 From: Seth Vargo Date: Thu, 23 Mar 2023 10:34:33 -0400 Subject: [PATCH] Add package for creating CLIs (#100) * Add package for creating CLIs This adds the `cli` package which creates well-structured command line interfaces and flag parsing. * Document prompt * Note users can run -help * Add more detailed examples * Interpolate {{ COMMAND }} * Make flag handling easier * Add an example for persistent flags * Faster trim --- cli/cli.go | 81 +++ cli/command.go | 310 ++++++++++ cli/command_doc_test.go | 137 +++++ cli/command_group_doc_test.go | 128 +++++ cli/command_persistent_flags_doc_test.go | 180 ++++++ cli/command_test.go | 321 +++++++++++ cli/flags.go | 697 +++++++++++++++++++++++ cli/flags_test.go | 96 ++++ go.mod | 2 + go.sum | 3 + 10 files changed, 1955 insertions(+) create mode 100644 cli/cli.go create mode 100644 cli/command.go create mode 100644 cli/command_doc_test.go create mode 100644 cli/command_group_doc_test.go create mode 100644 cli/command_persistent_flags_doc_test.go create mode 100644 cli/command_test.go create mode 100644 cli/flags.go create mode 100644 cli/flags_test.go diff --git a/cli/cli.go b/cli/cli.go new file mode 100644 index 00000000..cabbc49b --- /dev/null +++ b/cli/cli.go @@ -0,0 +1,81 @@ +// Copyright 2023 The Authors (see AUTHORS file) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cli defines an SDK for building performant and consistent CLIs. All +// commands start with a [RootCommand] which can then accept one or more nested +// subcommands. Subcommands can also be [RootCommand], which creates nested CLIs +// (e.g. "my-tool do the-thing"). +// +// The CLI provides opinionated, formatted help output including flag structure. +// It also provides a more integrated experience for defining CLI flags, hiding +// flags, and generating aliases. +// +// To minimize startup times, things are as lazy-loaded as possible. This means +// commands are instantiated only when needed. Most applications will create a +// private global variable that returns the root command: +// +// var rootCmd = func() cli.Command { +// return &cli.RootCommand{ +// Name: "my-tool", +// Version: "1.2.3", +// Commands: map[string]cli.CommandFactory{ +// "eat": func() cli.Command { +// return &EatCommand{} +// }, +// "sleep": func() cli.Command { +// return &SleepCommand{} +// }, +// }, +// } +// } +// +// This CLI could be invoked via: +// +// $ my-tool eat +// $ my-tool sleep +// +// Deeply-nested [RootCommand] behave like nested CLIs: +// +// var rootCmd = func() cli.Command { +// return &cli.RootCommand{ +// Name: "my-tool", +// Version: "1.2.3", +// Commands: map[string]cli.CommandFactory{ +// "transport": func() cli.Command { +// return &cli.RootCommand{ +// Name: "transport", +// Description: "Subcommands for transportation", +// Commands: map[string]cli.CommandFactory{ +// "bus": func() cli.Command { +// return &BusCommand{} +// }, +// "car": func() cli.Command { +// return &CarCommand{} +// }, +// "train": func() cli.Command { +// return &TrainCommand{} +// }, +// }, +// } +// }, +// }, +// } +// } +// +// This CLI could be invoked via: +// +// $ my-tool transport bus +// $ my-tool transport car +// $ my-tool transport train +package cli diff --git a/cli/command.go b/cli/command.go new file mode 100644 index 00000000..7d027b7f --- /dev/null +++ b/cli/command.go @@ -0,0 +1,310 @@ +// Copyright 2023 The Authors (see AUTHORS file) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cli + +import ( + "bufio" + "bytes" + "context" + "errors" + "flag" + "fmt" + "io" + "os" + "sort" + "strings" + + "github.com/mattn/go-isatty" +) + +// Command is the interface for a command or subcommand. Most of these functions +// have default implementations on [BaseCommand]. +type Command interface { + // Desc provides a short, one-line description of the command. It should be + // shorter than 50 characters. + Desc() string + + // Help is the long-form help output. It should include usage instructions and + // flag information. + // + // Callers can insert the literal string "{{ COMMAND }}" which will be + // replaced with the actual subcommand structure. + Help() string + + // Flags returns the list of flags that are defined on the command. + Flags() *FlagSet + + // Hidden indicates whether the command is hidden from help output. + Hidden() bool + + // Run executes the command. + Run(ctx context.Context, args []string) error + + // Prompt provides a mechanism for asking for user input. It reads from + // [Stdin]. If there's an input stream (e.g. a pipe), it will read the pipe. + // If the terminal is a TTY, it will prompt. Otherwise it will fail if there's + // no pipe and the terminal is not a tty. + Prompt(msg string) (string, error) + + // Stdout returns the stdout stream. SetStdout sets the stdout stream. + Stdout() io.Writer + SetStdout(w io.Writer) + + // Stderr returns the stderr stream. SetStderr sets the stderr stream. + Stderr() io.Writer + SetStderr(w io.Writer) + + // Stdin returns the stdin stream. SetStdin sets the stdin stream. + Stdin() io.Reader + SetStdin(r io.Reader) + + // Pipe creates new unqiue stdin, stdout, and stderr buffers, sets them on the + // command, and returns them. This is most useful for testing where callers + // want to simulate inputs or assert certain command outputs. + Pipe() (stdin, stdout, stderr *bytes.Buffer) +} + +// CommandFactory returns a new instance of a command. This returns a function +// instead of allocations because we want the CLI to load as fast as possible, +// so we lazy load as much as possible. +type CommandFactory func() Command + +// Ensure [RootCommand] implements [Command]. +var _ Command = (*RootCommand)(nil) + +// RootCommand represents a command root for a parent or collection of +// subcommands. +type RootCommand struct { + BaseCommand + + // Name is the name of the command or subcommand. For top-level commands, this + // should be the binary name. For subcommands, this should be the name of the + // subcommand. + Name string + + // Description is the human-friendly description of the command. + Description string + + // Hide marks the entire subcommand as hidden. It will not be shown in help + // output. + Hide bool + + // Version defines the version information for the command. This can be + // omitted for subcommands as it will be inherited from the parent. + Version string + + // Commands is the list of sub commands. + Commands map[string]CommandFactory +} + +// Desc is the root command description. It is used to satisfy the [Command] +// interface. +func (r *RootCommand) Desc() string { + return r.Description +} + +// Hidden determines whether the command group is hidden. It is used to satisfy +// the [Command] interface. +func (r *RootCommand) Hidden() bool { + return r.Hide +} + +// Help compiles structured help information. It is used to satisfy the +// [Command] interface. +func (r *RootCommand) Help() string { + var b strings.Builder + + longest := 0 + names := make([]string, 0, len(r.Commands)) + for name := range r.Commands { + names = append(names, name) + if l := len(name); l > longest { + longest = l + } + } + sort.Strings(names) + + fmt.Fprintf(&b, "Usage: %s COMMAND\n\n", r.Name) + for _, name := range names { + cmd := r.Commands[name]() + if cmd == nil { + continue + } + + if !cmd.Hidden() { + fmt.Fprintf(&b, " %-*s%s\n", longest+4, name, cmd.Desc()) + } + } + + return strings.TrimRight(b.String(), "\n") +} + +// Run executes the command and prints help output or delegates to a subcommand. +func (r *RootCommand) Run(ctx context.Context, args []string) error { + name, args := extractCommandAndArgs(args) + + // Short-circuit top-level help. + if name == "" || name == "-h" || name == "-help" || name == "--help" { + fmt.Fprintln(r.Stderr(), formatHelp(r.Help(), r.Name, r.Flags())) + return nil + } + + // Short-circuit version. + if name == "-v" || name == "-version" || name == "--version" { + fmt.Fprintln(r.Stderr(), r.Version) + return nil + } + + cmd, ok := r.Commands[name] + if !ok { + return fmt.Errorf("unknown command %q: run \"%s -help\" for a list of "+ + "commands", name, r.Name) + } + instance := cmd() + + // Ensure the child inherits the streams from the root. + instance.SetStdin(r.stdin) + instance.SetStdout(r.stdout) + instance.SetStderr(r.stderr) + + // If this is a subcommand, prefix the name with the parent and inherit some + // values. + if typ, ok := instance.(*RootCommand); ok { + typ.Name = r.Name + " " + typ.Name + typ.Version = r.Version + return typ.Run(ctx, args) + } + + if err := instance.Run(ctx, args); err != nil { + // Special case requesting help. + if errors.Is(err, flag.ErrHelp) { + fmt.Fprintln(instance.Stderr(), formatHelp(instance.Help(), r.Name+" "+name, instance.Flags())) + return nil + } + //nolint:wrapcheck // We want to bubble this error exactly as-is. + return err + } + return nil +} + +// extractCommandAndArgs is a helper that pulls the subcommand and arguments. +func extractCommandAndArgs(args []string) (string, []string) { + switch len(args) { + case 0: + return "", nil + case 1: + return args[0], nil + default: + return args[0], args[1:] + } +} + +// formatHelp is a helper function that does variable replacement from the help +// string. +func formatHelp(help, name string, flags *FlagSet) string { + h := strings.Trim(help, "\n") + if flags != nil { + if v := strings.Trim(flags.Help(), "\n"); v != "" { + h = h + "\n\n" + v + } + } + return strings.ReplaceAll(h, "{{ COMMAND }}", name) +} + +// BaseCommand is the default command structure. All commands should embed this +// structure. +type BaseCommand struct { + stdout, stderr io.Writer + stdin io.Reader +} + +// Flags returns the base command flags, which is always nil. +func (c *BaseCommand) Flags() *FlagSet { + return nil +} + +// Hidden indicates whether the command is hidden. The default is unhidden. +func (c *BaseCommand) Hidden() bool { + return false +} + +// Prompt prompts the user for a value. If stdin is a tty, it prompts. Otherwise +// it reads from the reader. +func (c *BaseCommand) Prompt(msg string) (string, error) { + scanner := bufio.NewScanner(io.LimitReader(c.Stdin(), 64*1_000)) + + if c.Stdin() == os.Stdin && isatty.IsTerminal(os.Stdin.Fd()) { + fmt.Fprint(c.Stdout(), msg) + } + + scanner.Scan() + + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("failed to read stdin: %w", err) + } + return scanner.Text(), nil +} + +// Stdout returns the stdout stream. +func (c *BaseCommand) Stdout() io.Writer { + if v := c.stdout; v != nil { + return v + } + return os.Stdout +} + +// SetStdout sets the standard out. +func (c *BaseCommand) SetStdout(w io.Writer) { + c.stdout = w +} + +// Stderr returns the stderr stream. +func (c *BaseCommand) Stderr() io.Writer { + if v := c.stderr; v != nil { + return v + } + return os.Stderr +} + +// SetStdout sets the standard error. +func (c *BaseCommand) SetStderr(w io.Writer) { + c.stderr = w +} + +// Stdin returns the stdin stream. +func (c *BaseCommand) Stdin() io.Reader { + if v := c.stdin; v != nil { + return v + } + return os.Stdin +} + +// SetStdout sets the standard input. +func (c *BaseCommand) SetStdin(r io.Reader) { + c.stdin = r +} + +// Pipe creates new unqiue stdin, stdout, and stderr buffers, sets them on the +// command, and returns them. This is most useful for testing where callers want +// to simulate inputs or assert certain command outputs. +func (c *BaseCommand) Pipe() (stdin, stdout, stderr *bytes.Buffer) { + stdin = bytes.NewBuffer(nil) + stdout = bytes.NewBuffer(nil) + stderr = bytes.NewBuffer(nil) + c.stdin = stdin + c.stdout = stdout + c.stderr = stderr + return +} diff --git a/cli/command_doc_test.go b/cli/command_doc_test.go new file mode 100644 index 00000000..e0a14190 --- /dev/null +++ b/cli/command_doc_test.go @@ -0,0 +1,137 @@ +// Copyright 2023 The Authors (see AUTHORS file) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cli_test + +import ( + "context" + "fmt" + "os" + "strconv" + + "github.com/abcxyz/pkg/cli" +) + +type CountCommand struct { + cli.BaseCommand + + flagStep int64 +} + +func (c *CountCommand) Desc() string { + return "Counts from 0 up to a number" +} + +func (c *CountCommand) Help() string { + return ` +Usage: {{ COMMAND }} [options] MAX + + The count command prints out a list of numbers starting from 0 up to and + including the provided MAX. + + $ {{ COMMAND }} 50 + + The value for MAX must be a positive integer. +` +} + +func (c *CountCommand) Flags() *cli.FlagSet { + set := cli.NewFlagSet() + + f := set.NewSection("Number options") + + f.Int64Var(&cli.Int64Var{ + Name: "step", + Aliases: []string{"s"}, + Example: "1", + Default: 1, + Target: &c.flagStep, + Usage: "Numeric value by which to increment between each number.", + }) + + return set +} + +func (c *CountCommand) Run(ctx context.Context, args []string) error { + f := c.Flags() + if err := f.Parse(args); err != nil { + return fmt.Errorf("failed to parse flags: %w", err) + } + + args = f.Args() + if len(args) != 1 { + return fmt.Errorf("expected 1 argument, got %q", args) + } + + maxStr := args[0] + max, err := strconv.ParseInt(maxStr, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse max: %w", err) + } + + for i := int64(0); i <= max; i += c.flagStep { + fmt.Fprintln(c.Stdout(), i) + } + + return nil +} + +func Example_commandWithFlags() { + ctx := context.Background() + + // Create the command. + rootCmd := func() cli.Command { + return &cli.RootCommand{ + Name: "my-tool", + Version: "1.2.3", + Commands: map[string]cli.CommandFactory{ + "count": func() cli.Command { + return &CountCommand{} + }, + }, + } + } + + cmd := rootCmd() + + // Help output is written to stderr by default. Redirect to stdout so the + // "Output" assertion works. + cmd.SetStderr(os.Stdout) + + fmt.Fprintln(cmd.Stdout(), "\nUp to 3:") + if err := cmd.Run(ctx, []string{"count", "3"}); err != nil { + panic(err) + } + + fmt.Fprintln(cmd.Stdout(), "\nUp to 10, stepping 2") + if err := cmd.Run(ctx, []string{"count", "-step=2", "10"}); err != nil { + panic(err) + } + + // Output: + // + // Up to 3: + // 0 + // 1 + // 2 + // 3 + // + // Up to 10, stepping 2 + // 0 + // 2 + // 4 + // 6 + // 8 + // 10 +} diff --git a/cli/command_group_doc_test.go b/cli/command_group_doc_test.go new file mode 100644 index 00000000..5e0f5d29 --- /dev/null +++ b/cli/command_group_doc_test.go @@ -0,0 +1,128 @@ +// Copyright 2023 The Authors (see AUTHORS file) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cli_test + +import ( + "context" + "fmt" + "os" + + "github.com/abcxyz/pkg/cli" +) + +type EatCommand struct { + cli.BaseCommand +} + +func (c *EatCommand) Desc() string { + return "Eat some food" +} + +func (c *EatCommand) Help() string { + return ` +Usage: {{ COMMAND }} [options] + + The eat command eats food. +` +} + +func (c *EatCommand) Flags() *cli.FlagSet { + return cli.NewFlagSet() +} + +func (c *EatCommand) Run(ctx context.Context, args []string) error { + if err := c.Flags().Parse(args); err != nil { + return fmt.Errorf("failed to parse flags: %w", err) + } + + // TODO: implement + return nil +} + +type DrinkCommand struct { + cli.BaseCommand +} + +func (c *DrinkCommand) Desc() string { + return "Drink some water" +} + +func (c *DrinkCommand) Help() string { + return ` +Usage: {{ COMMAND }} [options] + + The drink command drinks water. +` +} + +func (c *DrinkCommand) Flags() *cli.FlagSet { + return cli.NewFlagSet() +} + +func (c *DrinkCommand) Run(ctx context.Context, args []string) error { + if err := c.Flags().Parse(args); err != nil { + return fmt.Errorf("failed to parse flags: %w", err) + } + + // TODO: implement + return nil +} + +func Example_commandGroup() { + ctx := context.Background() + + rootCmd := func() cli.Command { + return &cli.RootCommand{ + Name: "my-tool", + Version: "1.2.3", + Commands: map[string]cli.CommandFactory{ + "eat": func() cli.Command { + return &EatCommand{} + }, + "drink": func() cli.Command { + return &DrinkCommand{} + }, + }, + } + } + + cmd := rootCmd() + + // Help output is written to stderr by default. Redirect to stdout so the + // "Output" assertion works. + cmd.SetStderr(os.Stdout) + + fmt.Fprintln(cmd.Stdout(), "\nTop-level help:") + if err := cmd.Run(ctx, []string{"-h"}); err != nil { + panic(err) + } + + fmt.Fprintln(cmd.Stdout(), "\nCommand-level help:") + if err := cmd.Run(ctx, []string{"eat", "-h"}); err != nil { + panic(err) + } + + // Output: + // Top-level help: + // Usage: my-tool COMMAND + // + // drink Drink some water + // eat Eat some food + // + // Command-level help: + // Usage: my-tool eat [options] + // + // The eat command eats food. +} diff --git a/cli/command_persistent_flags_doc_test.go b/cli/command_persistent_flags_doc_test.go new file mode 100644 index 00000000..a7d0b18e --- /dev/null +++ b/cli/command_persistent_flags_doc_test.go @@ -0,0 +1,180 @@ +// Copyright 2023 The Authors (see AUTHORS file) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cli_test + +import ( + "context" + "fmt" + "os" + + "github.com/abcxyz/pkg/cli" +) + +// serverFlags represent the shared flags among all server commands. Embed this +// struct into any commands that interact with a server. +type serverFlags struct { + flagAddress string + flagTLSSkipVerify bool +} + +func (s *serverFlags) addServerFlags(set *cli.FlagSet) { + f := set.NewSection("Server options") + + f.StringVar(&cli.StringVar{ + Name: "server-address", + Example: "https://my.corp.server:8145", + Default: "http://localhost:8145", + EnvVar: "CLI_SERVER_ADDRESS", + Target: &s.flagAddress, + Usage: "Endpoint, including protocol and port, the server.", + }) + + f.BoolVar(&cli.BoolVar{ + Name: "insecure", + Default: false, + EnvVar: "CLI_SERVER_TLS_SKIP_VERIFY", + Target: &s.flagTLSSkipVerify, + Usage: "Skip TLS verification. This is bad, please don't do it.", + }) +} + +type UploadCommand struct { + cli.BaseCommand + serverFlags +} + +func (c *UploadCommand) Desc() string { + return "Upload a file" +} + +func (c *UploadCommand) Help() string { + return ` +Usage: {{ COMMAND }} [options] + + Upload a file to the server. +` +} + +func (c *UploadCommand) Flags() *cli.FlagSet { + set := cli.NewFlagSet() + c.serverFlags.addServerFlags(set) + return set +} + +func (c *UploadCommand) Run(ctx context.Context, args []string) error { + if err := c.Flags().Parse(args); err != nil { + return fmt.Errorf("failed to parse flags: %w", err) + } + + _ = c.flagAddress // or c.serverFlags.flagAddress + _ = c.flagTLSSkipVerify + + // TODO: implement + return nil +} + +type DownloadCommand struct { + cli.BaseCommand + serverFlags +} + +func (c *DownloadCommand) Desc() string { + return "Download a file" +} + +func (c *DownloadCommand) Help() string { + return ` +Usage: {{ COMMAND }} [options] + + Download a file from the server. +` +} + +func (c *DownloadCommand) Flags() *cli.FlagSet { + set := cli.NewFlagSet() + c.serverFlags.addServerFlags(set) + return set +} + +func (c *DownloadCommand) Run(ctx context.Context, args []string) error { + if err := c.Flags().Parse(args); err != nil { + return fmt.Errorf("failed to parse flags: %w", err) + } + + _ = c.flagAddress // or c.serverFlags.flagAddress + _ = c.flagTLSSkipVerify + + // TODO: implement + return nil +} + +func Example_persistentFlags() { + ctx := context.Background() + + rootCmd := func() cli.Command { + return &cli.RootCommand{ + Name: "my-tool", + Version: "1.2.3", + Commands: map[string]cli.CommandFactory{ + "download": func() cli.Command { + return &DownloadCommand{} + }, + "upload": func() cli.Command { + return &UploadCommand{} + }, + }, + } + } + + cmd := rootCmd() + + // Help output is written to stderr by default. Redirect to stdout so the + // "Output" assertion works. + cmd.SetStderr(os.Stdout) + + fmt.Fprintln(cmd.Stdout(), "\nTop-level help:") + if err := cmd.Run(ctx, []string{"-h"}); err != nil { + panic(err) + } + + fmt.Fprintln(cmd.Stdout(), "\nCommand-level help:") + if err := cmd.Run(ctx, []string{"download", "-h"}); err != nil { + panic(err) + } + + // Output: + // Top-level help: + // Usage: my-tool COMMAND + // + // download Download a file + // upload Upload a file + // + // Command-level help: + // Usage: my-tool download [options] + // + // Download a file from the server. + // + // Server options + // + // -insecure + // Skip TLS verification. This is bad, please don't do it. The default + // value is "false". This option can also be specified with the + // CLI_SERVER_TLS_SKIP_VERIFY environment variable. + // + // -server-address="https://my.corp.server:8145" + // Endpoint, including protocol and port, the server. The default value + // is "http://localhost:8145". This option can also be specified with the + // CLI_SERVER_ADDRESS environment variable. +} diff --git a/cli/command_test.go b/cli/command_test.go new file mode 100644 index 00000000..d45621bd --- /dev/null +++ b/cli/command_test.go @@ -0,0 +1,321 @@ +// Copyright 2023 The Authors (see AUTHORS file) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cli + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/abcxyz/pkg/logging" + "github.com/abcxyz/pkg/testutil" +) + +func TestRootCommand_Help(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + cmd Command + exp string + }{ + { + name: "no_commands", + cmd: &RootCommand{ + Name: "test", + }, + exp: `Usage: test COMMAND`, + }, + { + name: "nil_command", + cmd: &RootCommand{ + Name: "test", + Commands: map[string]CommandFactory{ + "nil": func() Command { + return nil + }, + }, + }, + exp: `Usage: test COMMAND`, + }, + { + name: "single", + cmd: &RootCommand{ + Name: "test", + Commands: map[string]CommandFactory{ + "one": func() Command { return &TestCommand{} }, + }, + }, + exp: ` +Usage: test COMMAND + + one Test command +`, + }, + { + name: "multiple", + cmd: &RootCommand{ + Name: "test", + Commands: map[string]CommandFactory{ + "one": func() Command { return &TestCommand{} }, + "two": func() Command { return &TestCommand{} }, + "three": func() Command { return &TestCommand{} }, + }, + }, + exp: ` +Usage: test COMMAND + + one Test command + three Test command + two Test command +`, + }, + { + name: "hidden", + cmd: &RootCommand{ + Name: "test", + Commands: map[string]CommandFactory{ + "one": func() Command { return &TestCommand{} }, + "two": func() Command { + return &TestCommand{ + Hide: true, + } + }, + }, + }, + exp: ` +Usage: test COMMAND + + one Test command +`, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if got, want := strings.TrimSpace(tc.cmd.Help()), strings.TrimSpace(tc.exp); got != want { + t.Errorf("expected\n\n%s\n\nto be\n\n%s\n\n", got, want) + } + }) + } +} + +func TestRootCommand_Run(t *testing.T) { + t.Parallel() + + ctx := logging.WithLogger(context.Background(), logging.TestLogger(t)) + + rootCmd := func() Command { + return &RootCommand{ + Name: "test", + Version: "1.2.3", + Commands: map[string]CommandFactory{ + "default": func() Command { + return &TestCommand{ + Output: "output from default command", + } + }, + "error": func() Command { + return &TestCommand{ + Error: fmt.Errorf("a bad thing happened"), + } + }, + "hidden": func() Command { + return &TestCommand{ + Hide: true, + Output: "you found me", + } + }, + "child": func() Command { + return &RootCommand{ + Name: "child", + Description: "This is a child command", + Commands: map[string]CommandFactory{ + "default": func() Command { + return &TestCommand{ + Output: "output from child", + } + }, + }, + } + }, + }, + } + } + + cases := []struct { + name string + cmd Command + args []string + err string + expStdout string + expStderr string + }{ + { + name: "nothing", + args: nil, + expStderr: `Usage: test COMMAND`, + }, + { + name: "-h", + args: []string{"-h"}, + expStderr: `Usage: test COMMAND`, + }, + { + name: "-help", + args: []string{"-help"}, + expStderr: `Usage: test COMMAND`, + }, + { + name: "--help", + args: []string{"-help"}, + expStderr: `Usage: test COMMAND`, + }, + { + name: "-v", + args: []string{"-v"}, + expStderr: `1.2.3`, + }, + { + name: "-version", + args: []string{"-version"}, + expStderr: `1.2.3`, + }, + { + name: "--version", + args: []string{"--version"}, + expStderr: `1.2.3`, + }, + { + name: "unknown_command", + args: []string{"nope"}, + err: `unknown command "nope": run "test -help" for a list of commands`, + }, + { + name: "runs_parent_command", + args: []string{"default"}, + expStdout: `output from default command`, + }, + { + name: "handles_error", + args: []string{"error"}, + err: `a bad thing happened`, + }, + { + name: "runs_hidden", + args: []string{"hidden"}, + expStdout: `you found me`, + }, + { + name: "runs_child", + args: []string{"child", "default"}, + expStdout: `output from child`, + }, + { + name: "child_version", + args: []string{"child", "-v"}, + expStderr: `1.2.3`, + }, + { + name: "child_help", + args: []string{"child", "-h"}, + expStderr: `Usage: test child COMMAND`, + }, + { + name: "child_help_flags", + args: []string{"child", "default", "-h"}, + expStderr: `-string="my-string"`, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + cmd := rootCmd() + _, stdout, stderr := cmd.Pipe() + + err := cmd.Run(ctx, tc.args) + if diff := testutil.DiffErrString(err, tc.err); diff != "" { + t.Errorf("Unexpected err: %s", diff) + } + + if got, want := strings.TrimSpace(stdout.String()), strings.TrimSpace(tc.expStdout); !strings.Contains(got, want) { + t.Errorf("expected\n\n%s\n\nto contain\n\n%s\n\n", got, want) + } + if got, want := strings.TrimSpace(stderr.String()), strings.TrimSpace(tc.expStderr); !strings.Contains(got, want) { + t.Errorf("expected\n\n%s\n\nto contain\n\n%s\n\n", got, want) + } + }) + } +} + +type TestCommand struct { + BaseCommand + + Hide bool + Output string + Error error + + flagString string +} + +func (c *TestCommand) Desc() string { + return "Test command" +} + +func (c *TestCommand) Help() string { + return "Usage: {{ COMMAND }}" +} + +func (c *TestCommand) Flags() *FlagSet { + set := NewFlagSet() + + f := set.NewSection("Options") + + f.StringVar(&StringVar{ + Name: "string", + Example: "my-string", + Target: &c.flagString, + Usage: "A literal string.", + }) + + return set +} + +func (c *TestCommand) Hidden() bool { return c.Hide } + +func (c *TestCommand) Run(ctx context.Context, args []string) error { + if err := c.Flags().Parse(args); err != nil { + return fmt.Errorf("failed to parse flags: %w", err) + } + + if err := c.Error; err != nil { + return err + } + + if v := c.Output; v != "" { + fmt.Fprint(c.Stdout(), v) + } + + return nil +} diff --git a/cli/flags.go b/cli/flags.go new file mode 100644 index 00000000..bfbd93cc --- /dev/null +++ b/cli/flags.go @@ -0,0 +1,697 @@ +// Copyright 2023 The Authors (see AUTHORS file) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//nolint:wrapcheck // These functions intentionally just wrap flag.Flag. +package cli + +import ( + "flag" + "fmt" + "io" + "os" + "sort" + "strconv" + "strings" + "time" + + "github.com/abcxyz/pkg/timeutil" + "github.com/kr/text" +) + +const maxLineLength = 80 + +// LookupEnvFunc is the signature of a function for looking up environment +// variables. It makes that of [os.LookupEnv]. +type LookupEnvFunc = func(string) (string, bool) + +// MapLookuper returns a LookupEnvFunc that reads from a map instead of the +// environment. This is mostly used for testing. +func MapLookuper(m map[string]string) LookupEnvFunc { + return func(s string) (string, bool) { + v, ok := m[s] + return v, ok + } +} + +// FlagSet is the root flag set for creating and managing flag sections. +type FlagSet struct { + flagSet *flag.FlagSet + sections []*FlagSection + lookupEnv LookupEnvFunc +} + +// Option is an option to the flagset. +type Option func(fs *FlagSet) *FlagSet + +// WithLookupEnv defines a custom function for looking up environment variables. +// This is mostly useful for testing. +func WithLookupEnv(fn LookupEnvFunc) Option { + if fn == nil { + panic("lookup cannot be nil") + } + + return func(fs *FlagSet) *FlagSet { + fs.lookupEnv = fn + return fs + } +} + +// NewFlagSet creates a new root flag set. +func NewFlagSet(opts ...Option) *FlagSet { + f := flag.NewFlagSet("", flag.ContinueOnError) + + // Errors and usage are controlled by the writer. + f.Usage = func() {} + f.SetOutput(io.Discard) + + fs := &FlagSet{ + flagSet: f, + lookupEnv: os.LookupEnv, + } + + for _, opt := range opts { + fs = opt(fs) + } + + return fs +} + +// FlagSection represents a group section of flags. The flags are actually +// "flat" in memory, but maintain a structure for better help output and alias +// matching. +type FlagSection struct { + name string + flagNames []string + + // fields inherited from the parent + flagSet *flag.FlagSet + lookupEnv LookupEnvFunc +} + +// NewSection creates a new flag section. +func (f *FlagSet) NewSection(name string) *FlagSection { + fs := &FlagSection{ + name: name, + flagSet: f.flagSet, + lookupEnv: f.lookupEnv, + } + f.sections = append(f.sections, fs) + return fs +} + +// Args implements flag.FlagSet#Args. +func (f *FlagSet) Args() []string { + return f.flagSet.Args() +} + +// Args implements flag.FlagSet#Parse. +func (f *FlagSet) Parse(args []string) error { + return f.flagSet.Parse(args) +} + +// Args implements flag.FlagSet#Parsed. +func (f *FlagSet) Parsed() bool { + return f.flagSet.Parsed() +} + +// Args implements flag.FlagSet#Visit. +func (f *FlagSet) Visit(fn func(*flag.Flag)) { + f.flagSet.Visit(fn) +} + +// Args implements flag.FlagSet#VisitAll. +func (f *FlagSet) VisitAll(fn func(*flag.Flag)) { + f.flagSet.VisitAll(fn) +} + +// Help returns formatted help output. +func (f *FlagSet) Help() string { + var b strings.Builder + + for _, set := range f.sections { + sort.Strings(set.flagNames) + + fmt.Fprint(&b, set.name) + fmt.Fprint(&b, "\n\n") + + for _, name := range set.flagNames { + sub := set.flagSet.Lookup(name) + if sub == nil { + panic("inconsistency between flag structure and help") + } + + typ, ok := sub.Value.(Value) + if !ok { + panic(fmt.Sprintf("flag is incorrect type %T", sub.Value)) + } + + // Do not process hidden flags. + if typ.Hidden() { + continue + } + + // Incorporate aliases. + aliases := typ.Aliases() + sort.Slice(aliases, func(i, j int) bool { + return len(aliases[i]) < len(aliases[j]) + }) + all := make([]string, 0, len(aliases)+1) + for _, v := range aliases { + all = append(all, "-"+v) + } + all = append(all, "-"+sub.Name) + + // Handle boolean flags + if typ.IsBoolFlag() { + fmt.Fprintf(&b, " %s\n", strings.Join(all, ", ")) + } else { + fmt.Fprintf(&b, " %s=%q\n", strings.Join(all, ", "), typ.Example()) + } + + indented := wrapAtLengthWithPadding(sub.Usage, 8) + fmt.Fprint(&b, indented) + fmt.Fprint(&b, "\n\n") + } + } + + return strings.TrimRight(b.String(), "\n") +} + +// Value is an extension of [flag.Value] which adds additional fields for +// setting examples and defining aliases. All flags with this package must +// statisfy this interface. +type Value interface { + flag.Value + + // Get returns the value. Even though we know the concrete type with generics, + // this returns [any] to match the standard library. + Get() any + + // Aliases returns any defined aliases of the flag. + Aliases() []string + + // Example returns an example input for the flag. For example, if the flag was + // accepting a URL, this could be "https://example.com". This is largely meant + // as a hint to the CLI user and only affects help output. + Example() string + + // Hidden returns true if the flag is hidden, false otherwise. + Hidden() bool + + // IsBoolFlag returns true if the flag accepts no arguments, false otherwise. + IsBoolFlag() bool +} + +// ParserFunc is a function that parses a value into T, or returns an error. +type ParserFunc[T any] func(val string) (T, error) + +// PrinterFunc is a function that pretty-prints T. +type PrinterFunc[T any] func(cur T) string + +// SetterFunc is a function that sets *T to T. +type SetterFunc[T any] func(cur *T, val T) + +type Var[T any] struct { + Name string + Aliases []string + Usage string + Example string + Default T + Hidden bool + IsBool bool + EnvVar string + Target *T + + Parser ParserFunc[T] + Printer PrinterFunc[T] + + // Setter defines the function that sets the variable into the target. If nil, + // it uses a default setter which overwrites the entire value of the Target. + // Implementations that do special processing (such as appending to a slice), + // may override this to customize the behavior. + Setter SetterFunc[T] +} + +// Flag is a lower-level API for creating a flag on a flag section. Callers +// should use this for defining new flags as it sets defaults and provides more +// granular usage details. +// +// It panics if any of the target, parser, or printer are nil. +func Flag[T any](f *FlagSection, i *Var[T]) { + if i.Target == nil { + panic("missing target") + } + + parser := i.Parser + if parser == nil { + panic("missing parser func") + } + + printer := i.Printer + if printer == nil { + panic("missing printer func") + } + + setter := i.Setter + if setter == nil { + setter = func(cur *T, val T) { *cur = val } + } + + initial := i.Default + if v, ok := f.lookupEnv(i.EnvVar); ok { + if t, err := parser(v); err == nil { + initial = t + } + } + + // Set a default value. + *i.Target = initial + + // Compute a sane default if one was not given. + example := i.Example + if example == "" { + example = fmt.Sprintf("%T", *new(T)) + } + + // Pre-compute full usage. + usage := i.Usage + + if v := printer(i.Default); v != "" { + usage += fmt.Sprintf(" The default value is %q.", v) + } + + if v := i.EnvVar; v != "" { + usage += fmt.Sprintf(" This option can also be specified with the %s "+ + "environment variable.", v) + } + + fv := &flagValue[T]{ + target: i.Target, + hidden: i.Hidden, + isBool: i.IsBool, + example: example, + parser: parser, + printer: printer, + setter: setter, + aliases: i.Aliases, + } + f.flagNames = append(f.flagNames, i.Name) + f.flagSet.Var(fv, i.Name, usage) + + // Since aliases are not added as a flag name, we can safely add them to the + // main flag set. Our custom help will skip them. + for _, alias := range i.Aliases { + f.flagSet.Var(fv, alias, "") + } +} + +var _ Value = (*flagValue[any])(nil) + +type flagValue[T any] struct { + target *T + hidden bool + isBool bool + example string + + parser ParserFunc[T] + printer PrinterFunc[T] + setter SetterFunc[T] + aliases []string +} + +func (f *flagValue[T]) Set(s string) error { + v, err := f.parser(s) + if err != nil { + return err + } + f.setter(f.target, v) + return nil +} + +func (f *flagValue[T]) Get() any { return *f.target } +func (f *flagValue[T]) Aliases() []string { return f.aliases } +func (f *flagValue[T]) String() string { return f.printer(*f.target) } +func (f *flagValue[T]) Example() string { return f.example } +func (f *flagValue[T]) Hidden() bool { return f.hidden } +func (f *flagValue[T]) IsBoolFlag() bool { return f.isBool } + +type BoolVar struct { + Name string + Aliases []string + Usage string + Example string + Default bool + Hidden bool + EnvVar string + Target *bool +} + +func (f *FlagSection) BoolVar(i *BoolVar) { + Flag(f, &Var[bool]{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Example: i.Example, + IsBool: true, + Default: i.Default, + Hidden: i.Hidden, + EnvVar: i.EnvVar, + Target: i.Target, + Parser: strconv.ParseBool, + Printer: strconv.FormatBool, + }) +} + +type DurationVar struct { + Name string + Aliases []string + Usage string + Example string + Default time.Duration + Hidden bool + EnvVar string + Target *time.Duration +} + +func (f *FlagSection) DurationVar(i *DurationVar) { + Flag(f, &Var[time.Duration]{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Example: i.Example, + Default: i.Default, + Hidden: i.Hidden, + EnvVar: i.EnvVar, + Target: i.Target, + Parser: time.ParseDuration, + Printer: timeutil.HumanDuration, + }) +} + +type Float64Var struct { + Name string + Aliases []string + Usage string + Example string + Default float64 + Hidden bool + EnvVar string + Target *float64 +} + +func (f *FlagSection) Float64Var(i *Float64Var) { + parser := func(s string) (float64, error) { + return strconv.ParseFloat(s, 64) + } + printer := func(v float64) string { + return strconv.FormatFloat(v, 'e', -1, 64) + } + Flag(f, &Var[float64]{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Example: i.Example, + Default: i.Default, + Hidden: i.Hidden, + EnvVar: i.EnvVar, + Target: i.Target, + Parser: parser, + Printer: printer, + }) +} + +type IntVar struct { + Name string + Aliases []string + Usage string + Example string + Default int + Hidden bool + EnvVar string + Target *int +} + +func (f *FlagSection) IntVar(i *IntVar) { + parser := func(s string) (int, error) { + v, err := strconv.ParseInt(s, 10, 64) + return int(v), err + } + printer := func(v int) string { return strconv.FormatInt(int64(v), 10) } + + Flag(f, &Var[int]{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Example: i.Example, + Default: i.Default, + Hidden: i.Hidden, + EnvVar: i.EnvVar, + Target: i.Target, + Parser: parser, + Printer: printer, + }) +} + +type Int64Var struct { + Name string + Aliases []string + Usage string + Example string + Default int64 + Hidden bool + EnvVar string + Target *int64 +} + +func (f *FlagSection) Int64Var(i *Int64Var) { + parser := func(s string) (int64, error) { return strconv.ParseInt(s, 10, 64) } + printer := func(v int64) string { return strconv.FormatInt(v, 10) } + + Flag(f, &Var[int64]{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Example: i.Example, + Default: i.Default, + Hidden: i.Hidden, + EnvVar: i.EnvVar, + Target: i.Target, + Parser: parser, + Printer: printer, + }) +} + +type StringVar struct { + Name string + Aliases []string + Usage string + Example string + Default string + Hidden bool + EnvVar string + Target *string +} + +func (f *FlagSection) StringVar(i *StringVar) { + parser := func(s string) (string, error) { return s, nil } + printer := func(v string) string { return v } + + Flag(f, &Var[string]{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Example: i.Example, + Default: i.Default, + Hidden: i.Hidden, + EnvVar: i.EnvVar, + Target: i.Target, + Parser: parser, + Printer: printer, + }) +} + +type StringMapVar struct { + Name string + Aliases []string + Usage string + Example string + Default map[string]string + Hidden bool + EnvVar string + Target *map[string]string +} + +func (f *FlagSection) StringMapVar(i *StringMapVar) { + parser := func(s string) (map[string]string, error) { + idx := strings.Index(s, "=") + if idx == -1 { + return nil, fmt.Errorf("missing = in KV pair %q", s) + } + + m := make(map[string]string, 1) + m[s[0:idx]] = s[idx+1:] + return m, nil + } + + printer := func(m map[string]string) string { + list := make([]string, 0, len(m)) + for k, v := range m { + list = append(list, k+"="+v) + } + sort.Strings(list) + return strings.Join(list, ",") + } + + setter := func(cur *map[string]string, val map[string]string) { + if *cur == nil { + *cur = make(map[string]string) + } + for k, v := range val { + (*cur)[k] = v + } + } + + Flag(f, &Var[map[string]string]{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Example: i.Example, + Default: i.Default, + Hidden: i.Hidden, + EnvVar: i.EnvVar, + Target: i.Target, + Parser: parser, + Printer: printer, + Setter: setter, + }) +} + +type StringSliceVar struct { + Name string + Aliases []string + Usage string + Example string + Default []string + Hidden bool + EnvVar string + Target *[]string +} + +func (f *FlagSection) StringSliceVar(i *StringSliceVar) { + parser := func(s string) ([]string, error) { + final := make([]string, 0) + parts := strings.Split(s, ",") + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed != "" { + final = append(final, trimmed) + } + } + return final, nil + } + + printer := func(v []string) string { + return strings.Join(v, ",") + } + + setter := func(cur *[]string, val []string) { + *cur = append(*cur, val...) + } + + Flag(f, &Var[[]string]{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Example: i.Example, + Default: i.Default, + Hidden: i.Hidden, + EnvVar: i.EnvVar, + Target: i.Target, + Parser: parser, + Printer: printer, + Setter: setter, + }) +} + +type UintVar struct { + Name string + Aliases []string + Usage string + Example string + Default uint + Hidden bool + EnvVar string + Target *uint +} + +func (f *FlagSection) UintVar(i *UintVar) { + parser := func(s string) (uint, error) { + v, err := strconv.ParseUint(s, 10, 64) + return uint(v), err + } + printer := func(v uint) string { return strconv.FormatUint(uint64(v), 10) } + + Flag(f, &Var[uint]{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Example: i.Example, + Default: i.Default, + Hidden: i.Hidden, + EnvVar: i.EnvVar, + Target: i.Target, + Parser: parser, + Printer: printer, + }) +} + +type Uint64Var struct { + Name string + Aliases []string + Usage string + Example string + Default uint64 + Hidden bool + EnvVar string + Target *uint64 +} + +func (f *FlagSection) Uint64Var(i *Uint64Var) { + parser := func(s string) (uint64, error) { return strconv.ParseUint(s, 10, 64) } + printer := func(v uint64) string { return strconv.FormatUint(v, 10) } + + Flag(f, &Var[uint64]{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Example: i.Example, + Default: i.Default, + Hidden: i.Hidden, + EnvVar: i.EnvVar, + Target: i.Target, + Parser: parser, + Printer: printer, + }) +} + +// wrapAtLengthWithPadding wraps the given text at the maxLineLength, taking +// into account any provided left padding. +func wrapAtLengthWithPadding(s string, pad int) string { + wrapped := text.Wrap(s, maxLineLength-pad) + lines := strings.Split(wrapped, "\n") + for i, line := range lines { + lines[i] = strings.Repeat(" ", pad) + line + } + return strings.Join(lines, "\n") +} diff --git a/cli/flags_test.go b/cli/flags_test.go new file mode 100644 index 00000000..730b6ab1 --- /dev/null +++ b/cli/flags_test.go @@ -0,0 +1,96 @@ +// Copyright 2023 The Authors (see AUTHORS file) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cli + +import ( + "flag" + "io" + "reflect" + "strings" + "testing" +) + +func TestNewFlagSet(t *testing.T) { + t.Parallel() + + fs := NewFlagSet() + + if got, want := fs.flagSet.ErrorHandling(), flag.ContinueOnError; got != want { + t.Errorf("expected %q to be %q", got, want) + } + if got, want := fs.flagSet.Output(), io.Discard; got != want { + t.Errorf("expected %q to be %q", got, want) + } +} + +func TestFlagSet_NewSection(t *testing.T) { + t.Parallel() + + fs := NewFlagSet() + sec := fs.NewSection("child") + + if got, want := sec.name, "child"; got != want { + t.Errorf("expected %q to be %q", got, want) + } + // object equality check + if got, want := sec.flagSet, fs.flagSet; got != want { + t.Errorf("expected %v to be %v", got, want) + } + if got, want := fs.sections, []*FlagSection{sec}; !reflect.DeepEqual(got, want) { + t.Errorf("expected %v to be %v", got, want) + } +} + +func TestFlagSet_Help(t *testing.T) { + t.Parallel() + + fs := NewFlagSet() + + sec1 := fs.NewSection("child1") + sec1.BoolVar(&BoolVar{ + Name: "my-bool", + Usage: "One usage.", + Target: ptrTo(true), + }) + sec1.Int64Var(&Int64Var{ + Name: "my-int", + Usage: "One usage.", + Hidden: true, + Target: ptrTo(int64(0)), + }) + + sec2 := fs.NewSection("child2") + sec2.StringVar(&StringVar{ + Name: "two", + Usage: "Two usage.", + Aliases: []string{"t", "at"}, + Example: "example", + Target: ptrTo(""), + }) + + if got, want := fs.Help(), "One usage. The default value is"; !strings.Contains(got, want) { + t.Errorf("expected\n\n%s\n\nto include %q", got, want) + } + if got, want := fs.Help(), `-t, -at, -two="example"`; !strings.Contains(got, want) { + t.Errorf("expected\n\n%s\n\nto include %q", got, want) + } + if got, want := fs.Help(), "my-int"; strings.Contains(got, want) { + t.Errorf("expected\n\n%s\n\nto not include %q", got, want) + } +} + +func ptrTo[T any](v T) *T { + return &v +} diff --git a/go.mod b/go.mod index df35e9c2..71e45d75 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,9 @@ require ( github.com/go-sql-driver/mysql v1.7.0 github.com/google/go-cmp v0.5.9 github.com/hashicorp/hcl/v2 v2.16.2 + github.com/kr/text v0.1.0 github.com/lestrrat-go/jwx/v2 v2.0.8 + github.com/mattn/go-isatty v0.0.17 github.com/ory/dockertest/v3 v3.9.1 github.com/sethvargo/go-envconfig v0.9.0 go.uber.org/zap v1.24.0 diff --git a/go.sum b/go.sum index 3a48b02e..c4649309 100644 --- a/go.sum +++ b/go.sum @@ -87,6 +87,8 @@ github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmt github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2 h1:hRGSmZu7j271trc9sneMrpOW7GN5ngLm8YUZIPzf394= +github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= @@ -188,6 +190,7 @@ golang.org/x/sys v0.0.0-20210906170528-6f6e22806c34/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=