From 706fc98f9f9a6fa85fd7172a505037753622b7c0 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Thu, 5 Sep 2024 05:12:01 -0700 Subject: [PATCH] refac(git/config): Use an iterator for ListRegexp (#382) TODO from when ListRegexp was originally implemented. Use an iter.Seq now that we're using Go 1.23 to build. --- internal/git/config.go | 44 +++++++++---------- internal/git/config_test.go | 20 ++------- internal/sliceutil/collect.go | 18 ++++++++ internal/sliceutil/collect_test.go | 68 ++++++++++++++++++++++++++++++ internal/spice/config.go | 29 ++++--------- internal/spice/stack_edit.go | 4 +- 6 files changed, 120 insertions(+), 63 deletions(-) create mode 100644 internal/sliceutil/collect.go create mode 100644 internal/sliceutil/collect_test.go diff --git a/internal/git/config.go b/internal/git/config.go index 6909afda..66433cad 100644 --- a/internal/git/config.go +++ b/internal/git/config.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "io" + "iter" "strings" "github.com/charmbracelet/log" @@ -133,10 +134,7 @@ type ConfigEntry struct { // ListRegexp lists all configuration entries that match the given pattern. // If pattern is empty, '.' is used to match all entries. -func (cfg *Config) ListRegexp(ctx context.Context, pattern string) ( - func(yield func(ConfigEntry, error) bool), - error, -) { +func (cfg *Config) ListRegexp(ctx context.Context, pattern string) iter.Seq2[ConfigEntry, error] { if pattern == "" { pattern = "." } @@ -145,26 +143,25 @@ func (cfg *Config) ListRegexp(ctx context.Context, pattern string) ( var _newline = []byte("\n") -func (cfg *Config) list(ctx context.Context, args ...string) ( - func(yield func(ConfigEntry, error) bool), - error, -) { +func (cfg *Config) list(ctx context.Context, args ...string) iter.Seq2[ConfigEntry, error] { + log := cfg.log args = append([]string{"config", "--null"}, args...) - cmd := newGitCmd(ctx, cfg.log, args...). - Dir(cfg.dir). - AppendEnv(cfg.env...) - - stdout, err := cmd.StdoutPipe() - if err != nil { - return nil, fmt.Errorf("stdout pipe: %w", err) - } + return func(yield func(ConfigEntry, error) bool) { + cmd := newGitCmd(ctx, cfg.log, args...). + Dir(cfg.dir). + AppendEnv(cfg.env...) + + stdout, err := cmd.StdoutPipe() + if err != nil { + yield(ConfigEntry{}, fmt.Errorf("stdout pipe: %w", err)) + return + } - if err := cmd.Start(cfg.exec); err != nil { - return nil, fmt.Errorf("start git-config: %w", err) - } + if err := cmd.Start(cfg.exec); err != nil { + yield(ConfigEntry{}, fmt.Errorf("start git-config: %w", err)) + return + } - log := cfg.log - return func(yield func(ConfigEntry, error) bool) { // Always wait for the command to finish when this returns. // Ignore the error because git-config fails if there are no matches. // It's not an error for us if there are no matches. @@ -195,9 +192,10 @@ func (cfg *Config) list(ctx context.Context, args ...string) ( } if err := scan.Err(); err != nil { - _ = yield(ConfigEntry{}, fmt.Errorf("scan git-config output: %w", err)) + yield(ConfigEntry{}, fmt.Errorf("scan git-config output: %w", err)) + return } - }, nil + } } // scanNullDelimited is a bufio.SplitFunc that splits on null bytes. diff --git a/internal/git/config_test.go b/internal/git/config_test.go index db46bbcd..41feecc0 100644 --- a/internal/git/config_test.go +++ b/internal/git/config_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.abhg.dev/gs/internal/logtest" + "go.abhg.dev/gs/internal/sliceutil" "go.uber.org/mock/gomock" ) @@ -146,16 +147,8 @@ func TestConfigListRegexp(t *testing.T) { exec: execer, }) - iter, err := cfg.ListRegexp(context.Background(), ".") + got, err := sliceutil.CollectErr(cfg.ListRegexp(context.Background(), ".")) require.NoError(t, err) - - var got []ConfigEntry - iter(func(entry ConfigEntry, err error) bool { - require.NoError(t, err) - got = append(got, entry) - return true - }) - assert.Equal(t, tt.want, got) }) } @@ -251,15 +244,8 @@ func TestIntegrationConfigListRegexp(t *testing.T) { Log: log, }) - var got []ConfigEntry - iter, err := cfg.ListRegexp(ctx, tt.pattern) + got, err := sliceutil.CollectErr(cfg.ListRegexp(ctx, tt.pattern)) require.NoError(t, err) - iter(func(entry ConfigEntry, err error) bool { - require.NoError(t, err) - got = append(got, entry) - return true - }) - assert.ElementsMatch(t, tt.want, got) }) } diff --git a/internal/sliceutil/collect.go b/internal/sliceutil/collect.go new file mode 100644 index 00000000..dfa7b517 --- /dev/null +++ b/internal/sliceutil/collect.go @@ -0,0 +1,18 @@ +// Package sliceutil contains utility functions for working with slices. +// It's an extension of the std slices package. +package sliceutil + +import "iter" + +// CollectErr collects items from a sequence of items and errors, +// stopping at the first error and returning it. +func CollectErr[T any](ents iter.Seq2[T, error]) ([]T, error) { + var items []T + for item, err := range ents { + if err != nil { + return nil, err + } + items = append(items, item) + } + return items, nil +} diff --git a/internal/sliceutil/collect_test.go b/internal/sliceutil/collect_test.go new file mode 100644 index 00000000..5489484c --- /dev/null +++ b/internal/sliceutil/collect_test.go @@ -0,0 +1,68 @@ +package sliceutil_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.abhg.dev/gs/internal/sliceutil" +) + +func TestCollectErr(t *testing.T) { + type pair struct { + val int + err error + } + + tests := []struct { + name string + give []pair + + want []int + wantErr error + }{ + { + name: "Empty", + give: nil, + want: nil, + }, + { + name: "NoErrors", + give: []pair{ + {val: 1}, + {val: 2}, + {val: 3}, + }, + want: []int{1, 2, 3}, + }, + { + name: "Error", + give: []pair{ + {val: 1}, + {err: assert.AnError}, + {val: 3}, + }, + wantErr: assert.AnError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := sliceutil.CollectErr(func(yield func(int, error) bool) { + for _, p := range tt.give { + if !yield(p.val, p.err) { + break + } + } + }) + + if tt.wantErr != nil { + require.Error(t, err) + assert.ErrorIs(t, err, tt.wantErr) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} diff --git a/internal/spice/config.go b/internal/spice/config.go index 574bd010..3182c097 100644 --- a/internal/spice/config.go +++ b/internal/spice/config.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "iter" "sort" "github.com/alecthomas/kong" @@ -21,10 +22,7 @@ const ( // GitConfigLister provides access to git-config output. type GitConfigLister interface { - ListRegexp(context.Context, string) ( - func(yield func(git.ConfigEntry, error) bool), - error, - ) + ListRegexp(context.Context, string) iter.Seq2[git.ConfigEntry, error] } var _ GitConfigLister = (*git.Config)(nil) @@ -79,19 +77,12 @@ func LoadConfig(ctx context.Context, cfg GitConfigLister, opts ConfigOptions) (* opts.Log = log.New(io.Discard) } - entries, err := cfg.ListRegexp(ctx, `^`+_configSection+`\.`) - if err != nil { - return nil, fmt.Errorf("list configuration: %w", err) - } - items := make(map[git.ConfigKey][]string) shorthands := make(map[string][]string) - err = nil // TODO: use a range loop after Go 1.23 - entries(func(entry git.ConfigEntry, iterErr error) bool { - if iterErr != nil { - err = iterErr - return false + for entry, err := range cfg.ListRegexp(ctx, `^`+_configSection+`\.`) { + if err != nil { + return nil, fmt.Errorf("list configuration: %w", err) } key := entry.Key.Canonical() @@ -100,7 +91,7 @@ func LoadConfig(ctx context.Context, cfg GitConfigLister, opts ConfigOptions) (* // Ignore keys that are not in the spice namespace. // This will never happen if git config --get-regexp // behaves correctly, but it's easy to handle. - return true + continue } // Special-case: Everything under "spice.shorthand.*" @@ -114,18 +105,14 @@ func LoadConfig(ctx context.Context, cfg GitConfigLister, opts ConfigOptions) (* "value", entry.Value, "error", err, ) - return true + continue } shorthands[short] = longform - return true + continue } items[key] = append(items[key], entry.Value) - return true - }) - if err != nil { - return nil, fmt.Errorf("read configuration: %w", err) } return &Config{ diff --git a/internal/spice/stack_edit.go b/internal/spice/stack_edit.go index 901ff9bf..9bc52ee5 100644 --- a/internal/spice/stack_edit.go +++ b/internal/spice/stack_edit.go @@ -44,7 +44,7 @@ func (s *Service) StackEdit(ctx context.Context, req *StackEditRequest) (*StackE must.NotContainf(req.Stack, s.store.Trunk(), "cannot edit trunk") must.NotBeBlankf(req.Editor, "editor is required") - branches, err := editStackFile(ctx, req.Editor, req.Stack) + branches, err := editStackFile(req.Editor, req.Stack) if err != nil { return nil, err } @@ -71,7 +71,7 @@ func (s *Service) StackEdit(ctx context.Context, req *StackEditRequest) (*StackE // The response list will be in the same order as the input list. // // Returns ErrStackEditAborted if the user aborts the edit operation. -func editStackFile(ctx context.Context, editor string, branches []string) ([]string, error) { +func editStackFile(editor string, branches []string) ([]string, error) { originals := make(map[string]struct{}, len(branches)) for _, branch := range branches { originals[branch] = struct{}{}