Skip to content

Commit

Permalink
Optionally return the inner pid from spawn()
Browse files Browse the repository at this point in the history
bubblewrap does not support forwarding signals yet,
see containers/bubblewrap#586. As a workaround,
we need to make sure we send our signals to the inner process. To
make this work, we create a pipe, pass it through to the subprocess,
and prefix the command with a bash snippet that writes its pid to the
pipe before exec-ing the actual command.

The other thing we get from this is that we can register the inner pid
as a scope which makes the systemctl status output for the scopes we
create a lot more useful.
  • Loading branch information
DaanDeMeyer committed Apr 13, 2024
1 parent 6e7bcc7 commit 0880aeb
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 19 deletions.
28 changes: 15 additions & 13 deletions mkosi/qemu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import random
import shutil
import signal
import socket
import struct
import subprocess
Expand All @@ -38,7 +39,7 @@
from mkosi.log import ARG_DEBUG, die
from mkosi.mounts import finalize_source_mounts
from mkosi.partition import finalize_root, find_partitions
from mkosi.run import SD_LISTEN_FDS_START, AsyncioThread, find_binary, fork_and_wait, run, spawn
from mkosi.run import SD_LISTEN_FDS_START, AsyncioThread, find_binary, fork_and_wait, run, spawn, kill
from mkosi.sandbox import Mount
from mkosi.tree import copy_tree, rmtree
from mkosi.types import PathString
Expand Down Expand Up @@ -274,15 +275,15 @@ def start_swtpm(config: Config) -> Iterator[Path]:
cmdline,
pass_fds=(sock.fileno(),),
sandbox=config.sandbox(mounts=[Mount(state, state)]),
) as proc:
) as (swtpm, innerpid):
allocate_scope(
config,
name=f"mkosi-swtpm-{config.machine_or_name()}",
pid=proc.pid,
pid=innerpid,
description=f"swtpm for {config.machine_or_name()}",
)
yield path
proc.terminate()
kill(swtpm, innerpid, signal.SIGTERM)


def find_virtiofsd(*, tools: Path = Path("/")) -> Optional[Path]:
Expand Down Expand Up @@ -354,15 +355,15 @@ def start_virtiofsd(config: Config, directory: PathString, *, name: str, selinux
mounts=[Mount(directory, directory)],
options=["--uid", "0", "--gid", "0", "--cap-add", "all"],
),
) as proc:
) as (virtiofsd, innerpid):
allocate_scope(
config,
name=f"mkosi-virtiofsd-{name}",
pid=proc.pid,
pid=innerpid,
description=f"virtiofsd for {directory}",
)
yield path
proc.terminate()
kill(virtiofsd, innerpid, signal.SIGTERM)


@contextlib.contextmanager
Expand Down Expand Up @@ -442,15 +443,16 @@ def start_journal_remote(config: Config, sockfd: int) -> Iterator[None]:
# If all logs go into a single file, disable compact mode to allow for journal files exceeding 4G.
env={"SYSTEMD_JOURNAL_COMPACT": "0" if config.forward_journal.suffix == ".journal" else "1"},
foreground=False,
) as proc:
) as (remote, innerpid):
allocate_scope(
config,
name=f"mkosi-journal-remote-{config.machine_or_name()}",
pid=proc.pid,
pid=innerpid,
description=f"mkosi systemd-journal-remote for {config.machine_or_name()}",
)
yield
proc.terminate()
kill(remote, innerpid, signal.SIGTERM)



@contextlib.contextmanager
Expand Down Expand Up @@ -1097,7 +1099,7 @@ def add_virtiofs_mount(
log=False,
foreground=True,
sandbox=config.sandbox(network=True, devices=True, relaxed=True),
) as qemu:
) as (qemu, innerpid):
# We have to close these before we wait for qemu otherwise we'll deadlock as qemu will never exit.
for fd in qemu_device_fds.values():
os.close(fd)
Expand All @@ -1106,10 +1108,10 @@ def add_virtiofs_mount(
allocate_scope(
config,
name=name,
pid=qemu.pid,
pid=innerpid,
description=f"mkosi Virtual Machine {name}",
)
register_machine(config, qemu.pid, fname)
register_machine(config, innerpid, fname)

if qemu.wait() == 0 and (status := int(notifications.get("EXIT_STATUS", 0))):
raise subprocess.CalledProcessError(status, cmdline)
Expand Down
42 changes: 37 additions & 5 deletions mkosi/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def run(
preexec_fn=preexec_fn,
success_exit_status=success_exit_status,
sandbox=sandbox,
) as process:
innerpid=False,
) as (process, _):
out, err = process.communicate(input)

