Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

command: fix signalling to command process group #637

Merged
merged 2 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion experimental/runme.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,5 @@ server:

log:
enabled: true
path: "/var/log/runme.log"
path: "/tmp/runme.log"
verbose: true
4 changes: 4 additions & 0 deletions internal/cmd/beta/run_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/stateful/runme/v3/internal/command"
"github.com/stateful/runme/v3/internal/config/autoconfig"
runnerv2alpha1 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2alpha1"
"github.com/stateful/runme/v3/pkg/document"
"github.com/stateful/runme/v3/pkg/project"
)
Expand Down Expand Up @@ -115,6 +116,9 @@ func runCodeBlock(
if err != nil {
return err
}

cfg.Mode = runnerv2alpha1.CommandMode_COMMAND_MODE_CLI

cmd, err := factory.Build(cfg, options)
if err != nil {
return err
Expand Down
10 changes: 5 additions & 5 deletions internal/command/command_inline_shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ func (c *inlineShellCommand) Wait() error {
err := c.internalCommand.Wait()

if c.envCollector != nil {
if cErr := c.collectEnv(); cErr != nil {
c.logger.Info("failed to collect the environment", zap.Error(cErr))
if err == nil {
err = cErr
}
c.logger.Info("collecting the environment after the script execution")
cErr := c.collectEnv()
c.logger.Info("collected the environment after the script execution", zap.Error(cErr))
if cErr != nil && err == nil {
err = cErr
}
}

Expand Down
48 changes: 20 additions & 28 deletions internal/command/command_native.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@ import (
"go.uber.org/zap"
)

// SignalToProcessGroup is used in tests to disable sending signals to a process group.
var SignalToProcessGroup = true

type nativeCommand struct {
*base

logger *zap.Logger
disableNewProcessID bool
logger *zap.Logger

cmd *exec.Cmd
}
Expand All @@ -34,21 +32,14 @@ func (c *nativeCommand) Pid() int {
func (c *nativeCommand) Start(ctx context.Context) (err error) {
stdin := c.Stdin()

// TODO(adamb): include explanation why it is needed.
if f, ok := stdin.(*os.File); ok && f != nil {
// Duplicate /dev/stdin.
newStdinFd, err := dup(f.Fd())
if err != nil {
return errors.Wrap(err, "failed to dup stdin")
}
closeOnExec(newStdinFd)

// Setting stdin to the non-block mode fails on the simple "read" command.
// On the other hand, it allows to use SetReadDeadline().
// It turned out it's not needed, but keeping the code here for now.
// if err := syscall.SetNonblock(newStdinFd, true); err != nil {
// return nil, errors.Wrap(err, "failed to set new stdin fd in non-blocking mode")
// }

stdin = os.NewFile(uintptr(newStdinFd), "")
}

Expand All @@ -69,46 +60,47 @@ func (c *nativeCommand) Start(ctx context.Context) (err error) {
c.cmd.Stdout = c.Stdout()
c.cmd.Stderr = c.Stderr()

// Set the process group ID of the program.
// It is helpful to stop the program and its
// children.
// Note that Setsid set in setSysProcAttrCtty()
// already starts a new process group.
// Warning: it does not work with interactive programs
// like "python", hence, it's commented out.
// setSysProcAttrPgid(c.cmd)
if !c.disableNewProcessID {
// Creating a new process group is required to properly replicate a behaviour
// similar to CTRL-C in the terminal, which sends a SIGINT to the whole group.
setSysProcAttrPgid(c.cmd)
}

c.logger.Info("starting a native command", zap.Any("config", redactConfig(c.ProgramConfig())))
c.logger.Info("starting", zap.Any("config", redactConfig(c.ProgramConfig())))
if err := c.cmd.Start(); err != nil {
return errors.WithStack(err)
}
c.logger.Info("a native command started")
c.logger.Info("started")

return nil
}

func (c *nativeCommand) Signal(sig os.Signal) error {
c.logger.Info("stopping the native command with a signal", zap.Stringer("signal", sig))
c.logger.Info("stopping with signal", zap.Stringer("signal", sig))

if SignalToProcessGroup {
if !c.disableNewProcessID {
c.logger.Info("signaling to the process group", zap.Stringer("signal", sig))
// Try to terminate the whole process group. If it fails, fall back to stdlib methods.
err := signalPgid(c.cmd.Process.Pid, sig)
if err == nil {
return nil
}
c.logger.Info("failed to terminate process group; trying Process.Signal()", zap.Error(err))
c.logger.Info("failed to signal the process group; trying regular signaling", zap.Error(err))
}

if err := c.cmd.Process.Signal(sig); err != nil {
c.logger.Info("failed to signal process; trying Process.Kill()", zap.Error(err))
if sig == os.Kill {
return errors.WithStack(err)
}
c.logger.Info("failed to signal the process; trying kill signal", zap.Error(err))
return errors.WithStack(c.cmd.Process.Kill())
}

return nil
}

func (c *nativeCommand) Wait() (err error) {
c.logger.Info("waiting for the native command to finish")
c.logger.Info("waiting for finish")

var stderr []byte
err = errors.WithStack(c.cmd.Wait())
Expand All @@ -119,7 +111,7 @@ func (c *nativeCommand) Wait() (err error) {
}
}

c.logger.Info("the native command finished", zap.Error(err), zap.ByteString("stderr", stderr))
c.logger.Info("finished", zap.Error(err), zap.ByteString("stderr", stderr))

return
}
12 changes: 10 additions & 2 deletions internal/command/command_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,20 @@ import (
"golang.org/x/sys/unix"
)

func setSysProcAttrCtty(cmd *exec.Cmd) {
func setSysProcAttrCtty(cmd *exec.Cmd, tty int) {
if cmd.SysProcAttr == nil {
cmd.SysProcAttr = &syscall.SysProcAttr{}
}
cmd.SysProcAttr.Setsid = true
cmd.SysProcAttr.Ctty = tty
cmd.SysProcAttr.Setctty = true
cmd.SysProcAttr.Setsid = true
}

// setSysProcAttrPgid turns on the Setpgid flag in cmd.SysProcAttr so that
// the command is placed in a process group when it is started. An existing
// SysProcAttr is reused and only its Setpgid field is modified; if there is
// none yet, an empty one is allocated and attached to cmd first.
func setSysProcAttrPgid(cmd *exec.Cmd) {
	attr := cmd.SysProcAttr
	if attr == nil {
		attr = new(syscall.SysProcAttr)
		cmd.SysProcAttr = attr
	}
	attr.Setpgid = true
}

func disableEcho(fd uintptr) error {
Expand Down
6 changes: 0 additions & 6 deletions internal/command/command_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@ import (
"github.com/stateful/runme/v3/pkg/document/identity"
)

func init() {
// Set to false to disable sending signals to process groups in tests.
// This can be turned on if setSysProcAttrPgid() is called in Start().
SignalToProcessGroup = false
}

func TestCommand(t *testing.T) {
testCases := []struct {
name string
Expand Down
59 changes: 34 additions & 25 deletions internal/command/command_virtual.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ func (c *virtualCommand) Start(ctx context.Context) (err error) {
}

if !c.isEchoEnabled {
c.logger.Info("disabling echo")
if err := disableEcho(c.tty.Fd()); err != nil {
return err
}
Expand All @@ -69,6 +70,7 @@ func (c *virtualCommand) Start(ctx context.Context) (err error) {
if err != nil {
return err
}
c.logger.Info("detected program path and arguments", zap.String("program", program), zap.Strings("args", args))

c.cmd = exec.CommandContext(
ctx,
Expand All @@ -81,22 +83,28 @@ func (c *virtualCommand) Start(ctx context.Context) (err error) {
c.cmd.Stdout = c.tty
c.cmd.Stderr = c.tty

setSysProcAttrCtty(c.cmd)
// Create a new session and set the controlling terminal to tty.
// The new process group is created automatically so that sending
// a signal to the command will affect the whole group.
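// The descriptor number is 3 because descriptors 0-2 in the child are stdin,
// stdout and stderr, and the i-th element of ExtraFiles becomes descriptor 3+i;
// tty is ExtraFiles[0], hence 3.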
setSysProcAttrCtty(c.cmd, 3)
c.cmd.ExtraFiles = []*os.File{c.tty}

c.logger.Info("starting a virtual command", zap.Any("config", redactConfig(c.ProgramConfig())))
c.logger.Info("starting", zap.Any("config", redactConfig(c.ProgramConfig())))
if err := c.cmd.Start(); err != nil {
return errors.WithStack(err)
}
c.logger.Info("started")

if !isNil(c.stdin) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
n, err := io.Copy(c.pty, c.stdin)
c.logger.Info("finished copying from stdin to pty", zap.Error(err), zap.Int64("count", n))
if err != nil {
c.setErr(errors.WithStack(err))
}
c.logger.Info("copied from stdin to pty", zap.Error(err), zap.Int64("count", n))
}()
}

Expand All @@ -112,54 +120,59 @@ func (c *virtualCommand) Start(ctx context.Context) (err error) {
// a master pseudo-terminal which no longer has an open slave.
// See https://github.com/creack/pty/issues/21.
if errors.Is(err, syscall.EIO) {
c.logger.Debug("failed to copy from pty to stdout; handled EIO")
c.logger.Info("failed to copy from pty to stdout; handled EIO")
return
}
if errors.Is(err, os.ErrClosed) {
c.logger.Debug("failed to copy from pty to stdout; handled ErrClosed")
c.logger.Info("failed to copy from pty to stdout; handled ErrClosed")
return
}

c.logger.Info("failed to copy from pty to stdout", zap.Error(err))

c.setErr(errors.WithStack(err))
} else {
c.logger.Debug("finished copying from pty to stdout", zap.Int64("count", n))
}

c.logger.Info("copied from pty to stdout", zap.Int64("count", n))
}()
}

c.logger.Info("a virtual command started")

return nil
}

func (c *virtualCommand) Signal(sig os.Signal) error {
c.logger.Info("stopping the virtual command with signal", zap.String("signal", sig.String()))
c.logger.Info("stopping with signal", zap.String("signal", sig.String()))

// Try to terminate the whole process group. If it fails, fall back to stdlib methods.
if err := signalPgid(c.cmd.Process.Pid, sig); err != nil {
c.logger.Info("failed to terminate process group; trying Process.Signal()", zap.Error(err))
if err := c.cmd.Process.Signal(sig); err != nil {
c.logger.Info("failed to signal process; trying Process.Kill()", zap.Error(err))
return errors.WithStack(c.cmd.Process.Kill())
err := signalPgid(c.cmd.Process.Pid, sig)
if err == nil {
return nil
}

c.logger.Info("failed to signal the process group; trying regular signaling", zap.Error(err))

if err := c.cmd.Process.Signal(sig); err != nil {
if sig == os.Kill {
return errors.WithStack(err)
}
c.logger.Info("failed to signal the process; trying kill signal", zap.Error(err))
return errors.WithStack(c.cmd.Process.Kill())
}

return nil
}

func (c *virtualCommand) Wait() (err error) {
c.logger.Info("waiting for the virtual command to finish")
c.logger.Info("waiting for finish")
err = errors.WithStack(c.cmd.Wait())
c.logger.Info("the virtual command finished", zap.Error(err))
c.logger.Info("finished", zap.Error(err))

errIO := c.closeIO()
c.logger.Info("closed IO of the virtual command", zap.Error(errIO))
c.logger.Info("closed IO", zap.Error(errIO))
if err == nil && errIO != nil {
err = errIO
}

c.logger.Info("waiting IO goroutines")
c.wg.Wait()
c.logger.Info("finished waiting for IO goroutines")

c.mu.Lock()
if err == nil && c.err != nil {
Expand Down Expand Up @@ -192,10 +205,6 @@ func (c *virtualCommand) closeIO() (err error) {
err = multierr.Append(err, errors.WithMessage(errClose, "failed to close tty"))
}

// if err := c.pty.Close(); err != nil {
// return errors.WithMessage(err, "failed to close pty")
// }

return
}

Expand Down
2 changes: 1 addition & 1 deletion internal/command/command_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/pkg/errors"
)

func setSysProcAttrCtty(cmd *exec.Cmd) {}
func setSysProcAttrCtty(cmd *exec.Cmd, tty int) {}

// setSysProcAttrPgid is a no-op in the Windows build: cmd is left unchanged.
// The signature mirrors the Unix implementation.
func setSysProcAttrPgid(cmd *exec.Cmd) {}

Expand Down
Loading