Skip to content

feat(framework) Allow registering exit handlers #4927

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions framework/py/flwr/common/exit_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
signal.SIGINT: ExitCode.GRACEFUL_EXIT_SIGINT,
signal.SIGTERM: ExitCode.GRACEFUL_EXIT_SIGTERM,
}
registered_exit_handlers: list[Callable[[], None]] = []

# SIGQUIT is not available on Windows
if hasattr(signal, "SIGQUIT"):
Expand All @@ -41,6 +42,7 @@ def register_exit_handlers(
exit_message: Optional[str] = None,
grpc_servers: Optional[list[Server]] = None,
bckg_threads: Optional[list[Thread]] = None,
exit_handlers: Optional[list[Callable[[], None]]] = None,
) -> None:
"""Register exit handlers for `SIGINT`, `SIGTERM` and `SIGQUIT` signals.

Expand All @@ -56,8 +58,12 @@ def register_exit_handlers(
bckg_threads: Optional[List[Thread]] (default: None)
An optional list of threads that need to be gracefully
terminated before exiting.
exit_handlers: Optional[List[Callable[[], None]]] (default: None)
An optional list of exit handlers to be called before exiting.
Additional exit handlers can be added using `add_exit_handler`.
"""
default_handlers: dict[int, Callable[[int, FrameType], None]] = {}
registered_exit_handlers.extend(exit_handlers or [])

def graceful_exit_handler(signalnum: int, _frame: FrameType) -> None:
"""Exit handler to be registered with `signal.signal`.
Expand All @@ -68,6 +74,9 @@ def graceful_exit_handler(signalnum: int, _frame: FrameType) -> None:
# Reset to default handler
signal.signal(signalnum, default_handlers[signalnum]) # type: ignore

for handler in registered_exit_handlers:
handler()

if grpc_servers is not None:
for grpc_server in grpc_servers:
grpc_server.stop(grace=1)
Expand All @@ -87,3 +96,24 @@ def graceful_exit_handler(signalnum: int, _frame: FrameType) -> None:
for sig in SIGNAL_TO_EXIT_CODE:
default_handler = signal.signal(sig, graceful_exit_handler) # type: ignore
default_handlers[sig] = default_handler # type: ignore


def add_exit_handler(exit_handler: Callable[[], None]) -> None:
"""Add an exit handler to be called on graceful exit.

This function allows you to register additional exit handlers
that will be executed when the application exits gracefully,
if `register_exit_handlers` was called.

Parameters
----------
exit_handler : Callable[[], None]
A callable that takes no arguments and performs cleanup or
other actions before the application exits.

Notes
-----
This method is not thread-safe, and it allows you to add the
same exit handler multiple times.
"""
registered_exit_handlers.append(exit_handler)
64 changes: 64 additions & 0 deletions framework/py/flwr/common/exit_handlers_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for exit handler utils."""


import os
import signal
import unittest
from unittest.mock import Mock, patch

from .exit_handlers import (
add_exit_handler,
register_exit_handlers,
registered_exit_handlers,
)
from .telemetry import EventType


class TestExitHandlers(unittest.TestCase):
"""Tests for exit handler utils."""

def setUp(self) -> None:
"""Clear all exit handlers before each test."""
registered_exit_handlers.clear()

@patch("sys.exit")
def test_register_exit_handlers(self, mock_sys_exit: Mock) -> None:
"""Test register_exit_handlers."""
# Prepare
handlers = [Mock(), Mock(), Mock()]
register_exit_handlers(EventType.PING, exit_handlers=handlers[:-1]) # type: ignore
add_exit_handler(handlers[-1])

# Execute
os.kill(os.getpid(), signal.SIGTERM)

# Assert
for handler in handlers:
handler.assert_called()
mock_sys_exit.assert_called()
self.assertEqual(registered_exit_handlers, handlers)

def test_add_exit_handler(self) -> None:
"""Test add_exit_handler."""
# Prepare
handler = Mock()

# Execute
add_exit_handler(handler)

# Assert
self.assertIn(handler, registered_exit_handlers)
Loading