Skip to content

Commit

Permalink
Refactor output folder
Browse files Browse the repository at this point in the history
  • Loading branch information
EdenWuyifan committed Apr 26, 2024
1 parent dce07c6 commit 19e4f8b
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 123 deletions.
29 changes: 13 additions & 16 deletions alpha_automl/pipeline_search/agent_environment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import random

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Dict, Discrete
Expand Down Expand Up @@ -40,7 +41,6 @@ def __init__(self, config: EnvContext):
self.action_offsets = self.generate_action_offsets()
self.action_space = Discrete(self.max_actions)


def reset(self, *, seed=None, options=None):
# init number of steps
self.num_steps = 0
Expand All @@ -54,11 +54,11 @@ def reset(self, *, seed=None, options=None):

def step(self, action):
curr_step = self.step_stack.pop()
offseted_action = self.action_offsets[curr_step]+action
offseted_action = self.action_offsets[curr_step] + action
valid_action_size = self.action_spaces[curr_step]
# Check the action is illegal
valid_moves = self.game.getValidMoves(self.board)
if action >= valid_action_size or valid_moves[offseted_action-1] != 1:
if action >= valid_action_size or valid_moves[offseted_action - 1] != 1:
return (
{"board": np.array(self.board).astype(np.uint8)},
-1,
Expand All @@ -78,19 +78,19 @@ def step(self, action):
False,
{},
)
if non_terminals_moves[0] != "E" and non_terminals_moves[0].upper() == non_terminals_moves[0]:
if (
non_terminals_moves[0] != "E"
and non_terminals_moves[0].upper() == non_terminals_moves[0]
):
self.step_stack.extend(non_terminals_moves[::-1])


# update number of steps
self.num_steps += 1

# update board with new action
# print(f"action: {action}\n board: {self.board}")
self.board = self.game.getNextState(self.board, offseted_action-1)
self.board = self.game.getNextState(self.board, offseted_action - 1)

if self.num_steps > 9:
logger.debug(f"[YFW]================={self.board[self.game.m:]}")
# reward: win(1) - pipeline score, not end(0) - 0, bad(2) - 0
reward = 0
game_end = self.game.getGameEnded(self.board)
Expand Down Expand Up @@ -118,16 +118,14 @@ def step(self, action):
# else:
# split_move = move_string.split("->")
# non_terminals_moves = move_string.split("->")[1].strip().split(" ")

# if split_move[0].strip() == "ENSEMBLER":
# if "E" in non_terminals_moves:
# rewards = 5
# else:
# rewards = 5 - len(non_terminals_moves)
# else:
# rewards = random.uniform(0, 1)



# done & truncated
truncated = self.num_steps >= 20
Expand All @@ -153,21 +151,20 @@ def generate_action_spaces(self):
action_spaces = {}
for action in self.game.grammar["RULES"].values():
move_type, non_terminals_moves = self.extract_action_details(action)

if move_type not in action_spaces:
action_spaces[move_type] = 1
else:
action_spaces[move_type] += 1

return action_spaces

def generate_action_offsets(self):
action_offsets = {}
for action in self.game.grammar["RULES"].values():
move_type, non_terminals_moves = self.extract_action_details(action)

if move_type not in action_offsets:
action_offsets[move_type] = action

return action_offsets

77 changes: 49 additions & 28 deletions alpha_automl/pipeline_search/agent_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@
import logging
import os
import time
import json
from datetime import datetime

import ray
from alpha_automl.pipeline_search.agent_environment import AutoMLEnv
from ray.rllib.policy import Policy
from ray.rllib.utils.checkpoints import get_checkpoint_info
from ray.tune.logger import pretty_print
from ray.tune.registry import get_trainable_cls
from ray import tune

from alpha_automl.pipeline_search.agent_environment import AutoMLEnv

logger = logging.getLogger(__name__)


