Skip to content

Commit

Permalink
command,runner: limit environ size when starting command (#654)
Browse files Browse the repository at this point in the history
This change ensures that the size of individual environment variables,
as well as all environment variables (environ) passed to a command, are
within limits.

Relates to #648
  • Loading branch information
adambabik authored Sep 9, 2024
1 parent e781f00 commit 05dbf31
Show file tree
Hide file tree
Showing 22 changed files with 604 additions and 154 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ lint:
-exclude integration/subject/... \
./...
@staticcheck ./...
@gosec -quiet -exclude=G204,G304,G404 -exclude-generated ./...
@gosec -quiet -exclude=G110,G204,G304,G404 -exclude-generated ./...

.PHONY: pre-commit
pre-commit: build wasm test lint
Expand Down
7 changes: 4 additions & 3 deletions internal/cmd/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ import (
"bytes"
"context"
"fmt"
"github.com/pkg/errors"
"github.com/spf13/cobra"
"go.uber.org/zap"
"os"
"path/filepath"
"strconv"
"strings"

"github.com/pkg/errors"
"github.com/spf13/cobra"
"go.uber.org/zap"

"github.com/stateful/runme/v3/internal/runner/client"
"github.com/stateful/runme/v3/internal/tui"
"github.com/stateful/runme/v3/internal/tui/prompt"
Expand Down
48 changes: 46 additions & 2 deletions internal/command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

"github.com/pkg/errors"
"go.uber.org/zap"

"github.com/stateful/runme/v3/pkg/project"
)
Expand Down Expand Up @@ -37,8 +38,9 @@ type internalCommand interface {

type base struct {
cfg *ProgramConfig
logger *zap.Logger
project *project.Project
runtime runtime
runtime Runtime
session *Session
stdin io.Reader
stdout io.Writer
Expand Down Expand Up @@ -75,12 +77,20 @@ func (c *base) Env() []string {
env := c.runtime.Environ()

if c.project != nil {
projEnv, _ := c.project.LoadEnv()
projEnv, err := c.project.LoadEnv()
if err != nil {
c.logger.Warn("failed to load project env", zap.Error(err))
}
env = append(env, projEnv...)
}

env = append(env, c.session.GetAllEnv()...)
env = append(env, c.cfg.Env...)

if err := c.limitEnviron(env); err != nil {
c.logger.Error("environment size exceeds the limit", zap.Error(err))
}

return env
}

Expand All @@ -104,6 +114,40 @@ func (c *base) ProgramPath() (string, []string, error) {
return c.findDefaultProgram(c.cfg.ProgramName, c.cfg.Arguments)
}

func (c *base) limitEnviron(environ []string) error {
const stdoutPrefix = StoreStdoutEnvName + "="

stdoutEnvIdx := -1
size := 0
for idx, e := range environ {
size += len(e)

if strings.HasPrefix(e, stdoutPrefix) {
stdoutEnvIdx = idx
}
}

if size <= MaxEnvironSizeInBytes {
return nil
}

c.logger.Warn("environment size exceeds the limit", zap.Int("size", size), zap.Int("limit", MaxEnvironSizeInBytes))

if stdoutEnvIdx == -1 {
return errors.New("env is too large; no stdout env to trim")
}

stdoutCap := MaxEnvironSizeInBytes - size + len(environ[stdoutEnvIdx])
if stdoutCap < 0 {
return errors.New("env is too large even if trimming stdout env")
}

key, value := splitEnv(environ[stdoutEnvIdx])
environ[stdoutEnvIdx] = CreateEnv(key, value[len(value)-stdoutCap:])

return nil
}

func (c *base) getEnv(key string) string {
env := c.Env()
for i := len(env) - 1; i >= 0; i-- {
Expand Down
122 changes: 122 additions & 0 deletions internal/command/command_inline_shell_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@
package command

import (
"bytes"
"context"
"math"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"

"github.com/stateful/runme/v3/internal/command/testdata"
runnerv2 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2"
)

Expand Down Expand Up @@ -69,6 +75,122 @@ func TestInlineShellCommand_CollectEnv(t *testing.T) {
})
}

func TestInlineShellCommand_MaxEnvSize(t *testing.T) {
t.Parallel()

sess := NewSession()

envName := "TEST"
envValue := strings.Repeat("a", MaxEnvSizeInBytes-len(envName)-1) // -1 for the "=" sign
err := sess.SetEnv(createEnv(envName, envValue))
require.NoError(t, err)

factory := NewFactory(
WithLogger(zaptest.NewLogger(t)),
WithRuntime(&hostRuntime{}), // stub runtime and do not include environ
)
cfg := &ProgramConfig{
ProgramName: "bash",
Source: &runnerv2.ProgramConfig_Commands{
Commands: &runnerv2.ProgramConfig_CommandList{
Items: []string{
"echo -n $" + envName,
},
},
},
Mode: runnerv2.CommandMode_COMMAND_MODE_FILE,
}
stdout := bytes.NewBuffer(nil)
command, err := factory.Build(cfg, CommandOptions{Session: sess, Stdout: stdout})
require.NoError(t, err)

err = command.Start(context.Background())
require.NoError(t, err)
err = command.Wait()
require.NoError(t, err)

assert.Equal(t, envValue, stdout.String())
}

func TestInlineShellCommand_MaxEnvironSizeInBytes(t *testing.T) {
t.Parallel()

sess := NewSession()

// Set multiple environment variables of [MaxEnvSizeInBytes] length.
// [StoreStdoutEnvName] is also set but it exceeds [MaxEnvironSizeInBytes],
// however, it's allowed to be trimmed so it should not cause an error.
envCount := math.Ceil(float64(MaxEnvironSizeInBytes) / float64(MaxEnvSizeInBytes))
envValue := strings.Repeat("a", MaxEnvSizeInBytes-1) // -1 for the equal sign
for i := 0; i < int(envCount); i++ {
name := "TEST" + strconv.Itoa(i)
value := envValue[:len(envValue)-len(name)]
err := sess.SetEnv(createEnv(name, value))
require.NoError(t, err)
}
err := sess.SetEnv(createEnv(StoreStdoutEnvName, envValue[:len(envValue)-len(StoreStdoutEnvName)]))
require.NoError(t, err)

factory := NewFactory(
WithLogger(zaptest.NewLogger(t)),
WithRuntime(&hostRuntime{}), // stub runtime and do not include environ
)
cfg := &ProgramConfig{
ProgramName: "bash",
Source: &runnerv2.ProgramConfig_Commands{
Commands: &runnerv2.ProgramConfig_CommandList{
Items: []string{
"echo -n $" + StoreStdoutEnvName,
},
},
},
Mode: runnerv2.CommandMode_COMMAND_MODE_FILE,
}
command, err := factory.Build(cfg, CommandOptions{Session: sess})
require.NoError(t, err)

err = command.Start(context.Background())
require.NoError(t, err)
err = command.Wait()
require.NoError(t, err)
}

func TestInlineShellCommand_LargeOutput(t *testing.T) {
t.Parallel()

temp := t.TempDir()
fileName := filepath.Join(temp, "large_output.json")
_, err := testdata.UngzipToFile(testdata.Users1MGzip, fileName)
require.NoError(t, err)

factory := NewFactory(WithLogger(zaptest.NewLogger(t)))
cfg := &ProgramConfig{
ProgramName: "bash",
Source: &runnerv2.ProgramConfig_Commands{
Commands: &runnerv2.ProgramConfig_CommandList{
Items: []string{
"cat " + fileName,
},
},
},
Mode: runnerv2.CommandMode_COMMAND_MODE_INLINE,
}
sess := NewSession()
stdout := bytes.NewBuffer(nil)
command, err := factory.Build(cfg, CommandOptions{Session: sess, Stdout: stdout})
require.NoError(t, err)

err = command.Start(context.Background())
require.NoError(t, err)
err = command.Wait()
require.NoError(t, err)

expected, err := os.ReadFile(fileName)
require.NoError(t, err)
got := stdout.Bytes()
assert.EqualValues(t, expected, got)
}

func testInlineShellCommandCollectEnv(t *testing.T) {
t.Helper()

Expand Down
7 changes: 4 additions & 3 deletions internal/command/command_virtual.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,12 @@ func isNil(val any) bool {

v := reflect.ValueOf(val)

if v.Type().Kind() == reflect.Struct {
switch v.Type().Kind() {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.UnsafePointer:
return v.IsNil()
default:
return false
}

return reflect.ValueOf(val).IsNil()
}

// readCloser wraps [io.Reader] into [io.ReadCloser].
Expand Down
10 changes: 10 additions & 0 deletions internal/command/env_shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@ import (
"io"
)

const StoreStdoutEnvName = "__"

func CreateEnv(key, value string) string {
return createEnv(key, value)
}

func createEnv(key, value string) string {
return key + "=" + value
}

func setOnShell(shell io.Writer, prePath, postPath string) error {
var err error

Expand Down
31 changes: 11 additions & 20 deletions internal/command/env_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@ import (
"strings"
"sync"

"github.com/pkg/errors"
"golang.org/x/exp/slices"
)

// func envPairSize(k, v string) int {
// // +2 for the '=' and '\0' separators
// return len(k) + len(v) + 2
// }
const (
MaxEnvSizeInBytes = 128*1024 - 128
MaxEnvironSizeInBytes = MaxEnvSizeInBytes * 8
)

var ErrEnvTooLarge = errors.New("env too large")

type envStore struct {
mu sync.RWMutex
Expand Down Expand Up @@ -38,24 +41,12 @@ func (s *envStore) Get(k string) (string, bool) {
}

func (s *envStore) Set(k, v string) (*envStore, error) {
if len(k)+len(v) > MaxEnvSizeInBytes {
return s, ErrEnvTooLarge
}
s.mu.Lock()
defer s.mu.Unlock()

// environSize := envPairSize(k, v)

// for key, value := range s.items {
// if key == k {
// continue
// }
// environSize += envPairSize(key, value)
// }

// if environSize > MaxEnvironSizeInBytes {
// return s, errors.New("could not set environment variable, environment size limit exceeded")
// }

s.items[k] = v

s.mu.Unlock()
return s, nil
}

Expand Down
9 changes: 0 additions & 9 deletions internal/command/env_store_unix.go

This file was deleted.

7 changes: 0 additions & 7 deletions internal/command/env_store_windows.go

This file was deleted.

17 changes: 13 additions & 4 deletions internal/command/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ func WithProject(proj *project.Project) FactoryOption {
}
}

func WithRuntime(r Runtime) FactoryOption {
return func(f *commandFactory) {
f.runtime = r
}
}

func NewFactory(opts ...FactoryOption) Factory {
f := &commandFactory{}
for _, opt := range opts {
Expand All @@ -80,6 +86,7 @@ type commandFactory struct {
docker *dockerexec.Docker
logger *zap.Logger
project *project.Project
runtime Runtime
}

// Build creates a new command based on the provided [ProgramConfig] and [CommandOptions].
Expand Down Expand Up @@ -185,15 +192,17 @@ func (f *commandFactory) Build(cfg *ProgramConfig, opts CommandOptions) (Command
}

func (f *commandFactory) buildBase(cfg *ProgramConfig, opts CommandOptions) *base {
var runtime runtime
if f.docker != nil {
runtime := f.runtime

if isNil(runtime) && f.docker != nil {
runtime = &dockerRuntime{Docker: f.docker}
} else {
runtime = &hostRuntime{}
} else if isNil(runtime) {
runtime = &hostRuntime{useSystem: true}
}

return &base{
cfg: cfg,
logger: f.getLogger("Base"),
project: f.project,
runtime: runtime,
session: opts.Session,
Expand Down
Loading

0 comments on commit 05dbf31

Please sign in to comment.