diff --git a/cmd_regexp.go b/cmd_regexp.go index a9d5194..217d825 100644 --- a/cmd_regexp.go +++ b/cmd_regexp.go @@ -26,7 +26,7 @@ func (c *cmdRegexp) run(ctx context.Context, argv []string, outStream io.Writer, return fmt.Errorf("invalid index: %s", err) } - str, err := getOut(pkgs, detectTags(argv), total, idx) + str, err := getOut(pkgs, detectTags(argv), detectRace(argv), total, idx) if err != nil { return err } @@ -34,14 +34,14 @@ func (c *cmdRegexp) run(ctx context.Context, argv []string, outStream io.Writer, return err } -func getOut(pkgs []string, tags string, total, idx int) (string, error) { +func getOut(pkgs []string, tags string, withRace bool, total, idx int) (string, error) { if total < 1 { return "", fmt.Errorf("invalid total: %d", total) } if idx >= total { return "", fmt.Errorf("index shoud be between 0 to total-1, but: %d (total:%d)", idx, total) } - testLists, err := getTestListsFromPkgs(pkgs, tags) + testLists, err := getTestListsFromPkgs(pkgs, tags, withRace) if err != nil { return "", err } diff --git a/gotesplit.go b/gotesplit.go index 828910e..f9ad530 100644 --- a/gotesplit.go +++ b/gotesplit.go @@ -60,11 +60,16 @@ Options: return run(ctx, *total, *index, *junitDir, argv, outStream, errStream) } -func getTestListsFromPkgs(pkgs []string, tags string) ([]testList, error) { +func getTestListsFromPkgs(pkgs []string, tags string, withRace bool) ([]testList, error) { args := []string{"test", "-list", "."} if tags != "" { args = append(args, tags) } + if withRace { + // If -race is specified for test options, add -race to list + // to prevent compilation from being executed twice. + args = append(args, "-race") + } args = append(args, pkgs...) buf := &bytes.Buffer{} c := exec.Command("go", args...) @@ -95,6 +100,16 @@ func detectTags(argv []string) string { return "" } +func detectRace(argv []string) bool { + l := len(argv) + for i := 0; i < l; i++ { + if argv[i] == "-race" || argv[i] == "--race" { + return true + } + } + return false +} + type testList struct { pkg string list []string diff --git a/gotesplit_test.go b/gotesplit_test.go index 890c7a8..58267fd 100644 --- a/gotesplit_test.go +++ b/gotesplit_test.go @@ -125,6 +125,29 @@ func TestDetectTags(t *testing.T) { } } +func TestDetectRace(t *testing.T) { + testCases := []struct { + input []string + expect bool + desc string + }{ + {[]string{"-race"}, true, "-race only"}, + {[]string{"-tags", "aaa", "-race", "-bench"}, true, "-race with other flags"}, + {[]string{"--race", "-p", "1"}, true, "--race with other flags"}, + {[]string{}, false, "no flags"}, + {[]string{"-short", "-p", "1"}, false, "flags without -race"}, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + out := detectRace(tc.input) + if out != tc.expect { + t.Errorf("got: %t, expect: %t", out, tc.expect) + } + }) + } +} + func TestGetTestListFromPkgs(t *testing.T) { if err := os.Chdir("testdata/withtags"); err != nil { wd, _ := os.Getwd() @@ -139,7 +162,7 @@ func TestGetTestListFromPkgs(t *testing.T) { }, }} - got, err := getTestListsFromPkgs([]string{"."}, "-tags=a") + got, err := getTestListsFromPkgs([]string{"."}, "-tags=a", false) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/run.go b/run.go index 6f8fc16..dec9614 100644 --- a/run.go +++ b/run.go @@ -53,7 +53,7 @@ func run(ctx context.Context, total, idx uint, junitDir string, argv []string, o } } - testLists, err := getTestListsFromPkgs(pkgs, detectTags(testOpts)) + testLists, err := getTestListsFromPkgs(pkgs, detectTags(testOpts), detectRace(testOpts)) if err != nil { return err }