generated from VectorInstitute/aieng-template
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #35 from VectorInstitute/dbe/test_fenda_apfl_exchange
Adding Tests for the APFL and FENDA layer exchange flow.
- Loading branch information
Showing 7 changed files with 161 additions and 29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from fl4health.model_bases.apfl_base import APFLModule | ||
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger | ||
|
||
|
||
class ToyModel(nn.Module):
    """Tiny CNN used as the inner network for the APFL exchange test.

    Maps a (batch, 1, 10, 10) input to a (batch, 3) output via one conv
    layer, max pooling, and one fully connected layer.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 2, 2)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(2 * 4 * 4, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.pool(F.relu(self.conv1(x)))
        # (batch, 2, 4, 4) -> (batch, 32) before the linear head.
        flattened = features.view(-1, 2 * 4 * 4)
        return F.relu(self.fc1(flattened))
|
||
|
||
def test_apfl_layer_exchange() -> None:
    """Check that an APFLModule exchanges exactly its global-model layers.

    Pushes the global layers out of one model, pulls them into a second
    model, and then verifies that both personal and global forward passes
    still produce correctly shaped outputs.
    """
    source_model = APFLModule(ToyModel())
    exchanged_layer_names = sorted(source_model.layers_to_exchange())
    assert exchanged_layer_names == [
        "global_model.conv1.bias",
        "global_model.conv1.weight",
        "global_model.fc1.bias",
        "global_model.fc1.weight",
    ]

    exchanger = FixedLayerExchanger(exchanged_layer_names)
    pushed_parameters = exchanger.push_parameters(source_model)
    # conv1 and fc1 each contribute a separate weight and bias tensor.
    assert len(pushed_parameters) == 4
    source_states = source_model.state_dict()
    for name, parameters in zip(exchanged_layer_names, pushed_parameters):
        assert np.array_equal(parameters, source_states[name])

    # Load the pushed weights into a fresh model and confirm they landed.
    target_model = APFLModule(ToyModel())
    exchanger.pull_parameters(pushed_parameters, target_model)
    for name, parameters in target_model.state_dict().items():
        if name in exchanged_layer_names:
            assert np.array_equal(parameters, source_states[name])

    batch = torch.ones((3, 1, 10, 10))
    # With personal=True, APFL returns a dictionary of tensors: a convex
    # combination of the dual toy-model outputs (dimension 3) under the key
    # "personal" and the local model's prediction under the key "local".
    personal_outputs = source_model(batch, personal=True)
    assert "local" in personal_outputs
    personal_shape = personal_outputs["personal"].shape
    assert personal_shape[0] == 3  # batch size
    assert personal_shape[1] == 3  # output size

    # The module should still work when predicting with only the global
    # model, producing a dictionary keyed by "global".
    global_shape = source_model(batch, personal=False)["global"].shape
    assert global_shape[0] == 3  # batch size
    assert global_shape[1] == 3  # output size
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from fl4health.model_bases.fenda_base import FendaHeadModule, FendaJoinMode, FendaModel | ||
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger | ||
|
||
|
||
class FendaTestClassifier(FendaHeadModule):
    """Minimal FENDA head: joins local/global features, then one linear layer.

    The linear layer maps ``input_size`` features to 2 outputs; the join
    behavior is delegated to the base class via ``join_mode``.
    """

    def __init__(self, input_size: int, join_mode: FendaJoinMode) -> None:
        super().__init__(join_mode)
        self.fc1 = nn.Linear(input_size, 2)

    def local_global_concat(self, local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> torch.Tensor:
        # Tensors are batch-first, so fuse along the feature axis (dim 1).
        return torch.concat([local_tensor, global_tensor], dim=1)

    def head_forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.fc1(input_tensor)
|
||
|
||
class LocalFendaTest(nn.Module):
    """Local feature extractor for the FENDA tests.

    Maps a (batch, 1, 10, 10) input to (batch, 3) features via one conv
    layer, max pooling, and one fully connected layer.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 2, 2)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(2 * 4 * 4, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pooled = self.pool(F.relu(self.conv1(x)))
        # (batch, 2, 4, 4) -> (batch, 32) before the linear head.
        return F.relu(self.fc1(pooled.view(-1, 2 * 4 * 4)))
|
||
|
||
class GlobalFendaTest(nn.Module):
    """Global feature extractor for the FENDA tests.

    Structurally identical to ``LocalFendaTest``: maps a (batch, 1, 10, 10)
    input to (batch, 3) features. Its layers are the ones the exchanger is
    expected to transfer.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 2, 2)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(2 * 4 * 4, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pooled = self.pool(F.relu(self.conv1(x)))
        # (batch, 2, 4, 4) -> (batch, 32) before the linear head.
        return F.relu(self.fc1(pooled.view(-1, 2 * 4 * 4)))
|
||
|
||
def test_fenda_join_and_layer_exchange() -> None:
    """Check FENDA layer exchange and both head join modes.

    Only the global module's layers should be exchanged; the head must
    produce correctly shaped outputs under both CONCATENATE and SUM joins.
    """
    source_model = FendaModel(LocalFendaTest(), GlobalFendaTest(), FendaTestClassifier(6, FendaJoinMode.CONCATENATE))
    exchanged_layer_names = sorted(source_model.layers_to_exchange())
    assert exchanged_layer_names == [
        "global_module.conv1.bias",
        "global_module.conv1.weight",
        "global_module.fc1.bias",
        "global_module.fc1.weight",
    ]

    exchanger = FixedLayerExchanger(exchanged_layer_names)
    pushed_parameters = exchanger.push_parameters(source_model)
    # conv1 and fc1 each contribute a separate weight and bias tensor.
    assert len(pushed_parameters) == 4
    source_states = source_model.state_dict()
    for name, parameters in zip(exchanged_layer_names, pushed_parameters):
        assert np.array_equal(parameters, source_states[name])

    # Load the pushed weights into a fresh model and confirm they landed.
    target_model = FendaModel(LocalFendaTest(), GlobalFendaTest(), FendaTestClassifier(6, FendaJoinMode.CONCATENATE))
    exchanger.pull_parameters(pushed_parameters, target_model)
    for name, parameters in target_model.state_dict().items():
        if name in exchanged_layer_names:
            assert np.array_equal(parameters, source_states[name])

    batch = torch.ones((3, 1, 10, 10))
    # Concatenation joins the two 3-dim feature vectors into 6 before the head.
    concat_shape = source_model(batch).shape
    assert concat_shape[0] == 3  # batch size
    assert concat_shape[1] == 2  # head output size

    # Summation keeps the feature dimension at 3, hence the head input of 3.
    summing_model = FendaModel(LocalFendaTest(), GlobalFendaTest(), FendaTestClassifier(3, FendaJoinMode.SUM))
    sum_shape = summing_model(batch).shape
    assert sum_shape[0] == 3  # batch size
    assert sum_shape[1] == 2  # head output size
Empty file.