diff --git a/examples/burger/main.go b/examples/burger/main.go index e8b88c57..2d3b2dd2 100644 --- a/examples/burger/main.go +++ b/examples/burger/main.go @@ -51,7 +51,7 @@ type Burger struct { func main() { var burger Burger - var order = Order{Burger: burger} + order := Order{Burger: burger} // Should we run in accessible mode? accessible, _ := strconv.ParseBool(os.Getenv("ACCESSIBLE")) @@ -152,7 +152,6 @@ func main() { ).WithAccessible(accessible) err := form.Run() - if err != nil { fmt.Println("Uh oh:", err) os.Exit(1) diff --git a/go.mod b/go.mod index 89a0bbca..409d0ea3 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/charmbracelet/huh -go 1.21 +go 1.22 require ( github.com/catppuccin/go v0.2.0 diff --git a/spinner/examples/context-and-action-and-error/main.go b/spinner/examples/context-and-action-and-error/main.go new file mode 100644 index 00000000..a213bf51 --- /dev/null +++ b/spinner/examples/context-and-action-and-error/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/charmbracelet/huh/spinner" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + err := spinner.New(). + Context(ctx). + ActionWithErr(func(context.Context) error { + time.Sleep(5 * time.Second) + return nil + }). + Accessible(false). + Run() + if err != nil { + log.Fatalln(err) + } + fmt.Println("Done!") +} diff --git a/spinner/examples/context-and-action/main.go b/spinner/examples/context-and-action/main.go new file mode 100644 index 00000000..20ade5fa --- /dev/null +++ b/spinner/examples/context-and-action/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "context" + "fmt" + "log" + "math/rand" + "time" + + "github.com/charmbracelet/huh/spinner" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + err := spinner.New(). + Context(ctx). + Action(func() { + time.Sleep(time.Minute) + }). + Accessible(rand.Int()%2 == 0). + Run() + if err != nil { + log.Fatalln(err) + } + fmt.Println("Done!") +} diff --git a/spinner/examples/loading/main.go b/spinner/examples/loading/main.go index 6128d961..60683f5b 100644 --- a/spinner/examples/loading/main.go +++ b/spinner/examples/loading/main.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "os" "time" "github.com/charmbracelet/huh/spinner" @@ -10,11 +9,12 @@ import ( func main() { action := func() { - time.Sleep(2 * time.Second) + time.Sleep(1 * time.Second) } if err := spinner.New().Title("Preparing your burger...").Action(action).Run(); err != nil { - fmt.Println(err) - os.Exit(1) + fmt.Println("Failed:", err) + return } fmt.Println("Order up!") } + diff --git a/spinner/go.mod b/spinner/go.mod index 2bd4fef7..8201378a 100644 --- a/spinner/go.mod +++ b/spinner/go.mod @@ -1,6 +1,6 @@ module github.com/charmbracelet/huh/spinner -go 1.19 +go 1.22 require ( github.com/charmbracelet/bubbles v0.20.0 diff --git a/spinner/spinner.go b/spinner/spinner.go index 9b5fd876..d0fd1616 100644 --- a/spinner/spinner.go +++ b/spinner/spinner.go @@ -1,12 +1,12 @@ package spinner import ( + "cmp" "context" - "errors" "fmt" + "io" "os" "strings" - "time" "github.com/charmbracelet/bubbles/spinner" tea "github.com/charmbracelet/bubbletea" @@ -23,12 +23,13 @@ import ( // ⣾ Loading... type Spinner struct { spinner spinner.Model - action func() + action func(ctx context.Context) error ctx context.Context accessible bool - output *termenv.Output title string titleStyle lipgloss.Style + output io.Writer + err error } type Type spinner.Spinner @@ -60,8 +61,27 @@ func (s *Spinner) Title(title string) *Spinner { return s } +// Output set the output for the spinner. +// Default is STDOUT when [Spinner.Accessible], STDERR otherwise. +func (s *Spinner) Output(w io.Writer) *Spinner { + s.output = w + return s +} + // Action sets the action of the spinner. func (s *Spinner) Action(action func()) *Spinner { + s.action = func(context.Context) error { + action() + return nil + } + return s +} + +// ActionWithErr sets the action of the spinner. +// +// This is just like [Spinner.Action], but allows the action to use a `context.Context` +// and to return an error. +func (s *Spinner) ActionWithErr(action func(context.Context) error) *Spinner { s.action = action return s } @@ -98,24 +118,29 @@ func New() *Spinner { s.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("#F780E2")) return &Spinner{ - action: func() { time.Sleep(time.Second) }, spinner: s, title: "Loading...", titleStyle: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#00020A", Dark: "#FFFDF5"}), - output: termenv.NewOutput(os.Stdout), - ctx: nil, } } // Init initializes the spinner. func (s *Spinner) Init() tea.Cmd { - return s.spinner.Tick + return tea.Batch(s.spinner.Tick, func() tea.Msg { + if s.action != nil { + err := s.action(s.ctx) + return doneMsg{err} + } + return nil + }) } // Update updates the spinner. func (s *Spinner) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { - case spinner.TickMsg: + case doneMsg: + s.err = msg.err + return s, tea.Quit case tea.KeyMsg: switch msg.String() { case "ctrl+c": @@ -132,74 +157,71 @@ func (s *Spinner) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (s *Spinner) View() string { var title string if s.title != "" { - title = s.titleStyle.Render(s.title) + " " + title = s.titleStyle.Render(s.title) } return s.spinner.View() + title } // Run runs the spinner. func (s *Spinner) Run() error { - if s.accessible { - return s.runAccessible() + if s.ctx == nil && s.action == nil { + return nil } - - hasCtx := s.ctx != nil - hasCtxErr := hasCtx && s.ctx.Err() != nil - - if hasCtxErr { - if errors.Is(s.ctx.Err(), context.Canceled) { - return nil - } - return s.ctx.Err() + if s.ctx == nil { + s.ctx = context.Background() + } + if err := s.ctx.Err(); err != nil { + return err } - p := tea.NewProgram(s, tea.WithContext(s.ctx), tea.WithOutput(os.Stderr)) - if s.ctx == nil { - go func() { - s.action() - p.Quit() - }() + if s.accessible { + return s.runAccessible() } - _, err := p.Run() - if errors.Is(err, tea.ErrProgramKilled) { - return nil - } else { - return err + m, err := tea.NewProgram( + s, + tea.WithContext(s.ctx), + tea.WithOutput(s.output), + tea.WithInput(nil), + ).Run() + mm := m.(*Spinner) + if mm.err != nil { + return mm.err } + return err } // runAccessible runs the spinner in an accessible mode (statically). func (s *Spinner) runAccessible() error { - s.output.HideCursor() + tty := cmp.Or[io.Writer](s.output, os.Stdout) + output := termenv.NewOutput(tty) + output.HideCursor() frame := s.spinner.Style.Render("...") title := s.titleStyle.Render(strings.TrimSuffix(s.title, "...")) fmt.Println(title + frame) - if s.ctx == nil { - s.action() - s.output.ShowCursor() - s.output.CursorBack(len(frame) + len(title)) - return nil - } - - actionDone := make(chan struct{}) - - go func() { - s.action() - actionDone <- struct{}{} + defer func() { + output.ShowCursor() + output.CursorBack(len(frame) + len(title)) }() + actionDone := make(chan error) + if s.action != nil { + go func() { + actionDone <- s.action(s.ctx) + }() + } + for { select { case <-s.ctx.Done(): - s.output.ShowCursor() - s.output.CursorBack(len(frame) + len(title)) return s.ctx.Err() - case <-actionDone: - s.output.ShowCursor() - s.output.CursorBack(len(frame) + len(title)) - return nil + case err := <-actionDone: + return err } } } + +type doneMsg struct { + err error +} diff --git a/spinner/spinner_test.go b/spinner/spinner_test.go index 2e62f8d1..2f04d9dd 100644 --- a/spinner/spinner_test.go +++ b/spinner/spinner_test.go @@ -2,9 +2,12 @@ package spinner import ( "context" + "errors" + "io" "reflect" "strings" "testing" + "time" "github.com/charmbracelet/bubbles/spinner" tea "github.com/charmbracelet/bubbletea" @@ -45,15 +48,23 @@ func TestSpinnerView(t *testing.T) { } func TestSpinnerContextCancellation(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - - s := New().Context(ctx) - cancel() // Cancel before running + exercise(t, func() *Spinner { + ctx, cancel := context.WithCancel(context.Background()) + s := New().Context(ctx) + cancel() // Cancel before running + return s + }, requireContextCanceled) +} - err := s.Run() - if err != nil { - t.Errorf("Run() returned an error after context cancellation: %v", err) - } +func TestSpinnerContextCancellationWhileRunning(t *testing.T) { + exercise(t, func() *Spinner { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(250 * time.Millisecond) + cancel() + }() + return New().Context(ctx) + }, requireContextCanceled) } func TestSpinnerStyleMethods(t *testing.T) { @@ -105,10 +116,66 @@ func TestSpinnerUpdate(t *testing.T) { } } -func TestAccessibleSpinner(t *testing.T) { - s := New().Accessible(true) - err := s.Run() +func TestSpinnerSimple(t *testing.T) { + exercise(t, func() *Spinner { + return New().Action(func() {}) + }, requireNoError) +} + +func TestSpinnerWithContextAndAction(t *testing.T) { + exercise(t, func() *Spinner { + ctx := context.Background() + return New().Context(ctx).Action(func() {}) + }, requireNoError) +} + +func TestSpinnerWithActionError(t *testing.T) { + fake := errors.New("fake") + exercise(t, func() *Spinner { + return New().ActionWithErr(func(context.Context) error { return fake }) + }, requireErrorIs(fake)) +} + +func exercise(t *testing.T, factory func() *Spinner, checker func(tb testing.TB, err error)) { + t.Helper() + t.Run("accessible", func(t *testing.T) { + err := factory(). + Accessible(true). + Output(io.Discard). + Run() + checker(t, err) + }) + t.Run("regular", func(t *testing.T) { + err := factory(). + Accessible(false). + Output(io.Discard). + Run() + checker(t, err) + }) +} + +func requireNoError(tb testing.TB, err error) { + tb.Helper() if err != nil { - t.Errorf("Run() in accessible mode returned an error: %v", err) + tb.Errorf("expected no error, got %v", err) + } +} + +func requireErrorIs(target error) func(tb testing.TB, err error) { + return func(tb testing.TB, err error) { + tb.Helper() + if !errors.Is(err, target) { + tb.Errorf("expected error to be %v, got %v", target, err) + } + } +} + +func requireContextCanceled(tb testing.TB, err error) { + tb.Helper() + switch { + case errors.Is(err, context.Canceled): + case errors.Is(err, tea.ErrProgramKilled): + default: + tb.Errorf("expected to get a context canceled error, got %v", err) } }