Skip to content

Commit 956866f

Browse files
committed
[fix] catch errors
1 parent: d73c15d · commit: 956866f

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

ml_scheduler/exp/exp.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import subprocess
33
import threading
44
from logging import getLogger
5-
from typing import Any, Dict, List, Set
5+
from typing import Any, Dict, List, Optional, Set
66

77
from ..pools.base import BaseAllocator, BaseResources
88
from .runner import BaseRunner
@@ -61,6 +61,8 @@ def run_in_thread(**popen_kwargs):
6161

6262
return stdout
6363

64-
async def report(self, metrics: Dict[str, Any], **kwargs):
64+
async def report(self, metrics: Optional[Dict[str, Any]] = None, **kwargs):
65+
if metrics is None:
66+
metrics = {}
6567
metrics.update(kwargs)
6668
await self.runner._report(self.uuid, metrics)

ml_scheduler/exp/func.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,22 @@
11
import inspect
2+
from logging import getLogger
3+
from traceback import format_exc
24
from typing import Any, Tuple
35

46
from .exp import Exp
57
from .runner.csv import CSVRunner
68

9+
logger = getLogger(__name__)
10+
711

812
class ExpFunc:
913

1014
def __init__(self, exp_func) -> None:
1115
self.exp_func = exp_func
12-
self.run_csv = CSVRunner.set(self).run
16+
17+
csvrunner = CSVRunner.set(self)
18+
self.run_csv = csvrunner.run
19+
self.arun_csv = csvrunner.arun
1320

1421
async def __call__(self, exp: Exp, **kwargs) -> Tuple[Exp, Any]:
1522
assert isinstance(exp, Exp)
@@ -20,7 +27,13 @@ async def __call__(self, exp: Exp, **kwargs) -> Tuple[Exp, Any]:
2027
ba = inspect.signature(self.exp_func).bind(exp, **filtered_kwargs)
2128
ba.apply_defaults()
2229

23-
results = await self.exp_func(**ba.arguments)
30+
try:
31+
results = await self.exp_func(**ba.arguments)
32+
except Exception as e:
33+
logger.warning(format_exc())
34+
logger.error(f"Error in {self.exp_func.__name__}: {e}")
35+
results = ""
36+
2437
await exp.cleanup()
2538
return (exp, results)
2639

ml_scheduler/exp/runner/csv.py

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ def run(
2222
read_csv_kwargs: Optional[dict] = None,
2323
uuid_column: str = ":uuid:",
2424
retval_column: Optional[str] = ":retval:",
25+
extra_kwargs: Optional[Dict[str, Any]] = None,
2526
):
2627
"""Run experiments from a csv file
2728
@@ -32,14 +33,16 @@ def run(
3233
read_csv_kwargs (`Optional[dict]`, optional): Additional kwargs passed to `pandas.read_csv`. Defaults to None.
3334
uuid_column (`str`, optional): The column name for the uuid. Defaults to `":uuid:"`.
3435
retval_column (`Optional[str]`, optional): The column name for the return value. None for not saving the return value. Defaults to `":retval:"`.
36+
extra_kwargs (`Optional[Dict[str, Any]]`, optional): Extra kwargs passed to exp_func.
3537
"""
3638
kwargs = {
3739
"csv_path": csv_path,
3840
"continue_cols": continue_cols,
3941
"force_rerun": force_rerun,
40-
"read_csv_kwargs": read_csv_kwargs or {},
42+
"read_csv_kwargs": read_csv_kwargs,
4143
"uuid_column": uuid_column,
4244
"retval_column": retval_column,
45+
"extra_kwargs": extra_kwargs,
4346
}
4447
return asyncio.run(self.arun(**kwargs))
4548

@@ -68,15 +71,22 @@ def submit_from_csv(
6871
if not force_rerun:
6972
rows = False
7073
for col in self.continue_cols:
71-
rows |= df[col].isnull()
72-
added = int(rows.sum())
73-
logger.info(f"Adding {added} tasks ({len(df) - added} skipped).")
74+
if col in df.columns:
75+
rows |= df[col].isnull()
76+
if isinstance(rows, bool):
77+
rows = slice(None)
78+
logger.info(f"Adding {len(df)} tasks.")
79+
else:
80+
added = int(rows.sum())
81+
logger.info(
82+
f"Adding {added} tasks ({len(df) - added} skipped).")
7483
else:
7584
rows = slice(None)
7685
logger.info(f"Adding {len(df)} tasks.")
7786

7887
tasks = [
79-
self.create_task(uuid, **row) for uuid, row in df[rows].iterrows()
88+
self.create_task(uuid, **row, **self.extra_kwargs)
89+
for uuid, row in df[rows].iterrows()
8090
]
8191

8292
return tasks
@@ -112,13 +122,15 @@ async def arun(
112122
read_csv_kwargs: Optional[dict] = None,
113123
uuid_column: str = ":uuid:",
114124
retval_column: Optional[str] = ":retval:",
125+
extra_kwargs: Optional[Dict[str, Any]] = None,
115126
):
116127
"""Async run experiments from a csv file"""
117128

118129
self.csv_path = csv_path
119130
self.continue_cols = continue_cols
120-
self.read_csv_kwargs = read_csv_kwargs
131+
self.read_csv_kwargs = read_csv_kwargs or {}
121132
self.uuid_column = uuid_column
133+
self.extra_kwargs = extra_kwargs or {}
122134

123135
tasks = self.submit_from_csv(force_rerun)
124136

0 commit comments

Comments (0)