Skip to content

Commit df38c1c

Browse files
committed
Address review comments
1 parent 2ca2ace commit df38c1c

File tree

5 files changed

+48
-48
lines changed

5 files changed

+48
-48
lines changed

nvflare/client/api.py

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,28 @@
1313
# limitations under the License.
1414

1515
import logging
16+
from threading import Lock
1617
from typing import Any, Dict, Optional
1718

1819
from nvflare.apis.analytix import AnalyticsDataType
1920
from nvflare.app_common.abstract.fl_model import FLModel
2021

2122
from .api_context import APIContext
2223

24+
global_context_lock = Lock()
2325
context_dict = {}
2426
default_context = None
2527

2628

29+
def _get_context(ctx: Optional[APIContext] = None):
30+
if ctx:
31+
return ctx
32+
elif default_context:
33+
return default_context
34+
else:
35+
raise RuntimeError("APIContext is None")
36+
37+
2738
def init(rank: Optional[str] = None, config_file: Optional[str] = None) -> APIContext:
2839
"""Initializes NVFlare Client API environment.
2940
@@ -35,19 +46,20 @@ def init(rank: Optional[str] = None, config_file: Optional[str] = None) -> APICo
3546
Returns:
3647
APIContext
3748
"""
38-
global context_dict
39-
global default_context
40-
local_ctx = context_dict.get((rank, config_file))
41-
42-
if local_ctx is None:
43-
local_ctx = APIContext(rank=rank, config_file=config_file)
44-
context_dict[(rank, config_file)] = local_ctx
45-
default_context = local_ctx
46-
else:
47-
logging.warning(
48-
"Warning: called init() more than once with same parameters." "The subsequence calls are ignored"
49-
)
50-
return local_ctx
49+
with global_context_lock:
50+
global context_dict
51+
global default_context
52+
local_ctx = context_dict.get((rank, config_file))
53+
54+
if local_ctx is None:
55+
local_ctx = APIContext(rank=rank, config_file=config_file)
56+
context_dict[(rank, config_file)] = local_ctx
57+
default_context = local_ctx
58+
else:
59+
logging.warning(
60+
"Warning: called init() more than once with same parameters." "The subsequence calls are ignored"
61+
)
62+
return local_ctx
5163

5264

5365
def receive(timeout: Optional[float] = None, ctx: Optional[APIContext] = None) -> Optional[FLModel]:
@@ -56,8 +68,7 @@ def receive(timeout: Optional[float] = None, ctx: Optional[APIContext] = None) -
5668
Returns:
5769
An FLModel received.
5870
"""
59-
global default_context
60-
local_ctx = ctx if ctx else default_context
71+
local_ctx = _get_context(ctx)
6172
return local_ctx.api.receive(timeout)
6273

6374

@@ -70,8 +81,7 @@ def send(model: FLModel, clear_cache: bool = True, ctx: Optional[APIContext] = N
7081
"""
7182
if not isinstance(model, FLModel):
7283
raise TypeError("model needs to be an instance of FLModel")
73-
global default_context
74-
local_ctx = ctx if ctx else default_context
84+
local_ctx = _get_context(ctx)
7585
return local_ctx.api.send(model, clear_cache)
7686

7787

@@ -88,8 +98,7 @@ def system_info(ctx: Optional[APIContext] = None) -> Dict:
8898
A dict of system information.
8999
90100
"""
91-
global default_context
92-
local_ctx = ctx if ctx else default_context
101+
local_ctx = _get_context(ctx)
93102
return local_ctx.api.system_info()
94103

95104

@@ -99,8 +108,7 @@ def get_config(ctx: Optional[APIContext] = None) -> Dict:
99108
Returns:
100109
A dict of the configuration used in Client API.
101110
"""
102-
global default_context
103-
local_ctx = ctx if ctx else default_context
111+
local_ctx = _get_context(ctx)
104112
return local_ctx.api.get_config()
105113

106114

@@ -110,8 +118,7 @@ def get_job_id(ctx: Optional[APIContext] = None) -> str:
110118
Returns:
111119
The current job id.
112120
"""
113-
global default_context
114-
local_ctx = ctx if ctx else default_context
121+
local_ctx = _get_context(ctx)
115122
return local_ctx.api.get_job_id()
116123

117124

@@ -121,8 +128,7 @@ def get_site_name(ctx: Optional[APIContext] = None) -> str:
121128
Returns:
122129
The site name of this client.
123130
"""
124-
global default_context
125-
local_ctx = ctx if ctx else default_context
131+
local_ctx = _get_context(ctx)
126132
return local_ctx.api.get_site_name()
127133

128134

@@ -132,8 +138,7 @@ def get_task_name(ctx: Optional[APIContext] = None) -> str:
132138
Returns:
133139
The task name.
134140
"""
135-
global default_context
136-
local_ctx = ctx if ctx else default_context
141+
local_ctx = _get_context(ctx)
137142
return local_ctx.api.get_task_name()
138143

139144

@@ -143,8 +148,7 @@ def is_running(ctx: Optional[APIContext] = None) -> bool:
143148
Returns:
144149
True, if the system is up and running. False, otherwise.
145150
"""
146-
global default_context
147-
local_ctx = ctx if ctx else default_context
151+
local_ctx = _get_context(ctx)
148152
return local_ctx.api.is_running()
149153

