Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh committed Feb 7, 2025
1 parent 2ca2ace commit 675f658
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 51 deletions.
92 changes: 51 additions & 41 deletions nvflare/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,40 @@
# limitations under the License.

import logging
from threading import Lock
from typing import Any, Dict, Optional

from nvflare.apis.analytix import AnalyticsDataType
from nvflare.app_common.abstract.fl_model import FLModel

from .api_context import APIContext

global_context_lock = Lock()
context_dict = {}
default_context = None


def get_context(ctx: Optional[APIContext] = None) -> APIContext:
    """Gets an APIContext.

    Resolution order: the explicitly supplied ``ctx`` wins; otherwise the
    module-level ``default_context`` (assigned by ``init()``) is used.

    Args:
        ctx (Optional[APIContext]): The context to use,
            if None means use default context. Defaults to None.

    Raises:
        RuntimeError: if can't get a valid APIContext, i.e. no ``ctx`` was
            given and no default context has been set.

    Returns:
        An APIContext.
    """
    # Compare against None explicitly instead of relying on truthiness:
    # "no context supplied" is signalled by None, and an identity check keeps
    # working even if the context type ever defines __bool__ or __len__.
    if ctx is not None:
        return ctx
    if default_context is not None:
        return default_context
    raise RuntimeError("APIContext is None")


def init(rank: Optional[str] = None, config_file: Optional[str] = None) -> APIContext:
"""Initializes NVFlare Client API environment.
Expand All @@ -35,19 +58,20 @@ def init(rank: Optional[str] = None, config_file: Optional[str] = None) -> APICo
Returns:
APIContext
"""
global context_dict
global default_context
local_ctx = context_dict.get((rank, config_file))

if local_ctx is None:
local_ctx = APIContext(rank=rank, config_file=config_file)
context_dict[(rank, config_file)] = local_ctx
default_context = local_ctx
else:
logging.warning(
"Warning: called init() more than once with same parameters." "The subsequence calls are ignored"
)
return local_ctx
with global_context_lock:
global context_dict
global default_context
local_ctx = context_dict.get((rank, config_file))

if local_ctx is None:
local_ctx = APIContext(rank=rank, config_file=config_file)
context_dict[(rank, config_file)] = local_ctx
default_context = local_ctx
else:
logging.warning(
"Warning: called init() more than once with same parameters." "The subsequence calls are ignored"
)
return local_ctx


def receive(timeout: Optional[float] = None, ctx: Optional[APIContext] = None) -> Optional[FLModel]:
Expand All @@ -56,8 +80,7 @@ def receive(timeout: Optional[float] = None, ctx: Optional[APIContext] = None) -
Returns:
An FLModel received.
"""
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.receive(timeout)


Expand All @@ -70,8 +93,7 @@ def send(model: FLModel, clear_cache: bool = True, ctx: Optional[APIContext] = N
"""
if not isinstance(model, FLModel):
raise TypeError("model needs to be an instance of FLModel")
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.send(model, clear_cache)


Expand All @@ -88,8 +110,7 @@ def system_info(ctx: Optional[APIContext] = None) -> Dict:
A dict of system information.
"""
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.system_info()


Expand All @@ -99,8 +120,7 @@ def get_config(ctx: Optional[APIContext] = None) -> Dict:
Returns:
A dict of the configuration used in Client API.
"""
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.get_config()


Expand All @@ -110,8 +130,7 @@ def get_job_id(ctx: Optional[APIContext] = None) -> str:
Returns:
The current job id.
"""
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.get_job_id()


Expand All @@ -121,8 +140,7 @@ def get_site_name(ctx: Optional[APIContext] = None) -> str:
Returns:
The site name of this client.
"""
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.get_site_name()


Expand All @@ -132,8 +150,7 @@ def get_task_name(ctx: Optional[APIContext] = None) -> str:
Returns:
The task name.
"""
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.get_task_name()


Expand All @@ -143,8 +160,7 @@ def is_running(ctx: Optional[APIContext] = None) -> bool:
Returns:
True, if the system is up and running. False, otherwise.
"""
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.is_running()


Expand All @@ -154,8 +170,7 @@ def is_train(ctx: Optional[APIContext] = None) -> bool:
Returns:
True, if the current task is a training task. False, otherwise.
"""
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.is_train()


Expand All @@ -165,8 +180,7 @@ def is_evaluate(ctx: Optional[APIContext] = None) -> bool:
Returns:
True, if the current task is an evaluate task. False, otherwise.
"""
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.is_evaluate()


