2
2
import logging
3
3
import sys
4
4
from collections .abc import Callable
5
+ from optimizers import BaseOptimizer
5
6
6
7
import yaml
7
8
8
9
logger = logging .getLogger ("neps" )
9
10
11
+ # Define the allowed parameters based on the arguments of neps.run(), that have a
12
+ # simple value like a string or int type
13
+ # [searcher_kwargs, run_pipeline, preload_hooks] require special handling due to their
14
+ # values necessitating distinct treatment, that's why they are not listed
15
+ EXPECTED_PARAMETERS = [
16
+ "pipeline_space" ,
17
+ "root_directory" ,
18
+ "max_evaluations_total" ,
19
+ "max_cost_total" ,
20
+ "overwrite_working_directory" ,
21
+ "post_run_summary" ,
22
+ "development_stage_id" ,
23
+ "task_id" ,
24
+ "max_evaluations_per_run" ,
25
+ "continue_until_max_evaluation_completed" ,
26
+ "loss_value_on_error" ,
27
+ "cost_value_on_error" ,
28
+ "ignore_errors" ,
29
+ "searcher" ,
30
+ "searcher_path" ,
31
+ ]
32
+
33
+ # Mapping parameter names to their allowed types
34
+ # [task_id, development_stage_id, pre_load_hooks] require special handling of type,
35
+ # that's why they are not listed
36
+ ALLOWED_TYPES = {
37
+ "run_pipeline" : Callable ,
38
+ "root_directory" : str ,
39
+ "pipeline_space" : str ,
40
+ "max_evaluations_total" : int ,
41
+ "max_cost_total" : (int , float ),
42
+ "overwrite_working_directory" : bool ,
43
+ "post_run_summary" : bool ,
44
+ "max_evaluations_per_run" : int ,
45
+ "continue_until_max_evaluation_completed" : bool ,
46
+ "loss_value_on_error" : float ,
47
+ "cost_value_on_error" : float ,
48
+ "ignore_errors" : bool ,
49
+ "searcher" : str ,
50
+ "searcher_path" : str ,
51
+ "searcher_kwargs" : dict ,
52
+ }
53
+
10
54
11
55
def get_run_args_from_yaml (path ):
12
56
"""
@@ -30,35 +74,16 @@ def get_run_args_from_yaml(path):
30
74
# Load the YAML configuration file
31
75
config = config_loader (path )
32
76
33
- # Define the allowed parameters based on the arguments of neps.run()
34
- allowed_parameters = [
35
- "pipeline_space" ,
36
- "root_directory" ,
37
- "max_evaluations_total" ,
38
- "max_cost_total" ,
39
- "overwrite_working_directory" ,
40
- "post_run_summary" ,
41
- "development_stage_id" ,
42
- "task_id" ,
43
- "max_evaluations_per_run" ,
44
- "continue_until_max_evaluation_completed" ,
45
- "loss_value_on_error" ,
46
- "cost_value_on_error" ,
47
- "ignore_errors" ,
48
- "searcher" ,
49
- "searcher_path" ,
50
- ]
51
-
52
77
# Initialize an empty dictionary to hold the extracted settings
53
78
settings = {}
54
79
55
80
# Flatten yaml file and ignore hierarchical structure, only consider parameters(keys)
56
- # with a value
81
+ # with an explicit value
57
82
flat_config , special_configs = extract_leaf_keys (config )
58
83
59
- # Check if just neps arguments are provided
84
+ # Check if flatten dict just contains neps arguments
60
85
for parameter , value in flat_config .items ():
61
- if parameter in allowed_parameters :
86
+ if parameter in EXPECTED_PARAMETERS :
62
87
settings [parameter ] = value
63
88
else :
64
89
raise KeyError (f"Parameter '{ parameter } ' is not an argument of neps.run()." )
@@ -243,24 +268,7 @@ def check_run_args(settings):
243
268
Raises:
244
269
TypeError: For mismatched setting value types.
245
270
"""
246
- # Mapping parameter names to their allowed types
247
- allowed_types = {
248
- "run_pipeline" : Callable ,
249
- "root_directory" : str ,
250
- "pipeline_space" : str ,
251
- "max_evaluations_total" : int ,
252
- "max_cost_total" : (int , float ),
253
- "overwrite_working_directory" : bool ,
254
- "post_run_summary" : bool ,
255
- "max_evaluations_per_run" : int ,
256
- "continue_until_max_evaluation_completed" : bool ,
257
- "loss_value_on_error" : float ,
258
- "cost_value_on_error" : float ,
259
- "ignore_errors" : bool ,
260
- "searcher" : str ,
261
- "searcher_path" : str ,
262
- "searcher_kwargs" : dict ,
263
- }
271
+
264
272
for param , value in settings .items ():
265
273
if param == "development_stage_id" or param == "task_id" :
266
274
# this argument can be Any
@@ -270,7 +278,7 @@ def check_run_args(settings):
270
278
if not all (callable (item ) for item in value ):
271
279
raise TypeError ("All items in 'pre_load_hooks' must be callable." )
272
280
else :
273
- expected_type = allowed_types [param ]
281
+ expected_type = ALLOWED_TYPES [param ]
274
282
if not isinstance (value , expected_type ):
275
283
raise TypeError (
276
284
f"Parameter '{ param } ' expects a value of type { expected_type } , got "
@@ -305,7 +313,7 @@ def check_essential_arguments(
305
313
if not root_directory :
306
314
raise ValueError ("'root_directory' is required but was not provided." )
307
315
if not pipeline_space :
308
- if searcher is not Callable :
316
+ if not isinstance ( searcher , BaseOptimizer ) :
309
317
raise ValueError ("'pipeline_space' is required but was not provided." )
310
318
if not max_evaluation_total and not max_cost_total :
311
319
raise ValueError (
0 commit comments