Skip to content

Commit

Permalink
increase code maintenance + enabled strict usage of run_args, overwri…
Browse files Browse the repository at this point in the history
…tting all other arguments that are provided via neps.run(...) + Enable loading the pipeline_space provided as a dictionary from YAML and the loading of the BaseOptimizer as searcher from YAML
  • Loading branch information
danrgll committed Mar 19, 2024
1 parent a3cd1e6 commit 02ffe98
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 104 deletions.
43 changes: 21 additions & 22 deletions neps/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,50 +219,49 @@ def run(
del searcher_kwargs["budget"]
logger = logging.getLogger("neps")


# if arguments via run_args provided overwrite them
if run_args:
logger.info(
"Loading arguments from 'run_args'. Any arguments directly provided to the "
".run() method will be overwritten by those specified in 'run_args', but only"
" for arguments that are explicitly provided in 'run_args'."
"WARNING: Loading arguments from 'run_args'. Arguments directly provided "
"to neps.run(...) will be not used!"
)

optim_settings = get_run_args_from_yaml(run_args)

# Update each argument based on optimization_settings
run_pipeline = optim_settings.get("run_pipeline", run_pipeline)
root_directory = optim_settings.get("root_directory", root_directory)
pipeline_space = optim_settings.get("pipeline_space", pipeline_space)
# Update each argument based on optimization_settings, if not provided in yaml
# use default value, strict but will change in the future
run_pipeline = optim_settings.get("run_pipeline", None)
root_directory = optim_settings.get("root_directory", None)
pipeline_space = optim_settings.get("pipeline_space", None)
overwrite_working_directory = optim_settings.get(
"overwrite_working_directory", overwrite_working_directory
"overwrite_working_directory", False
)
post_run_summary = optim_settings.get("post_run_summary", post_run_summary)
post_run_summary = optim_settings.get("post_run_summary", False)
development_stage_id = optim_settings.get(
"development_stage_id", development_stage_id
"development_stage_id", None
)
task_id = optim_settings.get("task_id", task_id)
task_id = optim_settings.get("task_id", None)
max_evaluations_total = optim_settings.get(
"max_evaluations_total", max_evaluations_total
"max_evaluations_total", None
)
max_evaluations_per_run = optim_settings.get(
"max_evaluations_per_run", max_evaluations_per_run
"max_evaluations_per_run", None
)
continue_until_max_evaluation_completed = optim_settings.get(
"continue_until_max_evaluation_completed",
continue_until_max_evaluation_completed,
False,
)
max_cost_total = optim_settings.get("max_cost_total", max_cost_total)
ignore_errors = optim_settings.get("ignore_errors", ignore_errors)
max_cost_total = optim_settings.get("max_cost_total", None)
ignore_errors = optim_settings.get("ignore_errors", False)
loss_value_on_error = optim_settings.get(
"loss_value_on_error", loss_value_on_error
"loss_value_on_error", None
)
cost_value_on_error = optim_settings.get(
"cost_value_on_error", cost_value_on_error
"cost_value_on_error", None
)
pre_load_hooks = optim_settings.get("pre_load_hooks", pre_load_hooks)
searcher = optim_settings.get("searcher", searcher)
searcher_path = optim_settings.get("searcher_path", searcher_path)
pre_load_hooks = optim_settings.get("pre_load_hooks", None)
searcher = optim_settings.get("searcher", "default")
searcher_path = optim_settings.get("searcher_path", None)
for key, value in optim_settings.get("searcher_kwargs", {}).items():
searcher_kwargs[key] = value

Expand Down
Loading

0 comments on commit 02ffe98

Please sign in to comment.