Skip to content

Commit 4dffc96

Browse files
authored
Merge branch 'main' into add_lightning_DLI
2 parents bfec6bb + 4897865 commit 4dffc96

File tree

15 files changed

+1203
-5
lines changed

15 files changed

+1203
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import os
17+
18+
from src.kaplan_meier_wf import KM
19+
from src.kaplan_meier_wf_he import KM_HE
20+
21+
from nvflare import FedJob
22+
from nvflare.job_config.script_runner import ScriptRunner
23+
24+
25+
def main():
    """Assemble the Kaplan-Meier federated job, export it, and run it in the simulator.

    Options are read through ``define_parser()``. When ``--encryption`` is set,
    the homomorphic-encryption controller and client script are used;
    otherwise the plain Kaplan-Meier ones are.
    """
    args = define_parser()

    # Default locations of the per-site data and the client HE context
    data_root = "/tmp/nvflare/dataset/km_data"
    he_context_path = "/tmp/nvflare/he_context/he_context_client.txt"

    # Job name, client script and script arguments depend on the encryption flag
    if not args.encryption:
        job_name, train_script = "KM", "src/kaplan_meier_train.py"
        script_args = f"--data_root {data_root}"
    else:
        job_name, train_script = "KM_HE", "src/kaplan_meier_train_he.py"
        script_args = f"--data_root {data_root} --he_context_path {he_context_path}"

    # Thread count falls back to the client count when not given
    n_clients = args.num_clients
    n_threads = args.num_threads or n_clients

    # Simulator workspace is namespaced by job; export directory is used as given
    sim_workspace = os.path.join(args.workspace_dir, job_name)
    export_dir = args.job_dir

    job = FedJob(name=job_name, min_clients=n_clients)

    # Server side: the controller workflow matching the chosen mode
    workflow = (
        KM_HE(min_clients=n_clients, he_context_path=he_context_path)
        if args.encryption
        else KM(min_clients=n_clients)
    )
    job.to_server(workflow)

    # Client side: every client runs the training script in-process for the "train" task
    job.to_clients(
        ScriptRunner(
            script=train_script,
            script_args=script_args,
            params_exchange_format="raw",
            launch_external_process=False,
        ),
        tasks=["train"],
    )

    # Export the job definition
    print("job_dir=", export_dir)
    job.export_job(export_dir)

    # Run the job in the simulator
    print("workspace_dir=", sim_workspace)
    print("num_threads=", n_threads)
    job.simulator_run(sim_workspace, n_clients=n_clients, threads=n_threads)
