Commit 859c34b
Making small changes to make new mkmmd clients compatible with the new adaptive clients.
emersodb committed Sep 25, 2024
1 parent 960d39c commit 859c34b
Showing 8 changed files with 23 additions and 34 deletions.
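The thrust of the change: the drift-penalty weight `lam` is removed from the MK-MMD clients' constructors and argument parsing, and is instead supplied to the server-side `FedAvgWithAdaptiveConstraint` strategy as `initial_loss_weight`, matching how the new adaptive clients receive it. A minimal sketch of the new server-side wiring, distilled from the server.py hunks below (the `build_strategy` helper is ours, not part of the commit):

```python
from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters


def build_strategy(model, n_clients: int, lam: float) -> FedAvgWithAdaptiveConstraint:
    # lam used to be a client-side constructor argument; it now seeds the
    # strategy's constraint weight on the server.
    return FedAvgWithAdaptiveConstraint(
        min_fit_clients=n_clients,
        min_evaluate_clients=n_clients,
        min_available_clients=n_clients,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        initial_parameters=get_all_model_parameters(model),
        initial_loss_weight=lam,
    )
```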
8 changes: 2 additions & 6 deletions fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py
@@ -24,7 +24,6 @@ def __init__(
         device: torch.device,
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
         checkpointer: Optional[ClientCheckpointModule] = None,
-        lam: float = 1.0,
         mkmmd_loss_weight: float = 10.0,
         feature_extraction_layers: Optional[Sequence[str]] = None,
         feature_l2_norm_weight: float = 0.0,
@@ -45,7 +44,6 @@ def __init__(
             checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to
                 do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to
                 None.
-            lam (float, optional): weight applied to the Ditto drift loss. Defaults to 1.0.
             mkmmd_loss_weight (float, optional): weight applied to the MK-MMD loss. Defaults to 10.0.
             feature_extraction_layers (Optional[Sequence[str]], optional): List of layers from which to extract
                 and flatten features. Defaults to None.
@@ -62,7 +60,6 @@ def __init__(
             device=device,
             loss_meter_type=loss_meter_type,
             checkpointer=checkpointer,
-            lam=lam,
         )
         self.mkmmd_loss_weight = mkmmd_loss_weight
         if self.mkmmd_loss_weight == 0:
@@ -87,7 +84,7 @@ def __init__(
             self.flatten_feature_extraction_layers = {layer: True for layer in feature_extraction_layers}
         else:
             self.flatten_feature_extraction_layers = {}
-        self.mkmmd_losses = {}
+        self.mkmmd_losses: Dict[str, MkMmdLoss] = {}
         for layer in self.flatten_feature_extraction_layers.keys():
             self.mkmmd_losses[layer] = MkMmdLoss(
                 device=self.device, minimize_type_two_error=True, normalize_features=True, layer_name=layer
@@ -110,8 +107,7 @@ def update_before_train(self, current_server_round: int) -> None:
         # Register hooks to extract features from the local model if not already registered. Hooks have to be
         # removed to checkpoint the model, so we check whether they need to be re-registered each time.
         self.local_feature_extractor._maybe_register_hooks()
-        # Clone and freeze the initial weights GLOBAL MODEL. These are used to form the Ditto local
-        # update penalty term.
+        # Clone and freeze the initial GLOBAL MODEL. This is used to extract features for the MkMMD constraints.
        self.initial_global_model = self.clone_and_freeze_model(self.global_model)
        self.initial_global_feature_extractor = FeatureExtractorBuffer(
            model=self.initial_global_model,
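The `update_before_train` hunk relies on forward hooks to pull intermediate features out of both the local model and a frozen clone of the global model. As a rough illustration of that mechanism only (a hypothetical `FeatureGrabber`, not fl4health's `FeatureExtractorBuffer`, whose interface differs), a hook-based extractor looks like:

```python
from typing import Any, Dict, List, Sequence

import torch
import torch.nn as nn


class FeatureGrabber:
    """Minimal forward-hook buffer: stores flattened activations per layer name."""

    def __init__(self, model: nn.Module, layer_names: Sequence[str]) -> None:
        self.features: Dict[str, torch.Tensor] = {}
        self.handles: List[torch.utils.hooks.RemovableHandle] = []
        modules = dict(model.named_modules())
        for name in layer_names:
            self.handles.append(modules[name].register_forward_hook(self._make_hook(name)))

    def _make_hook(self, name: str):
        def hook(module: nn.Module, inputs: Any, output: torch.Tensor) -> None:
            # Flatten to (batch_size, features) so distribution losses can compare layers directly.
            self.features[name] = output.flatten(start_dim=1)

        return hook

    def remove_hooks(self) -> None:
        # Hooks must be removed before checkpointing and re-registered afterwards,
        # which is why the client checks registration each round.
        for handle in self.handles:
            handle.remove()
        self.handles = []
```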
5 changes: 1 addition & 4 deletions fl4health/clients/mkmmd_clients/mr_mtl_mkmmd_client.py
@@ -23,7 +23,6 @@ def __init__(
         device: torch.device,
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
         checkpointer: Optional[ClientCheckpointModule] = None,
-        lam: float = 1.0,
         mkmmd_loss_weight: float = 10.0,
         feature_extraction_layers: Optional[Sequence[str]] = None,
         feature_l2_norm_weight: float = 0.0,
@@ -45,7 +44,6 @@ def __init__(
             checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to
                 do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to
                 None.
-            lam (float, optional): weight applied to the MR-MTL drift loss. Defaults to 1.0.
             mkmmd_loss_weight (float, optional): weight applied to the MK-MMD loss. Defaults to 10.0.
             feature_extraction_layers (Optional[Sequence[str]], optional): List of layers from which to extract
                 and flatten features. Defaults to None.
@@ -62,7 +60,6 @@ def __init__(
             device=device,
             loss_meter_type=loss_meter_type,
             checkpointer=checkpointer,
-            lam=lam,
         )
         self.mkmmd_loss_weight = mkmmd_loss_weight
         if self.mkmmd_loss_weight == 0:
@@ -87,7 +84,7 @@ def __init__(
             self.flatten_feature_extraction_layers = {layer: True for layer in feature_extraction_layers}
         else:
             self.flatten_feature_extraction_layers = {}
-        self.mkmmd_losses = {}
+        self.mkmmd_losses: Dict[str, MkMmdLoss] = {}
         for layer in self.flatten_feature_extraction_layers.keys():
             self.mkmmd_losses[layer] = MkMmdLoss(
                 device=self.device, minimize_type_two_error=True, normalize_features=True, layer_name=layer
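Both clients instantiate one `MkMmdLoss` per extraction layer. The library version can normalize features and optimize the kernel mixture (the `minimize_type_two_error` flag above); as a sketch of the underlying quantity only, with function names of our own, a biased MMD² estimate under a uniform mixture of RBF kernels can be written as:

```python
import torch


def rbf_kernel(x: torch.Tensor, y: torch.Tensor, gamma: float) -> torch.Tensor:
    # k(a, b) = exp(-gamma * ||a - b||^2) for every pair of rows in x and y.
    return torch.exp(-gamma * torch.cdist(x, y).pow(2))


def mk_mmd_squared(x: torch.Tensor, y: torch.Tensor, gammas=(0.25, 1.0, 4.0)) -> torch.Tensor:
    # Biased estimate of MMD^2 between samples x and y, averaged over a uniform
    # mixture of RBF kernels; fl4health's MkMmdLoss learns the mixture instead.
    total = x.new_zeros(())
    for gamma in gammas:
        total = total + (
            rbf_kernel(x, x, gamma).mean()
            + rbf_kernel(y, y, gamma).mean()
            - 2.0 * rbf_kernel(x, y, gamma).mean()
        )
    return total / len(gammas)
```

In these clients, `x` would be a layer's features from the model being trained and `y` the matching features from the frozen initial global model, with the result scaled by `mkmmd_loss_weight`.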
7 changes: 0 additions & 7 deletions research/flamby/fed_isic2019/ditto_mkmmd/client.py
@@ -36,7 +36,6 @@ def __init__(
         device: torch.device,
         client_number: int,
         learning_rate: float,
-        lam: float = 0,
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
         mkmmd_loss_weight: float = 10,
         feature_l2_norm_weight: float = 1,
@@ -50,7 +49,6 @@ def __init__(
             device=device,
             loss_meter_type=loss_meter_type,
             checkpointer=checkpointer,
-            lam=lam,
             mkmmd_loss_weight=mkmmd_loss_weight,
             feature_extraction_layers=FED_ISIC2019_BASELINE_LAYERS[-1 * mkmmd_loss_depth :],
             feature_l2_norm_weight=feature_l2_norm_weight,
@@ -124,9 +122,6 @@ def get_criterion(self, config: Config) -> _Loss:
     parser.add_argument(
         "--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=LR
     )
-    parser.add_argument(
-        "--lam", action="store", type=float, help="Ditto loss weight for local model training", default=0.01
-    )
     parser.add_argument(
         "--seed",
         action="store",
@@ -170,7 +165,6 @@ def get_criterion(self, config: Config) -> _Loss:
     log(INFO, f"Device to be used: {DEVICE}")
     log(INFO, f"Server Address: {args.server_address}")
     log(INFO, f"Learning Rate: {args.learning_rate}")
-    log(INFO, f"Lambda: {args.lam}")
     log(INFO, f"Mu: {args.mu}")
     log(INFO, f"Feature L2 Norm Weight: {args.l2}")
     log(INFO, f"MKMMD Loss Depth: {args.mkmmd_loss_depth}")
@@ -189,7 +183,6 @@ def get_criterion(self, config: Config) -> _Loss:
         device=DEVICE,
         client_number=args.client_number,
         learning_rate=args.learning_rate,
-        lam=args.lam,
         checkpointer=checkpointer,
         feature_l2_norm_weight=args.l2,
         mkmmd_loss_depth=args.mkmmd_loss_depth,
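The flags this research client exposes map onto a composite training objective. Roughly (an illustrative composition under our own names and signature, not the client's actual API), the pieces combine as:

```python
import torch


def ditto_mkmmd_objective(
    task_loss: torch.Tensor,
    drift_loss: torch.Tensor,       # Ditto penalty tying the local model to the global weights
    mkmmd_loss: torch.Tensor,       # summed MK-MMD terms over the extraction layers
    feature_l2_norm: torch.Tensor,  # L2 norm of the extracted features
    lam: float,                     # drift weight, now initialized on the server via --lam
    mkmmd_loss_weight: float,
    feature_l2_norm_weight: float,
) -> torch.Tensor:
    return task_loss + lam * drift_loss + mkmmd_loss_weight * mkmmd_loss + feature_l2_norm_weight * feature_l2_norm
```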
2 changes: 1 addition & 1 deletion
@@ -134,6 +134,7 @@ do
     --config_path ${SERVER_CONFIG_PATH} \
     --server_address ${SERVER_ADDRESS} \
     --seed ${SEED} \
+    --lam ${LAM_VALUE} \
     > ${SERVER_OUTPUT_FILE} 2>&1 &

 # Sleep for 20 seconds to allow the server to come up.
@@ -154,7 +155,6 @@ do
     --run_name ${RUN_NAME} \
     --client_number ${c} \
     --learning_rate ${CLIENT_LR} \
-    --lam ${LAM_VALUE} \
     --mu ${MU_VALUE} \
     --l2 ${L2_VALUE} \
     --mkmmd_loss_depth ${MKMMD_LOSS_DEPTH} \
13 changes: 9 additions & 4 deletions research/flamby/fed_isic2019/ditto_mkmmd/server.py
@@ -7,8 +7,8 @@
 from flamby.datasets.fed_isic2019 import Baseline
 from flwr.common.logger import log
 from flwr.server.client_manager import SimpleClientManager
-from flwr.server.strategy import FedAvg

+from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
 from fl4health.utils.config import load_config
 from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
 from fl4health.utils.parameter_extraction import get_all_model_parameters
@@ -17,7 +17,7 @@
 from research.flamby.utils import fit_config, summarize_model_info


-def main(config: Dict[str, Any], server_address: str) -> None:
+def main(config: Dict[str, Any], server_address: str, lam: float) -> None:
     # This function will be used to produce a config that is sent to each client to initialize their own environment
     fit_config_fn = partial(
         fit_config,
@@ -30,7 +30,7 @@ def main(config: Dict[str, Any], server_address: str) -> None:
     summarize_model_info(model)

     # Server performs simple FedAveraging as its server-side optimization strategy
-    strategy = FedAvg(
+    strategy = FedAvgWithAdaptiveConstraint(
         min_fit_clients=config["n_clients"],
         min_evaluate_clients=config["n_clients"],
         # Server waits for min_available_clients before starting FL rounds
@@ -41,6 +41,7 @@ def main(config: Dict[str, Any], server_address: str) -> None:
         fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
         evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
         initial_parameters=get_all_model_parameters(model),
+        initial_loss_weight=lam,
     )

     server = PersonalServer(client_manager, strategy)
@@ -81,12 +82,16 @@ def main(config: Dict[str, Any], server_address: str) -> None:
         help="Seed for the random number generators across python, torch, and numpy",
         required=False,
     )
+    parser.add_argument(
+        "--lam", action="store", type=float, help="Ditto loss weight for local model training", default=0.01
+    )
     args = parser.parse_args()

     config = load_config(args.config_path)
     log(INFO, f"Server Address: {args.server_address}")
+    log(INFO, f"Lambda: {args.lam}")

     # Set the random seed for reproducibility
     set_all_random_seeds(args.seed)

-    main(config, args.server_address)
+    main(config, args.server_address, args.lam)
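The exact adaptation schedule lives inside `FedAvgWithAdaptiveConstraint`; the general idea behind adaptive constraint weights (as in adaptive FedProx) is to tighten the penalty when the aggregated loss deteriorates and relax it when training progresses. A hypothetical server-side rule, purely for intuition and not the library's implementation:

```python
from typing import List


def adapt_loss_weight(current_weight: float, loss_history: List[float], delta: float = 0.1) -> float:
    # Hypothetical schedule: raise the drift-penalty weight when the aggregated
    # train loss went up since last round, lower it (not below zero) otherwise.
    if len(loss_history) < 2:
        return current_weight
    if loss_history[-1] > loss_history[-2]:
        return current_weight + delta
    return max(current_weight - delta, 0.0)
```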
7 changes: 0 additions & 7 deletions research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py
@@ -36,7 +36,6 @@ def __init__(
         device: torch.device,
         client_number: int,
         learning_rate: float,
-        lam: float = 0,
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
         mkmmd_loss_weight: float = 10,
         feature_l2_norm_weight: float = 1,
@@ -50,7 +49,6 @@ def __init__(
             device=device,
             loss_meter_type=loss_meter_type,
             checkpointer=checkpointer,
-            lam=lam,
             mkmmd_loss_weight=mkmmd_loss_weight,
             feature_extraction_layers=FED_ISIC2019_BASELINE_LAYERS[-1 * mkmmd_loss_depth :],
             feature_l2_norm_weight=feature_l2_norm_weight,
@@ -120,9 +118,6 @@ def get_criterion(self, config: Config) -> _Loss:
     parser.add_argument(
         "--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=LR
     )
-    parser.add_argument(
-        "--lam", action="store", type=float, help="MR-MTL loss weight for local model training", default=0.01
-    )
     parser.add_argument(
         "--seed",
         action="store",
@@ -166,7 +161,6 @@ def get_criterion(self, config: Config) -> _Loss:
     log(INFO, f"Device to be used: {DEVICE}")
     log(INFO, f"Server Address: {args.server_address}")
     log(INFO, f"Learning Rate: {args.learning_rate}")
-    log(INFO, f"Lambda: {args.lam}")
     log(INFO, f"Mu: {args.mu}")
     log(INFO, f"Feature L2 Norm Weight: {args.l2}")
     log(INFO, f"MKMMD Loss Depth: {args.mkmmd_loss_depth}")
@@ -185,7 +179,6 @@ def get_criterion(self, config: Config) -> _Loss:
         device=DEVICE,
         client_number=args.client_number,
         learning_rate=args.learning_rate,
-        lam=args.lam,
         checkpointer=checkpointer,
         feature_l2_norm_weight=args.l2,
         mkmmd_loss_depth=args.mkmmd_loss_depth,
2 changes: 1 addition & 1 deletion
@@ -134,6 +134,7 @@ do
     --config_path ${SERVER_CONFIG_PATH} \
     --server_address ${SERVER_ADDRESS} \
     --seed ${SEED} \
+    --lam ${LAM_VALUE} \
     > ${SERVER_OUTPUT_FILE} 2>&1 &

 # Sleep for 20 seconds to allow the server to come up.
@@ -154,7 +155,6 @@ do
     --run_name ${RUN_NAME} \
     --client_number ${c} \
     --learning_rate ${CLIENT_LR} \
-    --lam ${LAM_VALUE} \
     --mu ${MU_VALUE} \
     --l2 ${L2_VALUE} \
     --mkmmd_loss_depth ${MKMMD_LOSS_DEPTH} \
13 changes: 9 additions & 4 deletions research/flamby/fed_isic2019/mr_mtl_mkmmd/server.py
@@ -7,8 +7,8 @@
 from flamby.datasets.fed_isic2019 import Baseline
 from flwr.common.logger import log
 from flwr.server.client_manager import SimpleClientManager
-from flwr.server.strategy import FedAvg

+from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
 from fl4health.utils.config import load_config
 from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
 from fl4health.utils.parameter_extraction import get_all_model_parameters
@@ -17,7 +17,7 @@
 from research.flamby.utils import fit_config, summarize_model_info


-def main(config: Dict[str, Any], server_address: str) -> None:
+def main(config: Dict[str, Any], server_address: str, lam: float) -> None:
     # This function will be used to produce a config that is sent to each client to initialize their own environment
     fit_config_fn = partial(
         fit_config,
@@ -30,7 +30,7 @@ def main(config: Dict[str, Any], server_address: str) -> None:
     summarize_model_info(model)

     # Server performs simple FedAveraging as its server-side optimization strategy
-    strategy = FedAvg(
+    strategy = FedAvgWithAdaptiveConstraint(
         min_fit_clients=config["n_clients"],
         min_evaluate_clients=config["n_clients"],
         # Server waits for min_available_clients before starting FL rounds
@@ -41,6 +41,7 @@ def main(config: Dict[str, Any], server_address: str) -> None:
         fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
         evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
         initial_parameters=get_all_model_parameters(model),
+        initial_loss_weight=lam,
     )

     server = PersonalServer(client_manager, strategy)
@@ -81,12 +82,16 @@ def main(config: Dict[str, Any], server_address: str) -> None:
         help="Seed for the random number generators across python, torch, and numpy",
         required=False,
     )
+    parser.add_argument(
+        "--lam", action="store", type=float, help="MR-MTL loss weight for local model training", default=0.01
+    )
     args = parser.parse_args()

     config = load_config(args.config_path)
     log(INFO, f"Server Address: {args.server_address}")
+    log(INFO, f"Lambda: {args.lam}")

     # Set the random seed for reproducibility
     set_all_random_seeds(args.seed)

-    main(config, args.server_address)
+    main(config, args.server_address, args.lam)
