Skip to content

Commit

Permalink
implementation of run_args, loading and checking settings from yaml +…
Browse files Browse the repository at this point in the history
… tests
  • Loading branch information
danrgll committed Mar 11, 2024
1 parent 2ab7003 commit 17fe070
Show file tree
Hide file tree
Showing 17 changed files with 1,107 additions and 3 deletions.
65 changes: 62 additions & 3 deletions neps/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Callable, Iterable, Literal

import ConfigSpace as CS
from utils.run_args_from_yaml import check_essential_arguments, get_run_args_from_yaml

from .metahyper import instance_from_map, metahyper_run
from .optimizers import BaseOptimizer, SearcherMapping
Expand Down Expand Up @@ -96,15 +97,16 @@ def write_loss_and_config(file_handle, loss_, config_id_, config_):


def run(
run_pipeline: Callable,
root_directory: str | Path,
run_pipeline: Callable = None,
root_directory: str | Path = None,
pipeline_space: (
dict[str, Parameter | CS.ConfigurationSpace]
| str
| Path
| CS.ConfigurationSpace
| None
) = None,
run_args: str | Path = None,
overwrite_working_directory: bool = False,
post_run_summary: bool = False,
development_stage_id=None,
Expand Down Expand Up @@ -146,6 +148,8 @@ def run(
pipeline_space: The search space to minimize over.
root_directory: The directory to save progress to. This is also used to
synchronize multiple calls to run(.) for parallelization.
run_args: An option for providing the optimization settings e.g.
max_evaluation_total in a YAML file.
overwrite_working_directory: If true, delete the working directory at the start of
the run. This is, e.g., useful when debugging a run_pipeline function.
post_run_summary: If True, creates a csv file after each worker is done,
Expand Down Expand Up @@ -213,11 +217,66 @@ def run(
)
max_cost_total = searcher_kwargs["budget"]
del searcher_kwargs["budget"]
logger = logging.getLogger("neps")

# if arguments via run_args provided overwrite them
if run_args:
logger.info(
"Loading arguments from 'run_args'. Any arguments directly provided to the "
".run() method will be overwritten by those specified in 'run_args', but only"
" for arguments that are explicitly provided in 'run_args'."
)

optim_settings = get_run_args_from_yaml(run_args)

# Update each argument based on optimization_settings
run_pipeline = optim_settings.get("run_pipeline", run_pipeline)
root_directory = optim_settings.get("root_directory", root_directory)
pipeline_space = optim_settings.get("pipeline_space", pipeline_space)
overwrite_working_directory = optim_settings.get(
"overwrite_working_directory", overwrite_working_directory
)
post_run_summary = optim_settings.get("post_run_summary", post_run_summary)
development_stage_id = optim_settings.get(
"development_stage_id", development_stage_id
)
task_id = optim_settings.get("task_id", task_id)
max_evaluations_total = optim_settings.get(
"max_evaluations_total", max_evaluations_total
)
max_evaluations_per_run = optim_settings.get(
"max_evaluations_per_run", max_evaluations_per_run
)
continue_until_max_evaluation_completed = optim_settings.get(
"continue_until_max_evaluation_completed",
continue_until_max_evaluation_completed,
)
max_cost_total = optim_settings.get("max_cost_total", max_cost_total)
ignore_errors = optim_settings.get("ignore_errors", ignore_errors)
loss_value_on_error = optim_settings.get(
"loss_value_on_error", loss_value_on_error
)
cost_value_on_error = optim_settings.get(
"cost_value_on_error", cost_value_on_error
)
pre_load_hooks = optim_settings.get("pre_load_hooks", pre_load_hooks)
searcher = optim_settings.get("searcher", searcher)
searcher_path = optim_settings.get("searcher_path", searcher_path)
for key, value in optim_settings.get("searcher_kwargs", {}).items():
searcher_kwargs[key] = value

# check if necessary arguments are provided.
check_essential_arguments(
run_pipeline,
root_directory,
pipeline_space,
max_cost_total,
max_evaluations_total,
)

if pre_load_hooks is None:
pre_load_hooks = []

logger = logging.getLogger("neps")
logger.info(f"Starting neps.run using root directory {root_directory}")

# Used to create the yaml holding information about the searcher.
Expand Down
Loading

0 comments on commit 17fe070

Please sign in to comment.