diff --git a/task_processing/plugins/kubernetes/kubernetes_pod_executor.py b/task_processing/plugins/kubernetes/kubernetes_pod_executor.py
index 46ea343..a52e2b0 100644
--- a/task_processing/plugins/kubernetes/kubernetes_pod_executor.py
+++ b/task_processing/plugins/kubernetes/kubernetes_pod_executor.py
@@ -1,7 +1,8 @@
 import logging
+import multiprocessing
 import queue
-import threading
 import time
+from multiprocessing import JoinableQueue
 from queue import Queue
 from typing import Collection
 from typing import Optional
@@ -51,8 +52,8 @@
 
 logger = logging.getLogger(__name__)
 
-POD_WATCH_THREAD_JOIN_TIMEOUT_S = 1.0
-POD_EVENT_THREAD_JOIN_TIMEOUT_S = 1.0
+POD_WATCH_PROCESS_JOIN_TIMEOUT_S = 1.0
+POD_EVENT_PROCESS_JOIN_TIMEOUT_S = 1.0
 QUEUE_GET_TIMEOUT_S = 0.5
 SUPPORTED_POD_MODIFIED_EVENT_PHASES = {
     "Failed",
@@ -100,7 +101,7 @@ def __init__(
         self.stopping = False
         self.task_metadata: PMap[str, KubernetesTaskMetadata] = pmap()
 
-        self.task_metadata_lock = threading.RLock()
+        self.task_metadata_lock = multiprocessing.RLock()
         if task_configs:
             for task_config in task_configs:
                 self._initialize_existing_task(task_config)
@@ -110,33 +111,33 @@ def __init__(
         # and we've opted to not do that processing in the Pod event watcher thread so as to keep
         # that logic for the threads that operate on them as simple as possible and to make it
         # possible to cleanly shutdown both of these.
-        self.pending_events: "Queue[PodEvent]" = Queue()
-        self.event_queue: "Queue[Event]" = Queue()
+        self.pending_events: "JoinableQueue[PodEvent]" = JoinableQueue()
+        self.event_queue: "JoinableQueue[Event]" = JoinableQueue()
 
         # TODO(TASKPROC-243): keep track of resourceVersion so that we can continue event processing
         # from where we left off on restarts
-        self.pod_event_watch_threads = []
+        self.pod_event_watch_processes = []
         self.watches = []
         for kube_client in [self.kube_client] + self.watcher_kube_clients:
             watch = kube_watch.Watch()
-            pod_event_watch_thread = threading.Thread(
+            pod_event_watch_process = multiprocessing.Process(
                 target=self._pod_event_watch_loop,
                 args=(kube_client, watch),
-                # ideally this wouldn't be a daemon thread, but a watch.Watch() only checks
+                # ideally this wouldn't be a daemon process, but a watch.Watch() only checks
                 # if it should stop after receiving an event - and it's possible that we
                 # have periods with no events so instead we'll attempt to stop the watch
                 # and then join() with a small timeout to make sure that, if we shutdown
-                # with the thread alive, we did not drop any events
+                # with the process alive, we did not drop any events
                 daemon=True,
             )
-            pod_event_watch_thread.start()
-            self.pod_event_watch_threads.append(pod_event_watch_thread)
+            pod_event_watch_process.start()
+            self.pod_event_watch_processes.append(pod_event_watch_process)
             self.watches.append(watch)
 
-        self.pending_event_processing_thread = threading.Thread(
+        self.pending_event_processing_process = multiprocessing.Process(
            target=self._pending_event_processing_loop,
         )
-        self.pending_event_processing_thread.start()
+        self.pending_event_processing_process.start()
 
     def _initialize_existing_task(self, task_config: KubernetesTaskConfig) -> None:
         """Generates task_metadata in UNKNOWN state for an existing KubernetesTaskConfig.
@@ -427,9 +428,18 @@ def _pending_event_processing_loop(self) -> None:
         """
         logger.debug("Starting Pod event processing.")
         event = None
-        while not self.stopping or not self.pending_events.empty():
+        while True:
             try:
                 event = self.pending_events.get(timeout=QUEUE_GET_TIMEOUT_S)
+                if event["type"] == "STOP":
+                    logger.debug("Received a STOP event - stopping processing.")
+                    try:
+                        self.pending_events.task_done()
+                    except ValueError:
+                        logger.error(
+                            "task_done() called on pending events queue too many times!"
+                        )
+                    break
                 self._process_pod_event(event)
             except queue.Empty:
                 logger.debug(
@@ -699,33 +709,49 @@ def kill(self, task_id: str) -> bool:
         return terminated
 
     def stop(self) -> None:
-        logger.debug("Preparing to stop all KubernetesPodExecutor threads.")
+        logger.debug("Preparing to stop all KubernetesPodExecutor processes.")
         self.stopping = True
 
-        logger.debug("Signaling Pod event Watch to stop streaming events...")
         # make sure that we've stopped watching for events before calling join() - otherwise,
         # join() will block until we hit the configured timeout (or forever with no timeout).
         for watch in self.watches:
             watch.stop()
+
+        # Add a STOP event to the pending events queue after stopping the watch to ensure that
+        # no events will be added after the STOP event
+        stop_event = PodEvent(type="STOP", object=None, raw_object={})
+        self.pending_events.put(stop_event)
+
         # timeout arbitrarily chosen - we mostly just want to make sure that we have a small
         # grace period to flush the current event to the pending_events queue as well as
-        # any other clean-up - it's possible that after this join() the thread is still alive
+        # any other clean-up - it's possible that after this join() the process is still alive
         # but in that case we can be reasonably sure that we're not dropping any data.
-        for pod_event_watch_thread in self.pod_event_watch_threads:
-            pod_event_watch_thread.join(timeout=POD_WATCH_THREAD_JOIN_TIMEOUT_S)
+        for pod_event_watch_process in self.pod_event_watch_processes:
+            pod_event_watch_process.join(timeout=POD_WATCH_PROCESS_JOIN_TIMEOUT_S)
 
         logger.debug("Waiting for all pending PodEvents to be processed...")
         # once we've stopped updating the pending events queue, we then wait until we're done
         # processing any events we've received - this will wait until task_done() has been
         # called for every item placed in this queue
+        # since we stopped the watch above, we don't expect any more events to be added to the queue;
+        # this ensures that we're not stuck on the STOP event if it wasn't processed by _pending_event_processing_loop
+        if (
+            self.pending_events.qsize() == 1
+            and self.pending_events.get(timeout=QUEUE_GET_TIMEOUT_S)["type"] == "STOP"
+        ):
+            try:
+                self.pending_events.task_done()
+            except ValueError:
+                logger.error(
+                    "task_done() called on pending events queue too many times!"
+                )
         self.pending_events.join()
         logger.debug("All pending PodEvents have been processed.")
         # and then give ourselves time to do any post-stop cleanup
-        self.pending_event_processing_thread.join(
-            timeout=POD_EVENT_THREAD_JOIN_TIMEOUT_S
+        self.pending_event_processing_process.join(
+            timeout=POD_EVENT_PROCESS_JOIN_TIMEOUT_S
         )
 
-        logger.debug("Done stopping KubernetesPodExecutor!")
 
-    def get_event_queue(self) -> "Queue[Event]":
+    def get_event_queue(self) -> "JoinableQueue[Event]":
         return self.event_queue
diff --git a/task_processing/plugins/kubernetes/types.py b/task_processing/plugins/kubernetes/types.py
index 8437a7e..3bf7e3e 100644
--- a/task_processing/plugins/kubernetes/types.py
+++ b/task_processing/plugins/kubernetes/types.py
@@ -81,7 +81,8 @@ class NodeAffinity(TypedDict):
 
 
 class PodEvent(TypedDict):
-    # there are only 3 possible types for Pod events: ADDED, DELETED, MODIFIED
+    # there are only 4 possible types for Pod events: ADDED, DELETED, MODIFIED or STOP
+    # STOP is a custom type that we use to signal STOP to all KubernetesPodExecutor processes
     # XXX: this should be typed as Literal["ADDED", "DELETED", "MODIFIED"] once we drop support
     # for older Python versions
     type: str
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index a2e136b..facc2be 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -1,3 +1,4 @@
+import multiprocessing
 import threading
 
 import mock
@@ -14,3 +15,9 @@ def mock_sleep():
 def mock_Thread():
     with mock.patch.object(threading, "Thread") as mock_Thread:
         yield mock_Thread
+
+
+@pytest.fixture
+def mock_Process():
+    with mock.patch.object(multiprocessing, "Process") as mock_Process:
+        yield mock_Process
diff --git a/tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py b/tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py
index 96356c5..01e5bfb 100644
--- a/tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py
+++ b/tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py
@@ -12,6 +12,7 @@
 from kubernetes.client import V1Pod
 from kubernetes.client import V1PodSecurityContext
 from kubernetes.client import V1PodSpec
+from kubernetes.client import V1PodStatus
 from kubernetes.client import V1ProjectedVolumeSource
 from kubernetes.client import V1ResourceRequirements
 from kubernetes.client import V1SecurityContext
@@ -38,7 +39,7 @@
 
 
 @pytest.fixture
-def k8s_executor(mock_Thread):
+def k8s_executor(mock_Process):
     with mock.patch(
         "task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
         autospec=True,
@@ -53,7 +54,7 @@
 
 
 @pytest.fixture
-def k8s_executor_with_watcher_clusters(mock_Thread):
+def k8s_executor_with_watcher_clusters(mock_Process):
     with mock.patch(
         "task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
         autospec=True,
@@ -87,7 +88,7 @@ def mock_task_configs():
 
 
 @pytest.fixture
-def k8s_executor_with_tasks(mock_Thread, mock_task_configs):
+def k8s_executor_with_tasks(mock_Process, mock_task_configs):
     with mock.patch(
         "task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
         autospec=True,
@@ -105,13 +106,13 @@
 
 
 def test_init_watch_setup(k8s_executor):
-    assert len(k8s_executor.watches) == len(k8s_executor.pod_event_watch_threads) == 1
+    assert len(k8s_executor.watches) == len(k8s_executor.pod_event_watch_processes) == 1
 
 
 def test_init_watch_setup_multicluster(k8s_executor_with_watcher_clusters):
     assert (
         len(k8s_executor_with_watcher_clusters.watches)
-        == len(k8s_executor_with_watcher_clusters.pod_event_watch_threads)
+        == len(k8s_executor_with_watcher_clusters.pod_event_watch_processes)
         == 2
     )
 
@@ -697,15 +698,18 @@ def test_process_event_enqueues_task_processing_events_pending_to_running(k8s_ex
     mock_pod.metadata.name = "test.1234"
     mock_pod.status.phase = "Running"
     mock_pod.spec.node_name = "node-1-2-3-4"
+    task_config = KubernetesTaskConfig(
+        image="test", command="test", uuid="uuid", name="pod--name"
+    )
     mock_event = PodEvent(
         type="MODIFIED",
         object=mock_pod,
-        raw_object=mock.Mock(),
+        raw_object={},
     )
     k8s_executor.task_metadata = pmap(
         {
             mock_pod.metadata.name: KubernetesTaskMetadata(
-                task_config=mock.Mock(spec=KubernetesTaskConfig),
+                task_config=task_config,
                 task_state=KubernetesTaskState.TASK_PENDING,
                 task_state_history=v(),
             )
@@ -736,15 +740,18 @@ def test_process_event_enqueues_task_processing_events_running_to_terminal( k8s_executor,
     mock_pod.metadata.name = "test.1234"
     mock_pod.status.phase = phase
     mock_pod.spec.node_name = "node-1-2-3-4"
+    task_config = KubernetesTaskConfig(
+        image="test", command="test", uuid="uuid", name="pod--name"
+    )
     mock_event = PodEvent(
         type="MODIFIED",
         object=mock_pod,
-        raw_object=mock.Mock(),
+        raw_object={},
     )
     k8s_executor.task_metadata = pmap(
         {
             mock_pod.metadata.name: KubernetesTaskMetadata(
-                task_config=mock.Mock(spec=KubernetesTaskConfig),
+                task_config=task_config,
                 task_state=KubernetesTaskState.TASK_RUNNING,
                 task_state_history=v(),
             )
@@ -779,7 +786,7 @@ def test_process_event_enqueues_task_processing_events_no_state_transition(
     mock_event = PodEvent(
         type="MODIFIED",
         object=mock_pod,
-        raw_object=mock.Mock(),
+        raw_object={},
     )
     k8s_executor.task_metadata = pmap(
         {
@@ -807,15 +814,28 @@
 def test_pending_event_processing_loop_processes_remaining_events_after_stop(
     k8s_executor,
 ):
+    # Create a V1Pod object to use for the multiprocessing tests instead of mock.Mock(),
+    # as it is not pickleable
+    test_pod = V1Pod(
+        metadata=V1ObjectMeta(
+            name="test-pod",
+            namespace="task_processing_tests",
+        )
+    )
     k8s_executor.pending_events.put(
         PodEvent(
             type="ADDED",
-            object=mock.Mock(),
-            raw_object=mock.Mock(),
+            object=test_pod,
+            raw_object={},
+        )
+    )
+    k8s_executor.pending_events.put(
+        PodEvent(
+            type="STOP",
+            object=None,
+            raw_object={},
         )
     )
-    k8s_executor.stopping = True
-
     with mock.patch.object(
         k8s_executor,
         "_process_pod_event",
@@ -835,15 +855,18 @@ def test_process_event_enqueues_task_processing_events_deleted(
     mock_pod.status.phase = "Running"
     mock_pod.status.host_ip = "1.2.3.4"
     mock_pod.spec.node_name = "kubenode"
+    task_config = KubernetesTaskConfig(
+        image="test", command="test", uuid="uuid", name="pod--name"
+    )
     mock_event = PodEvent(
         type="DELETED",
         object=mock_pod,
-        raw_object=mock.Mock(),
+        raw_object={},
     )
     k8s_executor.task_metadata = pmap(
         {
             mock_pod.metadata.name: KubernetesTaskMetadata(
-                task_config=mock.Mock(spec=KubernetesTaskConfig),
+                task_config=task_config,
                 task_state=KubernetesTaskState.TASK_RUNNING,
                 task_state_history=v(),
             )
@@ -870,14 +893,13 @@ def test_initial_task_metadata(k8s_executor_with_tasks):
 def test_reconcile_missing_pod(
     k8s_executor,
 ):
-    task_config = mock.Mock(spec=KubernetesTaskConfig)
-    task_config.pod_name = "pod--name.uuid"
-    task_config.name = "job-name"
-
+    task_config = KubernetesTaskConfig(
+        image="test", command="test", uuid="uuid", name="pod--name"
+    )
     k8s_executor.task_metadata = pmap(
         {
             task_config.pod_name: KubernetesTaskMetadata(
-                task_config=mock.Mock(spec=KubernetesTaskConfig),
+                task_config=task_config,
                 task_state=KubernetesTaskState.TASK_UNKNOWN,
                 task_state_history=v(),
             )
@@ -899,14 +921,13 @@ def test_reconcile_missing_pod(
 def test_reconcile_multicluster(
     k8s_executor_with_watcher_clusters,
 ):
-    task_config = mock.Mock(spec=KubernetesTaskConfig)
-    task_config.pod_name = "pod--name.uuid"
-    task_config.name = "job-name"
-
+    task_config = KubernetesTaskConfig(
+        image="test", command="test", uuid="uuid", name="pod--name"
+    )
     k8s_executor_with_watcher_clusters.task_metadata = pmap(
         {
             task_config.pod_name: KubernetesTaskMetadata(
-                task_config=mock.Mock(spec=KubernetesTaskConfig),
+                task_config=task_config,
                 task_state=KubernetesTaskState.TASK_UNKNOWN,
                 task_state_history=v(),
             )
@@ -968,10 +989,9 @@ def test_reconcile_existing_pods(k8s_executor, mock_task_configs):
 def test_reconcile_api_error(
     k8s_executor,
 ):
-    task_config = mock.Mock(spec=KubernetesTaskConfig)
-    task_config.pod_name = "pod--name.uuid"
-    task_config.name = "job-name"
-
+    task_config = KubernetesTaskConfig(
+        image="test", command="test", uuid="uuid", name="pod--name"
+    )
     with mock.patch.object(
         k8s_executor, "kube_client", autospec=True
     ) as mock_kube_client:
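
Reviewer note: below is a minimal, self-contained sketch (not task_processing code) of the shutdown pattern this diff relies on - a worker process drains a multiprocessing.JoinableQueue until it sees a STOP sentinel, acknowledging every item with task_done() so that queue.join() in the parent returns before the process is joined. The names worker and pending, and the literal event dicts, are illustrative assumptions.

import multiprocessing
from multiprocessing import JoinableQueue


def worker(pending):
    # consume events until the STOP sentinel shows up, acking each item so join() can return
    while True:
        event = pending.get()
        try:
            if event["type"] == "STOP":
                break
            print(f"processing {event['type']} for {event['object']}")
        finally:
            pending.task_done()


if __name__ == "__main__":
    pending = JoinableQueue()
    proc = multiprocessing.Process(target=worker, args=(pending,), daemon=True)
    proc.start()

    pending.put({"type": "ADDED", "object": "pod-a", "raw_object": {}})
    pending.put({"type": "STOP", "object": None, "raw_object": {}})

    pending.join()  # returns once task_done() has been called for every put()
    proc.join(timeout=1.0)

Because the sentinel is put() only after every watch has been stopped, no real event can land behind it, which is what lets stop() treat a lone unprocessed STOP item as safe to drain before calling join().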