
Commit 3d446fa

Cohort-parallel Federated Learning

81 files changed: 8,514 additions & 0 deletions

.gitignore

Lines changed: 14 additions & 0 deletions
.idea/**
.DS_Store
data/**

__pycache__/
*.py[cod]
simulations/data/**
simulations/dfl/data/**
simulations/dl/data/**
simulations/gl/data/**

scripts/data/**
*.log
*.txt

.gitmodules

Lines changed: 3 additions & 0 deletions
[submodule "pyipv8"]
	path = pyipv8
	url = https://github.com/devos50/py-ipv8

README.md

Lines changed: 69 additions & 0 deletions
# Cohort-Parallel Federated Learning (CPFL)

Repository for the source code of our paper *[Harnessing Increased Client Participation with Cohort-Parallel Federated Learning](https://arxiv.org/pdf/2405.15644)*, published at the [5th Workshop on Machine Learning and Systems (EuroMLSys 2025)](https://euromlsys.eu/#).

## Abstract

Federated learning (FL) is a machine learning approach where nodes collaboratively train a global model.
As more nodes participate in a round of FL, the effectiveness of each individual model update diminishes.
In this study, we increase the effectiveness of client updates by dividing the network into smaller partitions, or _cohorts_.
We introduce Cohort-Parallel Federated Learning (CPFL): a novel learning approach in which each cohort independently trains a global model using FL until convergence, after which the models produced by the cohorts are unified using knowledge distillation.
The insight behind CPFL is that smaller, isolated networks converge more quickly than a single network in which all nodes participate.
Through exhaustive experiments involving realistic traces and non-IID data distributions on the CIFAR-10 and FEMNIST image classification tasks, we investigate the balance between the number of cohorts, model accuracy, training time, and compute resources.
Compared to traditional FL, CPFL with four cohorts, a non-IID data distribution, and CIFAR-10 yields a 1.9x reduction in train time and a 1.3x reduction in resource usage, with a minimal drop in test accuracy.

## Installation

Start by cloning the repository recursively (since CPFL depends on the PyIPv8 networking library):

```
git clone git@github.com:sacs-epfl/cpfl.git --recursive
```

Install the required dependencies (preferably in a virtual environment to avoid conflicts with existing libraries):

```
pip install -r requirements.txt
```

In our paper, we evaluate CPFL using the CIFAR-10 and FEMNIST datasets.
For CIFAR-10 we use `torchvision`. The FEMNIST dataset has to be downloaded manually; we refer the reader to the [decentralizepy framework](https://github.com/sacs-epfl/decentralizepy), which uses the same dataset.

## Running CPFL

Training with CPFL can be done by invoking the following scripts from the root of the repository:

```
# Running with the CIFAR-10 dataset
bash scripts/cohorts/run_e2e_cifar10.sh <number_of_cohorts> <seed> <alpha> <peers>

# Running with the FEMNIST dataset
bash scripts/cohorts/run_e2e_femnist.sh <number_of_cohorts> <seed>
```

We refer to the respective bash scripts for more configuration options, such as the number of local steps, the number of participants, and other learning parameters.

The script first splits the data across participants and the participants across cohorts. These assignments are used during the distillation process; a minimal sketch of such a cohort assignment is shown below.
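
For illustration, here is a minimal sketch of how participants could be split across cohorts. The function name, round-robin strategy, and parameters below are assumptions for exposition, not the repository's exact implementation:

```
import random

def assign_cohorts(num_participants: int, num_cohorts: int, seed: int) -> dict:
    # Hypothetical sketch: shuffle the participant indices with a fixed seed
    # and deal them round-robin into cohorts of near-equal size.
    rng = random.Random(seed)
    participants = list(range(num_participants))
    rng.shuffle(participants)
    cohorts = {c: [] for c in range(num_cohorts)}
    for idx, participant in enumerate(participants):
        cohorts[idx % num_cohorts].append(participant)
    return cohorts

# Example: 200 participants split across 10 cohorts with seed 24082
cohorts = assign_cohorts(200, 10, seed=24082)
print({c: len(members) for c, members in cohorts.items()})  # 10 cohorts of 20 peers
```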

Then, during FL training, each cohort periodically checkpoints the current global model, as well as the best model so far (based on the loss on a validation set). The output of this experiment can be found in a separate folder in the `data` directory.

After training, the checkpointed models can be distilled into a single model using the following command:

```
python3 scripts/distill.py $PWD/data n_200_cifar10_dirichlet0.100000_sd24082_ct10_dfl cifar10 stl10 --cohort-file cohorts/cohorts_cifar10_n200_c10.txt --public-data-dir <path_to_public_data> --learning-rate 0.001 --momentum 0.9 --partitioner dirichlet --alpha 0.1 --weighting-scheme label --check-teachers-accuracy > output_distill.log 2>&1
```

The above command invokes the `distill.py` script, which scans the models in the `n_200_cifar10_dirichlet0.100000_sd24082_ct10_dfl` directory (created by the previous experiment) and merges them.
The command also requires the path to the cohort information file created during the previous steps.
The `distill.py` script automatically determines the accuracy attained by the resulting model after distillation.
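
To illustrate the merging step, below is a minimal, self-contained sketch of ensemble knowledge distillation in PyTorch: a student model is trained on public data to match the averaged softened predictions of the per-cohort teacher models. The uniform teacher weighting, temperature, and training loop here are illustrative assumptions that simplify the `--weighting-scheme label` option used above; this is not the exact logic of `distill.py`:

```
import torch
import torch.nn.functional as F

def distill(student, teachers, public_loader, epochs=1, lr=0.001, temperature=3.0):
    # Hypothetical sketch: distill an ensemble of cohort models (teachers)
    # into a single student model using an unlabeled public dataset.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    student.to(device).train()
    for teacher in teachers:
        teacher.to(device).eval()
    optimizer = torch.optim.SGD(student.parameters(), lr=lr, momentum=0.9)
    for _ in range(epochs):
        for inputs, _ in public_loader:  # labels of the public data are unused
            inputs = inputs.to(device)
            with torch.no_grad():
                # Average the teachers' temperature-softened class probabilities
                # (uniform weighting; the paper also weights teachers by label counts)
                teacher_probs = torch.stack(
                    [F.softmax(t(inputs) / temperature, dim=1) for t in teachers]
                ).mean(dim=0)
            student_log_probs = F.log_softmax(student(inputs) / temperature, dim=1)
            loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return student
```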

## Reference

If you find our work useful, you can cite us as follows:

```
@inproceedings{dhasade2025cpfl,
  title={Harnessing Increased Client Participation with Cohort-Parallel Federated Learning},
  author={Dhasade, Akash and Kermarrec, Anne-Marie and Nguyen, Tuan-Anh and Pires, Rafael and de Vos, Martijn},
  booktitle={Proceedings of the 5th Workshop on Machine Learning and Systems},
  year={2025}
}
```

cpfl/__init__.py

Lines changed: 3 additions & 0 deletions
"""
Contains code related to the DFL framework.
"""

cpfl/core/__init__.py

Lines changed: 10 additions & 0 deletions
from enum import Enum


class TransmissionMethod(Enum):
    EVA = 0


class NodeMembershipChange(Enum):
    JOIN = 0
    LEAVE = 1

cpfl/core/community.py

Lines changed: 160 additions & 0 deletions
import asyncio
import time
from asyncio import Future, ensure_future
from binascii import unhexlify, hexlify
from typing import Optional, Callable, Dict, List

from cpfl.core import TransmissionMethod
from cpfl.core.model_manager import ModelManager
from cpfl.core.peer_manager import PeerManager
from cpfl.core.session_settings import SessionSettings
from cpfl.util.eva.protocol import EVAProtocol
from cpfl.util.eva.result import TransferResult

from ipv8.community import Community
from ipv8.requestcache import RequestCache
from ipv8.types import Peer


class LearningCommunity(Community):
    community_id = unhexlify('d5889074c1e4c60423cdb6e9307ba0ca5695ead7')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.request_cache = RequestCache()
        self.my_id = self.my_peer.public_key.key_to_bin()
        self.round_complete_callback: Optional[Callable] = None
        self.aggregate_complete_callback: Optional[Callable] = None

        self.peers_list: List[Peer] = []

        # Settings
        self.settings: Optional[SessionSettings] = None

        # State
        self.is_active = False
        self.did_setup = False
        self.shutting_down = False

        # Components
        self.peer_manager: PeerManager = PeerManager(self.my_id, 100000)
        self.model_manager: Optional[ModelManager] = None  # Initialized when the process is set up

        # Model exchange parameters
        self.eva = EVAProtocol(self, self.on_receive, self.on_send_complete, self.on_error)

        # Availability traces
        self.traces: Optional[Dict] = None
        self.traces_count: int = 0

        self.logger.info("The %s started with peer ID: %s", self.__class__.__name__,
                         self.peer_manager.get_my_short_id())

    def start(self):
        """
        Start to participate in the training process.
        """
        assert self.did_setup, "Process has not been set up - call setup() first"
        self.is_active = True

    def set_traces(self, traces: Dict) -> None:
        self.traces = traces
        events: int = 0

        # Schedule the join/leave events
        for active_timestamp in self.traces["active"]:
            if active_timestamp == 0:
                continue  # We assume peers will be online at t=0

            self.register_anonymous_task("join", self.go_online, delay=active_timestamp)
            events += 1

        for inactive_timestamp in self.traces["inactive"]:
            self.register_anonymous_task("leave", self.go_offline, delay=inactive_timestamp)
            events += 1

        self.logger.info("Scheduled %d join/leave events for peer %s (trace length in sec: %d)", events,
                         self.peer_manager.get_my_short_id(), traces["finish_time"])

        # Schedule the next call to set_traces so the trace repeats after it finishes
        self.register_task("reapply-trace-%s-%d" % (self.peer_manager.get_my_short_id(), self.traces_count),
                           self.set_traces, self.traces, delay=self.traces["finish_time"])
        self.traces_count += 1

    def go_online(self):
        self.is_active = True
        cur_time = asyncio.get_event_loop().time() if self.settings.is_simulation else time.time()
        self.logger.info("Participant %s comes online (t=%d)", self.peer_manager.get_my_short_id(), cur_time)

    def go_offline(self, graceful: bool = True):
        self.is_active = False
        cur_time = asyncio.get_event_loop().time() if self.settings.is_simulation else time.time()
        self.logger.info("Participant %s will go offline (t=%d)", self.peer_manager.get_my_short_id(), cur_time)

    def setup(self, settings: SessionSettings):
        self.settings = settings
        for participant in settings.participants:
            self.peer_manager.add_peer(unhexlify(participant))

        # Initialize the model
        participant_index = settings.all_participants.index(hexlify(self.my_id).decode())
        self.model_manager = ModelManager(None, settings, participant_index)

        # Set up the model transmission
        if self.settings.transmission_method == TransmissionMethod.EVA:
            self.logger.info("Setting up EVA protocol")
            self.eva.settings.block_size = settings.eva_block_size
            self.eva.settings.max_simultaneous_transfers = settings.eva_max_simultaneous_transfers
        else:
            raise RuntimeError("Unsupported transmission method %s" % self.settings.transmission_method)

        self.did_setup = True

    def get_peers(self):
        if self.peers_list:
            return self.peers_list
        return super().get_peers()

    def get_peer_by_pk(self, target_pk: bytes):
        peers = list(self.get_peers())
        for peer in peers:
            if peer.public_key.key_to_bin() == target_pk:
                return peer
        return None

    def on_eva_send_done(self, future: Future, peer: Peer, serialized_response: bytes, binary_data: bytes, start_time: float):
        if future.cancelled():  # Do not reschedule if the future was cancelled
            return

        if future.exception():
            peer_id = self.peer_manager.get_short_id(peer.public_key.key_to_bin())
            self.logger.warning("Transfer to participant %s failed, scheduling it again (Exception: %s)",
                                peer_id, future.exception())
            # The transfer failed - try it again after some delay
            ensure_future(asyncio.sleep(self.settings.model_send_delay)).add_done_callback(
                lambda _: self.schedule_eva_send_model(peer, serialized_response, binary_data, start_time))
        else:
            # The transfer seems to be completed - record the transfer time
            end_time = asyncio.get_event_loop().time() if self.settings.is_simulation else time.time()

    def schedule_eva_send_model(self, peer: Peer, serialized_response: bytes, binary_data: bytes, start_time: float) -> Future:
        # Schedule the transfer
        future = ensure_future(self.eva.send_binary(peer, serialized_response, binary_data))
        future.add_done_callback(lambda f: self.on_eva_send_done(f, peer, serialized_response, binary_data, start_time))
        return future

    async def on_receive(self, result: TransferResult):
        raise NotImplementedError()

    async def on_send_complete(self, result: TransferResult):
        peer_id = self.peer_manager.get_short_id(result.peer.public_key.key_to_bin())
        my_peer_id = self.peer_manager.get_my_short_id()
        self.logger.info(f'Outgoing transfer {my_peer_id} -> {peer_id} has completed: {result.info.decode()}')

    async def on_error(self, peer, exception):
        self.logger.error(f'An error has occurred in transfer to peer {peer}: {exception}')

    async def unload(self):
        self.shutting_down = True
        await self.request_cache.shutdown()
        await super().unload()
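
For context, `set_traces` expects a dictionary with `active`, `inactive`, and `finish_time` keys (as read in the code above). A minimal, hypothetical trace for a set-up `LearningCommunity` instance could look as follows; the values are illustrative:

```
# Hypothetical availability trace: online at t=0, offline after 600 s,
# back online at 1200 s; the trace repeats every 3600 s (its finish_time).
trace = {
    "active": [0, 1200],    # timestamps (s) at which the peer comes online
    "inactive": [600],      # timestamps (s) at which the peer goes offline
    "finish_time": 3600,    # trace length (s); set_traces re-applies itself after this
}

community.set_traces(trace)  # 'community' is an instantiated LearningCommunity
```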
