Skip to content

Commit

Permalink
Merge pull request #48 from DIAGNijmegen/fix-47
Browse files Browse the repository at this point in the history
Simplify multi-processing error handling
  • Loading branch information
pkcakeout committed Aug 12, 2024
2 parents 43408bb + 5baa601 commit 4666136
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ FROM --platform=linux/amd64 pytorch/pytorch
# Use a 'large' base container to show-case how to load pytorch and use the GPU (when enabled)

# Ensures that Python output to stdout/stderr is not buffered: prevents missing information when terminating
ENV PYTHONUNBUFFERED 1
ENV PYTHONUNBUFFERED=1

RUN groupadd -r user && useradd -m --no-log-init -r -g user user
USER user
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ formatted_build_info=$(date -d "$build_timestamp" +"%Y%m%d_%H%M%S")
# Set the output filename with timestamp and build information
output_filename="${SCRIPT_DIR}/${container_tag}_${formatted_build_info}.tar.gz"

printenv

# Save the Docker container and gzip it
docker save "$container_tag" | gzip -c > "$output_filename"

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
FROM --platform=linux/amd64 docker.io/library/python:3.11-slim

# Ensures that Python output to stdout/stderr is not buffered: prevents missing information when terminating
ENV PYTHONUNBUFFERED 1
ENV PYTHONUNBUFFERED=1

RUN groupadd -r user && useradd -m --no-log-init -r -g user user
USER user
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def process(job):
report += pformat(job)
report += "\n"

# Firstly, find the location of the results
# Firstly, find the location of the results
{% for ci in cookiecutter.phase.algorithm_outputs %}
{%- set py_slug = ci.slug | replace("-", "_") -%}
location_{{ py_slug }} = get_file_location(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import multiprocessing
import os
import signal
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import Manager, Process

Expand Down Expand Up @@ -44,6 +43,8 @@ def run_prediction_processing(*, fn, predictions):
- if any child process is terminated, all prediction processing will abort
- after prediction processing is done, all child processes are terminated
Note that the results are returned in completing order.
Parameters
----------
fn : function
Expand Down Expand Up @@ -99,30 +100,9 @@ def _start_pool_worker(fn, predictions, max_workers, results, errors):


def _pool_worker(*, fn, predictions, max_workers, results, errors):
terminating_child_processes = False

caught_exception = False
with ProcessPoolExecutor(max_workers=max_workers) as executor:
try:

def handle_error(error, prediction="Unknown"):
executor.shutdown(wait=False, cancel_futures=True)
errors.append((prediction, error))

nonlocal terminating_child_processes
terminating_child_processes = True
_terminate_child_processes()

def sigchld_handler(*_, **__):
if not terminating_child_processes:
handle_error(
RuntimeError(
"Child process was terminated unexpectedly"
)
)

# Register the SIGCHLD handler
signal.signal(signal.SIGCHLD, sigchld_handler)

# Submit the processing tasks of the predictions
futures = [
executor.submit(fn, prediction) for prediction in predictions
Expand All @@ -137,28 +117,39 @@ def sigchld_handler(*_, **__):
result = future.result()
results.append(result)
except Exception as e:
handle_error(e, prediction=future_to_predictions[future])
errors.append((future_to_predictions[future], e))

if not caught_exception: # Hard stop
caught_exception = True

executor.shutdown(wait=False, cancel_futures=True)
_terminate_child_processes()
finally:
terminating_child_processes = True
# Be aggressive in cleaning up any left-over processes
_terminate_child_processes()


def _terminate_child_processes():
    """Terminate every child process of the current process, leaving no zombies.

    Strategy:
      1. politely ask each (recursive) child to stop via ``terminate()``,
      2. wait up to 5 seconds for them to exit,
      3. forcefully ``kill()`` any survivors,
      4. reap remaining child entries with ``os.waitpid`` so none linger
         as zombies.
    """
    process = psutil.Process(os.getpid())
    children = process.children(recursive=True)
    for child in children:
        try:
            child.terminate()
        except psutil.NoSuchProcess:
            pass  # Not a problem

    # Wait for processes to terminate
    _, still_alive = psutil.wait_procs(children, timeout=5)

    # Forcefully kill any remaining processes
    for p in still_alive:
        print(f"Forcefully killing child process {p.pid}")
        try:
            p.kill()
        except psutil.NoSuchProcess:
            pass  # That is fine

    # Finally, prevent zombies by waiting for all child processes
    try:
        os.waitpid(-1, 0)
    except ChildProcessError:
        pass  # No child processes, that is fine
46 changes: 24 additions & 22 deletions tests/test_evaluate_helpers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import signal
import sys
import time
from functools import partial
from multiprocessing import Process
from unittest import mock

import psutil
import pytest

# Do some creative path hacking to be able to import the helpers
Expand All @@ -27,10 +27,12 @@

# Some of the test below, if things go wrong, can potentially deadlock.
# So we set a maximum runtime
pytestmark = pytest.mark.timeout(5)
pytestmark = pytest.mark.timeout(4)


def working_process(p):
    """Fake prediction worker: returns a result string for *p*.

    The input ``"prediction1"`` is deliberately slowed down so that tests
    observe results arriving out of submission order.
    """
    delay = 2 if p == "prediction1" else 0
    if delay:
        time.sleep(delay)
    return "%s result" % p


Expand All @@ -55,14 +57,18 @@ def forever_process(*_):
time.sleep(1)


def send_signals_to_process(process, signal_to_send, interval):
while True:
try:
os.kill(process.pid, signal_to_send)
except ProcessLookupError:
# Race conditions sometimes have this try and send a signal even though
# the process is already terminated
pass
def stop_children(process, interval):
    """Poll *process* until it has child processes, then kill them all.

    Sleeps *interval* seconds between polls; exits after the first poll
    that finds (and kills) at least one child. Used in tests to simulate
    pool workers dying abruptly.
    """
    done = False
    while not done:
        children = psutil.Process(process.pid).children(recursive=True)
        for child in children:
            try:
                child.kill()
            except psutil.NoSuchProcess:
                pass  # Not a problem
        done = bool(children)
        time.sleep(interval)


Expand All @@ -71,7 +77,7 @@ def test_prediction_processing():
result = run_prediction_processing(
fn=working_process, predictions=predictions
)
assert ["prediction1 result", "prediction2 result"] == result
assert {"prediction1 result", "prediction2 result"} == set(result)


def test_prediction_processing_error():
Expand Down Expand Up @@ -103,18 +109,14 @@ def test_prediction_processing_killing_of_child_processes():
def test_prediction_processing_catching_killing_of_child_processes():
predictions = ["prediction1", "prediction2"]

child_terminator = None
child_stopper = None

# Set up the fake child murder scene
def add_child_terminator(*args, **kwargs):
process = _start_pool_worker(*args, **kwargs)
nonlocal child_terminator
child_terminator = Process(
target=partial(
send_signals_to_process, process, signal.SIGCHLD, 0.5
)
)
child_terminator.start() # Hasta la vista, baby
nonlocal child_stopper
child_stopper = Process(target=partial(stop_children, process, 0.5))
child_stopper.start() # Hasta la vista, baby
return process

try:
Expand All @@ -124,9 +126,9 @@ def add_child_terminator(*args, **kwargs):
fn=forever_process, predictions=predictions
)
finally:
if child_terminator:
child_terminator.terminate()
if child_stopper:
child_stopper.terminate()

assert "Child process was terminated unexpectedly" in str(
assert "A process in the process pool was terminated abruptly" in str(
excinfo.value.error
)

0 comments on commit 4666136

Please sign in to comment.