Skip to content

Commit 0f6788b

Browse files
committed
Refactor client utilities and add experiment heartbeat
Moved client utility functions and models to a new utils.py file, including the safe_request decorator. Introduced a client heartbeat mechanism to keep experiments active, with integration in CloudPyCallback. Updated imports and improved type hinting for better maintainability.
1 parent 28a09dd commit 0f6788b

File tree

7 files changed

+98
-50
lines changed

7 files changed

+98
-50
lines changed

.idea/SwanLab.iml

Lines changed: 0 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/dictionaries/project.xml

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

swanlab/core_python/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@
1010
# FIXME 存在循环引用,我们需要更优雅的代码结构
1111
# from . import auth
1212
# from . import uploader
13-
from .client import Client, create_client, reset_client, get_client, create_session
13+
from .client import *
1414
from .utils import timer

swanlab/core_python/api/experiment.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,29 @@
55
@description: 定义实验相关的后端API接口
66
"""
77

8-
from typing import Literal
8+
from typing import Literal, TYPE_CHECKING
99

10-
from swanlab.core_python.client import Client
10+
if TYPE_CHECKING:
11+
from swanlab.core_python.client import Client
12+
13+
14+
def send_experiment_heartbeat(
15+
client: "Client",
16+
*,
17+
cuid: str,
18+
flag_id: str,
19+
):
20+
"""
21+
发送实验心跳,保持实验处于活跃状态
22+
:param client: 已登录的客户端实例
23+
:param cuid: 实验唯一标识符
24+
:param flag_id: 实验标记ID
25+
"""
26+
client.post(f"/house/experiments/{cuid}/heartbeat", {"flagId": flag_id})
1127

1228

1329
def update_experiment_state(
14-
client: Client,
30+
client: "Client",
1531
*,
1632
username: str,
1733
projname: str,

swanlab/core_python/client/__init__.py

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,15 @@
1010
from typing import Optional, Tuple, Dict, Union, List, AnyStr
1111

1212
import requests
13-
from urllib3.exceptions import (
14-
MaxRetryError,
15-
TimeoutError,
16-
NewConnectionError,
17-
ConnectionError,
18-
ReadTimeoutError,
19-
ConnectTimeoutError,
20-
)
21-
22-
from swanlab.error import NetworkError, ApiError
13+
14+
from swanlab.error import ApiError
2315
from swanlab.log import swanlog
2416
from swanlab.package import get_package_version
25-
from .model import ProjectInfo, ExperimentInfo
2617
from .session import create_session
18+
from .utils import safe_request, ProjectInfo, ExperimentInfo
2719
from .. import auth
20+
from ..api.experiment import send_experiment_heartbeat
21+
from ..utils import timer
2822
from ...env import utc_time
2923

3024

@@ -114,11 +108,11 @@ def expname(self):
114108
return self.__exp.name
115109

116110
@property
117-
def web_proj_url(self):
111+
def web_proj_url(self) -> str:
118112
return f"{self.__login_info.web_host}/@{self.groupname}/{self.projname}"
119113

120114
@property
121-
def web_exp_url(self):
115+
def web_exp_url(self) -> str:
122116
return f"{self.web_proj_url}/runs/{self.exp_id}"
123117

124118
# ---------------------------------- http方法 ----------------------------------
@@ -374,41 +368,31 @@ def reset_client():
374368
client = None
375369

376370

377-
def safe_request(func):
371+
def create_client_heartbeat():
378372
"""
379-
在一些接口中我们不希望线程奔溃,而是返回一个错误对象
373+
创建客户端心跳定时器,保持实验处于活跃状态
374+
:return: 心跳定时器实例
380375
"""
376+
cl = get_client()
381377

382-
def wrapper(*args, **kwargs) -> Tuple[Optional[Union[dict, str]], Optional[Exception]]:
378+
# TODO 目前保证乡下兼容,如果报错也不提示用户,后续使用safe_request装饰器
379+
# func = safe_request(func=send_experiment_heartbeat)
380+
def func(c: Client, *, cuid: str, flag_id: str):
383381
try:
384-
# 在装饰器中调用被装饰的异步函数
385-
result = func(*args, **kwargs)
386-
return result, None
387-
except requests.exceptions.Timeout:
388-
return None, NetworkError()
389-
except requests.exceptions.ConnectionError:
390-
return None, NetworkError()
391-
# Catch urllib3 specific errors
392-
except (
393-
MaxRetryError,
394-
TimeoutError,
395-
NewConnectionError,
396-
ConnectionError,
397-
ReadTimeoutError,
398-
ConnectTimeoutError,
399-
):
400-
return None, NetworkError()
401-
except Exception as e:
402-
return None, e
403-
404-
return wrapper
382+
send_experiment_heartbeat(c, cuid=cuid, flag_id=flag_id)
383+
except ApiError as e:
384+
swanlog.debug("Failed to send heartbeat: " + str(e))
385+
386+
task = lambda: func(cl, cuid=cl.exp.cuid, flag_id=cl.exp.flag_id)
387+
return timer.Timer(task, interval=10, immediate=True).run()
405388

406389

407390
__all__ = [
408391
"get_client",
409392
"reset_client",
410393
"create_session",
411394
"create_client",
395+
"create_client_heartbeat",
412396
"safe_request",
413397
"decode_response",
414398
"Client",
Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,53 @@
11
"""
22
@author: cunyue
3-
@file: model.py
4-
@time: 2025/6/16 14:55
5-
@description: 实验、项目元信息
3+
@file: utils.py
4+
@time: 2025/12/31 13:29
5+
@description: 客户端工具函数
66
"""
77

8-
from typing import Optional
8+
from typing import Tuple, Optional, Union
9+
10+
import requests
11+
from urllib3.exceptions import (
12+
MaxRetryError,
13+
TimeoutError,
14+
NewConnectionError,
15+
ConnectionError,
16+
ReadTimeoutError,
17+
ConnectTimeoutError,
18+
)
19+
20+
from swanlab.error import NetworkError
21+
22+
23+
def safe_request(func):
24+
"""
25+
在一些接口中我们不希望线程奔溃,而是返回一个错误对象
26+
"""
27+
28+
def wrapper(*args, **kwargs) -> Tuple[Optional[Union[dict, str]], Optional[Exception]]:
29+
try:
30+
# 在装饰器中调用被装饰的异步函数
31+
result = func(*args, **kwargs)
32+
return result, None
33+
except requests.exceptions.Timeout:
34+
return None, NetworkError()
35+
except requests.exceptions.ConnectionError:
36+
return None, NetworkError()
37+
# Catch urllib3 specific errors
38+
except (
39+
MaxRetryError,
40+
TimeoutError,
41+
NewConnectionError,
42+
ConnectionError,
43+
ReadTimeoutError,
44+
ConnectTimeoutError,
45+
):
46+
return None, NetworkError()
47+
except Exception as e:
48+
return None, e
49+
50+
return wrapper
951

1052

1153
class ProjectInfo:

swanlab/data/callbacker/cloud.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from ..run import get_run
2828
from ...core_python import *
2929
from ...core_python.api.experiment import update_experiment_state
30+
from ...core_python.utils.timer import Timer
3031
from ...log.type import LogData
3132

3233

@@ -36,6 +37,7 @@ class CloudPyCallback(SwanLabRunCallback):
3637
def __init__(self):
3738
super().__init__()
3839
self.executor = ThreadPoolExecutor(max_workers=1)
40+
self.heartbeat: Optional[Timer] = None
3941

4042
def __str__(self):
4143
return "SwanLabCloudPyCallback"
@@ -62,12 +64,15 @@ def _converter_summarise_metric():
6264
pass
6365

6466
def on_init(self, *args, **kwargs):
65-
_ = self._create_client()
67+
self._create_client()
6668
# 检测是否有最新的版本
6769
U.check_latest_version()
70+
# 挂载项目、实验
6871
with Status("Creating experiment...", spinner="dots"):
6972
with Mounter() as mounter:
7073
mounter.execute()
74+
# 创建客户端心跳
75+
self.heartbeat = create_client_heartbeat()
7176

7277
def _terminal_handler(self, log_data: LogData):
7378
self.porter.trace_log(log_data)
@@ -100,6 +105,10 @@ def on_metric_create(self, metric_info: MetricInfo, *args, **kwargs):
100105
self.porter.trace_metric(metric_info)
101106

102107
def on_stop(self, error: str = None, *args, **kwargs):
108+
# 删除心跳
109+
self.heartbeat.cancel()
110+
self.heartbeat.join()
111+
# 删除终端代理和系统回调
103112
success = get_run().success
104113
# FIXME 等合并 swankit 以后优化一下 interrupt 的传递问题
105114
interrupt = kwargs.get("interrupt", False)

0 commit comments

Comments
 (0)