Feat/Add Sebulba (#105)
* feat: add sebulba ppo system and reorganise for future sebulba systems
EdanToledo committed Aug 18, 2024
1 parent 1a86c09 commit 5977cce
Showing 85 changed files with 2,188 additions and 124 deletions.
44 changes: 35 additions & 9 deletions README.md
@@ -32,11 +32,20 @@

## Welcome to Stoix! 🏛️

Stoix provides simplified code for quickly iterating on ideas in single-agent reinforcement learning with useful implementations of popular single-agent RL algorithms in JAX allowing for easy parallelisation across devices with JAX's `pmap`. All implementations are fully compilable with JAX's `jit` thus making training and environment execution very fast. However, this requires environments written in JAX. Algorithms and their default hyperparameters have not been hyper-optimised for any specific environment and are useful as a starting point for research and/or for initial baselines.
Stoix provides simplified code for quickly iterating on ideas in single-agent reinforcement learning, with useful implementations of popular single-agent RL algorithms in JAX that allow for easy parallelisation across devices with JAX's `pmap`. All implementations are fully compiled with JAX's `jit`, thus making training and environment execution very fast. However, this does require environments written in JAX. For environments not written in JAX, Stoix offers Sebulba systems (see below). Algorithms and their default hyperparameters have not been hyper-optimised for any specific environment and are useful as a starting point for research and/or for initial baselines.

To join us in these efforts, please feel free to reach out, raise issues or read our [contribution guidelines](#contributing-) (or just star 🌟 to stay up to date with the latest developments)!

Stoix is fully in JAX with substantial speed improvement compared to other popular libraries. We currently provide native support for the [Jumanji][jumanji] environment API and wrappers for popular JAX-based RL environments.
Stoix is fully in JAX with substantial speed improvement compared to other popular libraries. We currently provide native support for the [Jumanji][jumanji] environment API and wrappers for popular RL environments.

## System Design Paradigms
Stoix offers two primary system design paradigms (Podracer Architectures) to cater to different research and deployment needs:

- **Anakin:** Traditional Stoix implementations are fully end-to-end compiled with JAX, focusing on speed and simplicity with native JAX environments. This design paradigm is ideal for setups where all components, including environments, can be optimized using JAX, leveraging the full power of JAX's `pmap` and `jit`. For an illustration of the Anakin architecture, see this [figure](docs/images/anakin_arch.jpg) from the [Mava][mava] technical report.

- **Sebulba:** The Sebulba system introduces flexibility by allowing different devices to be assigned specifically for learning and acting. In this setup, acting devices serve as inference servers for multiple parallel environments, which can be written in any framework, not just JAX. This enables Stoix to be used with a broader range of environments while still benefiting from JAX's speed. For an illustration of the Sebulba architecture, see this [animation](docs/images/sebulba_arch.gif) from the [InstaDeep Sebulba implementation](https://github.com/instadeepai/sebulba/).

Not all algorithms have both Anakin and Sebulba implementations, but effort has gone into making the two designs as similar as possible to allow easy conversion between them.
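
To make the Sebulba dataflow concrete, below is a minimal sketch of the actor/learner split using plain Python threads and a bounded queue. Everything in it (the `policy` function, the parameter update, the shapes) is an illustrative stand-in rather than Stoix's actual API; the queue's `maxsize` plays the role of the pipeline size discussed in the considerations further down.

```python
import queue
import threading

import numpy as np

# The bounded queue is "the pipeline": actors push rollouts, the learner
# pulls them. Its maxsize caps how far acting can run ahead of learning.
pipeline: queue.Queue = queue.Queue(maxsize=10)
params_lock = threading.Lock()
params = np.zeros(4)  # stand-in for network parameters

def policy(p: np.ndarray, obs: np.ndarray) -> np.ndarray:
    """Stand-in inference step; in Sebulba this runs on an actor device."""
    return obs @ p

def actor_thread(num_envs: int = 8, rollout_length: int = 16) -> None:
    obs = np.random.randn(num_envs, 4)  # stand-in for env.reset()
    for _ in range(3):  # a real actor loops until training ends
        with params_lock:
            p = params.copy()  # fetch the latest learner parameters
        rollout = [policy(p, obs) for _ in range(rollout_length)]
        pipeline.put(rollout)  # blocks if the learner falls behind

def learner_thread(num_updates: int = 6) -> None:
    global params
    for _ in range(num_updates):
        rollout = pipeline.get()  # blocks until actors produce data
        with params_lock:
            params = params + 0.01  # stand-in for a jit-compiled update

threads = [threading.Thread(target=actor_thread) for _ in range(2)]
threads.append(threading.Thread(target=learner_thread))
for t in threads:
    t.start()
for t in threads:
    t.join()
```

The key design point is that the environments live entirely on the actor side, so they can be written in any framework; only inference and learning need to be in JAX.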

## Code Philosophy 🧘

@@ -47,9 +56,11 @@ The current code in Stoix was initially **largely** taken and subsequently adapt
### Stoix TLDR
1. **Algorithms:** Stoix offers easily hackable, single-file implementations of popular algorithms in pure JAX. You can vectorize algorithm training on a single device using `vmap` as well as distribute training across multiple devices with `pmap` (or both); a minimal sketch follows this list. Multi-host support (i.e., `vmap`/`pmap` over multiple devices **and** machines) is coming soon! All implementations include checkpointing to save and resume parameters and training runs.

2. **Hydra Config System:** Leverage the Hydra configuration system for efficient and consistent management of experiments, network architectures, and environments. Hydra facilitates the easy addition of new hyperparameters and supports multi-runs and Optuna hyperparameter optimization. No more need to create large bash scripts to run a series of experiments with differing hyperparameters, network architectures or environments.
2. **System Designs:** Choose between Anakin systems for fully JAX-optimized workflows or Sebulba systems for flexibility with non-JAX environments.

3. **Hydra Config System:** Leverage the Hydra configuration system for efficient and consistent management of experiments, network architectures, and environments. Hydra facilitates the easy addition of new hyperparameters and supports multi-runs and Optuna hyperparameter optimization. No more need to create large bash scripts to run a series of experiments with differing hyperparameters, network architectures or environments.

3. **Advanced Logging:** Stoix features advanced and configurable logging, ready for output to the terminal, TensorBoard, and other ML tracking dashboards (WandB and Neptune). It also supports logging experiments in JSON format ready for statistical tests and generating RLiable plots (see the plotting notebook). This enables statistically confident comparisons of algorithms natively.
4. **Advanced Logging:** Stoix features advanced and configurable logging, ready for output to the terminal, TensorBoard, and other ML tracking dashboards (WandB and Neptune). It also supports logging experiments in JSON format ready for statistical tests and generating RLiable plots (see the plotting notebook). This enables statistically confident comparisons of algorithms natively.
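
As a hedged illustration of the vectorisation claim in point 1 (the `train` function below is a stand-in, not Stoix's actual API), a pure-JAX training function can be vectorised over seeds with `vmap`, over devices with `pmap`, or both:

```python
import jax
import jax.numpy as jnp

def train(rng: jax.Array) -> jnp.ndarray:
    """Stand-in for a full jit-compiled training run; returns a metric."""
    return jnp.mean(jax.random.normal(rng, (4,)))

# Many seeds in parallel on a single device:
seeds = jax.random.split(jax.random.PRNGKey(0), 8)
per_seed_metrics = jax.vmap(train)(seeds)

# Both at once: one batch of seeds per local device.
n_devices = jax.local_device_count()
keys = jax.random.split(jax.random.PRNGKey(1), n_devices * 8)
keys = keys.reshape(n_devices, 8, -1)
metrics = jax.pmap(jax.vmap(train))(keys)
```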

Stoix currently offers the following building blocks for Single-Agent RL research:

@@ -78,14 +89,17 @@ Stoix currently offers the following building blocks for Single-Agent RL researc
- **Sampled Alpha/Mu-Zero** - [Paper](https://arxiv.org/abs/2104.06303)

### Environment Wrappers 🍬
Stoix offers wrappers for [Gymnax][gymnax], [Jumanji][jumanji], [Brax][brax], [XMinigrid][xminigrid], [Craftax][craftax], [POPJym][popjym], [Navix][navix] and even [JAXMarl][jaxmarl] (although using Centralised Controllers).
Stoix offers wrappers for:

- **JAX environments:** [Gymnax][gymnax], [Jumanji][jumanji], [Brax][brax], [XMinigrid][xminigrid], [Craftax][craftax], [POPJym][popjym], [Navix][navix] and even [JAXMarl][jaxmarl] (although using Centralised Controllers).
- **Non-JAX environments:** [Envpool][envpool] and [Gymnasium][gymnasium].

### Statistically Robust Evaluation 🧪
Stoix natively supports logging to JSON files which adhere to the standard suggested by [Gorsane et al. (2022)][toward_standard_eval]. This enables easy downstream experiment plotting and aggregation using the tools found in the [MARL-eval][marl_eval] library.

## Performance and Speed 🚀

As the code in Stoix (at the time of creation) was in essence a port of [Mava][mava], for further speed comparisons we point to their repo. Additionally, we refer to the PureJaxRL blog post [here](https://chrislu.page/blog/meta-disco/) where the speed benefits of end-to-end JAX systems are discussed.
As the code in Stoix (at the time of creation) was in essence a port of [Mava][mava], for further speed comparisons we point to their repo. Additionally, we refer to the PureJaxRL blog post [here](https://chrislu.page/blog/meta-disco/) where the speed benefits of end-to-end JAX systems are discussed. Lastly, we point to the Podracer architectures paper [here][anakin_paper] where these ideas were first discussed and benchmarked.

Below we provide some plots illustrating that Stoix performs on par with [PureJaxRL][purejaxrl], but with the added benefit that the code is already set up for `pmap` distribution over devices, as well as the other features provided (algorithm implementations, logging, config system, etc.).
<p align="center">
@@ -118,14 +132,22 @@ we advise users to explicitly install the correct JAX version (see the [official

To get started with training your first Stoix system, simply run one of the system files, e.g.:

For an Anakin system:

```bash
python stoix/systems/ppo/anakin/ff_ppo.py
```

or for a Sebulba system:

```bash
python stoix/systems/ppo/ff_ppo.py
python stoix/systems/ppo/sebulba/ff_ppo.py arch=sebulba env=envpool/pong network=visual_resnet
```

Stoix makes use of Hydra for config management. In order to see our default system configs please see the `stoix/configs/` directory. A benefit of Hydra is that configs can either be set in config yaml files or overwritten from the terminal on the fly. For an example of running a system on the CartPole environment, the above code can simply be adapted as follows:
Stoix makes use of Hydra for config management. Our default system configs can be found in the `stoix/configs/` directory. A benefit of Hydra is that configs can either be set in config yaml files or overwritten from the terminal on the fly. For an example of running a system on the CartPole environment and changing any hyperparameters, the above code can simply be adapted as follows:

```bash
python stoix/systems/ppo/ff_ppo.py env=gymnax/cartpole
python stoix/systems/ppo/ff_ppo.py env=gymnax/cartpole system.rollout_length=32 system.decay_learning_rates=True
```

Additionally, certain variants such as Dueling DQN are determined by the network architecture while the underlying algorithm stays the same. For example, if you wanted to run Dueling DQN you would simply do:
@@ -146,6 +168,8 @@ python stoix/systems/q_learning/ff_c51.py network=mlp_dueling_c51

2. Due to the way Stoix is set up, you are not guaranteed to run for exactly the number of timesteps you set. A warning is given at the beginning of a run stating the actual number of timesteps that will be run. This value will always be less than or equal to the specified sample budget. To run the exact number of transitions you want, ensure that the number of timesteps is divisible by `rollout_length * total_num_envs`, and additionally ensure that the number of evaluations spaced throughout training perfectly divides the number of updates to be performed. To see the exact calculation, see the file total_timestep_checker.py (a worked sketch of the arithmetic also follows this list). It's relatively trivial to set this up, but it is important to keep in mind.

3. Optimising the performance and speed for Sebulba systems can be a little tricky, as you need to balance the pipeline size, the number of actor threads, etc., so keep this in mind when applying an algorithm to a new problem.
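
As a worked sketch of the arithmetic in point 2 (the variable names only mirror the config keys; total_timestep_checker.py remains the authoritative calculation):

```python
# Hypothetical budget and config values, for illustration only.
total_timesteps = int(1e7)
rollout_length = 32
total_num_envs = 1024
num_evaluation = 20

# Each update consumes rollout_length * total_num_envs transitions,
# so the requested budget is floored to a whole number of updates...
steps_per_update = rollout_length * total_num_envs      # 32_768
num_updates = total_timesteps // steps_per_update       # 305

# ...and floored again so the evaluations evenly divide the updates.
num_updates = (num_updates // num_evaluation) * num_evaluation  # 300
actual_timesteps = num_updates * steps_per_update       # 9_830_400 <= 1e7
print(f"Will actually run for {actual_timesteps} timesteps.")
```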

## Contributing 🤝

Please read our [contributing docs](docs/CONTRIBUTING.md) for details on how to submit pull requests, our Contributor License Agreement and community guidelines.
@@ -217,5 +241,7 @@ We would like to thank the authors and developers of [Mava](mava) as this was es
[craftax]: https://github.com/MichaelTMatthews/Craftax
[popjym]: https://github.com/FLAIROx/popjym
[navix]: https://github.com/epignatelli/navix
[envpool]: https://github.com/sail-sg/envpool/
[gymnasium]: https://github.com/Farama-Foundation/Gymnasium

Disclaimer: This is not an official InstaDeep product nor is any of the work put forward associated with InstaDeep in any official capacity.
Binary file added docs/images/anakin_arch.jpg
Binary file added docs/images/sebulba_arch.gif
2 changes: 2 additions & 0 deletions requirements/requirements.txt
@@ -3,8 +3,10 @@ chex
colorama
craftax
distrax @ git+https://github.com/google-deepmind/distrax # distrax release doesn't support jax > 0.4.13
envpool
flashbax @ git+https://github.com/instadeepai/flashbax
flax
gymnasium
gymnax>=0.0.6
huggingface_hub
hydra-core==1.3.2
25 changes: 21 additions & 4 deletions stoix/base_types.py
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Tuple, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Tuple, TypeVar

import chex
from distrax import DistributionLike
@@ -36,7 +36,7 @@ class Observation(NamedTuple):

agent_view: chex.Array # (num_obs_features,)
action_mask: chex.Array # (num_actions,)
step_count: chex.Array # (,)
step_count: Optional[chex.Array] = None # (,)


class ObservationGlobalState(NamedTuple):
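
To illustrate the effect of this change (a hedged sketch with illustrative shapes): `step_count` now defaults to `None`, so environments that do not track step counts can construct an `Observation` without it.

```python
import jax.numpy as jnp

from stoix.base_types import Observation

# With a step count, as before:
obs = Observation(
    agent_view=jnp.zeros(4),
    action_mask=jnp.ones(2),
    step_count=jnp.array(0),
)

# Without one: step_count now defaults to None.
obs = Observation(agent_view=jnp.zeros(4), action_mask=jnp.ones(2))
```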
@@ -106,8 +106,18 @@ class ActorCriticHiddenStates(NamedTuple):
critic_hidden_state: HiddenState


class LearnerState(NamedTuple):
"""State of the learner."""
class CoreLearnerState(NamedTuple):
"""Base state of the learner. Can be used for both on-policy and off-policy learners.
Mainly used for Sebulba systems since we don't store the env state."""

params: Parameters
opt_states: OptStates
key: chex.PRNGKey
timestep: TimeStep


class OnPolicyLearnerState(NamedTuple):
"""State of the learner. Used for on-policy learners."""

params: Parameters
opt_states: OptStates
@@ -146,6 +156,9 @@ class OnlineAndTarget(NamedTuple):
StoixState = TypeVar(
"StoixState",
)
StoixTransition = TypeVar(
"StoixTransition",
)


class ExperimentOutput(NamedTuple, Generic[StoixState]):
@@ -158,6 +171,7 @@ class ExperimentOutput(NamedTuple, Generic[StoixState]):

RNNObservation: TypeAlias = Tuple[Observation, Done]
LearnerFn = Callable[[StoixState], ExperimentOutput[StoixState]]
SebulbaLearnerFn = Callable[[StoixState, StoixTransition], ExperimentOutput[StoixState]]
EvalFn = Callable[[FrozenDict, chex.PRNGKey], ExperimentOutput[StoixState]]

ActorApply = Callable[..., DistributionLike]
@@ -174,3 +188,6 @@ class ExperimentOutput(NamedTuple, Generic[StoixState]):
[FrozenDict, HiddenState, RNNObservation, chex.PRNGKey], Tuple[HiddenState, chex.Array]
]
RecCriticApply = Callable[[FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Value]]


EnvFactory = Callable[[int], Any]
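
The new `EnvFactory` alias captures the Sebulba contract for environment construction: a callable that, given a number of environments, returns a vectorised (possibly non-JAX) environment object, so each actor thread can build its own batch. A hedged sketch using envpool as an example backend (`make_pong_factory` is hypothetical, not part of Stoix):

```python
from typing import Any, Callable

import envpool  # vectorised C++ environments, no JAX required

EnvFactory = Callable[[int], Any]

def make_pong_factory(seed: int = 42) -> EnvFactory:
    """Returns a factory that each actor thread calls with its env count."""
    def factory(num_envs: int) -> Any:
        return envpool.make(
            "Pong-v5", env_type="gymnasium", num_envs=num_envs, seed=seed
        )
    return factory

envs = make_pong_factory()(8)  # one actor's batch of 8 environments
```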
2 changes: 1 addition & 1 deletion stoix/configs/arch/anakin.yaml
@@ -1,5 +1,5 @@
# --- Anakin config ---

architecture_name: anakin
# --- Training ---
seed: 42 # RNG seed.
update_batch_size: 1 # Number of vectorised gradient updates per device.
29 changes: 29 additions & 0 deletions stoix/configs/arch/sebulba.yaml
@@ -0,0 +1,29 @@
# --- Sebulba config ---
architecture_name: sebulba
# --- Training ---
seed: 42 # RNG seed.
total_num_envs: 1024 # Total number of vectorised environments across all actors. Needs to be divisible by the number of actor devices and actors per device.
total_timesteps: 1e7 # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: ~ # Number of updates

# Define the number of actors per device and which devices to use.
actor:
  device_ids: [0,1] # Define which devices to use for the actors.
  actor_per_device: 2 # number of different threads per actor device.

# Define which devices to use for the learner.
learner:
  device_ids: [2,3] # Define which devices to use for the learner.

# Size of the queue for the pipeline where actors push data and the learner pulls data.
pipeline_queue_size: 10

# --- Evaluation ---
evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
# an action which corresponds to the greatest logit. If false, the policy will sample
# from the logits.
num_eval_episodes: 128 # Number of episodes to evaluate per evaluation.
num_evaluation: 20 # Number of evenly spaced evaluations to perform during training.
absolute_metric: True # Whether the absolute metric should be computed. For more details
# on the absolute metric please see: https://arxiv.org/abs/2209.10485
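
To show how these device assignments might be consumed (a hedged sketch, not the actual Stoix code; the modulo guard only keeps it runnable on machines with fewer than four devices):

```python
import jax

# Mirrors actor.device_ids: [0,1] and learner.device_ids: [2,3] above.
actor_ids, learner_ids = [0, 1], [2, 3]
local_devices = jax.local_devices()
actor_devices = [local_devices[i % len(local_devices)] for i in actor_ids]
learner_devices = [local_devices[i % len(local_devices)] for i in learner_ids]

# Each actor device hosts actor_per_device inference threads, and the
# total_num_envs environments are split evenly across all actor threads.
total_num_envs, actor_per_device = 1024, 2
envs_per_thread = total_num_envs // (len(actor_devices) * actor_per_device)  # 256
```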
@@ -5,3 +5,7 @@ defaults:
- network: mlp
- env: gymnax/cartpole
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_continuous
- env: brax/ant
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp
- env: gymnax/cartpole
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_c51
- env: gymnax/cartpole
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_d4pg
- env: brax/ant
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_ddpg
- env: brax/ant
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_dqn
- env: gymnax/cartpole
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_continuous
- env: brax/ant
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_dqn
- env: gymnax/cartpole
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_dqn
- env: gymnax/cartpole
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_dqn
- env: gymnax/cartpole
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_mpo
- env: gymnax/cartpole
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_mpo_continuous
- env: brax/ant
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: muzero
- env: gymnax/cartpole
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp
- env: gymnax/cartpole
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_continuous
- env: brax/ant
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_qr_dqn
- env: gymnax/cartpole
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_noisy_dueling_c51
- env: gymnax/cartpole
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp
- env: gymnax/cartpole
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_continuous
- env: gymnax/pendulum
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_sac
- env: brax/ant
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: mlp_continuous
- env: brax/ant
- _self_

hydra:
  searchpath:
    - file://stoix/configs

@@ -5,3 +5,7 @@ defaults:
- network: sampled_muzero
- env: gymnax/pendulum
- _self_

hydra:
  searchpath:
    - file://stoix/configs
