Commit: Merge branch 'main' into add_lightning_DLI
Showing 15 changed files with 1,203 additions and 5 deletions.
Binary file added (+28 KB): ...convert_survival_analysis_to_federated_learning/code/figs/km_curve_baseline.png
Binary file added (+16.4 KB): ...2.3.3_convert_survival_analysis_to_federated_learning/code/figs/km_curve_fl.png
Binary file added (+16.6 KB): ....3_convert_survival_analysis_to_federated_learning/code/figs/km_curve_fl_he.png
..._federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/km_job.py (115 additions, 0 deletions)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from src.kaplan_meier_wf import KM
from src.kaplan_meier_wf_he import KM_HE

from nvflare import FedJob
from nvflare.job_config.script_runner import ScriptRunner


def main():
    args = define_parser()
    # Default paths
    data_root = "/tmp/nvflare/dataset/km_data"
    he_context_path = "/tmp/nvflare/he_context/he_context_client.txt"

    # Set the script and config
    if args.encryption:
        job_name = "KM_HE"
        train_script = "src/kaplan_meier_train_he.py"
        script_args = f"--data_root {data_root} --he_context_path {he_context_path}"
    else:
        job_name = "KM"
        train_script = "src/kaplan_meier_train.py"
        script_args = f"--data_root {data_root}"

    # Set the number of clients and threads
    num_clients = args.num_clients
    if args.num_threads:
        num_threads = args.num_threads
    else:
        num_threads = num_clients

    # Set the output workspace and job directories
    workspace_dir = os.path.join(args.workspace_dir, job_name)
    job_dir = args.job_dir

    # Create the FedJob
    job = FedJob(name=job_name, min_clients=num_clients)

    # Define the KM controller workflow and send to server
    if args.encryption:
        controller = KM_HE(min_clients=num_clients, he_context_path=he_context_path)
    else:
        controller = KM(min_clients=num_clients)
    job.to_server(controller)

    # Define the ScriptRunner and send to all clients
    runner = ScriptRunner(
        script=train_script,
        script_args=script_args,
        params_exchange_format="raw",
        launch_external_process=False,
    )
    job.to_clients(runner, tasks=["train"])

    # Export the job
    print("job_dir=", job_dir)
    job.export_job(job_dir)

    # Run the job
    print("workspace_dir=", workspace_dir)
    print("num_threads=", num_threads)
    job.simulator_run(workspace_dir, n_clients=num_clients, threads=num_threads)


def define_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--workspace_dir",
        type=str,
        default="/tmp/nvflare/jobs/km/workdir",
        help="work directory, default to '/tmp/nvflare/jobs/km/workdir'",
    )
    parser.add_argument(
        "--job_dir",
        type=str,
        default="/tmp/nvflare/jobs/km/jobdir",
        help="directory for job export, default to '/tmp/nvflare/jobs/km/jobdir'",
    )
    parser.add_argument(
        "--encryption",
        action=argparse.BooleanOptionalAction,
        help="whether to enable encryption, default to False",
    )
    parser.add_argument(
        "--num_clients",
        type=int,
        default=5,
        help="number of clients to simulate, default to 5",
    )
    parser.add_argument(
        "--num_threads",
        type=int,
        help="number of threads to use for FL simulation, default to the number of clients if not specified",
    )
    return parser.parse_args()


if __name__ == "__main__":
    main()
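A minimal usage sketch for this launcher (assuming it is run from the code/ directory so the src/ imports resolve, and that the synthetic data split, plus the TenSEAL context for the HE variant, have already been prepared under the default /tmp/nvflare paths used above):

    python km_job.py                   # plain Kaplan-Meier job with 5 simulated clients
    python km_job.py --encryption      # homomorphic-encryption variant (KM_HE controller, kaplan_meier_train_he.py)
    python km_job.py --num_clients 3   # change the number of simulated clients (threads default to the same number)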
...ted_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/requirements.txt (3 additions, 0 deletions)

lifelines
tenseal
scikit-survival
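The job and client scripts also import nvflare itself, which is assumed to be installed separately; the analysis dependencies listed here can be installed with, for example, pip install -r requirements.txt.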
...ing/02.3.3_convert_survival_analysis_to_federated_learning/code/src/kaplan_meier_train.py (152 additions, 0 deletions)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from lifelines import KaplanMeierFitter
from lifelines.utils import survival_table_from_events

# (1) import nvflare client API
import nvflare.client as flare
from nvflare.app_common.abstract.fl_model import FLModel, ParamsType


# Client code
def details_save(kmf):
    # Get the survival function at all observed time points
    survival_function_at_all_times = kmf.survival_function_
    # Get the timeline (time points)
    timeline = survival_function_at_all_times.index.values
    # Get the KM estimate
    km_estimate = survival_function_at_all_times["KM_estimate"].values
    # Get the event count at each time point
    event_count = kmf.event_table.iloc[:, 0].values  # Assuming the first column is the observed events
    # Get the cumulative event rate (1 - survival probability) at each time point
    survival_rate = 1 - survival_function_at_all_times.iloc[:, 0].values
    # Return the results
    results = {
        "timeline": timeline.tolist(),
        "km_estimate": km_estimate.tolist(),
        "event_count": event_count.tolist(),
        "survival_rate": survival_rate.tolist(),
    }
    file_path = os.path.join(os.getcwd(), "km_global.json")
    print(f"save the details of KM analysis result to {file_path} \n")
    with open(file_path, "w") as json_file:
        json.dump(results, json_file, indent=4)


def plot_and_save(kmf):
    # Plot and save the Kaplan-Meier survival curve
    plt.figure()
    plt.title("Federated")
    kmf.plot_survival_function()
    plt.ylim(0, 1)
    plt.ylabel("prob")
    plt.xlabel("time")
    plt.legend("", frameon=False)
    plt.tight_layout()
    file_path = os.path.join(os.getcwd(), "km_curve_fl.png")
    print(f"save the curve plot to {file_path} \n")
    plt.savefig(file_path)


def main():
    parser = argparse.ArgumentParser(description="KM analysis")
    parser.add_argument("--data_root", type=str, help="Root path for data files")
    args = parser.parse_args()

    flare.init()

    site_name = flare.get_site_name()
    print(f"Kaplan-Meier analysis for {site_name}")

    # get local data
    data_path = os.path.join(args.data_root, site_name + ".csv")
    data = pd.read_csv(data_path)
    event_local = data["event"]
    time_local = data["time"]

    while flare.is_running():
        # receives global message from NVFlare
        global_msg = flare.receive()
        curr_round = global_msg.current_round
        print(f"current_round={curr_round}")

        if curr_round == 1:
            # First round:
            # Empty payload from server, send local histogram
            # Convert local data to histogram
            event_table = survival_table_from_events(time_local, event_local)
            hist_idx = event_table.index.values.astype(int)
            hist_obs = {}
            hist_cen = {}
            for idx in range(max(hist_idx)):
                hist_obs[idx] = 0
                hist_cen[idx] = 0
            # Assign values
            idx = event_table.index.values.astype(int)
            observed = event_table["observed"].to_numpy()
            censored = event_table["censored"].to_numpy()
            for i in range(len(idx)):
                hist_obs[idx[i]] = observed[i]
                hist_cen[idx[i]] = censored[i]
            # Send histograms to server
            response = FLModel(params={"hist_obs": hist_obs, "hist_cen": hist_cen}, params_type=ParamsType.FULL)
            flare.send(response)

        elif curr_round == 2:
            # Get global histograms
            hist_obs_global = global_msg.params["hist_obs_global"]
            hist_cen_global = global_msg.params["hist_cen_global"]
            # Unfold histogram to event list
            time_unfold = []
            event_unfold = []
            for i in hist_obs_global.keys():
                for j in range(hist_obs_global[i]):
                    time_unfold.append(i)
                    event_unfold.append(True)
                for k in range(hist_cen_global[i]):
                    time_unfold.append(i)
                    event_unfold.append(False)
            time_unfold = np.array(time_unfold)
            event_unfold = np.array(event_unfold)

            # Perform Kaplan-Meier analysis on global aggregated information
            # Create a Kaplan-Meier estimator
            kmf = KaplanMeierFitter()

            # Fit the model
            kmf.fit(durations=time_unfold, event_observed=event_unfold)

            # Plot and save the KM curve
            plot_and_save(kmf)

            # Save details of the KM result to a json file
            details_save(kmf)

            # Send a simple response to server
            response = FLModel(params={}, params_type=ParamsType.FULL)
            flare.send(response)

        print(f"finish send for {site_name}, complete")


if __name__ == "__main__":
    main()
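The matching server-side controllers (src/kaplan_meier_wf.py and src/kaplan_meier_wf_he.py) are among the 15 changed files but are not shown in this view. For intuition only, the round-1 aggregation they need to perform on the histograms sent by the client code above is an element-wise sum across sites; a minimal sketch of that step, written independently of the actual NVFlare controller API:

    # Hypothetical helper for illustration only, not the actual KM/KM_HE controller code.
    from collections import Counter

    def aggregate_histograms(site_histograms):
        # site_histograms: list of (hist_obs, hist_cen) dict pairs, one per client,
        # keyed by time index as produced in round 1 of the client script above
        hist_obs_global = Counter()
        hist_cen_global = Counter()
        for hist_obs, hist_cen in site_histograms:
            hist_obs_global.update(hist_obs)  # add per-site observed-event counts
            hist_cen_global.update(hist_cen)  # add per-site censoring counts
        # The controller would broadcast these global histograms back to the clients,
        # which unfold them and fit a KaplanMeierFitter in round 2 above.
        return dict(hist_obs_global), dict(hist_cen_global)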