return CompletedProcess(cmdline, process.returncode, out, err)
Expand All @@ -182,7 +183,8 @@ def spawn(
preexec_fn: Optional[Callable[[], None]] = None,
success_exit_status: Sequence[int] = (0,),
sandbox: AbstractContextManager[Sequence[PathString]] = contextlib.nullcontext([]),
) -> Iterator[Popen]:
innerpid: bool = True,
) -> Iterator[tuple[Popen, int]]:
assert sorted(set(pass_fds)) == list(pass_fds)

cmdline = [os.fspath(x) for x in cmdline]
Expand Down Expand Up @@ -271,6 +273,16 @@ def preexec() -> None:
# command.
prefix += ["sh", "-c", f"LISTEN_FDS={len(pass_fds)} LISTEN_PID=$$ exec $0 \"$@\""]

if prefix and innerpid:
r, w = os.pipe2(os.O_CLOEXEC)
q = fcntl.fcntl(w, fcntl.F_DUPFD_CLOEXEC, len(pass_fds) + 1)
os.close(w)
w = q
# dash doesn't support working with file descriptors higher than 9 so make sure we use bash.
prefix += ["bash", "-c", f"echo $$ >&{w} && exec {w}>&- && exec $0 \"$@\""]
else:
r, w = (None, None)

try:
with subprocess.Popen(
prefix + cmdline,
Expand All @@ -282,15 +294,24 @@ def preexec() -> None:
group=group,
# pass_fds only comes into effect after python has invoked the preexec function, so we make sure that
# pass_fds contains the file descriptors to keep open after we've done our transformation in preexec().
pass_fds=[SD_LISTEN_FDS_START + i for i in range(len(pass_fds))],
pass_fds=[SD_LISTEN_FDS_START + i for i in range(len(pass_fds))] + ([w] if w else []),
env=env,
cwd=cwd,
preexec_fn=preexec,
) as proc:
if w:
os.close(w)
pid = proc.pid
try:
yield proc
if r:
with open(r) as f:
s = f.read()
if s:
pid = int(s)

yield proc, pid
except BaseException:
proc.terminate()
kill(proc, pid, signal.SIGTERM)
raise
finally:
returncode = proc.wait()
Expand Down Expand Up @@ -339,6 +360,17 @@ def find_binary(*names: PathString, root: Path = Path("/")) -> Optional[Path]:
return None


def kill(process: Popen, innerpid: int, signal: int) -> None:
    """Deliver *signal* to the inner (sandboxed) process of a spawned command.

    bubblewrap does not forward signals to its child, so signals have to be
    sent directly to the inner pid that was reported back over the pid pipe.
    If the outer process has already exited, the inner one is gone as well
    and nothing is sent.
    """
    # poll() reaps the outer process if it has exited; a non-None result
    # means the whole sandbox is already down, so there is nothing to signal.
    if process.poll() is not None:
        return

    try:
        os.kill(innerpid, signal)
    except ProcessLookupError:
        # The inner process exited between the liveness check above and the
        # kill — that race is benign, so ignore it.
        pass


class AsyncioThread(threading.Thread):
"""
The default threading.Thread() is not interruptable, so we make our own version by using the concurrency
Expand Down
6 changes: 5 additions & 1 deletion mkosi/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,11 @@ def become_root() -> None:
# execute using flock so they don't execute before they can get a lock on the same temporary file, then we
# unshare the user namespace and finally we unlock the temporary file, which allows the newuidmap and newgidmap
# processes to execute. we then wait for the processes to finish before continuing.
with flock(lock) as fd, spawn(newuidmap) as uidmap, spawn(newgidmap) as gidmap:
with (
flock(lock) as fd,
spawn(newuidmap, innerpid=False) as (uidmap, _),
spawn(newgidmap, innerpid=False) as (gidmap, _)
):
unshare(CLONE_NEWUSER)
fcntl.flock(fd, fcntl.LOCK_UN)
uidmap.wait()
Expand Down

0 comments on commit 0880aeb

Please sign in to comment.