Skip to content

Commit

Permalink
Merge pull request #246 from VectorInstitute/nnunet-server-checkpointing
Browse files (browse the repository at this point in the history)
Nnunet server checkpointing
  • Loading branch information
jewelltaylor committed Sep 25, 2024
2 parents 26b1ae9 + 34d216b commit b753a96
Show file tree
Hide file tree
Showing 85 changed files with 827 additions and 310 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ settings.json
**/datasets/skin_cancer/ISIC_2019/**
**/datasets/skin_cancer/Derm7pt/**
**/datasets/nnunet/**
**/datasets/nnunet_raw/**
**/datasets/nnunet_preprocessed/**

# logs

Expand Down
8 changes: 4 additions & 4 deletions examples/ae_examples/cvae_dim_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from examples.models.mnist_model import MnistNet
from fl4health.clients.basic_client import BasicClient
from fl4health.preprocessing.autoencoders.dim_reduction import CvaeFixedConditionProcessor
from fl4health.utils.config import narrow_config_type
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy, Metric
from fl4health.utils.random import set_all_random_seeds
Expand All @@ -27,8 +27,8 @@ def __init__(self, data_path: Path, metrics: Sequence[Metric], DEVICE: torch.dev
self.condition = condition

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_config_type(config, "batch_size", int)
cvae_model_path = Path(narrow_config_type(config, "cvae_model_path", str))
batch_size = narrow_dict_type(config, "batch_size", int)
cvae_model_path = Path(narrow_dict_type(config, "cvae_model_path", str))
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)])
# CvaeFixedConditionProcessor is added to the data transform pipeline to encode the data samples
Expand All @@ -47,7 +47,7 @@ def get_optimizer(self, config: Config) -> Optimizer:
return torch.optim.Adam(self.model.parameters(), lr=0.001)

def get_model(self, config: Config) -> nn.Module:
latent_dim = narrow_config_type(config, "latent_dim", int)
latent_dim = narrow_dict_type(config, "latent_dim", int)
# Dimensionality reduction reduces the size of inputs to the size of cat(mu, logvar).
return MnistNet(latent_dim * 2).to(self.device)

Expand Down
9 changes: 8 additions & 1 deletion examples/ae_examples/cvae_dim_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,14 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointer)
server = FlServerWithCheckpointing(
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)

fl.server.start_server(
server=server,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from fl4health.clients.basic_client import BasicClient
from fl4health.model_bases.autoencoders_base import ConditionalVae
from fl4health.preprocessing.autoencoders.loss import VaeLoss
from fl4health.utils.config import narrow_config_type
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.dataset_converter import AutoEncoderDatasetConverter
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Metric
Expand Down Expand Up @@ -57,7 +57,7 @@ def setup_client(self, config: Config) -> None:
self.model.unpack_input_condition = self.autoencoder_converter.get_unpacking_function()

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_config_type(config, "batch_size", int)
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
# To make sure pixels stay in the range [0.0, 1.0].
transform = transforms.Compose([transforms.ToTensor()])
Expand All @@ -75,14 +75,14 @@ def get_criterion(self, config: Config) -> _Loss:
# The base_loss is the loss function used for comparing the original and generated image pixels.
# We are using MSE loss to calculate the difference between the reconstructed and original images.
base_loss = torch.nn.MSELoss(reduction="sum")
latent_dim = narrow_config_type(config, "latent_dim", int)
latent_dim = narrow_dict_type(config, "latent_dim", int)
return VaeLoss(latent_dim, base_loss)

def get_optimizer(self, config: Config) -> Optimizer:
return torch.optim.Adam(self.model.parameters(), lr=0.001)

def get_model(self, config: Config) -> nn.Module:
latent_dim = narrow_config_type(config, "latent_dim", int)
latent_dim = narrow_dict_type(config, "latent_dim", int)
encoder = ConvConditionalEncoder(latent_dim=latent_dim)
decoder = ConvConditionalDecoder(latent_dim=latent_dim)
return ConditionalVae(encoder=encoder, decoder=decoder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,14 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointer)
server = FlServerWithCheckpointing(
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)

fl.server.start_server(
server=server,
Expand Down
8 changes: 4 additions & 4 deletions examples/ae_examples/cvae_examples/mlp_cvae_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from fl4health.clients.basic_client import BasicClient
from fl4health.model_bases.autoencoders_base import ConditionalVae
from fl4health.preprocessing.autoencoders.loss import VaeLoss
from fl4health.utils.config import narrow_config_type
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.dataset_converter import AutoEncoderDatasetConverter
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Metric
Expand Down Expand Up @@ -45,7 +45,7 @@ def setup_client(self, config: Config) -> None:
self.model.unpack_input_condition = self.autoencoder_converter.get_unpacking_function()

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_config_type(config, "batch_size", int)
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
# ToTensor transform is used to make sure pixels stay in the range [0.0, 1.0].
# Flattening the image data to match the input shape of the model.
Expand All @@ -65,14 +65,14 @@ def get_criterion(self, config: Config) -> _Loss:
# The base_loss is the loss function used for comparing the original and generated image pixels.
# We are using MSE loss to calculate the difference between the reconstructed and original images.
base_loss = torch.nn.MSELoss(reduction="sum")
latent_dim = narrow_config_type(config, "latent_dim", int)
latent_dim = narrow_dict_type(config, "latent_dim", int)
return VaeLoss(latent_dim, base_loss)

def get_optimizer(self, config: Config) -> Optimizer:
return torch.optim.Adam(self.model.parameters(), lr=0.001)

def get_model(self, config: Config) -> nn.Module:
latent_dim = narrow_config_type(config, "latent_dim", int)
latent_dim = narrow_dict_type(config, "latent_dim", int)
# The input/output size is the flattened MNIST image size.
encoder = MnistConditionalEncoder(input_size=784, latent_dim=latent_dim)
decoder = MnistConditionalDecoder(latent_dim=latent_dim, output_size=784)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,14 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointer)
server = FlServerWithCheckpointing(
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)

fl.server.start_server(
server=server,
Expand Down
8 changes: 4 additions & 4 deletions examples/ae_examples/fedprox_vae_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
from fl4health.clients.fed_prox_client import FedProxClient
from fl4health.model_bases.autoencoders_base import VariationalAe
from fl4health.preprocessing.autoencoders.loss import VaeLoss
from fl4health.utils.config import narrow_config_type
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.dataset_converter import AutoEncoderDatasetConverter
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.sampler import DirichletLabelBasedSampler


class VaeFedProxClient(FedProxClient):
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_config_type(config, "batch_size", int)
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
# Flattening the input images to use an MLP-based variational autoencoder.
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)])
Expand All @@ -42,14 +42,14 @@ def get_criterion(self, config: Config) -> _Loss:
# The base_loss is the loss function used for comparing the original and generated image pixels.
# We are using MSE loss to calculate the difference between the reconstructed and original images.
base_loss = torch.nn.MSELoss(reduction="sum")
latent_dim = narrow_config_type(config, "latent_dim", int)
latent_dim = narrow_dict_type(config, "latent_dim", int)
return VaeLoss(latent_dim, base_loss)

def get_optimizer(self, config: Config) -> Optimizer:
return torch.optim.Adam(self.model.parameters(), lr=0.001)

def get_model(self, config: Config) -> nn.Module:
latent_dim = narrow_config_type(config, "latent_dim", int)
latent_dim = narrow_dict_type(config, "latent_dim", int)
encoder = MnistVariationalEncoder(input_size=784, latent_dim=latent_dim)
decoder = MnistVariationalDecoder(latent_dim=latent_dim, output_size=784)
return VariationalAe(encoder=encoder, decoder=decoder)
Expand Down
9 changes: 8 additions & 1 deletion examples/ae_examples/fedprox_vae_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,14 @@ def main(config: Dict[str, Any]) -> None:
proximal_weight=config["proximal_weight"],
)

server = FlServerWithCheckpointing(SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointer)
server = FlServerWithCheckpointing(
client_manager=SimpleClientManager(),
model=model,
parameter_exchanger=parameter_exchanger,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)

fl.server.start_server(
server=server,
Expand Down
4 changes: 2 additions & 2 deletions examples/apfl_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from examples.models.cnn_model import MnistNetWithBnAndFrozen
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.utils.config import narrow_config_type
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
from fl4health.utils.random import set_all_random_seeds
Expand All @@ -22,7 +22,7 @@

class MnistApflClient(ApflClient):
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_config_type(config, "batch_size", int)
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75)
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler)
return train_loader, val_loader
Expand Down
6 changes: 3 additions & 3 deletions examples/basic_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@

from examples.models.cnn_model import Net
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.config import narrow_config_type
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data
from fl4health.utils.metrics import Accuracy


class CifarClient(BasicClient):
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_config_type(config, "batch_size", int)
batch_size = narrow_dict_type(config, "batch_size", int)
train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size)
return train_loader, val_loader

def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
batch_size = narrow_config_type(config, "batch_size", int)
batch_size = narrow_dict_type(config, "batch_size", int)
test_loader, _ = load_cifar10_test_data(self.data_path, batch_size)
return test_loader

Expand Down
7 changes: 6 additions & 1 deletion examples/basic_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@ def main(config: Dict[str, Any]) -> None:
)

server = FlServerWithCheckpointing(
SimpleClientManager(), model, parameter_exchanger, None, strategy, checkpointers
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointers,
)

fl.server.start_server(
Expand Down
4 changes: 2 additions & 2 deletions examples/ditto_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from examples.models.cnn_model import MnistNet
from fl4health.clients.ditto_client import DittoClient
from fl4health.utils.config import narrow_config_type
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
from fl4health.utils.random import set_all_random_seeds
Expand All @@ -24,7 +24,7 @@
class MnistDittoClient(DittoClient):
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1)
batch_size = narrow_config_type(config, "batch_size", int)
batch_size = narrow_dict_type(config, "batch_size", int)
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler)
return train_loader, val_loader

Expand Down
4 changes: 2 additions & 2 deletions examples/docker_basic_example/fl_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from examples.models.cnn_model import Net
from fl4health.clients.basic_client import BasicClient
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.utils.config import narrow_config_type
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_cifar10_data
from fl4health.utils.metrics import Accuracy, Metric

Expand All @@ -22,7 +22,7 @@ def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.dev

def setup_client(self, config: Config) -> None:
super().setup_client(config)
batch_size = narrow_config_type(config, "batch_size", int)
batch_size = narrow_dict_type(config, "batch_size", int)
train_loader, validation_loader, num_examples = load_cifar10_data(self.data_path, batch_size)

self.train_loader = train_loader
Expand Down
4 changes: 2 additions & 2 deletions examples/dp_fed_examples/client_level_dp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@

from examples.models.cnn_model import Net
from fl4health.clients.clipping_client import NumpyClippingClient
from fl4health.utils.config import narrow_config_type
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_cifar10_data
from fl4health.utils.metrics import Accuracy


class CifarClient(NumpyClippingClient):
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_config_type(config, "batch_size", int)
batch_size = narrow_dict_type(config, "batch_size", int)
train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size)
return train_loader, val_loader

Expand Down
6 changes: 3 additions & 3 deletions examples/dp_fed_examples/client_level_dp_weighted/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from examples.dp_fed_examples.client_level_dp_weighted.data import load_data
from examples.models.logistic_regression import LogisticRegression
from fl4health.clients.clipping_client import NumpyClippingClient
from fl4health.utils.config import narrow_config_type
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.metrics import Accuracy


Expand All @@ -22,8 +22,8 @@ def get_model(self, config: Config) -> nn.Module:
return LogisticRegression(input_dim=31, output_dim=1).to(self.device)

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_config_type(config, "batch_size", int)
scaler_bytes = narrow_config_type(config, "scaler", bytes)
batch_size = narrow_dict_type(config, "batch_size", int)
scaler_bytes = narrow_dict_type(config, "scaler", bytes)
train_loader, val_loader, _ = load_data(self.data_path, batch_size, scaler_bytes)
return train_loader, val_loader

Expand Down
4 changes: 2 additions & 2 deletions examples/dp_fed_examples/instance_level_dp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from fl4health.checkpointing.client_module import ClientCheckpointModule
from fl4health.checkpointing.opacus_checkpointer import BestLossOpacusCheckpointer
from fl4health.clients.instance_level_dp_client import InstanceLevelDpClient
from fl4health.utils.config import narrow_config_type
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_cifar10_data
from fl4health.utils.metrics import Accuracy


class CifarClient(InstanceLevelDpClient):
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_config_type(config, "batch_size", int)
batch_size = narrow_dict_type(config, "batch_size", int)
train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size)
return train_loader, val_loader

Expand Down
4 changes: 2 additions & 2 deletions examples/dp_scaffold_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@

from examples.models.cnn_model import MnistNet
from fl4health.clients.scaffold_client import DPScaffoldClient
from fl4health.utils.config import narrow_config_type
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
from fl4health.utils.sampler import DirichletLabelBasedSampler


class MnistDPScaffoldClient(DPScaffoldClient):
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_config_type(config, "batch_size", int)
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=1.0)
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler)
return train_loader, val_loader
Expand Down
Loading

0 comments on commit b753a96

Please sign in to comment.