Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Allow registering exit handlers #4927

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/py/flwr/common/exit_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from threading import Thread
from types import FrameType
from typing import Callable, Optional
from uuid import uuid4

from grpc import Server

Expand All @@ -30,6 +31,7 @@
signal.SIGINT: ExitCode.GRACEFUL_EXIT_SIGINT,
signal.SIGTERM: ExitCode.GRACEFUL_EXIT_SIGTERM,
}
_handlers: dict[str, Callable[[], None]] = {}

# SIGQUIT is not available on Windows
if hasattr(signal, "SIGQUIT"):
Expand All @@ -38,6 +40,7 @@

def register_exit_handlers(
event_type: EventType,
handlers: Optional[list[Callable[[], None]]] = None,
exit_message: Optional[str] = None,
grpc_servers: Optional[list[Server]] = None,
bckg_threads: Optional[list[Thread]] = None,
Expand All @@ -48,6 +51,8 @@ def register_exit_handlers(
----------
event_type : EventType
The telemetry event that should be logged before exit.
handlers : Optional[List[Callable[[], None]]] (default: None)
An optional list of handlers to be called before exiting.
exit_message : Optional[str] (default: None)
The message to be logged before exiting.
grpc_servers: Optional[List[Server]] (default: None)
Expand All @@ -68,6 +73,9 @@ def graceful_exit_handler(signalnum: int, _frame: FrameType) -> None:
# Reset to default handler
signal.signal(signalnum, default_handlers[signalnum]) # type: ignore

for handler in _handlers.values():
handler()

if grpc_servers is not None:
for grpc_server in grpc_servers:
grpc_server.stop(grace=1)
Expand All @@ -83,7 +91,28 @@ def graceful_exit_handler(signalnum: int, _frame: FrameType) -> None:
event_type=event_type,
)

# Register exit handlers
if handlers:
for handler in handlers:
_handlers[str(uuid4())] = handler

# Register signal handlers
for sig in SIGNAL_TO_EXIT_CODE:
default_handler = signal.signal(sig, graceful_exit_handler) # type: ignore
default_handlers[sig] = default_handler # type: ignore


def add_exit_handler(handler: Callable[[], None], name: Optional[str] = None) -> None:
"""Add an exit handler."""
if name is None:
name = str(uuid4())

_handlers[name] = handler


def remove_exit_handler(name: str) -> None:
"""Remove an exit handler."""
if name in _handlers:
del _handlers[name]
else:
raise KeyError(f"Handler with name '{name}' not found.")
91 changes: 91 additions & 0 deletions src/py/flwr/common/exit_handlers_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for exit handler utils."""


import os
import signal
import unittest
from unittest.mock import Mock, patch

from .exit_handlers import (
_handlers,
add_exit_handler,
register_exit_handlers,
remove_exit_handler,
)
from .telemetry import EventType


class TestExitHandlers(unittest.TestCase):
"""Tests for exit handler utils."""

def setUp(self) -> None:
"""Clear all exit handlers before each test."""
_handlers.clear()

@patch("sys.exit")
def test_register_exit_handlers(self, mock_sys_exit: Mock) -> None:
"""Test register_exit_handlers."""
# Prepare
handlers = [Mock(), Mock()]
register_exit_handlers(EventType.PING, handlers=handlers) # type: ignore

# Execute
os.kill(os.getpid(), signal.SIGTERM)

# Assert
for handler in handlers:
handler.assert_called()
mock_sys_exit.assert_called()
self.assertEqual(list(_handlers.values()), handlers)

def test_add_exit_handler(self) -> None:
"""Test add_exit_handler."""
# Prepare
handler = Mock()

# Execute
add_exit_handler(handler, "mock_handler")

# Assert
self.assertIn("mock_handler", _handlers)
self.assertEqual(_handlers["mock_handler"], handler)

def test_remove_exit_handler(self) -> None:
"""Test remove_exit_handler."""
# Prepare
handler = Mock()
add_exit_handler(handler, "mock_handler")

# Execute
remove_exit_handler("mock_handler")

# Assert
self.assertNotIn("mock_handler", _handlers)

def test_remove_exit_handler_not_found(self) -> None:
"""Test remove_exit_handler with invalid name."""
# Prepare
handler = Mock()
add_exit_handler(handler, "mock_handler")

# Execute
with self.assertRaises(KeyError):
remove_exit_handler("non_existent_handler")

# Assert
self.assertIn("mock_handler", _handlers)
self.assertEqual(_handlers["mock_handler"], handler)
Loading