diff --git a/src/py/flwr/common/exit_handlers.py b/src/py/flwr/common/exit_handlers.py index 7670fd58c059..ae722a211340 100644 --- a/src/py/flwr/common/exit_handlers.py +++ b/src/py/flwr/common/exit_handlers.py @@ -19,6 +19,7 @@ from threading import Thread from types import FrameType from typing import Callable, Optional +from uuid import uuid4 from grpc import Server @@ -30,6 +31,7 @@ signal.SIGINT: ExitCode.GRACEFUL_EXIT_SIGINT, signal.SIGTERM: ExitCode.GRACEFUL_EXIT_SIGTERM, } +_handlers: dict[str, Callable[[], None]] = {} # SIGQUIT is not available on Windows if hasattr(signal, "SIGQUIT"): @@ -38,6 +40,7 @@ def register_exit_handlers( event_type: EventType, + handlers: Optional[list[Callable[[], None]]] = None, exit_message: Optional[str] = None, grpc_servers: Optional[list[Server]] = None, bckg_threads: Optional[list[Thread]] = None, @@ -48,6 +51,8 @@ def register_exit_handlers( ---------- event_type : EventType The telemetry event that should be logged before exit. + handlers : Optional[List[Callable[[], None]]] (default: None) + An optional list of handlers to be called before exiting. exit_message : Optional[str] (default: None) The message to be logged before exiting. grpc_servers: Optional[List[Server]] (default: None) @@ -68,6 +73,9 @@ def graceful_exit_handler(signalnum: int, _frame: FrameType) -> None: # Reset to default handler signal.signal(signalnum, default_handlers[signalnum]) # type: ignore + for handler in _handlers.values(): + handler() + if grpc_servers is not None: for grpc_server in grpc_servers: grpc_server.stop(grace=1) @@ -83,7 +91,28 @@ def graceful_exit_handler(signalnum: int, _frame: FrameType) -> None: event_type=event_type, ) + # Register exit handlers + if handlers: + for handler in handlers: + _handlers[str(uuid4())] = handler + # Register signal handlers for sig in SIGNAL_TO_EXIT_CODE: default_handler = signal.signal(sig, graceful_exit_handler) # type: ignore default_handlers[sig] = default_handler # type: ignore + + +def add_exit_handler(handler: Callable[[], None], name: Optional[str] = None) -> None: + """Add an exit handler.""" + if name is None: + name = str(uuid4()) + + _handlers[name] = handler + + +def remove_exit_handler(name: str) -> None: + """Remove an exit handler.""" + if name in _handlers: + del _handlers[name] + else: + raise KeyError(f"Handler with name '{name}' not found.") diff --git a/src/py/flwr/common/exit_handlers_test.py b/src/py/flwr/common/exit_handlers_test.py new file mode 100644 index 000000000000..89b0b854192c --- /dev/null +++ b/src/py/flwr/common/exit_handlers_test.py @@ -0,0 +1,91 @@ +# Copyright 2025 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for exit handler utils.""" + + +import os +import signal +import unittest +from unittest.mock import Mock, patch + +from .exit_handlers import ( + _handlers, + add_exit_handler, + register_exit_handlers, + remove_exit_handler, +) +from .telemetry import EventType + + +class TestExitHandlers(unittest.TestCase): + """Tests for exit handler utils.""" + + def setUp(self) -> None: + """Clear all exit handlers before each test.""" + _handlers.clear() + + @patch("sys.exit") + def test_register_exit_handlers(self, mock_sys_exit: Mock) -> None: + """Test register_exit_handlers.""" + # Prepare + handlers = [Mock(), Mock()] + register_exit_handlers(EventType.PING, handlers=handlers) # type: ignore + + # Execute + os.kill(os.getpid(), signal.SIGTERM) + + # Assert + for handler in handlers: + handler.assert_called() + mock_sys_exit.assert_called() + self.assertEqual(list(_handlers.values()), handlers) + + def test_add_exit_handler(self) -> None: + """Test add_exit_handler.""" + # Prepare + handler = Mock() + + # Execute + add_exit_handler(handler, "mock_handler") + + # Assert + self.assertIn("mock_handler", _handlers) + self.assertEqual(_handlers["mock_handler"], handler) + + def test_remove_exit_handler(self) -> None: + """Test remove_exit_handler.""" + # Prepare + handler = Mock() + add_exit_handler(handler, "mock_handler") + + # Execute + remove_exit_handler("mock_handler") + + # Assert + self.assertNotIn("mock_handler", _handlers) + + def test_remove_exit_handler_not_found(self) -> None: + """Test remove_exit_handler with invalid name.""" + # Prepare + handler = Mock() + add_exit_handler(handler, "mock_handler") + + # Execute + with self.assertRaises(KeyError): + remove_exit_handler("non_existent_handler") + + # Assert + self.assertIn("mock_handler", _handlers) + self.assertEqual(_handlers["mock_handler"], handler)