79+
80+
81+
def define_parser(argv=None):
    """Parse the command-line options of the job script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what a
            plain ``define_parser()`` call has always done.

    Returns:
        argparse.Namespace with the attributes ``workspace_dir``, ``job_dir``,
        ``encryption``, ``num_clients`` and ``num_threads``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--workspace_dir",
        type=str,
        default="/tmp/nvflare/jobs/km/workdir",
        help="work directory, default to '/tmp/nvflare/jobs/km/workdir'",
    )
    parser.add_argument(
        "--job_dir",
        type=str,
        default="/tmp/nvflare/jobs/km/jobdir",
        help="directory for job export, default to '/tmp/nvflare/jobs/km/jobdir'",
    )
    parser.add_argument(
        "--encryption",
        action=argparse.BooleanOptionalAction,
        # Without an explicit default the attribute is None when the flag is
        # omitted, which contradicts the documented default of False.
        default=False,
        help="whether to enable encryption, default to False",
    )
    parser.add_argument(
        "--num_clients",
        type=int,
        default=5,
        help="number of clients to simulate, default to 5",
    )
    parser.add_argument(
        "--num_threads",
        type=int,
        # Intentionally no default: None lets the caller fall back to num_clients.
        help="number of threads to use for FL simulation, default to the number of clients if not specified",
    )
    return parser.parse_args(argv)
112+
113+
114+
if __name__ == "__main__":
115+
main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
lifelines
2+
tenseal
3+
scikit-survival
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import json
17+
import os
18+
19+
import matplotlib.pyplot as plt
20+
import numpy as np
21+
import pandas as pd
22+
from lifelines import KaplanMeierFitter
23+
from lifelines.utils import survival_table_from_events
24+
25+
# (1) import nvflare client API
26+
import nvflare.client as flare
27+
from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
28+
29+
30+
# Client code
31+
def details_save(kmf):
    """Write the details of a fitted Kaplan-Meier model to ``km_global.json``.

    The file is created in the current working directory and holds four
    parallel lists: ``timeline``, ``km_estimate``, ``event_count`` and
    ``survival_rate``.

    Args:
        kmf: Fitted estimator exposing ``survival_function_`` (a DataFrame
            indexed by time with a "KM_estimate" column) and ``event_table``
            (a DataFrame with an "observed" column).
    """
    # Survival function at all observed time points
    survival_fn = kmf.survival_function_
    timeline = survival_fn.index.values
    km_estimate = survival_fn["KM_estimate"].values

    # Select the observed-event counts by name. lifelines orders the event
    # table as removed, observed, censored, entrance, at_risk, so taking the
    # first column positionally would return "removed" (observed + censored).
    event_count = kmf.event_table["observed"].values

    # NOTE: the "survival_rate" key is kept for existing consumers of the file;
    # the value stored is 1 minus the first column of the survival function.
    survival_rate = 1 - survival_fn.iloc[:, 0].values

    results = {
        "timeline": timeline.tolist(),
        "km_estimate": km_estimate.tolist(),
        "event_count": event_count.tolist(),
        "survival_rate": survival_rate.tolist(),
    }
    file_path = os.path.join(os.getcwd(), "km_global.json")
    print(f"save the details of KM analysis result to {file_path} \n")
    with open(file_path, "w", encoding="utf-8") as json_file:
        json.dump(results, json_file, indent=4)
53+
54+
55+
def plot_and_save(kmf):
    """Plot the Kaplan-Meier survival curve and save it as ``km_curve_fl.png``.

    The image is written to the current working directory.

    Args:
        kmf: Fitted estimator providing ``plot_survival_function()``.
    """
    fig = plt.figure()
    try:
        plt.title("Federated")
        kmf.plot_survival_function()
        plt.ylim(0, 1)
        plt.ylabel("prob")
        plt.xlabel("time")
        # Empty legend without a frame hides the default curve label
        plt.legend("", frameon=False)
        plt.tight_layout()
        file_path = os.path.join(os.getcwd(), "km_curve_fl.png")
        print(f"save the curve plot to {file_path} \n")
        plt.savefig(file_path)
    finally:
        # pyplot keeps every figure registered until it is closed explicitly;
        # release it so repeated calls in one process do not accumulate figures.
        plt.close(fig)
68+
69+
70+
def main():
    """Run the client side of the two-round federated Kaplan-Meier analysis.

    Round 1 sends this site's observed/censored histograms to the server.
    Round 2 reads the global histograms from the server message, fits a
    ``KaplanMeierFitter`` on the unfolded events, saves the curve and the
    result details, and sends back an empty response.
    """
    parser = argparse.ArgumentParser(description="KM analysis")
    parser.add_argument("--data_root", type=str, help="Root path for data files")
    args = parser.parse_args()

    flare.init()

    site_name = flare.get_site_name()
    print(f"Kaplan-meier analysis for {site_name}")

    # Local data lives in <data_root>/<site_name>.csv
    local_df = pd.read_csv(os.path.join(args.data_root, site_name + ".csv"))
    event_local = local_df["event"]
    time_local = local_df["time"]

    while flare.is_running():
        # Receive the global message from NVFlare
        global_msg = flare.receive()
        curr_round = global_msg.current_round
        print(f"current_round={curr_round}")

        if curr_round == 1:
            # First round: the server payload is not used; reply with the local histograms
            hist_obs, hist_cen = _local_histograms(time_local, event_local)
            flare.send(
                FLModel(
                    params={"hist_obs": hist_obs, "hist_cen": hist_cen},
                    params_type=ParamsType.FULL,
                )
            )

        elif curr_round == 2:
            # Second round: rebuild an event list from the global histograms
            time_unfold, event_unfold = _unfold_histograms(
                global_msg.params["hist_obs_global"],
                global_msg.params["hist_cen_global"],
            )

            # Kaplan-Meier analysis on the globally aggregated information
            kmf = KaplanMeierFitter()
            kmf.fit(durations=time_unfold, event_observed=event_unfold)

            # Persist the curve plot and the numeric details
            plot_and_save(kmf)
            details_save(kmf)

            # Send a simple (empty) response to the server
            flare.send(FLModel(params={}, params_type=ParamsType.FULL))

    print(f"finish send for {site_name}, complete")


def _local_histograms(time_local, event_local):
    """Build the observed and censored histograms for one site's records."""
    event_table = survival_table_from_events(time_local, event_local)
    time_bins = event_table.index.values.astype(int)

    # Zero entry for every integer time below the largest index; the largest
    # index itself is filled in by the assignment loop.
    hist_obs = dict.fromkeys(range(max(time_bins)), 0)
    hist_cen = dict.fromkeys(range(max(time_bins)), 0)

    observed = event_table["observed"].to_numpy()
    censored = event_table["censored"].to_numpy()
    for time_bin, n_observed, n_censored in zip(time_bins, observed, censored):
        hist_obs[time_bin] = n_observed
        hist_cen[time_bin] = n_censored
    return hist_obs, hist_cen


def _unfold_histograms(hist_obs_global, hist_cen_global):
    """Expand histograms into parallel arrays of times and event flags."""
    times = []
    flags = []
    for time_key in hist_obs_global:
        # Observed events first, then censored ones, for each time key
        for _ in range(hist_obs_global[time_key]):
            times.append(time_key)
            flags.append(True)
        for _ in range(hist_cen_global[time_key]):
            times.append(time_key)
            flags.append(False)
    return np.array(times), np.array(flags)
149+
150+
151+
if __name__ == "__main__":
152+
main()

0 commit comments

Comments
 (0)