@@ -22,6 +22,7 @@ def run(
22
22
read_csv_kwargs : Optional [dict ] = None ,
23
23
uuid_column : str = ":uuid:" ,
24
24
retval_column : Optional [str ] = ":retval:" ,
25
+ extra_kwargs : Optional [Dict [str , Any ]] = None ,
25
26
):
26
27
"""Run experiments from a csv file
27
28
@@ -32,14 +33,16 @@ def run(
32
33
read_csv_kwargs (`Optional[dict]`, optional): Additional kwargs passed to `pandas.read_csv`. Defaults to None.
33
34
uuid_column (`str`, optional): The column name for the uuid. Defaults to `":uuid:"`.
34
35
retval_column (`Optional[str]`, optional): The column name for the return value. None for not saving the return value. Defaults to `":retval:"`.
36
+ extra_kwargs (`Optional[Dict[str, Any]]`, optional): Extra kwargs passed to exp_func.
35
37
"""
36
38
kwargs = {
37
39
"csv_path" : csv_path ,
38
40
"continue_cols" : continue_cols ,
39
41
"force_rerun" : force_rerun ,
40
- "read_csv_kwargs" : read_csv_kwargs or {} ,
42
+ "read_csv_kwargs" : read_csv_kwargs ,
41
43
"uuid_column" : uuid_column ,
42
44
"retval_column" : retval_column ,
45
+ "extra_kwargs" : extra_kwargs ,
43
46
}
44
47
return asyncio .run (self .arun (** kwargs ))
45
48
@@ -68,15 +71,22 @@ def submit_from_csv(
68
71
if not force_rerun :
69
72
rows = False
70
73
for col in self .continue_cols :
71
- rows |= df [col ].isnull ()
72
- added = int (rows .sum ())
73
- logger .info (f"Adding { added } tasks ({ len (df ) - added } skipped)." )
74
+ if col in df .columns :
75
+ rows |= df [col ].isnull ()
76
+ if isinstance (rows , bool ):
77
+ rows = slice (None )
78
+ logger .info (f"Adding { len (df )} tasks." )
79
+ else :
80
+ added = int (rows .sum ())
81
+ logger .info (
82
+ f"Adding { added } tasks ({ len (df ) - added } skipped)." )
74
83
else :
75
84
rows = slice (None )
76
85
logger .info (f"Adding { len (df )} tasks." )
77
86
78
87
tasks = [
79
- self .create_task (uuid , ** row ) for uuid , row in df [rows ].iterrows ()
88
+ self .create_task (uuid , ** row , ** self .extra_kwargs )
89
+ for uuid , row in df [rows ].iterrows ()
80
90
]
81
91
82
92
return tasks
@@ -112,13 +122,15 @@ async def arun(
112
122
read_csv_kwargs : Optional [dict ] = None ,
113
123
uuid_column : str = ":uuid:" ,
114
124
retval_column : Optional [str ] = ":retval:" ,
125
+ extra_kwargs : Optional [Dict [str , Any ]] = None ,
115
126
):
116
127
"""Async run experiments from a csv file"""
117
128
118
129
self .csv_path = csv_path
119
130
self .continue_cols = continue_cols
120
- self .read_csv_kwargs = read_csv_kwargs
131
+ self .read_csv_kwargs = read_csv_kwargs or {}
121
132
self .uuid_column = uuid_column
133
+ self .extra_kwargs = extra_kwargs or {}
122
134
123
135
tasks = self .submit_from_csv (force_rerun )
124
136
0 commit comments