150154

@@ -154,8 +158,7 @@ def is_train(ctx: Optional[APIContext] = None) -> bool:
154158
Returns:
155159
True, if the current task is a training task. False, otherwise.
156160
"""
157-
global default_context
158-
local_ctx = ctx if ctx else default_context
161+
local_ctx = _get_context(ctx)
159162
return local_ctx.api.is_train()
160163

161164

@@ -165,8 +168,7 @@ def is_evaluate(ctx: Optional[APIContext] = None) -> bool:
165168
Returns:
166169
True, if the current task is an evaluate task. False, otherwise.
167170
"""
168-
global default_context
169-
local_ctx = ctx if ctx else default_context
171+
local_ctx = _get_context(ctx)
170172
return local_ctx.api.is_evaluate()
171173

172174

@@ -176,8 +178,7 @@ def is_submit_model(ctx: Optional[APIContext] = None) -> bool:
176178
Returns:
177179
True, if the current task is a submit_model. False, otherwise.
178180
"""
179-
global default_context
180-
local_ctx = ctx if ctx else default_context
181+
local_ctx = _get_context(ctx)
181182
return local_ctx.api.is_submit_model()
182183

183184

@@ -195,20 +196,17 @@ def log(key: str, value: Any, data_type: AnalyticsDataType, ctx: Optional[APICon
195196
Returns:
196197
whether the key value pair is logged successfully
197198
"""
198-
global default_context
199-
local_ctx = ctx if ctx else default_context
199+
local_ctx = _get_context(ctx)
200200
return local_ctx.api.log(key, value, data_type, **kwargs)
201201

202202

203203
def clear(ctx: Optional[APIContext] = None):
204204
"""Clears the cache."""
205-
global default_context
206-
local_ctx = ctx if ctx else default_context
205+
local_ctx = _get_context(ctx)
207206
return local_ctx.api.clear()
208207

209208

210209
def shutdown(ctx: Optional[APIContext] = None):
211210
"""Releases all threads and resources used by the API and stops operation."""
212-
global default_context
213-
local_ctx = ctx if ctx else default_context
211+
local_ctx = _get_context(ctx)
214212
return local_ctx.api.shutdown()

nvflare/client/api_context.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from .api_spec import CLIENT_API_KEY, CLIENT_API_TYPE_KEY, APISpec
2424
from .ex_process.api import ExProcessClientAPI
25+
from .in_process.api import InProcessClientAPI
2526

2627
DEFAULT_CONFIG = f"config/{CLIENT_API_CONFIG}"
2728
data_bus = DataBus()
@@ -54,6 +55,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
5455
def _create_client_api(self, api_type: ClientAPIType) -> APISpec:
5556
"""Creates a new client_api based on the provided API type."""
5657
if api_type == ClientAPIType.IN_PROCESS_API:
57-
return data_bus.get_data(CLIENT_API_KEY)
58+
api = data_bus.get_data(CLIENT_API_KEY)
59+
if not isinstance(api, InProcessClientAPI):
60+
raise RuntimeError(f"api {api} is not a valid InProcessClientAPI")
5861
else:
5962
return ExProcessClientAPI(config_file=self.config_file)

nvflare/client/ex_process/api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(self, config_file: str):
7878
self.logger = get_obj_logger(self)
7979
self.receive_called = False
8080
self.config_file = config_file
81+
self.flare_agent = None
8182

8283
def get_model_registry(self) -> ModelRegistry:
8384
"""Gets the ModelRegistry."""
@@ -133,6 +134,7 @@ def init(self, rank: Optional[str] = None):
133134
flare_agent.start()
134135

135136
self.model_registry = ModelRegistry(client_config, rank, flare_agent)
137+
self.flare_agent = flare_agent
136138
except Exception as e:
137139
self.logger.error(f"flare.init failed: {e}")
138140
raise e
@@ -216,5 +218,5 @@ def clear(self):
216218
self.receive_called = False
217219

218220
def shutdown(self):
219-
model_registry = self.get_model_registry()
220-
model_registry.shutdown()
221+
if self.flare_agent:
222+
self.flare_agent.stop()

nvflare/client/in_process/api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,4 +237,5 @@ def __continue_job(self) -> bool:
237237

238238
def shutdown(self):
239239
self.stop = True
240+
self.event_manager.fire_event(TOPIC_STOP)
240241
self.stop_reason = "API shutdown called."

nvflare/client/task_registry.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,3 @@ def clear(self) -> None:
113113

114114
def __str__(self):
115115
return f"{self.__class__.__name__}(config: {self.config.get_config()})"
116-
117-
def shutdown(self):
118-
if self.flare_agent:
119-
self.flare_agent.stop()

0 commit comments

Comments
 (0)