
Commit bfec6bb

add lightning to FL content for DLI

1 parent c0a6e7c commit bfec6bb

7 files changed: +780 -0 lines changed
Job script (new file):

@@ -0,0 +1,44 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from src.lit_net import LitNet

from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
from nvflare.job_config.script_runner import ScriptRunner

if __name__ == "__main__":
    n_clients = 5
    num_rounds = 2

    # BaseFedJob holds the initial global model that clients start from
    job = BaseFedJob(
        name="cifar10_lightning_fedavg",
        initial_model=LitNet(),
    )

    # The FedAvg controller runs on the server and coordinates the rounds
    controller = FedAvg(
        num_clients=n_clients,
        num_rounds=num_rounds,
    )
    job.to(controller, "server")

    # Add clients
    for i in range(n_clients):
        runner = ScriptRunner(
            script="src/cifar10_lightning_fl.py",
            script_args="",  # f"--batch_size 32 --data_path /tmp/data/site-{i}"
        )
        job.to(runner, f"site-{i+1}")

    # Export the job config and run it locally with the FL simulator
    job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0", log_config="./log_config.json")
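The commented-out script_args above shows how per-site arguments could be passed to each client's copy of the training script. If that line is enabled, the client script needs a matching argparse block; here is a hypothetical sketch (the committed src/cifar10_lightning_fl.py does not parse arguments, and these flag names simply mirror the comment):

# Hypothetical argparse block for src/cifar10_lightning_fl.py, matching the
# commented-out script_args above; not part of the committed script.
import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--data_path", type=str, default="/tmp/nvflare/data")
    return parser.parse_args()

# In main():
#   args = parse_args()
#   cifar10_dm = CIFAR10DataModule(data_dir=args.data_path, batch_size=args.batch_size)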
log_config.json (new file):

@@ -0,0 +1,87 @@
{
    "version": 1,
    "disable_existing_loggers": false,
    "formatters": {
        "baseFormatter": {
            "()": "nvflare.fuel.utils.log_utils.BaseFormatter",
            "fmt": "%(asctime)s - %(name)s - %(levelname)s - %(fl_ctx)s - %(message)s"
        },
        "colorFormatter": {
            "()": "nvflare.fuel.utils.log_utils.ColorFormatter",
            "fmt": "%(asctime)s - %(levelname)s - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S"
        },
        "jsonFormatter": {
            "()": "nvflare.fuel.utils.log_utils.JsonFormatter",
            "fmt": "%(asctime)s - %(identity)s - %(name)s - %(fullName)s - %(levelname)s - %(fl_ctx)s - %(message)s"
        }
    },
    "filters": {
        "FLFilter": {
            "()": "nvflare.fuel.utils.log_utils.LoggerNameFilter",
            "logger_names": ["custom", "nvflare.app_common", "nvflare.app_opt"]
        }
    },
    "handlers": {
        "consoleHandler": {
            "class": "logging.StreamHandler",
            "level": "INFO",
            "formatter": "colorFormatter",
            "filters": ["FLFilter"],
            "stream": "ext://sys.stdout"
        },
        "logFileHandler": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "DEBUG",
            "formatter": "baseFormatter",
            "filename": "log.txt",
            "mode": "a",
            "maxBytes": 20971520,
            "backupCount": 10
        },
        "errorFileHandler": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "ERROR",
            "formatter": "baseFormatter",
            "filename": "log_error.txt",
            "mode": "a",
            "maxBytes": 20971520,
            "backupCount": 10
        },
        "jsonFileHandler": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "DEBUG",
            "formatter": "jsonFormatter",
            "filename": "log.json",
            "mode": "a",
            "maxBytes": 20971520,
            "backupCount": 10
        },
        "FLFileHandler": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "DEBUG",
            "formatter": "baseFormatter",
            "filters": ["FLFilter"],
            "filename": "log_fl.txt",
            "mode": "a",
            "maxBytes": 20971520,
            "backupCount": 10,
            "delay": true
        }
    },
    "loggers": {
        "root": {
            "level": "INFO",
            "handlers": ["consoleHandler", "logFileHandler", "errorFileHandler", "jsonFileHandler", "FLFileHandler"]
        }
    }
}
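This file follows Python's standard logging dictConfig schema ("version": 1; the "()" keys name custom factory classes), so it can be sanity-checked outside NVFlare. A minimal sketch, assuming nvflare is installed so the referenced formatter and filter classes resolve:

# Load log_config.json with the standard library to verify it is well-formed.
# dictConfig raises ValueError if handlers/formatters/filters are misconfigured.
import json
import logging.config

with open("log_config.json") as f:
    logging.config.dictConfig(json.load(f))

# A logger under nvflare.app_common passes FLFilter and reaches the console.
logging.getLogger("nvflare.app_common.check").info("log config loaded")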
requirements.txt (new file):

@@ -0,0 +1,5 @@
nvflare~=2.5.0rc
torch
torchvision
pytorch_lightning
tensorboard
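These can be installed with pip install -r requirements.txt; the ~=2.5.0rc specifier pins nvflare to the 2.5 series, release candidates included.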
src/cifar10_lightning_fl.py (new file):

@@ -0,0 +1,105 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torchvision
import torchvision.transforms as transforms
from lit_net import LitNet
from pytorch_lightning import LightningDataModule, Trainer, seed_everything
from torch.utils.data import DataLoader, random_split

# (1) import nvflare lightning client API
import nvflare.client.lightning as flare

seed_everything(7)


DATASET_PATH = "/tmp/nvflare/data"
BATCH_SIZE = 4

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


class CIFAR10DataModule(LightningDataModule):
    def __init__(self, data_dir: str = DATASET_PATH, batch_size: int = BATCH_SIZE):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=transform)
        torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=transform)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage == "validate":
            cifar_full = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, download=False, transform=transform
            )
            self.cifar_train, self.cifar_val = random_split(cifar_full, [0.8, 0.2])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage == "predict":
            self.cifar_test = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=False, download=False, transform=transform
            )

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)


def main():
    model = LitNet()
    cifar10_dm = CIFAR10DataModule()
    trainer = Trainer(max_epochs=1, devices=1, accelerator="gpu" if torch.cuda.is_available() else "cpu")

    # (2) patch the lightning trainer
    flare.patch(trainer)

    while flare.is_running():
        # (3) receive FLModel from NVFlare
        # Note that we don't need to pass this input_model to the trainer,
        # because after flare.patch, trainer.fit/validate will get the
        # global model internally
        input_model = flare.receive()
        print(f"\n[Current Round={input_model.current_round}, Site = {flare.get_site_name()}]\n")

        # (4) evaluate the current global model to allow server-side model selection
        print("--- validate global model ---")
        trainer.validate(model, datamodule=cifar10_dm)

        # perform local training starting with the received global model
        print("--- train new model ---")
        trainer.fit(model, datamodule=cifar10_dm)

        # test local model
        print("--- test new model ---")
        trainer.test(ckpt_path="best", datamodule=cifar10_dm)

        # get predictions
        print("--- prediction with new best model ---")
        trainer.predict(ckpt_path="best", datamodule=cifar10_dm)


if __name__ == "__main__":
    main()
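The only federated-specific pieces in this script are the flare.* calls; everything else is ordinary Lightning code. For comparison, here is a minimal non-federated sketch of the same loop (it reuses LitNet and CIFAR10DataModule from the files in this commit):

# Plain (non-federated) version of the same training flow, for comparison:
# drop flare.patch/is_running/receive and run a single local train/test pass.
import torch
from pytorch_lightning import Trainer

from lit_net import LitNet  # src/lit_net.py in this commit
# CIFAR10DataModule is defined in src/cifar10_lightning_fl.py above

model = LitNet()
cifar10_dm = CIFAR10DataModule()
trainer = Trainer(max_epochs=1, devices=1, accelerator="gpu" if torch.cuda.is_available() else "cpu")

trainer.validate(model, datamodule=cifar10_dm)
trainer.fit(model, datamodule=cifar10_dm)
trainer.test(ckpt_path="best", datamodule=cifar10_dm)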
src/lit_net.py (new file):

@@ -0,0 +1,93 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy

NUM_CLASSES = 10
criterion = nn.CrossEntropyLoss()


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class LitNet(LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = Net()
        self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        # (optional) pass additional information via self.__fl_meta__
        self.__fl_meta__ = {}

    def forward(self, x):
        out = self.model(x)
        return out

    def training_step(self, batch, batch_idx):
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.train_acc(outputs, labels)
        self.log("train_loss", loss)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=False)
        return loss

    def evaluate(self, batch, stage=None):
        # shared eval logic for the validation/test/predict steps below
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.valid_acc(outputs, labels)

        if stage:
            self.log(f"{stage}_loss", loss)
            self.log(f"{stage}_acc", self.valid_acc, on_step=True, on_epoch=True)
        return outputs

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return self.evaluate(batch)

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        return {"optimizer": optimizer}
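Net is the classic CIFAR-10 CNN: two conv/pool stages shrink a 3x32x32 image to 16 feature maps of 5x5 (32 -> 28 -> 14 -> 10 -> 5), which is why fc1 takes 16 * 5 * 5 inputs. A quick shape check, as a local sketch rather than part of the commit:

# Sanity-check the shapes assumed by Net/LitNet (local sketch, not committed).
import torch

model = LitNet()
dummy = torch.randn(4, 3, 32, 32)  # a fake CIFAR-10 batch of 4 images
logits = model(dummy)
assert logits.shape == (4, 10)     # NUM_CLASSES = 10 output logits per image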

examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/convert_torch_lightning_to_dl.ipynb

Whitespace-only changes.
