diff --git a/internal/command/env_collector.go b/internal/command/env_collector.go index d54a04b8..8e6f6800 100644 --- a/internal/command/env_collector.go +++ b/internal/command/env_collector.go @@ -15,6 +15,8 @@ import ( const ( envStartFileName = ".env_start" envEndFileName = ".env_end" + + maxScannerBufferSizeInBytes = 1024 * 1024 * 1024 ) // EnvDumpCommand is a command that dumps the environment variables. @@ -133,7 +135,18 @@ func (c *shellEnvCollector) readEnvFromFile(name string) (result []string, _ err } defer func() { _ = f.Close() }() + fileInfo, err := f.Stat() + if err != nil { + return nil, errors.WithMessagef(err, "failed to get the file info of the env file %q", name) + } + + scannerBufferSizeInBytes := fileInfo.Size() + if scannerBufferSizeInBytes > maxScannerBufferSizeInBytes { + return nil, errors.Errorf("the env file %q is too big: %d bytes", name, scannerBufferSizeInBytes) + } + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 4096), int(scannerBufferSizeInBytes)) // 4096 is taken from bufio as the initial buffer size scanner.Split(splitNull) for scanner.Scan() { diff --git a/internal/project/projectservice/project_service_test.go b/internal/project/projectservice/project_service_test.go index 9a8edeb3..35107677 100644 --- a/internal/project/projectservice/project_service_test.go +++ b/internal/project/projectservice/project_service_test.go @@ -77,7 +77,7 @@ func TestProjectServiceServer_Load(t *testing.T) { }) } -func TestProjectServiceServer_Load_ErrorWhileSending(t *testing.T) { +func TestProjectServiceServer_Load_ClientConnClosed(t *testing.T) { t.Parallel() temp := t.TempDir() @@ -98,8 +98,10 @@ func TestProjectServiceServer_Load_ErrorWhileSending(t *testing.T) { loadClient, err := client.Load(context.Background(), req) require.NoError(t, err) - err = clientConn.Close() - require.NoError(t, err) + errc := make(chan error, 1) + go func() { + errc <- clientConn.Close() + }() for { _, err := loadClient.Recv() @@ -108,6 +110,8 @@ func TestProjectServiceServer_Load_ErrorWhileSending(t *testing.T) { break } } + + require.NoError(t, <-errc) } func collectLoadEventTypes(client projectv1.ProjectService_LoadClient) ([]projectv1.LoadEventType, error) {