def pipeline_search_rllib(game, time_bound, checkpoint_load_folder, checkpoint_save_folder):
def pipeline_search_rllib(
game, time_bound, checkpoint_load_folder, checkpoint_save_folder
):
"""
Search for pipelines using Rllib
"""
Expand All @@ -31,7 +30,6 @@ def pipeline_search_rllib(game, time_bound, checkpoint_load_folder, checkpoint_s

# train model
train_rllib_model(algo, time_bound, checkpoint_load_folder, checkpoint_save_folder)
save_rllib_checkpoint(algo, checkpoint_save_folder)
logger.debug("[RlLib] Done")
ray.shutdown()

Expand Down Expand Up @@ -63,21 +61,21 @@ def load_rllib_checkpoint(game, checkpoint_load_folder, num_rollout_workers):
logger.debug("[RlLib] Create Config done")

# Checking if the list is empty or not
if [f for f in os.listdir(checkpoint_load_folder) if not f.startswith(".")] == []:
if contain_checkpoints(checkpoint_load_folder):
logger.debug("[RlLib] Cannot read RlLib checkpoint, create a new one.")
return config.build()
else:
algo = config.build()
weights = load_rllib_policy_weights(checkpoint_load_folder)

algo.set_weights(weights)
# Restore the old state.
# algo.restore(load_folder)
# checkpoint_info = get_checkpoint_info(load_folder)
return algo


def train_rllib_model(algo, time_bound, load_folder, checkpoint_save_folder):
def train_rllib_model(algo, time_bound, checkpoint_load_folder, checkpoint_save_folder):
timeout = time.time() + time_bound
result = algo.train()
last_best = result["episode_reward_mean"]
Expand All @@ -92,9 +90,12 @@ def train_rllib_model(algo, time_bound, load_folder, checkpoint_save_folder):
):
logger.debug(f"[RlLib] Train Timeout")
break

if [f for f in os.listdir(load_folder) if not f.startswith(".")] != []:
weights = load_rllib_policy_weights()

if contain_checkpoints(checkpoint_save_folder):
weights = load_rllib_policy_weights(checkpoint_save_folder)
algo.set_weights(weights)
elif contain_checkpoints(checkpoint_load_folder):
weights = load_rllib_policy_weights(checkpoint_load_folder)
algo.set_weights(weights)
result = algo.train()
logger.debug(pretty_print(result))
Expand All @@ -108,16 +109,17 @@ def train_rllib_model(algo, time_bound, load_folder, checkpoint_save_folder):
algo.stop()


def load_rllib_policy_weights(checkpoint_load_folder):
def load_rllib_policy_weights(checkpoint_folder):
logger.debug(f"[RlLib] Synchronizing model weights...")
policy = Policy.from_checkpoint(checkpoint_load_folder)
policy = policy['default_policy']
policy = Policy.from_checkpoint(checkpoint_folder)
policy = policy["default_policy"]
weights = policy.get_weights()

weights = {'default_policy': weights}
weights = {"default_policy": weights}

return weights


def save_rllib_checkpoint(algo, checkpoint_save_folder):
save_result = algo.save(checkpoint_dir=checkpoint_save_folder)
path_to_checkpoint = save_result.checkpoint.path
Expand All @@ -131,15 +133,14 @@ def dump_result_to_json(primitives, task_start, score, output_folder=None):
output_path = generate_json_path(output_folder)
# Read JSON data from input file
if not os.path.exists(output_path) or os.path.getsize(output_path) == 0:
with open(output_path, 'w') as f:
with open(output_path, "w") as f:
json.dump({}, f)
with open(output_path, 'r') as f:
with open(output_path, "r") as f:
data = json.load(f)



timestamp = str(datetime.now() - task_start)
# strftime("%Y-%m-%d %H:%M:%S")

# Check for duplicate elements
if primitives in data.values():
return
Expand All @@ -152,13 +153,10 @@ def dump_result_to_json(primitives, task_start, score, output_folder=None):

def read_result_to_pipeline(builder, output_folder=None):
output_path = generate_json_path(output_folder)

pipelines = []
# Read JSON data from input file
if (
not os.path.exists(output_path)
or os.path.getsize(output_path) == 0
):
if not os.path.exists(output_path) or os.path.getsize(output_path) == 0:
return []
with open(output_path, "r") as f:
data = json.load(f)
Expand All @@ -168,11 +166,34 @@ def read_result_to_pipeline(builder, output_folder=None):
pipeline = builder.make_pipeline(primitives)
if pipeline:
pipelines.append(pipeline)

return pipelines


def generate_json_path(output_folder=None):
output_path = os.path.join(output_folder, "result.json")

return output_path


def contain_checkpoints(folder_path):
if folder_path is None:
return False

file_list = os.listdir(folder_path)

if [f for f in file_list if not f.startswith(".")] == []:
return False

if (
"algorithm_state.pkl" in file_list
and "policies" in file_list
and "rllib_checkpoint.json" in file_list
):
return True
else:
logger.info(
f"[RlLib] Checkpoint folder {folder_path} does not contain all necessary files, files: {file_list}."
)

return False
Loading

0 comments on commit 19e4f8b

Please sign in to comment.