Skip to content

Commit

Permalink
fix provided arguments check, special case searcher-pipeline_space
Browse files Browse the repository at this point in the history
  • Loading branch information
danrgll committed Mar 18, 2024
1 parent 54eca9c commit ec7889b
Showing 1 changed file with 50 additions and 42 deletions.
92 changes: 50 additions & 42 deletions neps/utils/run_args_from_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,55 @@
import logging
import sys
from collections.abc import Callable
from optimizers import BaseOptimizer

import yaml

logger = logging.getLogger("neps")

# Define the allowed parameters based on the arguments of neps.run(), that have a
# simple value like a string or int type
# [searcher_kwargs, run_pipeline, preload_hooks] require special handling due to their
# values necessitating distinct treatment, that's why they are not listed
EXPECTED_PARAMETERS = [
"pipeline_space",
"root_directory",
"max_evaluations_total",
"max_cost_total",
"overwrite_working_directory",
"post_run_summary",
"development_stage_id",
"task_id",
"max_evaluations_per_run",
"continue_until_max_evaluation_completed",
"loss_value_on_error",
"cost_value_on_error",
"ignore_errors",
"searcher",
"searcher_path",
]

# Mapping parameter names to their allowed types
# [task_id, development_stage_id, pre_load_hooks] require special handling of type,
# that's why they are not listed
ALLOWED_TYPES = {
"run_pipeline": Callable,
"root_directory": str,
"pipeline_space": str,
"max_evaluations_total": int,
"max_cost_total": (int, float),
"overwrite_working_directory": bool,
"post_run_summary": bool,
"max_evaluations_per_run": int,
"continue_until_max_evaluation_completed": bool,
"loss_value_on_error": float,
"cost_value_on_error": float,
"ignore_errors": bool,
"searcher": str,
"searcher_path": str,
"searcher_kwargs": dict,
}


def get_run_args_from_yaml(path):
"""
Expand All @@ -30,35 +74,16 @@ def get_run_args_from_yaml(path):
# Load the YAML configuration file
config = config_loader(path)

# Define the allowed parameters based on the arguments of neps.run()
allowed_parameters = [
"pipeline_space",
"root_directory",
"max_evaluations_total",
"max_cost_total",
"overwrite_working_directory",
"post_run_summary",
"development_stage_id",
"task_id",
"max_evaluations_per_run",
"continue_until_max_evaluation_completed",
"loss_value_on_error",
"cost_value_on_error",
"ignore_errors",
"searcher",
"searcher_path",
]

# Initialize an empty dictionary to hold the extracted settings
settings = {}

# Flatten yaml file and ignore hierarchical structure, only consider parameters(keys)
# with a value
# with an explicit value
flat_config, special_configs = extract_leaf_keys(config)

# Check if just neps arguments are provided
# Check if flatten dict just contains neps arguments
for parameter, value in flat_config.items():
if parameter in allowed_parameters:
if parameter in EXPECTED_PARAMETERS:
settings[parameter] = value
else:
raise KeyError(f"Parameter '{parameter}' is not an argument of neps.run().")
Expand Down Expand Up @@ -243,24 +268,7 @@ def check_run_args(settings):
Raises:
TypeError: For mismatched setting value types.
"""
# Mapping parameter names to their allowed types
allowed_types = {
"run_pipeline": Callable,
"root_directory": str,
"pipeline_space": str,
"max_evaluations_total": int,
"max_cost_total": (int, float),
"overwrite_working_directory": bool,
"post_run_summary": bool,
"max_evaluations_per_run": int,
"continue_until_max_evaluation_completed": bool,
"loss_value_on_error": float,
"cost_value_on_error": float,
"ignore_errors": bool,
"searcher": str,
"searcher_path": str,
"searcher_kwargs": dict,
}

for param, value in settings.items():
if param == "development_stage_id" or param == "task_id":
# this argument can be Any
Expand All @@ -270,7 +278,7 @@ def check_run_args(settings):
if not all(callable(item) for item in value):
raise TypeError("All items in 'pre_load_hooks' must be callable.")
else:
expected_type = allowed_types[param]
expected_type = ALLOWED_TYPES[param]
if not isinstance(value, expected_type):
raise TypeError(
f"Parameter '{param}' expects a value of type {expected_type}, got "
Expand Down Expand Up @@ -305,7 +313,7 @@ def check_essential_arguments(
if not root_directory:
raise ValueError("'root_directory' is required but was not provided.")
if not pipeline_space:
if searcher is not Callable:
if not isinstance(searcher, BaseOptimizer):
raise ValueError("'pipeline_space' is required but was not provided.")
if not max_evaluation_total and not max_cost_total:
raise ValueError(
Expand Down

0 comments on commit ec7889b

Please sign in to comment.