Skip to content

Commit

Permalink
refac(git/config): Use an iterator for ListRegexp (#382)
Browse files Browse the repository at this point in the history
TODO from when ListRegexp was originally implemented.
Use an iter.Seq now that we're using Go 1.23 to build.
  • Loading branch information
abhinav authored Sep 5, 2024
1 parent 6d8d795 commit 706fc98
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 63 deletions.
44 changes: 21 additions & 23 deletions internal/git/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"fmt"
"io"
"iter"
"strings"

"github.com/charmbracelet/log"
Expand Down Expand Up @@ -133,10 +134,7 @@ type ConfigEntry struct {

// ListRegexp lists all configuration entries that match the given pattern.
// If pattern is empty, '.' is used to match all entries.
func (cfg *Config) ListRegexp(ctx context.Context, pattern string) (
func(yield func(ConfigEntry, error) bool),
error,
) {
func (cfg *Config) ListRegexp(ctx context.Context, pattern string) iter.Seq2[ConfigEntry, error] {
if pattern == "" {
pattern = "."
}
Expand All @@ -145,26 +143,25 @@ func (cfg *Config) ListRegexp(ctx context.Context, pattern string) (

var _newline = []byte("\n")

func (cfg *Config) list(ctx context.Context, args ...string) (
func(yield func(ConfigEntry, error) bool),
error,
) {
func (cfg *Config) list(ctx context.Context, args ...string) iter.Seq2[ConfigEntry, error] {
log := cfg.log
args = append([]string{"config", "--null"}, args...)
cmd := newGitCmd(ctx, cfg.log, args...).
Dir(cfg.dir).
AppendEnv(cfg.env...)

stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("stdout pipe: %w", err)
}
return func(yield func(ConfigEntry, error) bool) {
cmd := newGitCmd(ctx, cfg.log, args...).
Dir(cfg.dir).
AppendEnv(cfg.env...)

stdout, err := cmd.StdoutPipe()
if err != nil {
yield(ConfigEntry{}, fmt.Errorf("stdout pipe: %w", err))
return
}

if err := cmd.Start(cfg.exec); err != nil {
return nil, fmt.Errorf("start git-config: %w", err)
}
if err := cmd.Start(cfg.exec); err != nil {
yield(ConfigEntry{}, fmt.Errorf("start git-config: %w", err))
return
}

log := cfg.log
return func(yield func(ConfigEntry, error) bool) {
// Always wait for the command to finish when this returns.
// Ignore the error because git-config fails if there are no matches.
// It's not an error for us if there are no matches.
Expand Down Expand Up @@ -195,9 +192,10 @@ func (cfg *Config) list(ctx context.Context, args ...string) (
}

if err := scan.Err(); err != nil {
_ = yield(ConfigEntry{}, fmt.Errorf("scan git-config output: %w", err))
yield(ConfigEntry{}, fmt.Errorf("scan git-config output: %w", err))
return
}
}, nil
}
}

// scanNullDelimited is a bufio.SplitFunc that splits on null bytes.
Expand Down
20 changes: 3 additions & 17 deletions internal/git/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.abhg.dev/gs/internal/logtest"
"go.abhg.dev/gs/internal/sliceutil"
"go.uber.org/mock/gomock"
)

Expand Down Expand Up @@ -146,16 +147,8 @@ func TestConfigListRegexp(t *testing.T) {
exec: execer,
})

iter, err := cfg.ListRegexp(context.Background(), ".")
got, err := sliceutil.CollectErr(cfg.ListRegexp(context.Background(), "."))
require.NoError(t, err)

var got []ConfigEntry
iter(func(entry ConfigEntry, err error) bool {
require.NoError(t, err)
got = append(got, entry)
return true
})

assert.Equal(t, tt.want, got)
})
}
Expand Down Expand Up @@ -251,15 +244,8 @@ func TestIntegrationConfigListRegexp(t *testing.T) {
Log: log,
})

var got []ConfigEntry
iter, err := cfg.ListRegexp(ctx, tt.pattern)
got, err := sliceutil.CollectErr(cfg.ListRegexp(ctx, tt.pattern))
require.NoError(t, err)
iter(func(entry ConfigEntry, err error) bool {
require.NoError(t, err)
got = append(got, entry)
return true
})

assert.ElementsMatch(t, tt.want, got)
})
}
Expand Down
18 changes: 18 additions & 0 deletions internal/sliceutil/collect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Package sliceutil contains utility functions for working with slices.
// It's an extension of the std slices package.
package sliceutil

