Skip to content

Commit ec7889b

Browse files
committed
fix provided arguments check, special case searcher-pipeline_space
1 parent 54eca9c commit ec7889b

File tree

1 file changed

+50
-42
lines changed

1 file changed

+50
-42
lines changed

neps/utils/run_args_from_yaml.py

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,55 @@
22
import logging
33
import sys
44
from collections.abc import Callable
5+
from optimizers import BaseOptimizer
56

67
import yaml
78

89
logger = logging.getLogger("neps")
910

11+
# Define the allowed parameters based on the arguments of neps.run(), that have a
12+
# simple value like a string or int type
13+
# [searcher_kwargs, run_pipeline, preload_hooks] require special handling due to their
14+
# values necessitating distinct treatment, that's why they are not listed
15+
EXPECTED_PARAMETERS = [
16+
"pipeline_space",
17+
"root_directory",
18+
"max_evaluations_total",
19+
"max_cost_total",
20+
"overwrite_working_directory",
21+
"post_run_summary",
22+
"development_stage_id",
23+
"task_id",
24+
"max_evaluations_per_run",
25+
"continue_until_max_evaluation_completed",
26+
"loss_value_on_error",
27+
"cost_value_on_error",
28+
"ignore_errors",
29+
"searcher",
30+
"searcher_path",
31+
]
32+
33+
# Mapping parameter names to their allowed types
34+
# [task_id, development_stage_id, pre_load_hooks] require special handling of type,
35+
# that's why they are not listed
36+
ALLOWED_TYPES = {
37+
"run_pipeline": Callable,
38+
"root_directory": str,
39+
"pipeline_space": str,
40+
"max_evaluations_total": int,
41+
"max_cost_total": (int, float),
42+
"overwrite_working_directory": bool,
43+
"post_run_summary": bool,
44+
"max_evaluations_per_run": int,
45+
"continue_until_max_evaluation_completed": bool,
46+
"loss_value_on_error": float,
47+
"cost_value_on_error": float,
48+
"ignore_errors": bool,
49+
"searcher": str,
50+
"searcher_path": str,
51+
"searcher_kwargs": dict,
52+
}
53+
1054

1155
def get_run_args_from_yaml(path):
1256
"""
@@ -30,35 +74,16 @@ def get_run_args_from_yaml(path):
3074
# Load the YAML configuration file
3175
config = config_loader(path)
3276

33-
# Define the allowed parameters based on the arguments of neps.run()
34-
allowed_parameters = [
35-
"pipeline_space",
36-
"root_directory",
37-
"max_evaluations_total",
38-
"max_cost_total",
39-
"overwrite_working_directory",
40-
"post_run_summary",
41-
"development_stage_id",
42-
"task_id",
43-
"max_evaluations_per_run",
44-
"continue_until_max_evaluation_completed",
45-
"loss_value_on_error",
46-
"cost_value_on_error",
47-
"ignore_errors",
48-
"searcher",
49-
"searcher_path",
50-
]
51-
5277
# Initialize an empty dictionary to hold the extracted settings
5378
settings = {}
5479

5580
# Flatten yaml file and ignore hierarchical structure, only consider parameters(keys)
56-
# with a value
81+
# with an explicit value
5782
flat_config, special_configs = extract_leaf_keys(config)
5883

59-
# Check if just neps arguments are provided
84+
# Check if flatten dict just contains neps arguments
6085
for parameter, value in flat_config.items():
61-
if parameter in allowed_parameters:
86+
if parameter in EXPECTED_PARAMETERS:
6287
settings[parameter] = value
6388
else:
6489
raise KeyError(f"Parameter '{parameter}' is not an argument of neps.run().")
@@ -243,24 +268,7 @@ def check_run_args(settings):
243268
Raises:
244269
TypeError: For mismatched setting value types.
245270
"""
246-
# Mapping parameter names to their allowed types
247-
allowed_types = {
248-
"run_pipeline": Callable,
249-
"root_directory": str,
250-
"pipeline_space": str,
251-
"max_evaluations_total": int,
252-
"max_cost_total": (int, float),
253-
"overwrite_working_directory": bool,
254-
"post_run_summary": bool,
255-
"max_evaluations_per_run": int,
256-
"continue_until_max_evaluation_completed": bool,
257-
"loss_value_on_error": float,
258-
"cost_value_on_error": float,
259-
"ignore_errors": bool,
260-
"searcher": str,
261-
"searcher_path": str,
262-
"searcher_kwargs": dict,
263-
}
271+
264272
for param, value in settings.items():
265273
if param == "development_stage_id" or param == "task_id":
266274
# this argument can be Any
@@ -270,7 +278,7 @@ def check_run_args(settings):
270278
if not all(callable(item) for item in value):
271279
raise TypeError("All items in 'pre_load_hooks' must be callable.")
272280
else:
273-
expected_type = allowed_types[param]
281+
expected_type = ALLOWED_TYPES[param]
274282
if not isinstance(value, expected_type):
275283
raise TypeError(
276284
f"Parameter '{param}' expects a value of type {expected_type}, got "
@@ -305,7 +313,7 @@ def check_essential_arguments(
305313
if not root_directory:
306314
raise ValueError("'root_directory' is required but was not provided.")
307315
if not pipeline_space:
308-
if searcher is not Callable:
316+
if not isinstance(searcher, BaseOptimizer):
309317
raise ValueError("'pipeline_space' is required but was not provided.")
310318
if not max_evaluation_total and not max_cost_total:
311319
raise ValueError(

0 commit comments

Comments
 (0)