Skip to content

Commit

Permalink
add documentations
Browse files Browse the repository at this point in the history
  • Loading branch information
EdenWuyifan committed Apr 27, 2024
1 parent 9d3e490 commit e44a7b8
Showing 1 changed file with 17 additions and 40 deletions.
57 changes: 17 additions & 40 deletions alpha_automl/pipeline_search/agent_environment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import random

import gymnasium as gym
import numpy as np
Expand All @@ -16,40 +15,40 @@ class AutoMLEnv(gym.Env):
step: take an action and return the next state, reward, done, and info
rewards in detail:
- win:
- CLASSIFICATION: 10 + (pipeline score) ^ 5 * 100
- CLASSIFICATION: 10 + (pipeline score) ^ 2 * 100
- REGRESSION: 10 + (100 / pipeline score)
- not end: 1
- invalid: 10
- bad: -100
- bad: -1
"""

def __init__(self, config: EnvContext):
self.game = config["game"]
self.board = self.game.getInitBoard()
self.step_stack = ["S"]
self.game = config["game"] # PipelineGame
self.board = self.game.getInitBoard() # initial board
self.step_stack = ["S"] # stack for steps
self.metadata = self.board[: self.game.m]
self.observation_space = Dict(
{
"board": Box(
0, 85, shape=(self.game.p + self.game.m,), dtype=np.uint8
), # board
), # Ray env board contains pipeline and metadata
}
)
# self.action_space = Discrete(85) # primitives to choose from
self.max_actions = 24
self.action_spaces = self.generate_action_spaces()
self.action_offsets = self.generate_action_offsets()
self.action_space = Discrete(self.max_actions)
self.max_actions = 24 # max number of actions (depends on the largest step in the grammar, i.e. CLASSIFIER)
self.action_spaces = (
self.generate_action_spaces()
) # number of actions for each step
self.action_offsets = (
self.generate_action_offsets()
) # offset for each step, for translating action to PipelineGame action
self.action_space = Discrete(self.max_actions) # Ray env action space

def reset(self, *, seed=None, options=None):
# init number of steps
self.num_steps = 0

self.step_stack = ["S"]
self.board = self.game.getInitBoard()
self.metadata = self.board[: self.game.m]

# print(f"metadata: {self.metadata}\n board: {self.board}")
return {"board": np.array(self.board).astype(np.uint8)}, {}

def step(self, action):
Expand All @@ -69,7 +68,6 @@ def step(self, action):

# Check the action is out of order
move_type, non_terminals_moves = self.extract_action_details(offseted_action)
# logger.critical(f"offseted_action: {offseted_action} ===> curr_step: {curr_step}")
if move_type != curr_step:
return (
{"board": np.array(self.board).astype(np.uint8)},
Expand All @@ -88,10 +86,9 @@ def step(self, action):
self.num_steps += 1

# update board with new action
# print(f"action: {action}\n board: {self.board}")
self.board = self.game.getNextState(self.board, offseted_action - 1)

# reward: win(1) - pipeline score, not end(0) - 0, bad(2) - 0
# reward: win(1) - 10 + score-based bonus, not end(0) - 1, finished but invalid(2) - 10
reward = 0
game_end = self.game.getGameEnded(self.board)
if game_end == 1: # pipeline score over threshold
Expand All @@ -100,32 +97,12 @@ def step(self, action):
reward = 10 + (100 / self.game.getEvaluation(self.board))
else:
reward = 10 + (self.game.getEvaluation(self.board)) ** 2 * 100
except:
logger.critical(f"[PIPELINE FOUND] Error happened")
except Exception as e:
logger.critical(f"[PIPELINE FOUND] Error happened: {str(e)}")
elif game_end == 2: # finished but invalid
reward = 10
else:
# if move_type == "S":
# reward = 1
# elif move_type == "ENCODERS":
# reward = 1
# else:
# if move_type == "IMPUTER" or move_type == "CATEGORICAL_ENCODER":
# reward = 1
reward = 1
# if move_string.upper() != move_string:
# reward = random.uniform(0, 1)
# else:
# split_move = move_string.split("->")
# non_terminals_moves = move_string.split("->")[1].strip().split(" ")

# if split_move[0].strip() == "ENSEMBLER":
# if "E" in non_terminals_moves:
# rewards = 5
# else:
# rewards = 5 - len(non_terminals_moves)
# else:
# rewards = random.uniform(0, 1)

# done & truncated
truncated = self.num_steps >= 20
Expand Down

0 comments on commit e44a7b8

Please sign in to comment.