Skip to content

Commit

Permalink
Fix loading checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
roquelopez committed Apr 29, 2024
1 parent e9c86b4 commit a6871be
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 15 deletions.
2 changes: 1 addition & 1 deletion alpha_automl/automl_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _search_pipelines(self, automl_hyperparams):
found_pipelines += 1
yield {'pipeline': alphaautoml_pipeline, 'message': 'SCORED'}
except:
logger.info(f'Pipeline scoring error!')
logger.debug(f'Pipeline scoring error!')
continue

logger.debug(f'Found {found_pipelines} pipelines')
Expand Down
25 changes: 11 additions & 14 deletions alpha_automl/pipeline_search/agent_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,20 @@
logger = logging.getLogger(__name__)


def pipeline_search_rllib(
game, time_bound, checkpoint_load_folder, checkpoint_save_folder
):
def pipeline_search_rllib(game, time_bound, checkpoint_load_folder, checkpoint_save_folder):
"""
Search for pipelines using Rllib
"""
ray.init(local_mode=True, num_cpus=8)
num_cpus = int(ray.available_resources()["CPU"])
logger.debug("[RlLib] Ready")

# load checkpoint or create a new one
algo = load_rllib_checkpoint(game, checkpoint_load_folder, num_rollout_workers=7)
logger.debug("[RlLib] Create Algo object done")
logger.debug("Create Algo object done")

# train model
train_rllib_model(algo, time_bound, checkpoint_load_folder, checkpoint_save_folder)
logger.debug("[RlLib] Done")
logger.debug("Training done")
ray.shutdown()


Expand Down Expand Up @@ -58,11 +55,11 @@ def load_rllib_checkpoint(game, checkpoint_load_folder, num_rollout_workers):
)
config.lr = 1e-5
config.simple_optimizer = True
logger.debug("[RlLib] Create Config done")
logger.debug("Create Config done")

# Checking if the list is empty or not
if contain_checkpoints(checkpoint_load_folder):
logger.debug("[RlLib] Cannot read RlLib checkpoint, create a new one.")
if not contain_checkpoints(checkpoint_load_folder):
logger.debug("Cannot read checkpoint, create a new one.")
return config.build()
else:
algo = config.build()
Expand All @@ -88,7 +85,7 @@ def train_rllib_model(algo, time_bound, checkpoint_load_folder, checkpoint_save_
or (best_unchanged_iter >= 600 and result["episode_reward_mean"] >= 0)
# or result["episode_reward_mean"] >= 70
):
logger.debug(f"[RlLib] Train Timeout")
logger.debug(f"Training timeout reached")
break

if contain_checkpoints(checkpoint_save_folder):
Expand All @@ -110,7 +107,7 @@ def train_rllib_model(algo, time_bound, checkpoint_load_folder, checkpoint_save_


def load_rllib_policy_weights(checkpoint_folder):
logger.debug(f"[RlLib] Synchronizing model weights...")
logger.debug(f"Synchronizing model weights...")
policy = Policy.from_checkpoint(checkpoint_folder)
policy = policy["default_policy"]
weights = policy.get_weights()
Expand All @@ -125,7 +122,7 @@ def save_rllib_checkpoint(algo, checkpoint_save_folder):
path_to_checkpoint = save_result.checkpoint.path

logger.debug(
f"[RlLib] An Algorithm checkpoint has been created inside directory: '{path_to_checkpoint}'."
f"An Algorithm checkpoint has been created inside directory: '{path_to_checkpoint}'."
)


Expand Down Expand Up @@ -192,8 +189,8 @@ def contain_checkpoints(folder_path):
):
return True
else:
logger.info(
f"[RlLib] Checkpoint folder {folder_path} does not contain all necessary files, files: {file_list}."
logger.debug(
f"Checkpoint folder {folder_path} does not contain all necessary files, files: {file_list}."
)

return False

0 comments on commit a6871be

Please sign in to comment.