13
13
# limitations under the License.
14
14
15
15
import logging
16
+ from threading import Lock
16
17
from typing import Any , Dict , Optional
17
18
18
19
from nvflare .apis .analytix import AnalyticsDataType
19
20
from nvflare .app_common .abstract .fl_model import FLModel
20
21
21
22
from .api_context import APIContext
22
23
24
+ global_context_lock = Lock ()
23
25
context_dict = {}
24
26
default_context = None
25
27
26
28
29
+ def _get_context (ctx : Optional [APIContext ] = None ):
30
+ if ctx :
31
+ return ctx
32
+ elif default_context :
33
+ return default_context
34
+ else :
35
+ raise RuntimeError ("APIContext is None" )
36
+
37
+
27
38
def init (rank : Optional [str ] = None , config_file : Optional [str ] = None ) -> APIContext :
28
39
"""Initializes NVFlare Client API environment.
29
40
@@ -35,19 +46,20 @@ def init(rank: Optional[str] = None, config_file: Optional[str] = None) -> APICo
35
46
Returns:
36
47
APIContext
37
48
"""
38
- global context_dict
39
- global default_context
40
- local_ctx = context_dict .get ((rank , config_file ))
41
-
42
- if local_ctx is None :
43
- local_ctx = APIContext (rank = rank , config_file = config_file )
44
- context_dict [(rank , config_file )] = local_ctx
45
- default_context = local_ctx
46
- else :
47
- logging .warning (
48
- "Warning: called init() more than once with same parameters." "The subsequence calls are ignored"
49
- )
50
- return local_ctx
49
+ with global_context_lock :
50
+ global context_dict
51
+ global default_context
52
+ local_ctx = context_dict .get ((rank , config_file ))
53
+
54
+ if local_ctx is None :
55
+ local_ctx = APIContext (rank = rank , config_file = config_file )
56
+ context_dict [(rank , config_file )] = local_ctx
57
+ default_context = local_ctx
58
+ else :
59
+ logging .warning (
60
+ "Warning: called init() more than once with same parameters." "The subsequence calls are ignored"
61
+ )
62
+ return local_ctx
51
63
52
64
53
65
def receive (timeout : Optional [float ] = None , ctx : Optional [APIContext ] = None ) -> Optional [FLModel ]:
@@ -56,8 +68,7 @@ def receive(timeout: Optional[float] = None, ctx: Optional[APIContext] = None) -
56
68
Returns:
57
69
An FLModel received.
58
70
"""
59
- global default_context
60
- local_ctx = ctx if ctx else default_context
71
+ local_ctx = _get_context (ctx )
61
72
return local_ctx .api .receive (timeout )
62
73
63
74
@@ -70,8 +81,7 @@ def send(model: FLModel, clear_cache: bool = True, ctx: Optional[APIContext] = N
70
81
"""
71
82
if not isinstance (model , FLModel ):
72
83
raise TypeError ("model needs to be an instance of FLModel" )
73
- global default_context
74
- local_ctx = ctx if ctx else default_context
84
+ local_ctx = _get_context (ctx )
75
85
return local_ctx .api .send (model , clear_cache )
76
86
77
87
@@ -88,8 +98,7 @@ def system_info(ctx: Optional[APIContext] = None) -> Dict:
88
98
A dict of system information.
89
99
90
100
"""
91
- global default_context
92
- local_ctx = ctx if ctx else default_context
101
+ local_ctx = _get_context (ctx )
93
102
return local_ctx .api .system_info ()
94
103
95
104
@@ -99,8 +108,7 @@ def get_config(ctx: Optional[APIContext] = None) -> Dict:
99
108
Returns:
100
109
A dict of the configuration used in Client API.
101
110
"""
102
- global default_context
103
- local_ctx = ctx if ctx else default_context
111
+ local_ctx = _get_context (ctx )
104
112
return local_ctx .api .get_config ()
105
113
106
114
@@ -110,8 +118,7 @@ def get_job_id(ctx: Optional[APIContext] = None) -> str:
110
118
Returns:
111
119
The current job id.
112
120
"""
113
- global default_context
114
- local_ctx = ctx if ctx else default_context
121
+ local_ctx = _get_context (ctx )
115
122
return local_ctx .api .get_job_id ()
116
123
117
124
@@ -121,8 +128,7 @@ def get_site_name(ctx: Optional[APIContext] = None) -> str:
121
128
Returns:
122
129
The site name of this client.
123
130
"""
124
- global default_context
125
- local_ctx = ctx if ctx else default_context
131
+ local_ctx = _get_context (ctx )
126
132
return local_ctx .api .get_site_name ()
127
133
128
134
@@ -132,8 +138,7 @@ def get_task_name(ctx: Optional[APIContext] = None) -> str:
132
138
Returns:
133
139
The task name.
134
140
"""
135
- global default_context
136
- local_ctx = ctx if ctx else default_context
141
+ local_ctx = _get_context (ctx )
137
142
return local_ctx .api .get_task_name ()
138
143
139
144
@@ -143,8 +148,7 @@ def is_running(ctx: Optional[APIContext] = None) -> bool:
143
148
Returns:
144
149
True, if the system is up and running. False, otherwise.
145
150
"""
146
- global default_context
147
- local_ctx = ctx if ctx else default_context
151
+ local_ctx = _get_context (ctx )
148
152
return local_ctx .api .is_running ()
149
153
150
154
@@ -154,8 +158,7 @@ def is_train(ctx: Optional[APIContext] = None) -> bool:
154
158
Returns:
155
159
True, if the current task is a training task. False, otherwise.
156
160
"""
157
- global default_context
158
- local_ctx = ctx if ctx else default_context
161
+ local_ctx = _get_context (ctx )
159
162
return local_ctx .api .is_train ()
160
163
161
164
@@ -165,8 +168,7 @@ def is_evaluate(ctx: Optional[APIContext] = None) -> bool:
165
168
Returns:
166
169
True, if the current task is an evaluate task. False, otherwise.
167
170
"""
168
- global default_context
169
- local_ctx = ctx if ctx else default_context
171
+ local_ctx = _get_context (ctx )
170
172
return local_ctx .api .is_evaluate ()
171
173
172
174
@@ -176,8 +178,7 @@ def is_submit_model(ctx: Optional[APIContext] = None) -> bool:
176
178
Returns:
177
179
True, if the current task is a submit_model. False, otherwise.
178
180
"""
179
- global default_context
180
- local_ctx = ctx if ctx else default_context
181
+ local_ctx = _get_context (ctx )
181
182
return local_ctx .api .is_submit_model ()
182
183
183
184
@@ -195,20 +196,17 @@ def log(key: str, value: Any, data_type: AnalyticsDataType, ctx: Optional[APICon
195
196
Returns:
196
197
whether the key value pair is logged successfully
197
198
"""
198
- global default_context
199
- local_ctx = ctx if ctx else default_context
199
+ local_ctx = _get_context (ctx )
200
200
return local_ctx .api .log (key , value , data_type , ** kwargs )
201
201
202
202
203
203
def clear (ctx : Optional [APIContext ] = None ):
204
204
"""Clears the cache."""
205
- global default_context
206
- local_ctx = ctx if ctx else default_context
205
+ local_ctx = _get_context (ctx )
207
206
return local_ctx .api .clear ()
208
207
209
208
210
209
def shutdown (ctx : Optional [APIContext ] = None ):
211
210
"""Releases all threads and resources used by the API and stops operation."""
212
- global default_context
213
- local_ctx = ctx if ctx else default_context
211
+ local_ctx = _get_context (ctx )
214
212
return local_ctx .api .shutdown ()
0 commit comments