Modify FlServerWithCheckpointing model argument to be optional and update references accordingly
jewelltaylor committed Sep 25, 2024
1 parent 94acd93 commit 34d216b
Showing 14 changed files with 98 additions and 29 deletions.
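As a rough sketch of what the new signature permits (illustrative only, reusing names from the example servers below rather than code taken from this commit), a server that never needs to hydrate a model for checkpointing can now be constructed without one:

server = FlServerWithCheckpointing(
    client_manager=SimpleClientManager(),
    parameter_exchanger=FullParameterExchanger(),
    # model now defaults to None; it is only required when a checkpointer must hydrate it
    wandb_reporter=None,
    strategy=strategy,
    checkpointer=None,
)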
9 changes: 8 additions & 1 deletion examples/ae_examples/cvae_dim_example/server.py
@@ -63,7 +63,14 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointer)
server = FlServerWithCheckpointing(
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)

fl.server.start_server(
server=server,

@@ -64,7 +64,14 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointer)
server = FlServerWithCheckpointing(
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)

fl.server.start_server(
server=server,

@@ -64,7 +64,14 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointer)
server = FlServerWithCheckpointing(
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)

fl.server.start_server(
server=server,
9 changes: 8 additions & 1 deletion examples/ae_examples/fedprox_vae_example/server.py
@@ -66,7 +66,14 @@ def main(config: Dict[str, Any]) -> None:
proximal_weight=config["proximal_weight"],
)

server = FlServerWithCheckpointing(SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointer)
server = FlServerWithCheckpointing(
client_manager=SimpleClientManager(),
model=model,
parameter_exchanger=parameter_exchanger,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)

fl.server.start_server(
server=server,
7 changes: 6 additions & 1 deletion examples/basic_example/server.py
@@ -63,7 +63,12 @@ def main(config: Dict[str, Any]) -> None:
)

server = FlServerWithCheckpointing(
SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointers
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointers,
)

fl.server.start_server(
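The basic_example server above passes a sequence of checkpointers, which the updated base_server.py docstring further down notes is supported for checkpointing on multiple criteria. A hypothetical sketch of such a sequence (the checkpointer class names here are assumed for illustration, not taken from this commit):

checkpointers = [
    BestLossTorchCheckpointer(checkpoint_dir, "best_model.pkl"),  # keep the lowest-loss aggregate model
    LatestTorchCheckpointer(checkpoint_dir, "latest_model.pkl"),  # always keep the most recent aggregate model
]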
9 changes: 8 additions & 1 deletion examples/fedpca_examples/dim_reduction/server.py
@@ -63,7 +63,14 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointer)
server = FlServerWithCheckpointing(
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)

fl.server.start_server(
server=server,

@@ -67,7 +67,14 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointer)
server = FlServerWithCheckpointing(
client_manager=SimpleClientManager(),
model=model,
parameter_exchanger=parameter_exchanger,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)

fl.server.start_server(
server=server,

@@ -63,7 +63,14 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointer)
server = FlServerWithCheckpointing(
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)

fl.server.start_server(
server=server,
20 changes: 11 additions & 9 deletions fl4health/server/base_server.py
@@ -328,8 +328,8 @@ class FlServerWithCheckpointing(FlServer, Generic[ExchangerType]):
def __init__(
self,
client_manager: ClientManager,
model: nn.Module,
parameter_exchanger: ExchangerType,
model: Optional[nn.Module] = None,
wandb_reporter: Optional[ServerWandBReporter] = None,
strategy: Optional[Strategy] = None,
checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None,
@@ -347,21 +347,19 @@ def __init__(
Args:
client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if
they are to be sampled at all.
model (nn.Module): This is the torch model to be hydrated by the _hydrate_model_for_checkpointing function
parameter_exchanger (ExchangerType): This is the parameter exchanger to be used to hydrate the model.
model (Optional[nn.Module]): This is the torch model to be hydrated
by the _hydrate_model_for_checkpointing function. Defaults to None.
strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle
client updates and other information potentially sent by the participating clients. If None the
strategy is FedAvg as set by the flwr Server.
wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log
information and results to a Weights and Biases account. If None is provided, no logging occurs.
Defaults to None.
checkpointer (Optional[Union[TorchCheckpointer, Sequence
[TorchCheckpointer]]], optional): To be provided if the server
should perform server side checkpointing based on some
criteria. If none, then no server-side checkpointing is
performed. Multiple checkpointers can also be passed in a
sequence to checkpoint based on multiple criteria. Defaults to
None.
checkpointer (Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]], optional):
To be provided if the server should perform server side checkpointing
based on some criteria. If none, then no server-side checkpointing is performed. Multiple checkpointers
can also be passed in a sequence to checkpoint based on multiple criteria. Defaults to None.
metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics
intermediate_server_state_dir (Path): A directory to store and load checkpoints from for the server
during an FL experiment.
@@ -392,6 +390,10 @@ def __init__(
self.history: History

def _hydrate_model_for_checkpointing(self) -> nn.Module:
assert (
self.server_model is not None
), "Model hydration has been called but no server_model is defined to hydrate"

model_ndarrays = parameters_to_ndarrays(self.parameters)
self.parameter_exchanger.pull_parameters(model_ndarrays, self.server_model)
return self.server_model
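A short sketch of the behaviour the new assertion introduces (hypothetical usage; it assumes __init__ stores the optional model on self.server_model, as the hydration code above implies):

server = FlServerWithCheckpointing(
    client_manager=SimpleClientManager(),
    parameter_exchanger=FullParameterExchanger(),
)
# Constructing without a model is now allowed, but hydration fails fast:
# server._hydrate_model_for_checkpointing()
# AssertionError: Model hydration has been called but no server_model is defined to hydrate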
6 changes: 5 additions & 1 deletion research/flamby/flamby_servers/fedprox_server.py
@@ -15,10 +15,11 @@ class FedProxServer(FlServerWithCheckpointing[ParameterExchangerWithPacking]):
def __init__(
self,
client_manager: ClientManager,
model: nn.Module,
model: Optional[nn.Module] = None,
strategy: Optional[Strategy] = None,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
assert model is not None
# To help with model rehydration
parameter_exchanger = ParameterExchangerWithPacking(ParameterPackerFedProx())
super().__init__(
@@ -30,6 +31,9 @@ def __init__(
)

def _hydrate_model_for_checkpointing(self) -> nn.Module:
assert (
self.server_model is not None
), "Model hydration has been called but no server_model is defined to hydrate"
# Overriding the standard hydration method to account for the unpacking
packed_parameters = parameters_to_ndarrays(self.parameters)
# Don't need the extra fedprox variable for checkpointing.
4 changes: 2 additions & 2 deletions research/flamby/flamby_servers/full_exchange_server.py
@@ -13,16 +13,16 @@ class FullExchangeServer(FlServerWithCheckpointing):
def __init__(
self,
client_manager: ClientManager,
model: nn.Module,
model: Optional[nn.Module] = None,
strategy: Optional[Strategy] = None,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
# To help with model rehydration
parameter_exchanger = FullParameterExchanger()
super().__init__(
client_manager=client_manager,
model=model,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
)
8 changes: 6 additions & 2 deletions research/flamby/flamby_servers/scaffold_server.py
@@ -15,22 +15,26 @@ class ScaffoldServer(FlServerWithCheckpointing[ParameterExchangerWithPacking]):
def __init__(
self,
client_manager: ClientManager,
model: nn.Module,
model: Optional[nn.Module] = None,
strategy: Optional[Strategy] = None,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
assert model is not None
# To help with model rehydration
model_size = len(model.state_dict())
parameter_exchanger = ParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size))
super().__init__(
client_manager=client_manager,
model=model,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
)

def _hydrate_model_for_checkpointing(self) -> nn.Module:
assert (
self.server_model is not None
), "Model hydration has been called but no server_model is defined to hydrate"
packed_parameters = parameters_to_ndarrays(self.parameters)
# Don't need the control variates for checkpointing.
model_ndarrays, _ = self.parameter_exchanger.unpack_parameters(packed_parameters)
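For ScaffoldServer above, a model is still effectively required even though the inherited parameter now defaults to None, which is why the assert was added to its __init__: the number of tensors in the model's state_dict determines how control variates are packed. Condensed from the diff:

assert model is not None
model_size = len(model.state_dict())  # number of tensors the packer must separate from control variates
parameter_exchanger = ParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size))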
7 changes: 6 additions & 1 deletion tests/server/test_base_server.py
@@ -80,7 +80,12 @@ def test_fl_server_with_checkpointing(tmp_path: Path) -> None:
parameter_exchanger = FullParameterExchanger()

server = FlServerWithCheckpointing(
PoissonSamplingClientManager(), initial_model, parameter_exchanger, None, None, checkpointer
client_manager=PoissonSamplingClientManager(),
model=initial_model,
parameter_exchanger=parameter_exchanger,
wandb_reporter=None,
strategy=None,
checkpointer=checkpointer,
)
# Parameters after aggregation (i.e. the updated server-side model)
server.parameters = ndarrays_to_parameters(parameter_exchanger.push_parameters(updated_model))
12 changes: 6 additions & 6 deletions tests/smoke_tests/load_from_checkpoint_example/server.py
@@ -65,12 +65,12 @@ def main(config: Dict[str, Any], intermediate_server_state_dir: str, server_name
)

server = FlServerWithCheckpointing(
SimpleClientManager(),
model,
parameter_exchanger,
None,
strategy,
checkpointers,
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointers,
intermediate_server_state_dir=Path(intermediate_server_state_dir),
server_name=server_name,
)
