From 8d2aaa458971cba97c3bfec1b0380322e024b514 Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Mon, 12 Feb 2024 19:19:55 +0000 Subject: [PATCH] Add test for stdout scanner race with runner.Wait() (#300) --- client_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 84 insertions(+), 4 deletions(-) diff --git a/client_test.go b/client_test.go index c51a371..a94fc8a 100644 --- a/client_test.go +++ b/client_test.go @@ -28,10 +28,12 @@ import ( func TestClient(t *testing.T) { process := helperProcess("mock") + logger := &trackingLogger{Logger: hclog.Default()} c := NewClient(&ClientConfig{ Cmd: process, HandshakeConfig: testHandshake, Plugins: testPluginMap, + Logger: logger, }) defer c.Kill() @@ -61,6 +63,9 @@ func TestClient(t *testing.T) { if !c.killed() { t.Fatal("Client should have failed") } + + // One error for connection refused, one for plugin exited. + assertLines(t, logger.errorLogs, 2) } // This tests a bug where Kill would start @@ -112,19 +117,19 @@ func TestClient_killStart(t *testing.T) { } func TestClient_testCleanup(t *testing.T) { - // Create a temporary dir to store the result file - td := t.TempDir() - defer os.RemoveAll(td) + t.Parallel() // Create a path that the helper process will write on cleanup - path := filepath.Join(td, "output") + path := filepath.Join(t.TempDir(), "output") // Test the cleanup process := helperProcess("cleanup", path) + logger := &trackingLogger{Logger: hclog.Default()} c := NewClient(&ClientConfig{ Cmd: process, HandshakeConfig: testHandshake, Plugins: testPluginMap, + Logger: logger, }) // Grab the client so the process starts @@ -140,6 +145,61 @@ func TestClient_testCleanup(t *testing.T) { if _, err := os.Stat(path); err != nil { t.Fatalf("err: %s", err) } + + assertLines(t, logger.errorLogs, 0) +} + +func TestClient_noStdoutScannerRace(t *testing.T) { + t.Parallel() + + process := helperProcess("test-grpc") + logger := &trackingLogger{Logger: hclog.Default()} + c := NewClient(&ClientConfig{ + RunnerFunc: func(l hclog.Logger, cmd *exec.Cmd, tmpDir string) (runner.Runner, error) { + process.Env = append(process.Env, cmd.Env...) + concreteRunner, err := cmdrunner.NewCmdRunner(l, process) + if err != nil { + return nil, err + } + // Inject a delay before calling .Read() method on the command's + // stdout reader. This ensures that if there is a race between the + // stdout scanner loop reading stdout and runner.Wait() closing + // stdout, .Wait() will win and trigger a scanner error in the logs. + return &delayedStdoutCmdRunner{concreteRunner}, nil + }, + HandshakeConfig: testHandshake, + Plugins: testGRPCPluginMap, + AllowedProtocols: []Protocol{ProtocolGRPC}, + Logger: logger, + }) + + // Grab the client so the process starts + if _, err := c.Client(); err != nil { + c.Kill() + t.Fatalf("err: %s", err) + } + + // Kill it gracefully + c.Kill() + + assertLines(t, logger.errorLogs, 0) +} + +type delayedStdoutCmdRunner struct { + *cmdrunner.CmdRunner +} + +func (m *delayedStdoutCmdRunner) Stdout() io.ReadCloser { + return &delayedReader{m.CmdRunner.Stdout()} +} + +type delayedReader struct { + io.ReadCloser +} + +func (d *delayedReader) Read(p []byte) (n int, err error) { + time.Sleep(100 * time.Millisecond) + return d.ReadCloser.Read(p) } func TestClient_testInterface(t *testing.T) { @@ -1563,3 +1623,23 @@ func TestClient_logStderrParseJSON(t *testing.T) { } } } + +type trackingLogger struct { + hclog.Logger + errorLogs []string +} + +func (l *trackingLogger) Error(msg string, args ...interface{}) { + l.errorLogs = append(l.errorLogs, fmt.Sprintf("%s: %v", msg, args)) + l.Logger.Error(msg, args...) +} + +func assertLines(t *testing.T, lines []string, expected int) { + t.Helper() + if len(lines) != expected { + t.Errorf("expected %d, got %d", expected, len(lines)) + for _, log := range lines { + t.Error(log) + } + } +}