Skip to content

Commit

Permalink
Merge branch 'main' into add_lightning_DLI
Browse files Browse the repository at this point in the history
  • Loading branch information
nvkevlu authored Feb 6, 2025
2 parents bfec6bb + 4897865 commit 4dffc96
Show file tree
Hide file tree
Showing 15 changed files with 1,203 additions and 5 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from src.kaplan_meier_wf import KM
from src.kaplan_meier_wf_he import KM_HE

from nvflare import FedJob
from nvflare.job_config.script_runner import ScriptRunner


def main():
    """Assemble the Kaplan-Meier FedJob, export it, and run it in the simulator.

    Command-line options come from ``define_parser``. When ``--encryption``
    is set, the ``KM_HE`` controller and the ``*_he`` training script are
    used; otherwise the plain ``KM`` variants are used.
    """
    args = define_parser()

    # Default paths
    data_root = "/tmp/nvflare/dataset/km_data"
    he_context_path = "/tmp/nvflare/he_context/he_context_client.txt"

    # Pick the job name, client script and script arguments for the chosen mode
    use_encryption = args.encryption
    if use_encryption:
        job_name, train_script = "KM_HE", "src/kaplan_meier_train_he.py"
        script_args = f"--data_root {data_root} --he_context_path {he_context_path}"
    else:
        job_name, train_script = "KM", "src/kaplan_meier_train.py"
        script_args = f"--data_root {data_root}"

    # Thread count falls back to the client count when not given (or given as 0)
    num_clients = args.num_clients
    num_threads = args.num_threads or num_clients

    # Output locations: the simulator workspace is namespaced by job name
    workspace_dir = os.path.join(args.workspace_dir, job_name)
    job_dir = args.job_dir

    # Create the FedJob
    job = FedJob(name=job_name, min_clients=num_clients)

    # Server side: the KM controller workflow matching the chosen mode
    controller = (
        KM_HE(min_clients=num_clients, he_context_path=he_context_path)
        if use_encryption
        else KM(min_clients=num_clients)
    )
    job.to_server(controller)

    # Client side: every client runs the training script for the "train" task
    job.to_clients(
        ScriptRunner(
            script=train_script,
            script_args=script_args,
            params_exchange_format="raw",
            launch_external_process=False,
        ),
        tasks=["train"],
    )

    # Export the job definition
    print("job_dir=", job_dir)
    job.export_job(job_dir)

    # Run the job in the simulator
    print("workspace_dir=", workspace_dir)
    print("num_threads=", num_threads)
    job.simulator_run(workspace_dir, n_clients=num_clients, threads=num_threads)


def define_parser():
    """Parse the job script's command-line options.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv`` with the
        attributes ``workspace_dir``, ``job_dir``, ``encryption``,
        ``num_clients`` and ``num_threads``.
    """
    # Each entry is (flag, keyword arguments for ``add_argument``); the order
    # here is the order the options are registered in.
    option_specs = [
        (
            "--workspace_dir",
            {
                "type": str,
                "default": "/tmp/nvflare/jobs/km/workdir",
                "help": "work directory, default to '/tmp/nvflare/jobs/km/workdir'",
            },
        ),
        (
            "--job_dir",
            {
                "type": str,
                "default": "/tmp/nvflare/jobs/km/jobdir",
                "help": "directory for job export, default to '/tmp/nvflare/jobs/km/jobdir'",
            },
        ),
        (
            "--encryption",
            {
                # Registers both --encryption and --no-encryption
                "action": argparse.BooleanOptionalAction,
                "help": "whether to enable encryption, default to False",
            },
        ),
        (
            "--num_clients",
            {
                "type": int,
                "default": 5,
                "help": "number of clients to simulate, default to 5",
            },
        ),
        (
            "--num_threads",
            {
                "type": int,
                "help": "number of threads to use for FL simulation, default to the number of clients if not specified",
            },
        ),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


# Script entry point: only build and run the job when executed directly.
if __name__ == "__main__":
    main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
lifelines
tenseal
scikit-survival
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from lifelines import KaplanMeierFitter
from lifelines.utils import survival_table_from_events

# (1) import nvflare client API
import nvflare.client as flare
from nvflare.app_common.abstract.fl_model import FLModel, ParamsType


# Client code
def details_save(kmf):
    """Write the fitted Kaplan-Meier details to ``km_global.json`` in the CWD.

    Args:
        kmf: A fitted estimator exposing ``survival_function_`` (a DataFrame
            with a ``"KM_estimate"`` column indexed by time) and
            ``event_table`` (a DataFrame with an ``"observed"`` column), such
            as a lifelines ``KaplanMeierFitter``.

    The JSON file holds four parallel lists: ``timeline``, ``km_estimate``,
    ``event_count`` and ``survival_rate``.
    """
    # Survival function at all observed time points
    survival_function = kmf.survival_function_
    # Time points
    timeline = survival_function.index.values
    # KM estimate at each time point
    km_estimate = survival_function["KM_estimate"].values
    # Observed event count at each time point. Select the column by name:
    # lifelines orders the event table as removed/observed/censored/..., so
    # positional column 0 is "removed" (observed + censored), not the
    # observed events.
    event_count = kmf.event_table["observed"].values
    # NOTE(review): this is 1 - (first column of the survival function), i.e.
    # the complement of the survival estimate, although it is stored under the
    # "survival_rate" key. Kept unchanged so the output schema stays stable.
    survival_rate = 1 - survival_function.iloc[:, 0].values

    results = {
        "timeline": timeline.tolist(),
        "km_estimate": km_estimate.tolist(),
        "event_count": event_count.tolist(),
        "survival_rate": survival_rate.tolist(),
    }
    file_path = os.path.join(os.getcwd(), "km_global.json")
    print(f"save the details of KM analysis result to {file_path} \n")
    with open(file_path, "w", encoding="utf-8") as json_file:
        json.dump(results, json_file, indent=4)


def plot_and_save(kmf):
    """Plot the Kaplan-Meier survival curve and save it as ``km_curve_fl.png``.

    The image is written to the current working directory.

    Args:
        kmf: A fitted estimator providing ``plot_survival_function()``, such
            as a lifelines ``KaplanMeierFitter``.
    """
    fig = plt.figure()
    try:
        plt.title("Federated")
        kmf.plot_survival_function()
        plt.ylim(0, 1)
        plt.ylabel("prob")
        plt.xlabel("time")
        # Presumably meant to hide the legend the plot call adds: an empty
        # label list with no frame.
        plt.legend("", frameon=False)
        plt.tight_layout()
        file_path = os.path.join(os.getcwd(), "km_curve_fl.png")
        print(f"save the curve plot to {file_path} \n")
        plt.savefig(file_path)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)