import "iter"

// CollectErr collects items from a sequence of items and errors,
// stopping at the first error and returning it.
func CollectErr[T any](ents iter.Seq2[T, error]) ([]T, error) {
var items []T
for item, err := range ents {
if err != nil {
return nil, err
}
items = append(items, item)
}
return items, nil
}
68 changes: 68 additions & 0 deletions internal/sliceutil/collect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package sliceutil_test

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.abhg.dev/gs/internal/sliceutil"
)

func TestCollectErr(t *testing.T) {
type pair struct {
val int
err error
}

tests := []struct {
name string
give []pair

want []int
wantErr error
}{
{
name: "Empty",
give: nil,
want: nil,
},
{
name: "NoErrors",
give: []pair{
{val: 1},
{val: 2},
{val: 3},
},
want: []int{1, 2, 3},
},
{
name: "Error",
give: []pair{
{val: 1},
{err: assert.AnError},
{val: 3},
},
wantErr: assert.AnError,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := sliceutil.CollectErr(func(yield func(int, error) bool) {
for _, p := range tt.give {
if !yield(p.val, p.err) {
break
}
}
})

if tt.wantErr != nil {
require.Error(t, err)
assert.ErrorIs(t, err, tt.wantErr)
} else {
require.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}
29 changes: 8 additions & 21 deletions internal/spice/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"iter"
"sort"

"github.com/alecthomas/kong"
Expand All @@ -21,10 +22,7 @@ const (

// GitConfigLister provides access to git-config output.
type GitConfigLister interface {
ListRegexp(context.Context, string) (
func(yield func(git.ConfigEntry, error) bool),
error,
)
ListRegexp(context.Context, string) iter.Seq2[git.ConfigEntry, error]
}

var _ GitConfigLister = (*git.Config)(nil)
Expand Down Expand Up @@ -79,19 +77,12 @@ func LoadConfig(ctx context.Context, cfg GitConfigLister, opts ConfigOptions) (*
opts.Log = log.New(io.Discard)
}

entries, err := cfg.ListRegexp(ctx, `^`+_configSection+`\.`)
if err != nil {
return nil, fmt.Errorf("list configuration: %w", err)
}

items := make(map[git.ConfigKey][]string)
shorthands := make(map[string][]string)

err = nil // TODO: use a range loop after Go 1.23
entries(func(entry git.ConfigEntry, iterErr error) bool {
if iterErr != nil {
err = iterErr
return false
for entry, err := range cfg.ListRegexp(ctx, `^`+_configSection+`\.`) {
if err != nil {
return nil, fmt.Errorf("list configuration: %w", err)
}

key := entry.Key.Canonical()
Expand All @@ -100,7 +91,7 @@ func LoadConfig(ctx context.Context, cfg GitConfigLister, opts ConfigOptions) (*
// Ignore keys that are not in the spice namespace.
// This will never happen if git config --get-regexp
// behaves correctly, but it's easy to handle.
return true
continue
}

// Special-case: Everything under "spice.shorthand.*"
Expand All @@ -114,18 +105,14 @@ func LoadConfig(ctx context.Context, cfg GitConfigLister, opts ConfigOptions) (*
"value", entry.Value,
"error", err,
)
return true
continue
}

shorthands[short] = longform
return true
continue
}

items[key] = append(items[key], entry.Value)
return true
})
if err != nil {
return nil, fmt.Errorf("read configuration: %w", err)
}

return &Config{
Expand Down
4 changes: 2 additions & 2 deletions internal/spice/stack_edit.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (s *Service) StackEdit(ctx context.Context, req *StackEditRequest) (*StackE
must.NotContainf(req.Stack, s.store.Trunk(), "cannot edit trunk")
must.NotBeBlankf(req.Editor, "editor is required")

branches, err := editStackFile(ctx, req.Editor, req.Stack)
branches, err := editStackFile(req.Editor, req.Stack)
if err != nil {
return nil, err
}
Expand All @@ -71,7 +71,7 @@ func (s *Service) StackEdit(ctx context.Context, req *StackEditRequest) (*StackE
// The response list will be in the same order as the input list.
//
// Returns ErrStackEditAborted if the user aborts the edit operation.
func editStackFile(ctx context.Context, editor string, branches []string) ([]string, error) {
func editStackFile(editor string, branches []string) ([]string, error) {
originals := make(map[string]struct{}, len(branches))
for _, branch := range branches {
originals[branch] = struct{}{}
Expand Down

0 comments on commit 706fc98

Please sign in to comment.