Expand All @@ -176,8 +190,7 @@ def is_submit_model(ctx: Optional[APIContext] = None) -> bool:
Returns:
True, if the current task is a submit_model. False, otherwise.
"""
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.is_submit_model()


Expand All @@ -195,20 +208,17 @@ def log(key: str, value: Any, data_type: AnalyticsDataType, ctx: Optional[APICon
Returns:
whether the key value pair is logged successfully
"""
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.log(key, value, data_type, **kwargs)


def clear(ctx: Optional[APIContext] = None):
"""Clears the cache."""
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.clear()


def shutdown(ctx: Optional[APIContext] = None):
"""Releases all threads and resources used by the API and stops operation."""
global default_context
local_ctx = ctx if ctx else default_context
local_ctx = get_context(ctx)
return local_ctx.api.shutdown()
6 changes: 5 additions & 1 deletion nvflare/client/api_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from .api_spec import CLIENT_API_KEY, CLIENT_API_TYPE_KEY, APISpec
from .ex_process.api import ExProcessClientAPI
from .in_process.api import InProcessClientAPI

DEFAULT_CONFIG = f"config/{CLIENT_API_CONFIG}"
data_bus = DataBus()
Expand Down Expand Up @@ -54,6 +55,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def _create_client_api(self, api_type: ClientAPIType) -> APISpec:
"""Creates a new client_api based on the provided API type."""
if api_type == ClientAPIType.IN_PROCESS_API:
return data_bus.get_data(CLIENT_API_KEY)
api = data_bus.get_data(CLIENT_API_KEY)
if not isinstance(api, InProcessClientAPI):
raise RuntimeError(f"api {api} is not a valid InProcessClientAPI")
return api
else:
return ExProcessClientAPI(config_file=self.config_file)
6 changes: 4 additions & 2 deletions nvflare/client/ex_process/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(self, config_file: str):
self.logger = get_obj_logger(self)
self.receive_called = False
self.config_file = config_file
self.flare_agent = None

def get_model_registry(self) -> ModelRegistry:
"""Gets the ModelRegistry."""
Expand Down Expand Up @@ -133,6 +134,7 @@ def init(self, rank: Optional[str] = None):
flare_agent.start()

self.model_registry = ModelRegistry(client_config, rank, flare_agent)
self.flare_agent = flare_agent
except Exception as e:
self.logger.error(f"flare.init failed: {e}")
raise e
Expand Down Expand Up @@ -216,5 +218,5 @@ def clear(self):
self.receive_called = False

def shutdown(self):
model_registry = self.get_model_registry()
model_registry.shutdown()
if self.flare_agent:
self.flare_agent.stop()
1 change: 1 addition & 0 deletions nvflare/client/in_process/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,4 +237,5 @@ def __continue_job(self) -> bool:

def shutdown(self):
    """Marks this API as stopped and broadcasts the stop event.

    Sets ``self.stop`` to True, fires ``TOPIC_STOP`` through
    ``self.event_manager`` and records a human-readable ``stop_reason``.
    """
    # Raise the stop flag first so it is already set when the event fires.
    self.stop = True
    self.event_manager.fire_event(TOPIC_STOP)
    # NOTE(review): stop_reason is assigned after the event is fired; if
    # TOPIC_STOP listeners run synchronously and read stop_reason they would
    # not see this value yet - confirm this ordering is intended.
    self.stop_reason = "API shutdown called."
4 changes: 0 additions & 4 deletions nvflare/client/task_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,3 @@ def clear(self) -> None:

def __str__(self):
return f"{self.__class__.__name__}(config: {self.config.get_config()})"

def shutdown(self):
if self.flare_agent:
self.flare_agent.stop()
5 changes: 2 additions & 3 deletions nvflare/client/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@

# flake8: noqa
from .api import default_context as default_context
from .api import log
from .api import log, get_context
from .api_context import APIContext


class _BaseWriter:
def __init__(self, ctx: Optional[APIContext] = None):
global default_context
self.ctx = ctx if ctx else default_context
self.ctx = get_context(ctx)


class SummaryWriter(_BaseWriter):
Expand Down

0 comments on commit 675f658

Please sign in to comment.