Skip to content

Commit

Permalink
Prevent multiple shutdowns happening at the same time
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisvanrun committed Jul 31, 2024
1 parent 1305914 commit 1165dac
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,21 @@ def _start_pool_worker(fn, predictions, max_workers, results, errors):

def _pool_worker(*, fn, predictions, max_workers, results, errors):
terminating_child_processes = False

executor_shutting_down = False
with ProcessPoolExecutor(max_workers=max_workers) as executor:
try:

def handle_error(error, prediction="Unknown"):
executor.shutdown(wait=False, cancel_futures=True)
nonlocal terminating_child_processes
if terminating_child_processes:
return

nonlocal executor_shutting_down
if not executor_shutting_down:
executor_shutting_down = True
executor.shutdown(wait=False, cancel_futures=True)
errors.append((prediction, error))

nonlocal terminating_child_processes
terminating_child_processes = True
_terminate_child_processes()

Expand Down Expand Up @@ -139,13 +145,14 @@ def sigchld_handler(*_, **__):
except Exception as e:
handle_error(e, prediction=future_to_predictions[future])
finally:
terminating_child_processes = True
_terminate_child_processes()
if not terminating_child_processes:
terminating_child_processes = True
_terminate_child_processes()


def _terminate_child_processes():
current_process = psutil.Process(os.getpid())
children = current_process.children(recursive=True)
def _terminate_child_processes(pid=None):
process = psutil.Process(pid or os.getpid())
children = process.children(recursive=True)
for child in children:
try:
child.terminate()
Expand All @@ -157,7 +164,6 @@ def _terminate_child_processes():

# Forcefully kill any remaining processes
for p in still_alive:
print(f"Forcefully killing child process {p.pid}")
try:
p.kill()
except psutil.NoSuchProcess:
Expand Down
18 changes: 14 additions & 4 deletions tests/test_evaluate_helpers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import signal
import sys
import time
from functools import partial
from multiprocessing import Process
from unittest import mock

import psutil
import pytest

# Do some creating path hacking to be able to import the helpers
Expand Down Expand Up @@ -66,6 +66,18 @@ def send_signals_to_process(process, signal_to_send, interval):
time.sleep(interval)


def terminate_children(process, interval):
while True:
process = psutil.Process(process.pid)
children = process.children(recursive=True)
for child in children:
try:
child.kill()
except psutil.NoSuchProcess:
pass # Not a problem
time.sleep(interval)


def test_prediction_processing():
predictions = ["prediction1", "prediction2"]
result = run_prediction_processing(
Expand Down Expand Up @@ -110,9 +122,7 @@ def add_child_terminator(*args, **kwargs):
process = _start_pool_worker(*args, **kwargs)
nonlocal child_terminator
child_terminator = Process(
target=partial(
send_signals_to_process, process, signal.SIGCHLD, 0.5
)
target=partial(terminate_children, process, 0.5)
)
child_terminator.start() # Hasta la vista, baby
return process
Expand Down

0 comments on commit 1165dac

Please sign in to comment.