diff --git a/cmd_regexp.go b/cmd_regexp.go index 5e73e4d..a9d5194 100644 --- a/cmd_regexp.go +++ b/cmd_regexp.go @@ -12,6 +12,7 @@ import ( type cmdRegexp struct{} func (c *cmdRegexp) run(ctx context.Context, argv []string, outStream io.Writer, errStream io.Writer) error { + // FIXME: if len(argv) < 3 { return fmt.Errorf("not enough arguments") } @@ -25,7 +26,7 @@ func (c *cmdRegexp) run(ctx context.Context, argv []string, outStream io.Writer, return fmt.Errorf("invalid index: %s", err) } - str, err := getOut(pkgs, total, idx) + str, err := getOut(pkgs, detectTags(argv), total, idx) if err != nil { return err } @@ -33,14 +34,14 @@ func (c *cmdRegexp) run(ctx context.Context, argv []string, outStream io.Writer, return err } -func getOut(pkgs []string, total, idx int) (string, error) { +func getOut(pkgs []string, tags string, total, idx int) (string, error) { if total < 1 { return "", fmt.Errorf("invalid total: %d", total) } if idx >= total { return "", fmt.Errorf("index shoud be between 0 to total-1, but: %d (total:%d)", idx, total) } - testLists, err := getTestListsFromPkgs(pkgs) + testLists, err := getTestListsFromPkgs(pkgs, tags) if err != nil { return "", err } diff --git a/gotesplit.go b/gotesplit.go index b38efdb..766224b 100644 --- a/gotesplit.go +++ b/gotesplit.go @@ -9,6 +9,7 @@ import ( "log" "os" "os/exec" + "regexp" "sort" "strings" ) @@ -59,8 +60,12 @@ Options: return run(ctx, *total, *index, *junitDir, argv, outStream, errStream) } -func getTestListsFromPkgs(pkgs []string) ([]testList, error) { - args := append([]string{"test", "-list", "."}, pkgs...) +func getTestListsFromPkgs(pkgs []string, tags string) ([]testList, error) { + args := []string{"test", "-list"} + if tags != "" { + args = append(args, tags) + } + args = append(append(args, "."), pkgs...) buf := &bytes.Buffer{} c := exec.Command("go", args...) c.Stdout = buf @@ -72,6 +77,24 @@ func getTestListsFromPkgs(pkgs []string) ([]testList, error) { return getTestLists(buf.String()), nil } +var tagsReg = regexp.MustCompile(`^--?tags(=.*)?$`) + +func detectTags(argv []string) string { + l := len(argv) + for i := 0; i < l; i++ { + tags := argv[i] + m := tagsReg.FindStringSubmatch(tags) + if len(m) < 2 { + continue + } + if m[1] == "" && i+1 < l { + tags += "=" + argv[i+1] + } + return tags + } + return "" +} + type testList struct { pkg string list []string diff --git a/gotesplit_test.go b/gotesplit_test.go index d07effc..2524f07 100644 --- a/gotesplit_test.go +++ b/gotesplit_test.go @@ -102,3 +102,24 @@ ok github.com/x-motemen/ghq/logger 0.106s` t.Errorf("expect: %#v\ngot: %#v", expect, got) } } + +func TestDetectTags(t *testing.T) { + testCases := []struct { + input []string + expect string + }{ + {[]string{"aa", "bb"}, ""}, + {[]string{"aa", "-tags", "bb"}, "-tags=bb"}, + {[]string{"aa", "--tags=ccc", "bb"}, "--tags=ccc"}, + {[]string{"aa", "-tags"}, "-tags"}, + } + + for _, tc := range testCases { + t.Run(tc.expect, func(t *testing.T) { + out := detectTags(tc.input) + if out != tc.expect { + t.Errorf("got: %s, expect: %s", out, tc.expect) + } + }) + } +} diff --git a/run.go b/run.go index 70814c2..98aed47 100644 --- a/run.go +++ b/run.go @@ -51,7 +51,7 @@ func run(ctx context.Context, total, idx uint, junitDir string, argv []string, o } } - testLists, err := getTestListsFromPkgs(pkgs) + testLists, err := getTestListsFromPkgs(pkgs, detectTags(argv)) if err != nil { return err }