diff --git a/nvflare/client/api.py b/nvflare/client/api.py index 9b7c2a9028..f6322cdae5 100644 --- a/nvflare/client/api.py +++ b/nvflare/client/api.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from threading import Lock from typing import Any, Dict, Optional from nvflare.apis.analytix import AnalyticsDataType @@ -20,10 +21,32 @@ from .api_context import APIContext +global_context_lock = Lock() context_dict = {} default_context = None +def get_context(ctx: Optional[APIContext] = None) -> APIContext: + """Gets an APIContext. + + Args: + ctx (Optional[APIContext]): The context to use, + if None means use default context. Defaults to None. + + Raises: + RuntimeError: if can't get a valid APIContext. + + Returns: + An APIContext. + """ + if ctx: + return ctx + elif default_context: + return default_context + else: + raise RuntimeError("APIContext is None") + + def init(rank: Optional[str] = None, config_file: Optional[str] = None) -> APIContext: """Initializes NVFlare Client API environment. @@ -35,19 +58,20 @@ def init(rank: Optional[str] = None, config_file: Optional[str] = None) -> APICo Returns: APIContext """ - global context_dict - global default_context - local_ctx = context_dict.get((rank, config_file)) - - if local_ctx is None: - local_ctx = APIContext(rank=rank, config_file=config_file) - context_dict[(rank, config_file)] = local_ctx - default_context = local_ctx - else: - logging.warning( - "Warning: called init() more than once with same parameters." "The subsequence calls are ignored" - ) - return local_ctx + with global_context_lock: + global context_dict + global default_context + local_ctx = context_dict.get((rank, config_file)) + + if local_ctx is None: + local_ctx = APIContext(rank=rank, config_file=config_file) + context_dict[(rank, config_file)] = local_ctx + default_context = local_ctx + else: + logging.warning( + "Warning: called init() more than once with same parameters." "The subsequence calls are ignored" + ) + return local_ctx def receive(timeout: Optional[float] = None, ctx: Optional[APIContext] = None) -> Optional[FLModel]: @@ -56,8 +80,7 @@ def receive(timeout: Optional[float] = None, ctx: Optional[APIContext] = None) - Returns: An FLModel received. """ - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.receive(timeout) @@ -70,8 +93,7 @@ def send(model: FLModel, clear_cache: bool = True, ctx: Optional[APIContext] = N """ if not isinstance(model, FLModel): raise TypeError("model needs to be an instance of FLModel") - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.send(model, clear_cache) @@ -88,8 +110,7 @@ def system_info(ctx: Optional[APIContext] = None) -> Dict: A dict of system information. """ - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.system_info() @@ -99,8 +120,7 @@ def get_config(ctx: Optional[APIContext] = None) -> Dict: Returns: A dict of the configuration used in Client API. """ - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.get_config() @@ -110,8 +130,7 @@ def get_job_id(ctx: Optional[APIContext] = None) -> str: Returns: The current job id. """ - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.get_job_id() @@ -121,8 +140,7 @@ def get_site_name(ctx: Optional[APIContext] = None) -> str: Returns: The site name of this client. """ - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.get_site_name() @@ -132,8 +150,7 @@ def get_task_name(ctx: Optional[APIContext] = None) -> str: Returns: The task name. """ - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.get_task_name() @@ -143,8 +160,7 @@ def is_running(ctx: Optional[APIContext] = None) -> bool: Returns: True, if the system is up and running. False, otherwise. """ - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.is_running() @@ -154,8 +170,7 @@ def is_train(ctx: Optional[APIContext] = None) -> bool: Returns: True, if the current task is a training task. False, otherwise. """ - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.is_train() @@ -165,8 +180,7 @@ def is_evaluate(ctx: Optional[APIContext] = None) -> bool: Returns: True, if the current task is an evaluate task. False, otherwise. """ - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.is_evaluate() @@ -176,8 +190,7 @@ def is_submit_model(ctx: Optional[APIContext] = None) -> bool: Returns: True, if the current task is a submit_model. False, otherwise. """ - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.is_submit_model() @@ -195,20 +208,17 @@ def log(key: str, value: Any, data_type: AnalyticsDataType, ctx: Optional[APICon Returns: whether the key value pair is logged successfully """ - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.log(key, value, data_type, **kwargs) def clear(ctx: Optional[APIContext] = None): """Clears the cache.""" - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.clear() def shutdown(ctx: Optional[APIContext] = None): """Releases all threads and resources used by the API and stops operation.""" - global default_context - local_ctx = ctx if ctx else default_context + local_ctx = get_context(ctx) return local_ctx.api.shutdown() diff --git a/nvflare/client/api_context.py b/nvflare/client/api_context.py index d576b322d7..f16cd31cfd 100644 --- a/nvflare/client/api_context.py +++ b/nvflare/client/api_context.py @@ -22,6 +22,7 @@ from .api_spec import CLIENT_API_KEY, CLIENT_API_TYPE_KEY, APISpec from .ex_process.api import ExProcessClientAPI +from .in_process.api import InProcessClientAPI DEFAULT_CONFIG = f"config/{CLIENT_API_CONFIG}" data_bus = DataBus() @@ -54,6 +55,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): def _create_client_api(self, api_type: ClientAPIType) -> APISpec: """Creates a new client_api based on the provided API type.""" if api_type == ClientAPIType.IN_PROCESS_API: - return data_bus.get_data(CLIENT_API_KEY) + api = data_bus.get_data(CLIENT_API_KEY) + if not isinstance(api, InProcessClientAPI): + raise RuntimeError(f"api {api} is not a valid InProcessClientAPI") + return api else: return ExProcessClientAPI(config_file=self.config_file) diff --git a/nvflare/client/ex_process/api.py b/nvflare/client/ex_process/api.py index 7f83a44abd..d5d6eceb05 100644 --- a/nvflare/client/ex_process/api.py +++ b/nvflare/client/ex_process/api.py @@ -78,6 +78,7 @@ def __init__(self, config_file: str): self.logger = get_obj_logger(self) self.receive_called = False self.config_file = config_file + self.flare_agent = None def get_model_registry(self) -> ModelRegistry: """Gets the ModelRegistry.""" @@ -133,6 +134,7 @@ def init(self, rank: Optional[str] = None): flare_agent.start() self.model_registry = ModelRegistry(client_config, rank, flare_agent) + self.flare_agent = flare_agent except Exception as e: self.logger.error(f"flare.init failed: {e}") raise e @@ -216,5 +218,5 @@ def clear(self): self.receive_called = False def shutdown(self): - model_registry = self.get_model_registry() - model_registry.shutdown() + if self.flare_agent: + self.flare_agent.stop() diff --git a/nvflare/client/in_process/api.py b/nvflare/client/in_process/api.py index 0014d2a662..5993b61af9 100644 --- a/nvflare/client/in_process/api.py +++ b/nvflare/client/in_process/api.py @@ -237,4 +237,5 @@ def __continue_job(self) -> bool: def shutdown(self): self.stop = True + self.event_manager.fire_event(TOPIC_STOP) self.stop_reason = "API shutdown called." diff --git a/nvflare/client/task_registry.py b/nvflare/client/task_registry.py index bcfeefdaca..743fb53ae5 100644 --- a/nvflare/client/task_registry.py +++ b/nvflare/client/task_registry.py @@ -113,7 +113,3 @@ def clear(self) -> None: def __str__(self): return f"{self.__class__.__name__}(config: {self.config.get_config()})" - - def shutdown(self): - if self.flare_agent: - self.flare_agent.stop() diff --git a/nvflare/client/tracking.py b/nvflare/client/tracking.py index 51e506d0d1..5b91565739 100644 --- a/nvflare/client/tracking.py +++ b/nvflare/client/tracking.py @@ -19,14 +19,13 @@ # flake8: noqa from .api import default_context as default_context -from .api import log +from .api import log, get_context from .api_context import APIContext class _BaseWriter: def __init__(self, ctx: Optional[APIContext] = None): - global default_context - self.ctx = ctx if ctx else default_context + self.ctx = get_context(ctx) class SummaryWriter(_BaseWriter):