Merge pull request #35 from VectorInstitute/dbe/test_fenda_apfl_exchange
Adding Tests for the APFL and FENDA layer exchange flow.
emersodb committed Jul 13, 2023
2 parents 57a86f6 + c09ca88 commit 5513c2d
Showing 7 changed files with 161 additions and 29 deletions.
6 changes: 3 additions & 3 deletions examples/models/fenda_cnn.py
@@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F

-from fl4health.model_bases.fenda_base import FendaGlobalModule, FendaHeadModule, FendaJoinMode, FendaLocalModule
+from fl4health.model_bases.fenda_base import FendaHeadModule, FendaJoinMode


class FendaClassifier(FendaHeadModule):
@@ -21,7 +21,7 @@ def head_forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
return x


-class LocalCnn(FendaLocalModule):
+class LocalCnn(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
@@ -37,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


-class GlobalCnn(FendaGlobalModule):
+class GlobalCnn(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
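
With LocalCnn and GlobalCnn now deriving from plain nn.Module, the example model is assembled through FendaModel rather than FENDA-specific base classes. A minimal sketch of that assembly, assuming it is run from the repository root; FendaClassifier's constructor arguments are not shown in this diff, so the single FendaJoinMode argument below is an assumption:

# Illustrative sketch only; FendaClassifier's constructor arguments are an assumption.
from examples.models.fenda_cnn import FendaClassifier, GlobalCnn, LocalCnn
from fl4health.model_bases.fenda_base import FendaJoinMode, FendaModel

model = FendaModel(LocalCnn(), GlobalCnn(), FendaClassifier(FendaJoinMode.CONCATENATE))
# Only the global feature extractor's layers are exchanged with the server.
print(model.layers_to_exchange())  # e.g. ['global_module.conv1.weight', 'global_module.conv1.bias', ...]
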
23 changes: 2 additions & 21 deletions fl4health/model_bases/fenda_base.py
@@ -11,22 +11,6 @@ class FendaJoinMode(Enum):
SUM = "SUM"


-class FendaGlobalModule(nn.Module, ABC):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def get_layer_names(self) -> List[str]:
-        # This function supplies the names of the layers to be exchanged with the central server during FL training
-        # NOTE: By default, global FENDA modules will return all layer names to be exchanged. This behavior can be
-        # modified by overriding this function
-        return list(self.state_dict().keys())
-
-
-class FendaLocalModule(nn.Module):
-    def __init__(self) -> None:
-        super().__init__()


class FendaHeadModule(nn.Module, ABC):
def __init__(self, mode: FendaJoinMode) -> None:
super().__init__()
@@ -50,17 +34,14 @@ def forward(self, local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> to


class FendaModel(nn.Module):
-    def __init__(
-        self, local_module: FendaLocalModule, global_module: FendaGlobalModule, model_head: FendaHeadModule
-    ) -> None:
+    def __init__(self, local_module: nn.Module, global_module: nn.Module, model_head: FendaHeadModule) -> None:
super().__init__()
self.local_module = local_module
self.global_module = global_module
self.model_head = model_head

def layers_to_exchange(self) -> List[str]:
-        # NOTE: that the prepending string must match the name of the global module variable
-        return [f"global_module.{layer_name}" for layer_name in self.global_module.get_layer_names()]
+        return [layer_name for layer_name in self.state_dict().keys() if layer_name.startswith("global_module.")]

def forward(self, input: torch.Tensor) -> torch.Tensor:
local_output = self.local_module.forward(input)
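
The rewritten layers_to_exchange no longer asks a FendaGlobalModule for its layer names; it filters the composite model's state_dict for keys prefixed with "global_module.", the attribute name used inside FendaModel. A small sketch of that filtering, where TinyLocal, TinyGlobal, and TinyWrapper are hypothetical stand-ins for FendaModel's attribute layout:

# Illustrative sketch only; these tiny modules are stand-ins, not fl4health classes.
import torch.nn as nn

class TinyLocal(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

class TinyGlobal(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

class TinyWrapper(nn.Module):
    # Mirrors FendaModel's attribute names: local_module and global_module.
    def __init__(self) -> None:
        super().__init__()
        self.local_module = TinyLocal()
        self.global_module = TinyGlobal()

wrapper = TinyWrapper()
# Same filtering logic as the new layers_to_exchange: keep only keys under the global module.
exchanged = [name for name in wrapper.state_dict().keys() if name.startswith("global_module.")]
print(exchanged)  # ['global_module.fc.weight', 'global_module.fc.bias']
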
7 changes: 2 additions & 5 deletions fl4health/parameter_exchange/layer_exchanger.py
@@ -15,11 +15,8 @@ def __init__(self, layers_to_transfer: List[str]) -> None:

def apply_layer_filter(self, model: nn.Module) -> NDArrays:
# NOTE: Filtering layers only works if each client exchanges exactly the same layers
-        return [
-            layer_parameters.cpu().numpy()
-            for layer_name, layer_parameters in model.state_dict().items()
-            if layer_name in self.layers_to_transfer
-        ]
+        model_state_dict = model.state_dict()
+        return [model_state_dict[layer_to_transfer].cpu().numpy() for layer_to_transfer in self.layers_to_transfer]

def push_parameters(
self, model: nn.Module, initial_model: Optional[nn.Module] = None, config: Optional[Config] = None
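
The new apply_layer_filter indexes the state dict directly with the configured layer names, so the returned arrays follow the order of layers_to_transfer, and a requested layer missing from the model surfaces as a KeyError rather than being silently skipped. A small sketch of the same pattern on a toy model; the layer names below are specific to this hypothetical nn.Sequential, not to any fl4health model:

# Illustrative sketch only; mirrors the filtering pattern, not the FixedLayerExchanger API itself.
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
layers_to_transfer = ["2.weight", "2.bias"]  # names of the second Linear layer in this toy model
model_state_dict = model.state_dict()
filtered = [model_state_dict[name].cpu().numpy() for name in layers_to_transfer]
print([arr.shape for arr in filtered])  # [(2, 3), (2,)]
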
Empty file.
65 changes: 65 additions & 0 deletions tests/parameter_exchange/test_apfl_exchange.py
@@ -0,0 +1,65 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from fl4health.model_bases.apfl_base import APFLModule
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger


class ToyModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 2, 2)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(2 * 4 * 4, 3)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = x.view(-1, 2 * 4 * 4)
x = F.relu(self.fc1(x))
return x


def test_apfl_layer_exchange() -> None:
model = APFLModule(ToyModel())
apfl_layers_to_exchange = sorted(model.layers_to_exchange())
assert apfl_layers_to_exchange == [
"global_model.conv1.bias",
"global_model.conv1.weight",
"global_model.fc1.bias",
"global_model.fc1.weight",
]
parameter_exchanger = FixedLayerExchanger(apfl_layers_to_exchange)
parameters_to_exchange = parameter_exchanger.push_parameters(model)
# 4 layers are expected, as weight and bias are separate for conv1 and fc1
assert len(parameters_to_exchange) == 4
model_state_dict = model.state_dict()
for layer_name, layer_parameters in zip(apfl_layers_to_exchange, parameters_to_exchange):
assert np.array_equal(layer_parameters, model_state_dict[layer_name])

# Insert the weights back into another model
model_2 = APFLModule(ToyModel())
parameter_exchanger.pull_parameters(parameters_to_exchange, model_2)
for layer_name, layer_parameters in model_2.state_dict().items():
if layer_name in apfl_layers_to_exchange:
assert np.array_equal(layer_parameters, model_state_dict[layer_name])

input = torch.ones((3, 1, 10, 10))
    # APFL returns a dictionary of tensors. For personal predictions, it produces a convex combination of the dual
    # toy model outputs (each of dimension 3) under the key "personal", along with a prediction from the local model
    # under the key "local".
apfl_output_dict = model(input, personal=True)
assert "local" in apfl_output_dict
personal_shape = apfl_output_dict["personal"].shape
# Batch size
assert personal_shape[0] == 3
# Output size
assert personal_shape[1] == 3
# Make sure that the APFL module still correctly functions when making predictions using only the global model. It
# should produce a dictionary with key "global"
global_shape = model(input, personal=False)["global"].shape
# Batch size
assert global_shape[0] == 3
# Output size
assert global_shape[1] == 3
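
For context on the shape assertions above: APFL's personal prediction is a convex combination of the local and global model outputs, so it keeps the (batch size 3, output size 3) shape of the underlying ToyModel outputs. A sketch of that combination with a scalar mixing weight; the names alpha, local_out, and global_out are stand-ins, not APFLModule attributes:

# Illustrative sketch only; shows the convex combination behind the "personal" prediction.
import torch

alpha = torch.tensor(0.25)
local_out = torch.ones((3, 3))    # stand-in for the local ToyModel output
global_out = torch.zeros((3, 3))  # stand-in for the global ToyModel output
personal = alpha * local_out + (1.0 - alpha) * global_out  # convex combination
assert personal.shape == (3, 3)   # batch size 3, output size 3, matching the test
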
89 changes: 89 additions & 0 deletions tests/parameter_exchange/test_fenda_exchange.py
@@ -0,0 +1,89 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from fl4health.model_bases.fenda_base import FendaHeadModule, FendaJoinMode, FendaModel
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger


class FendaTestClassifier(FendaHeadModule):
def __init__(self, input_size: int, join_mode: FendaJoinMode) -> None:
super().__init__(join_mode)
self.fc1 = nn.Linear(input_size, 2)

def local_global_concat(self, local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> torch.Tensor:
# Assuming tensors are "batch first" so join column-wise
return torch.concat([local_tensor, global_tensor], dim=1)

def head_forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
x = self.fc1(input_tensor)
return x


class LocalFendaTest(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 2, 2)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(2 * 4 * 4, 3)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = x.view(-1, 2 * 4 * 4)
x = F.relu(self.fc1(x))
return x


class GlobalFendaTest(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 2, 2)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(2 * 4 * 4, 3)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = x.view(-1, 2 * 4 * 4)
x = F.relu(self.fc1(x))
return x


def test_fenda_join_and_layer_exchange() -> None:
model = FendaModel(LocalFendaTest(), GlobalFendaTest(), FendaTestClassifier(6, FendaJoinMode.CONCATENATE))
fenda_layers_to_exchange = sorted(model.layers_to_exchange())
assert fenda_layers_to_exchange == [
"global_module.conv1.bias",
"global_module.conv1.weight",
"global_module.fc1.bias",
"global_module.fc1.weight",
]
parameter_exchanger = FixedLayerExchanger(fenda_layers_to_exchange)
parameters_to_exchange = parameter_exchanger.push_parameters(model)
# 4 layers are expected, as weight and bias are separate for conv1 and fc1
assert len(parameters_to_exchange) == 4
model_state_dict = model.state_dict()
for layer_name, layer_parameters in zip(fenda_layers_to_exchange, parameters_to_exchange):
assert np.array_equal(layer_parameters, model_state_dict[layer_name])

# Insert the weights back into another model
model_2 = FendaModel(LocalFendaTest(), GlobalFendaTest(), FendaTestClassifier(6, FendaJoinMode.CONCATENATE))
parameter_exchanger.pull_parameters(parameters_to_exchange, model_2)
for layer_name, layer_parameters in model_2.state_dict().items():
if layer_name in fenda_layers_to_exchange:
assert np.array_equal(layer_parameters, model_state_dict[layer_name])

input = torch.ones((3, 1, 10, 10))
# Test that concatenation produces the right output dimension
output_shape = model(input).shape
# Batch size
assert output_shape[0] == 3
# Output size
assert output_shape[1] == 2
# Test that summing produces the right output dimension
model = FendaModel(LocalFendaTest(), GlobalFendaTest(), FendaTestClassifier(3, FendaJoinMode.SUM))
output_shape = model(input).shape
# Batch size
assert output_shape[0] == 3
# Output size
assert output_shape[1] == 2
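
The head sizes in this test follow directly from the join modes: CONCATENATE joins the two 3-dimensional feature tensors column-wise into a 6-dimensional input, while SUM keeps dimension 3, which is why FendaTestClassifier is built with input_size 6 in the first case and 3 in the second. A short sketch of the two joins; the element-wise sum shown for FendaJoinMode.SUM is an assumption about its behaviour:

# Illustrative sketch only; shows why the head input size is 6 for CONCATENATE and 3 for SUM.
import torch

local_out = torch.ones((3, 3))   # stand-in for LocalFendaTest output
global_out = torch.ones((3, 3))  # stand-in for GlobalFendaTest output

concatenated = torch.concat([local_out, global_out], dim=1)  # FendaJoinMode.CONCATENATE
summed = local_out + global_out                              # FendaJoinMode.SUM (assumed element-wise sum)
assert concatenated.shape == (3, 6)  # head input size 6
assert summed.shape == (3, 3)        # head input size 3
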
Empty file added tests/server/__init__.py
Empty file.
