[FEATURE] Implement self-play for two-player zero-sum games #103

Open · wants to merge 4 commits into base: main
6 changes: 6 additions & 0 deletions stoix/base_types.py
@@ -153,6 +153,12 @@ class OnlineAndTarget(NamedTuple):
target: FrozenDict


class OnlineTargetOpponent(NamedTuple):
online: FrozenDict
target: FrozenDict
opponent: FrozenDict


StoixState = TypeVar(
"StoixState",
)
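For orientation, a minimal self-contained sketch of how the new OnlineTargetOpponent container is meant to be used (toy dict parameters rather than the real Stoix networks; the 1-step opponent lag and the optax.incremental_update target update mirror the learner changes further down in this diff):

import jax.numpy as jnp
import optax
from typing import NamedTuple

class OnlineTargetOpponent(NamedTuple):
    online: dict
    target: dict
    opponent: dict

# Toy parameters standing in for the real FrozenDicts.
params = OnlineTargetOpponent(
    online={"w": jnp.ones(3)},
    target={"w": jnp.ones(3)},
    opponent={"w": jnp.ones(3)},
)
updates = {"w": -0.1 * jnp.ones(3)}  # a pretend gradient update

new_online = optax.apply_updates(params.online, updates)
new_target = optax.incremental_update(new_online, params.target, step_size=0.005)
# The opponent slot simply trails the online params by one update.
params = OnlineTargetOpponent(new_online, new_target, opponent=params.online)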
20 changes: 20 additions & 0 deletions stoix/configs/arch/anakin_debug.yaml
@@ -0,0 +1,20 @@
# --- Anakin config ---
architecture_name: anakin_debug
# --- Training ---
seed: 42 # RNG seed.
update_batch_size: 1 # Number of vectorised gradient updates per device.
total_num_envs: 4 # Total number of vectorised environments across all devices and batched updates. Must be divisible by n_devices*update_batch_size (see the sketch after this config).
total_timesteps: 1e4 # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: ~ # Number of updates

# --- Evaluation ---
evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
# an action which corresponds to the greatest logit. If False, the policy will sample
# from the logits.
num_eval_episodes: 8 # Number of episodes to evaluate per evaluation.
num_evaluation: 10 # Number of evenly spaced evaluations to perform during training.
absolute_metric: True # Whether the absolute metric should be computed. For more details
# on the absolute metric please see: https://arxiv.org/abs/2209.10485
11 changes: 11 additions & 0 deletions stoix/configs/default/anakin/default_ff_dqn_selfplay.yaml
@@ -0,0 +1,11 @@
defaults:
- logger: base_logger
- arch: anakin_debug
- system: ff_dqn
- network: mlp_dqn
- env: pgx/chess
- _self_

hydra:
searchpath:
- file://stoix/configs
168 changes: 168 additions & 0 deletions stoix/evaluator.py
@@ -173,6 +173,135 @@ def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> EvaluationOut
return evaluator_fn


def get_selfplay_evaluator_fn(
env: Environment,
act_fn: ActFn,
config: DictConfig,
log_solve_rate: bool = False,
eval_multiplier: int = 1,
) -> EvalFn:
"""Get the evaluator function for feedforward networks.

Args:
env (Environment): An environment instance for evaluation.
act_fn (callable): The act_fn that returns the action taken by the agent.
config (dict): Experiment configuration.
log_solve_rate (bool): Whether to log the solve rate of evaluation episodes.
eval_multiplier (int): A scalar that increases the number of evaluation
episodes by a fixed factor. This enables the computation of the `absolute metric`,
which is computed at the end of training by rolling out the policy that
obtained the greatest evaluation performance during training for 10 times
more episodes than were used at a single evaluation step.
"""

def eval_one_episode(
params: FrozenDict, opponent_params: FrozenDict, init_eval_state: EvalState
) -> Dict:
"""Evaluate one episode. It is vectorized over the number of evaluation episodes."""

def _env_step(eval_state: EvalState) -> EvalState:
"""Step the environment."""
# PRNG keys.
key, env_state, last_timestep, step_count, episode_return = eval_state

# Select action.
key, policy_key = jax.random.split(key)

def select_actions(
params: FrozenDict,
opponent_params: FrozenDict,
obs: chex.Array,
current_player: int,
policy_key: chex.PRNGKey,
) -> chex.Array:
"""
Samples an action from the agent or the opponent depending on
`env_state.current_player`.
"""
# player 0: opponent, player 1: agent
current_player_params = jax.lax.select(current_player, params, opponent_params)
actor_policy = act_fn(current_player_params, obs)

return actor_policy.sample(seed=policy_key)

action = select_actions(
params,
opponent_params,
jax.tree_util.tree_map(lambda x: x[jnp.newaxis, ...], last_timestep.observation),
env_state.env_state.current_player,
policy_key,
)

# Step environment.
env_state, timestep = env.step(env_state, action.squeeze())

# Log episode metrics.
episode_return += timestep.reward
step_count += 1
eval_state = EvalState(key, env_state, timestep, step_count, episode_return)
return eval_state

def not_done(carry: Tuple) -> bool:
"""Check if the episode is done."""
timestep = carry[2]
is_not_done: bool = ~timestep.last()
return is_not_done

final_state = jax.lax.while_loop(not_done, _env_step, init_eval_state)

eval_metrics = {
"episode_return": final_state.episode_return,
"episode_length": final_state.step_count,
}
# Log solve episode if solve rate is required.
if log_solve_rate:
eval_metrics["solve_episode"] = jnp.all(
final_state.episode_return >= config.env.solved_return_threshold
).astype(int)

return eval_metrics

def evaluator_fn(
trained_params: FrozenDict, opponent_params: FrozenDict, key: chex.PRNGKey
) -> EvaluationOutput[EvalState]:
"""Evaluator function."""

# Initialise environment states and timesteps.
n_devices = len(jax.devices())

eval_batch = (config.arch.num_eval_episodes // n_devices) * eval_multiplier

key, *env_keys = jax.random.split(key, eval_batch + 1)
env_states, timesteps = jax.vmap(env.reset)(
jnp.stack(env_keys),
)
# Split keys for each core.
key, *step_keys = jax.random.split(key, eval_batch + 1)
# Add dimension to pmap over.
step_keys = jnp.stack(step_keys).reshape(eval_batch, -1)

eval_state = EvalState(
key=step_keys,
env_state=env_states,
timestep=timesteps,
step_count=jnp.zeros((eval_batch, 1)),
episode_return=jnp.zeros_like(timesteps.reward),
)

eval_metrics = jax.vmap(
eval_one_episode,
in_axes=(None, None, 0),
axis_name="eval_batch",
)(trained_params, opponent_params, eval_state)

return EvaluationOutput(
learner_state=eval_state,
episode_metrics=eval_metrics,
)

return evaluator_fn


def get_rnn_evaluator_fn(
env: Environment,
rec_act_fn: RecActFn,
@@ -355,6 +484,45 @@ def evaluator_setup(
return evaluator, absolute_metric_evaluator, (trained_params, eval_keys)


def selfplay_evaluator_setup(
eval_env: Environment,
key_e: chex.PRNGKey,
eval_act_fn: ActFn,
params: FrozenDict,
opponent_params: FrozenDict,
config: DictConfig,
) -> Tuple[EvalFn, EvalFn, Tuple[FrozenDict, chex.Array]]:
"""Initialise evaluator_fn."""
# Get available TPU cores.
n_devices = len(jax.devices())
# Check if solve rate is required for evaluation.
if hasattr(config.env, "solved_return_threshold"):
log_solve_rate = True
else:
log_solve_rate = False
# Create the self-play evaluator functions.

evaluator = get_selfplay_evaluator_fn(eval_env, eval_act_fn, config, log_solve_rate) # type: ignore
absolute_metric_evaluator = get_selfplay_evaluator_fn(
eval_env,
eval_act_fn, # type: ignore
config,
log_solve_rate,
10,
)

evaluator = jax.pmap(evaluator, axis_name="device")
absolute_metric_evaluator = jax.pmap(absolute_metric_evaluator, axis_name="device")

# Broadcast trained params to cores and split keys for each core.
trained_params = unreplicate_batch_dim(params)
opponent_params = unreplicate_batch_dim(opponent_params)
key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
eval_keys = jnp.stack(eval_keys).reshape(n_devices, -1)

return evaluator, absolute_metric_evaluator, (trained_params, opponent_params, eval_keys)


def get_sebulba_eval_fn(
env_factory: EnvFactory,
act_fn: ActFn,
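For readers unfamiliar with the pmap plumbing in selfplay_evaluator_setup above, a self-contained sketch of the per-device evaluation key handling (shapes assume JAX's default uint32 PRNG keys):

import jax
import jax.numpy as jnp

n_devices = len(jax.devices())
key_e = jax.random.PRNGKey(0)

# One evaluation key per device, stacked so the leading axis matches the
# device axis that jax.pmap maps over.
key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
eval_keys = jnp.stack(eval_keys).reshape(n_devices, -1)
assert eval_keys.shape == (n_devices, 2)  # a default PRNG key is two uint32s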
@@ -23,9 +23,9 @@
LearnerFn,
LogEnvState,
OffPolicyLearnerState,
OnlineAndTarget,
OnlineTargetOpponent,
)
from stoix.evaluator import evaluator_setup, get_distribution_act_fn
from stoix.evaluator import get_distribution_act_fn, selfplay_evaluator_setup
from stoix.networks.base import FeedForwardActor as Actor
from stoix.systems.q_learning.dqn_types import Transition
from stoix.utils import make_env as environments
@@ -155,7 +155,9 @@ def _q_loss_fn(
q_t = q_apply_fn(target_q_params, transitions.next_obs).preferences

# Cast and clip rewards.
discount = 1.0 - transitions.done.astype(jnp.float32)
discount = (
1.0 - transitions.done.astype(jnp.float32)
) * -1 # negate the discount: the next state is the opponent's turn, so its
# bootstrapped value counts against the current player (zero-sum; sketched below)
d_t = (discount * config.system.gamma).astype(jnp.float32)
r_t = jnp.clip(
transitions.reward, -config.system.max_abs_reward, config.system.max_abs_reward
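The sign flip on the discount is the core self-play trick here: the next observation stored in the buffer belongs to the opponent, so its bootstrapped value has to count against the current player, negamax-style. A toy sketch of the resulting TD target (the numbers are made up):

import jax.numpy as jnp

gamma = 0.99
reward = jnp.array(0.0)          # e.g. no intermediate reward in chess
done = jnp.array(0.0)            # episode not finished
opponent_value = jnp.array(0.3)  # max_a Q_target(s', a), seen from the opponent's seat

discount = (1.0 - done) * -1     # negated discount, as in the change above
td_target = reward + gamma * discount * opponent_value
# td_target == -0.297: a position that is good for the opponent is bad for us.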
@@ -186,6 +188,9 @@ def _q_loss_fn(
transition_sample = buffer_sample_fn(buffer_state, sample_key)
transitions: Transition = transition_sample.experience

# The opponent params are the online params with a 1-step lag.
# JAX arrays are immutable, so holding a reference is enough (no deep copy needed).
opponent_params = params.online

# CALCULATE Q LOSS
q_grad_fn = jax.grad(_q_loss_fn, has_aux=True)
q_grads, q_loss_info = q_grad_fn(
@@ -208,7 +213,9 @@
new_target_q_params = optax.incremental_update(
q_new_online_params, params.target, config.system.tau
)
q_new_params = OnlineAndTarget(q_new_online_params, new_target_q_params)
q_new_params = OnlineTargetOpponent(
q_new_online_params, new_target_q_params, opponent_params
)

# PACK NEW PARAMS AND OPTIMISER STATE
new_params = q_new_params
@@ -271,7 +278,7 @@ def learner_setup(
config.system.action_dim = action_dim

# PRNG keys.
key, q_net_key = keys
key, q_net_key, opponent_key = keys

# Define networks and optimiser.
q_network_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
@@ -305,7 +312,10 @@
q_target_params = q_online_params
q_opt_state = q_optim.init(q_online_params)

params = OnlineAndTarget(q_online_params, q_target_params)
# Initialise opponent parameters
opponent_q_params = q_network.init(opponent_key, init_x)

params = OnlineTargetOpponent(q_online_params, q_target_params, opponent_q_params)
opt_states = q_opt_state

q_network_apply_fn = q_network.apply
@@ -378,7 +388,7 @@ def reshape_states(x: chex.Array) -> chex.Array:
**config.logger.checkpointing.load_args, # Other checkpoint args
)
# Restore the learner state from the checkpoint
restored_params, _ = loaded_checkpoint.restore_params(TParams=OnlineAndTarget)
restored_params, _ = loaded_checkpoint.restore_params(TParams=OnlineTargetOpponent)
# Update the params
params = restored_params

@@ -433,17 +443,26 @@ def run_experiment(_config: DictConfig) -> float:
env, eval_env = environments.make(config=config)

# PRNG keys.
key, key_e, q_net_key = jax.random.split(jax.random.PRNGKey(config.arch.seed), num=3)
key, key_e, opponent_key, q_net_key = jax.random.split(
jax.random.PRNGKey(config.arch.seed), num=4
)

# Setup learner.
learn, eval_q_network, learner_state = learner_setup(env, (key, q_net_key), config)
learn, eval_q_network, learner_state = learner_setup(
env, (key, q_net_key, opponent_key), config
)

# Setup evaluator.
evaluator, absolute_metric_evaluator, (trained_params, eval_keys) = evaluator_setup(
(
evaluator,
absolute_metric_evaluator,
(trained_params, opponent_eval_params, eval_keys),
) = selfplay_evaluator_setup(
eval_env=eval_env,
key_e=key_e,
eval_act_fn=get_distribution_act_fn(config, eval_q_network.apply),
params=learner_state.params.online,
opponent_params=learner_state.params.opponent,
config=config,
)

@@ -555,8 +574,9 @@


@hydra.main(
# TODO: change back to anakin after testing
config_path="../../configs/default/anakin",
config_name="default_ff_dqn.yaml",
config_name="default_ff_dqn_selfplay.yaml",
version_base="1.2",
)
def hydra_entry_point(cfg: DictConfig) -> float: