Skip to content

Commit

Permalink
Add content for experiment tracking for DLI (#3203)
Browse files Browse the repository at this point in the history
Add content for experiment tracking for DLI.

### Description

Add content for experiment tracking for DLI.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.
  • Loading branch information
nvkevlu authored Feb 10, 2025
1 parent 7f03432 commit b40e8a1
Show file tree
Hide file tree
Showing 21 changed files with 1,547 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -752,9 +752,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 5
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"\n",
" * [customize client logics](../01.4_customize_client_training/customize_client_training.ipynb)\n",
"\n",
"5. **Tracking the trainig metrics** \n",
"5. **Tracking the training metrics** \n",
"\n",
" * [experiment tracking](../01.5_experiment_tracking/experiment_tracking.ipynb )\n",
"\n",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This Dirichlet sampling strategy for creating a heterogeneous partition is adopted
# from FedMA (https://github.com/IBM/FedMA).

# MIT License

# Copyright (c) 2020 International Business Machines

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import argparse

import torchvision.datasets as datasets

# default dataset path
CIFAR10_ROOT = "/tmp/nvflare/data/cifar10"


def define_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, default=CIFAR10_ROOT, nargs="?")
args = parser.parse_args()
return args


def main(args):
datasets.CIFAR10(root=args.dataset_path, train=True, download=True)
datasets.CIFAR10(root=args.dataset_path, train=False, download=True)


if __name__ == "__main__":
main(define_parser())
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from src.network import SimpleNetwork

from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob
from nvflare.job_config.script_runner import ScriptRunner

if __name__ == "__main__":
n_clients = 5
num_rounds = 20

train_script = "src/client.py"

job = FedAvgJob(name="fedavg", n_clients=n_clients, num_rounds=num_rounds, initial_model=SimpleNetwork())

# Add clients
for i in range(n_clients):
executor = ScriptRunner(
script=train_script, script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}"
)
job.to(executor, f"site-{i + 1}")

job.simulator_run(workspace="/tmp/nvflare/jobs/workdir", log_config="./log_config.json")
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
{
"version": 1,
"disable_existing_loggers": false,
"formatters": {
"baseFormatter": {
"()": "nvflare.fuel.utils.log_utils.BaseFormatter",
"fmt": "%(asctime)s - %(name)s - %(levelname)s - %(fl_ctx)s - %(message)s"
},
"colorFormatter": {
"()": "nvflare.fuel.utils.log_utils.ColorFormatter",
"fmt": "%(asctime)s - %(levelname)s - %(message)s",
"datefmt": "%Y-%m-%d %H:%M:%S"
},
"jsonFormatter": {
"()": "nvflare.fuel.utils.log_utils.JsonFormatter",
"fmt": "%(asctime)s - %(identity)s - %(name)s - %(fullName)s - %(levelname)s - %(fl_ctx)s - %(message)s"
}
},
"filters": {
"FLFilter": {
"()": "nvflare.fuel.utils.log_utils.LoggerNameFilter",
"logger_names": ["custom", "nvflare.app_common", "nvflare.app_opt"]
}
},
"handlers": {
"consoleHandler": {
"class": "logging.StreamHandler",
"level": "INFO",
"formatter": "colorFormatter",
"filters": ["FLFilter"],
"stream": "ext://sys.stdout"
},
"logFileHandler": {
"class": "logging.handlers.RotatingFileHandler",
"level": "DEBUG",
"formatter": "baseFormatter",
"filename": "log.txt",
"mode": "a",
"maxBytes": 20971520,
"backupCount": 10
},
"errorFileHandler": {
"class": "logging.handlers.RotatingFileHandler",
"level": "ERROR",
"formatter": "baseFormatter",
"filename": "log_error.txt",
"mode": "a",
"maxBytes": 20971520,
"backupCount": 10
},
"jsonFileHandler": {
"class": "logging.handlers.RotatingFileHandler",
"level": "DEBUG",
"formatter": "jsonFormatter",
"filename": "log.json",
"mode": "a",
"maxBytes": 20971520,
"backupCount": 10
},
"FLFileHandler": {
"class": "logging.handlers.RotatingFileHandler",
"level": "DEBUG",
"formatter": "baseFormatter",
"filters": ["FLFilter"],
"filename": "log_fl.txt",
"mode": "a",
"maxBytes": 20971520,
"backupCount": 10,
"delay": true
}
},
"loggers": {
"root": {
"level": "INFO",
"handlers": ["consoleHandler", "logFileHandler", "errorFileHandler", "jsonFileHandler", "FLFileHandler"]
}
}
}









Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
torch
torchvision
tensorboard
Loading

0 comments on commit b40e8a1

Please sign in to comment.