def _build_local_histograms(event_table):
    """Turn an event table into dense per-integer-time observed/censored histograms.

    Args:
        event_table: DataFrame indexed by event time with ``"observed"`` and
            ``"censored"`` count columns.

    Returns:
        ``(hist_obs, hist_cen)``: dicts mapping each integer time bin from 0
        through the largest bin to its observed / censored count.
    """
    # Event times are truncated to integer bins
    bins = event_table.index.values.astype(int)
    observed = event_table["observed"].to_numpy()
    censored = event_table["censored"].to_numpy()

    # Start with a zero count for every bin so gaps in the data are explicit
    top_bin = max(bins)
    hist_obs = {b: 0 for b in range(top_bin + 1)}
    hist_cen = {b: 0 for b in range(top_bin + 1)}

    # Accumulate rather than assign: distinct event times can truncate to the
    # same integer bin, and plain assignment would keep only the last one.
    for b, n_obs, n_cen in zip(bins, observed, censored):
        hist_obs[b] = hist_obs.get(b, 0) + n_obs
        hist_cen[b] = hist_cen.get(b, 0) + n_cen
    return hist_obs, hist_cen


def _unfold_histograms(hist_obs_global, hist_cen_global):
    """Expand global histograms into per-record duration and event-flag arrays.

    Args:
        hist_obs_global: Mapping of time bin -> number of observed events.
        hist_cen_global: Mapping of time bin -> number of censored records;
            must contain every key of ``hist_obs_global``.

    Returns:
        ``(durations, events)`` as numpy arrays; for each bin the observed
        records (``True``) come first, followed by the censored ones (``False``).
    """
    durations = []
    events = []
    for time_bin, n_obs in hist_obs_global.items():
        n_cen = hist_cen_global[time_bin]
        for flag, count in ((True, n_obs), (False, n_cen)):
            # range() keeps the original requirement that counts are integers
            for _ in range(count):
                durations.append(time_bin)
                events.append(flag)
    return np.array(durations), np.array(events)


def main():
    """Run the client side of the two-round federated Kaplan-Meier analysis.

    Round 1: build local observed/censored histograms from ``<site>.csv``
    under ``--data_root`` and send them to the server.
    Round 2: receive the global histograms, fit a Kaplan-Meier estimator on
    the unfolded records, save the curve plot and the JSON details, and send
    an empty response.
    """
    parser = argparse.ArgumentParser(description="KM analysis")
    parser.add_argument("--data_root", type=str, help="Root path for data files")
    args = parser.parse_args()

    flare.init()

    site_name = flare.get_site_name()
    print(f"Kaplan-meier analysis for {site_name}")

    # get local data
    data_path = os.path.join(args.data_root, site_name + ".csv")
    data = pd.read_csv(data_path)
    event_local = data["event"]
    time_local = data["time"]

    while flare.is_running():
        # receives global message from NVFlare
        global_msg = flare.receive()
        curr_round = global_msg.current_round
        print(f"current_round={curr_round}")

        if curr_round == 1:
            # First round: empty payload from server, send local histograms
            event_table = survival_table_from_events(time_local, event_local)
            hist_obs, hist_cen = _build_local_histograms(event_table)
            response = FLModel(params={"hist_obs": hist_obs, "hist_cen": hist_cen}, params_type=ParamsType.FULL)
            flare.send(response)

        elif curr_round == 2:
            # Second round: unfold the global histograms into an event list
            time_unfold, event_unfold = _unfold_histograms(
                global_msg.params["hist_obs_global"],
                global_msg.params["hist_cen_global"],
            )

            # Perform Kaplan-Meier analysis on the globally aggregated information
            kmf = KaplanMeierFitter()
            kmf.fit(durations=time_unfold, event_observed=event_unfold)

            # Plot and save the KM curve
            plot_and_save(kmf)

            # Save details of the KM result to a json file
            details_save(kmf)

            # Send a simple response to server
            response = FLModel(params={}, params_type=ParamsType.FULL)
            flare.send(response)

    print(f"finish send for {site_name}, complete")


# Script entry point: run the client analysis only when executed directly.
if __name__ == "__main__":
    main()
Loading

0 comments on commit 4dffc96

Please sign in to comment.