Skip to content

Commit

Permalink
shell: convert the Each method into a range function
Browse files Browse the repository at this point in the history
- Move (most of) the tests into a separate test package.
- Update the examples.
  • Loading branch information
creachadair committed Sep 12, 2024
1 parent 6a3c5af commit 80cfc50
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 57 deletions.
42 changes: 42 additions & 0 deletions shell/internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) 2015, Michael J. Fromberger

package shell

import (
"fmt"
"testing"
)

func TestQuote(t *testing.T) {
type testCase struct{ in, want string }
tests := []testCase{
{"", "''"}, // empty is special
{"abc", "abc"}, // nothing to quote
{"--flag", "--flag"}, // "
{"'abc", `\'abc`}, // single quote only
{"abc'", `abc\'`}, // "
{`shan't`, `shan\'t`}, // "
{"--flag=value", `'--flag=value'`},
{"a b\tc", "'a b\tc'"},
{`a"b"c`, `'a"b"c'`},
{`'''`, `\'\'\'`},
{`\`, `'\'`},
{`'a=b`, `\''a=b'`}, // quotes and other stuff
{`a='b`, `'a='\''b'`}, // "
{`a=b'`, `'a=b'\'`}, // "
}
// Verify that all the designated special characters get quoted.
for _, c := range shouldQuote + mustQuote {
tests = append(tests, testCase{
in: string(c),
want: fmt.Sprintf(`'%c'`, c),
})
}

for _, test := range tests {
got := Quote(test.in)
if got != test.want {
t.Errorf("Quote %q: got %q, want %q", test.in, got, test.want)
}
}
}
13 changes: 5 additions & 8 deletions shell/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,18 +231,15 @@ func (s *Scanner) Rest() io.Reader {
return s.buf
}

// Each calls f for each token in the scanner until the input is exhausted, f
// returns false, or an error occurs.
func (s *Scanner) Each(f func(tok string) bool) error {
// Each is a range function that calls f for each token in the scanner. It
// continues until the input is exhausted, f returns false, or an error other
// than [io.EOF] occurs. Use the Err method to check for an error.
func (s *Scanner) Each(f func(tok string) bool) {
for s.Next() {
if !f(s.Text()) {
return nil
return
}
}
if err := s.Err(); err != io.EOF {
return err
}
return nil
}

// Split returns the remaining tokens in s, not including the current token if
Expand Down
69 changes: 20 additions & 49 deletions shell/shell_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2015, Michael J. Fromberger

package shell
package shell_test

import (
"fmt"
Expand All @@ -9,43 +9,10 @@ import (
"strings"
"testing"

"github.com/creachadair/mds/shell"
"github.com/google/go-cmp/cmp"
)

func TestQuote(t *testing.T) {
type testCase struct{ in, want string }
tests := []testCase{
{"", "''"}, // empty is special
{"abc", "abc"}, // nothing to quote
{"--flag", "--flag"}, // "
{"'abc", `\'abc`}, // single quote only
{"abc'", `abc\'`}, // "
{`shan't`, `shan\'t`}, // "
{"--flag=value", `'--flag=value'`},
{"a b\tc", "'a b\tc'"},
{`a"b"c`, `'a"b"c'`},
{`'''`, `\'\'\'`},
{`\`, `'\'`},
{`'a=b`, `\''a=b'`}, // quotes and other stuff
{`a='b`, `'a='\''b'`}, // "
{`a=b'`, `'a=b'\'`}, // "
}
// Verify that all the designated special characters get quoted.
for _, c := range shouldQuote + mustQuote {
tests = append(tests, testCase{
in: string(c),
want: fmt.Sprintf(`'%c'`, c),
})
}

for _, test := range tests {
got := Quote(test.in)
if got != test.want {
t.Errorf("Quote %q: got %q, want %q", test.in, got, test.want)
}
}
}

func TestSplit(t *testing.T) {
tests := []struct {
in string
Expand Down Expand Up @@ -98,7 +65,7 @@ func TestSplit(t *testing.T) {
{`a "b \"`, []string{"a", `b "`}, false},
}
for _, test := range tests {
got, ok := Split(test.in)
got, ok := shell.Split(test.in)
if ok != test.ok {
t.Errorf("Split %#q: got valid=%v, want %v", test.in, ok, test.ok)
}
Expand All @@ -125,7 +92,7 @@ func TestScannerSplit(t *testing.T) {
for _, test := range tests {
t.Logf("Scanner split input: %q", test.in)

s := NewScanner(strings.NewReader(test.in))
s := shell.NewScanner(strings.NewReader(test.in))
var got, rest []string
for s.Next() {
if s.Text() == "--" {
Expand Down Expand Up @@ -163,9 +130,9 @@ func TestRoundTrip(t *testing.T) {
{"cat", "a${b}.txt", "|", "tee", "capture", "2>", "/dev/null"},
}
for _, test := range tests {
s := Join(test)
s := shell.Join(test)
t.Logf("Join %#q = %v", test, s)
got, ok := Split(s)
got, ok := shell.Split(s)
if !ok {
t.Errorf("Split %+q: should be valid, but is not", s)
}
Expand All @@ -177,21 +144,21 @@ func TestRoundTrip(t *testing.T) {

func ExampleScanner() {
const input = `a "free range" exploration of soi\ disant novelties`
s := NewScanner(strings.NewReader(input))
s := shell.NewScanner(strings.NewReader(input))
sum, count := 0, 0
for s.Next() {
for tok := range s.Each {
count++
sum += len(s.Text())
sum += len(tok)
}
fmt.Println(len(input), count, sum, s.Complete(), s.Err())
// Output: 51 6 43 true EOF
}

func ExampleScanner_Rest() {
const input = `things 'and stuff' %end% all the remaining stuff`
s := NewScanner(strings.NewReader(input))
for s.Next() {
if s.Text() == "%end%" {
s := shell.NewScanner(strings.NewReader(input))
for tok := range s.Each {
if tok == "%end%" {
fmt.Print("found marker; ")
break
}
Expand All @@ -206,10 +173,14 @@ func ExampleScanner_Rest() {

func ExampleScanner_Each() {
const input = `a\ b 'c d' "e f's g" stop "go directly to jail"`
if err := NewScanner(strings.NewReader(input)).Each(func(tok string) bool {
s := shell.NewScanner(strings.NewReader(input))
for tok := range s.Each {
fmt.Println(tok)
return tok != "stop"
}); err != nil {
if tok == "stop" {
break
}
}
if err := s.Err(); err != nil {
log.Fatal(err)
}
// Output:
Expand All @@ -222,7 +193,7 @@ func ExampleScanner_Each() {
func ExampleScanner_Split() {
const input = `cmd -flag=t -- foo bar baz`

s := NewScanner(strings.NewReader(input))
s := shell.NewScanner(strings.NewReader(input))
for s.Next() {
if s.Text() == "--" {
fmt.Println("** Args:", strings.Join(s.Split(), ", "))
Expand Down

0 comments on commit 80cfc50

Please sign in to comment.