Merge pull request #38 from VectorInstitute/dbe/fix_apfl_with_frozen_params

Fixes a failure in the APFL implementation that occurs when a model contains frozen layers.
emersodb committed Jul 13, 2023
2 parents a37c159 + da34c7f commit 57a86f6
Showing 3 changed files with 13 additions and 5 deletions.
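
The underlying issue: PyTorch never populates .grad for parameters created with requires_grad=False, so APFL's alpha update, which zips local and global parameters and asserts that both gradients exist, fails as soon as a model contains a frozen layer. A minimal sketch (not from this PR) reproducing that behavior:

import torch
import torch.nn as nn

# Minimal sketch (not from this PR): frozen parameters never receive a .grad
# tensor, so any loop that assumes every parameter has one will break.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for param in model[0].parameters():
    param.requires_grad = False  # freeze the first layer

loss = model(torch.randn(1, 4)).sum()
loss.backward()

for name, param in model.named_parameters():
    # Prints True for the frozen layer's weight and bias, False otherwise
    print(name, param.grad is None)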
4 changes: 2 additions & 2 deletions examples/apfl_example/client.py
@@ -6,7 +6,7 @@
 import torch
 from flwr.common.typing import Config
 
-from examples.models.cnn_model import MnistNet
+from examples.models.cnn_model import MnistNetWithBnAndFrozen
 from fl4health.clients.apfl_client import ApflClient
 from fl4health.model_bases.apfl_base import APFLModule
 from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
@@ -26,7 +26,7 @@ def __init__(
 
     def setup_client(self, config: Config) -> None:
         batch_size = self.narrow_config_type(config, "batch_size", int)
-        self.model: APFLModule = APFLModule(MnistNet()).to(self.device)
+        self.model: APFLModule = APFLModule(MnistNetWithBnAndFrozen()).to(self.device)
         self.criterion = torch.nn.CrossEntropyLoss()
         self.local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=0.01)
         self.global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=0.01)
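
The example now exercises the fix by using a model that mixes BatchNorm with a frozen layer. The actual definition of MnistNetWithBnAndFrozen lives in examples/models/cnn_model.py; the sketch below is a hypothetical stand-in illustrating the shape of such a model, not the repository's code:

import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sketch of a CNN with a BatchNorm layer and a frozen conv layer,
# in the spirit of MnistNetWithBnAndFrozen; the real class may differ.
class SmallMnistNetWithBnAndFrozen(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=5)
        self.bn = nn.BatchNorm2d(8)
        self.fc = nn.Linear(8 * 12 * 12, 10)
        # Freeze the conv layer so its parameters never receive gradients
        for param in self.conv.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = F.relu(self.bn(self.conv(x)))
        x = F.max_pool2d(x, 2)
        return self.fc(x.flatten(start_dim=1))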
4 changes: 2 additions & 2 deletions examples/apfl_example/server.py
@@ -7,7 +7,7 @@
 from flwr.common.typing import Config, Metrics, Parameters
 from flwr.server.strategy import FedAvg
 
-from examples.models.cnn_model import MnistNet
+from examples.models.cnn_model import MnistNetWithBnAndFrozen
 from examples.simple_metric_aggregation import metric_aggregation, normalize_metrics
 from fl4health.model_bases.apfl_base import APFLModule
 from fl4health.utils.config import load_config
@@ -30,7 +30,7 @@ def evaluate_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]
 def get_initial_model_parameters() -> Parameters:
     # Initializing the model parameters on the server side.
     # Currently uses the Pytorch default initialization for the model parameters.
-    initial_model = APFLModule(MnistNet())
+    initial_model = APFLModule(MnistNetWithBnAndFrozen())
     return ndarrays_to_parameters([val.cpu().numpy() for _, val in initial_model.state_dict().items()])
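
Because the server seeds training from this module's state_dict, client and server must construct identical architectures for the parameter shapes to line up. A sanity-check sketch, assuming flwr's ndarrays_to_parameters/parameters_to_ndarrays helpers:

from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays

from examples.models.cnn_model import MnistNetWithBnAndFrozen
from fl4health.model_bases.apfl_base import APFLModule

# Sanity-check sketch: the initial parameters are plain numpy views of the
# state_dict, so a client building the same module should see matching shapes.
server_model = APFLModule(MnistNetWithBnAndFrozen())
initial = ndarrays_to_parameters([val.cpu().numpy() for _, val in server_model.state_dict().items()])

client_model = APFLModule(MnistNetWithBnAndFrozen())
for array, (name, tensor) in zip(parameters_to_ndarrays(initial), client_model.state_dict().items()):
    assert array.shape == tuple(tensor.shape), f"shape mismatch at {name}"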
10 changes: 9 additions & 1 deletion fl4health/model_bases/apfl_base.py
@@ -39,9 +39,17 @@ def update_alpha(self) -> None:
         # https://github.com/MLOPTPSU/FedTorch/blob
         # /ab8068dbc96804a5c1a8b898fd115175cfebfe75/fedtorch/comms/utils/flow_utils.py#L240
 
+        # Need to filter out frozen parameters, as they have no grad object
+        local_parameters = [
+            local_params for local_params in self.local_model.parameters() if local_params.requires_grad
+        ]
+        global_parameters = [
+            global_params for global_params in self.global_model.parameters() if global_params.requires_grad
+        ]
+
         # Accumulate gradient of alpha across layers
         grad_alpha: float = 0.0
-        for local_p, global_p in zip(self.local_model.parameters(), self.global_model.parameters()):
+        for local_p, global_p in zip(local_parameters, global_parameters):
             local_grad = local_p.grad
             global_grad = global_p.grad
             assert local_grad is not None and global_grad is not None
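
For context, the filtered lists feed the rest of the alpha update, which follows the FedTorch reference linked in the comment: accumulate the inner product of the local-minus-global parameter difference with the alpha-mixed gradient, then take a gradient step on alpha and clip it to [0, 1]. A hedged sketch of that remainder; names like alpha_lr are assumptions, not necessarily fl4health's actual attributes:

import torch

# Hedged sketch of how update_alpha plausibly completes after the shown loop,
# following the FedTorch reference above; alpha_lr is an assumed hyperparameter.
def update_alpha_sketch(local_parameters, global_parameters, alpha: float, alpha_lr: float) -> float:
    grad_alpha = 0.0
    for local_p, global_p in zip(local_parameters, global_parameters):
        assert local_p.grad is not None and global_p.grad is not None
        diff = (local_p.data - global_p.data).flatten()
        mixed_grad = (alpha * local_p.grad + (1 - alpha) * global_p.grad).flatten()
        grad_alpha += torch.dot(diff, mixed_grad).item()
    # One gradient step on alpha, clipped back into the valid [0, 1] range
    return min(max(alpha - alpha_lr * grad_alpha, 0.0), 1.0)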
