Skip to content

Commit

Permalink
add utils and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Feb 11, 2025
1 parent 4f0201a commit d13564f
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 0 deletions.
29 changes: 29 additions & 0 deletions src/py/flwr/common/exit_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from threading import Thread
from types import FrameType
from typing import Callable, Optional
from uuid import uuid4

from grpc import Server

Expand All @@ -30,6 +31,7 @@
signal.SIGINT: ExitCode.GRACEFUL_EXIT_SIGINT,
signal.SIGTERM: ExitCode.GRACEFUL_EXIT_SIGTERM,
}
_handlers: dict[str, Callable[[], None]] = {}

# SIGQUIT is not available on Windows
if hasattr(signal, "SIGQUIT"):
Expand All @@ -38,6 +40,7 @@

def register_exit_handlers(
event_type: EventType,
handlers: Optional[list[Callable[[], None]]] = None,
exit_message: Optional[str] = None,
grpc_servers: Optional[list[Server]] = None,
bckg_threads: Optional[list[Thread]] = None,
Expand All @@ -48,6 +51,8 @@ def register_exit_handlers(
----------
event_type : EventType
The telemetry event that should be logged before exit.
handlers : Optional[List[Callable[[], None]]] (default: None)
An optional list of handlers to be called before exiting.
exit_message : Optional[str] (default: None)
The message to be logged before exiting.
grpc_servers: Optional[List[Server]] (default: None)
Expand All @@ -68,6 +73,9 @@ def graceful_exit_handler(signalnum: int, _frame: FrameType) -> None:
# Reset to default handler
signal.signal(signalnum, default_handlers[signalnum]) # type: ignore

for handler in _handlers.values():
handler()

if grpc_servers is not None:
for grpc_server in grpc_servers:
grpc_server.stop(grace=1)
Expand All @@ -83,7 +91,28 @@ def graceful_exit_handler(signalnum: int, _frame: FrameType) -> None:
event_type=event_type,
)

# Register exit handlers
if handlers:
for handler in handlers:
_handlers[str(uuid4())] = handler

# Register signal handlers
for sig in SIGNAL_TO_EXIT_CODE:
default_handler = signal.signal(sig, graceful_exit_handler) # type: ignore
default_handlers[sig] = default_handler # type: ignore


def add_exit_handler(handler: Callable[[], None], name: Optional[str] = None) -> None:
"""Add an exit handler."""
if name is None:
name = str(uuid4())

_handlers[name] = handler


def remove_exit_handler(name: str) -> None:
"""Remove an exit handler."""
if name in _handlers:
del _handlers[name]
else:
raise KeyError(f"Handler with name '{name}' not found.")
91 changes: 91 additions & 0 deletions src/py/flwr/common/exit_handlers_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for exit handler utils."""


import os
import signal
import unittest
from unittest.mock import Mock, patch

from .exit_handlers import (
_handlers,
add_exit_handler,
register_exit_handlers,
remove_exit_handler,
)
from .telemetry import EventType


class TestExitHandlers(unittest.TestCase):
"""Tests for exit handler utils."""

def setUp(self) -> None:
"""Clear all exit handlers before each test."""
_handlers.clear()

@patch("sys.exit")
def test_register_exit_handlers(self, mock_sys_exit: Mock) -> None:
"""Test register_exit_handlers."""
# Prepare
handlers = [Mock(), Mock()]
register_exit_handlers(EventType.PING, handlers=handlers) # type: ignore

# Execute
os.kill(os.getpid(), signal.SIGTERM)

# Assert
for handler in handlers:
handler.assert_called()
mock_sys_exit.assert_called()
self.assertEqual(list(_handlers.values()), handlers)

def test_add_exit_handler(self) -> None:
"""Test add_exit_handler."""
# Prepare
handler = Mock()

# Execute
add_exit_handler(handler, "mock_handler")

# Assert
self.assertIn("mock_handler", _handlers)
self.assertEqual(_handlers["mock_handler"], handler)

def test_remove_exit_handler(self) -> None:
"""Test remove_exit_handler."""
# Prepare
handler = Mock()
add_exit_handler(handler, "mock_handler")

# Execute
remove_exit_handler("mock_handler")

# Assert
self.assertNotIn("mock_handler", _handlers)

def test_remove_exit_handler_not_found(self) -> None:
"""Test remove_exit_handler with invalid name."""
# Prepare
handler = Mock()
add_exit_handler(handler, "mock_handler")

# Execute
with self.assertRaises(KeyError):
remove_exit_handler("non_existent_handler")

# Assert
self.assertIn("mock_handler", _handlers)
self.assertEqual(_handlers["mock_handler"], handler)

0 comments on commit d13564f

Please sign in to comment.