chore: refactor types
EdanToledo committed Aug 22, 2024
1 parent b5efc72 commit cef6b11
Showing 34 changed files with 207 additions and 156 deletions.
23 changes: 19 additions & 4 deletions stoix/base_types.py
@@ -161,18 +161,33 @@ class OnlineAndTarget(NamedTuple):
)


class ExperimentOutput(NamedTuple, Generic[StoixState]):
class SebulbaExperimentOutput(NamedTuple, Generic[StoixState]):
"""Experiment output."""

learner_state: StoixState
train_metrics: Dict[str, chex.Array]


class AnakinExperimentOutput(NamedTuple, Generic[StoixState]):
"""Experiment output."""

learner_state: StoixState
episode_metrics: Dict[str, chex.Array]
train_metrics: Dict[str, chex.Array]


class EvaluationOutput(NamedTuple, Generic[StoixState]):
"""Evaluation output."""

learner_state: StoixState
episode_metrics: Dict[str, chex.Array]


RNNObservation: TypeAlias = Tuple[Observation, Done]
LearnerFn = Callable[[StoixState], ExperimentOutput[StoixState]]
SebulbaLearnerFn = Callable[[StoixState, StoixTransition], ExperimentOutput[StoixState]]
EvalFn = Callable[[FrozenDict, chex.PRNGKey], ExperimentOutput[StoixState]]
LearnerFn = Callable[[StoixState], AnakinExperimentOutput[StoixState]]
SebulbaLearnerFn = Callable[[StoixState, StoixTransition], SebulbaExperimentOutput[StoixState]]
EvalFn = Callable[[FrozenDict, chex.PRNGKey], EvaluationOutput[StoixState]]
SebulbaEvalFn = Callable[[FrozenDict, chex.PRNGKey], Dict[str, chex.Array]]

ActorApply = Callable[..., DistributionLike]

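For orientation: the hunk above splits the old ExperimentOutput into three role-specific types and repoints the aliases accordingly. LearnerFn now returns AnakinExperimentOutput (episode and train metrics), SebulbaLearnerFn returns SebulbaExperimentOutput (train metrics only, presumably because rollout metrics are collected outside the learner in the Sebulba setup), and EvalFn returns EvaluationOutput (episode metrics only). The sketch below is a minimal, self-contained stand-in for these definitions rather than Stoix code: the Generic[StoixState] parameter is dropped and a plain int stands in for the learner state so the snippet runs on its own.

# Toy stand-in for the refactored output types and aliases (simplified).
from typing import Callable, Dict, NamedTuple

import jax.numpy as jnp


class SebulbaExperimentOutput(NamedTuple):
    # Learner-side output only: no episode metrics.
    learner_state: int
    train_metrics: Dict[str, jnp.ndarray]


class AnakinExperimentOutput(NamedTuple):
    # Anakin learners interleave rollouts and updates, so both metric dicts.
    learner_state: int
    episode_metrics: Dict[str, jnp.ndarray]
    train_metrics: Dict[str, jnp.ndarray]


class EvaluationOutput(NamedTuple):
    # Evaluation reports episode metrics only (no placeholder train_metrics).
    learner_state: int
    episode_metrics: Dict[str, jnp.ndarray]


# Each role gets its own alias instead of sharing one ExperimentOutput.
LearnerFn = Callable[[int], AnakinExperimentOutput]
SebulbaLearnerFn = Callable[[int, Dict], SebulbaExperimentOutput]
EvalFn = Callable[[Dict, jnp.ndarray], EvaluationOutput]
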
39 changes: 7 additions & 32 deletions stoix/evaluator.py
@@ -1,6 +1,6 @@
import math
import time
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import chex
import flax.linen as nn
@@ -18,11 +18,12 @@
EnvFactory,
EvalFn,
EvalState,
ExperimentOutput,
EvaluationOutput,
RecActFn,
RecActorApply,
RNNEvalState,
RNNObservation,
SebulbaEvalFn,
)
from stoix.utils.jax_utils import unreplicate_batch_dim

@@ -133,7 +134,7 @@ def not_done(carry: Tuple) -> bool:

return eval_metrics

def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> ExperimentOutput[EvalState]:
def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> EvaluationOutput[EvalState]:
"""Evaluator function."""

# Initialise environment states and timesteps.
@@ -164,10 +165,9 @@ def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> ExperimentOut
axis_name="eval_batch",
)(trained_params, eval_state)

return ExperimentOutput(
return EvaluationOutput(
learner_state=eval_state,
episode_metrics=eval_metrics,
train_metrics={},
)

return evaluator_fn
@@ -248,7 +248,7 @@ def not_done(carry: Tuple) -> bool:

def evaluator_fn(
trained_params: FrozenDict, key: chex.PRNGKey
) -> ExperimentOutput[RNNEvalState]:
) -> EvaluationOutput[RNNEvalState]:
"""Evaluator function."""

# Initialise environment states and timesteps.
@@ -289,10 +289,9 @@ def evaluator_fn(
axis_name="eval_batch",
)(trained_params, eval_state)

return ExperimentOutput(
return EvaluationOutput(
learner_state=eval_state,
episode_metrics=eval_metrics,
train_metrics={},
)

return evaluator_fn
@@ -356,11 +355,6 @@ def evaluator_setup(
return evaluator, absolute_metric_evaluator, (trained_params, eval_keys)


##### THIS IS TEMPORARY

SebulbaEvalFn = Callable[[FrozenDict, chex.PRNGKey], Dict[str, chex.Array]]


def get_sebulba_eval_fn(
env_factory: EnvFactory,
act_fn: ActFn,
@@ -369,18 +363,7 @@ def get_sebulba_eval_fn(
device: jax.Device,
eval_multiplier: float = 1.0,
) -> Tuple[SebulbaEvalFn, Any]:
"""Creates a function that can be used to evaluate agents on a given environment.

Args:
----
env: an environment that conforms to the mava environment spec.
act_fn: a function that takes in params, timestep, key and optionally a state
and returns actions and optionally a state (see `EvalActFn`).
config: the system config.
np_rng: a numpy random number generator.
eval_multiplier: a scalar that will increase the number of evaluation episodes
by a fixed factor.
"""
eval_episodes = config.arch.num_eval_episodes * eval_multiplier

# We calculate here the number of parallel envs we can run in parallel.
@@ -405,14 +388,6 @@ def get_sebulba_eval_fn(
print(f"{Fore.YELLOW}{Style.BRIGHT}{msg}{Style.RESET_ALL}")

def eval_fn(params: FrozenDict, key: chex.PRNGKey) -> Dict:
"""Evaluates the given params on an environment and returns relevant metrics.
Metrics are collected by the `RecordEpisodeMetrics` wrapper: episode return and length,
also win rate for environments that support it.
Returns: Dict[str, Array] - dictionary of metric name to metric values for each episode.
"""

def _run_episodes(key: chex.PRNGKey) -> Tuple[chex.PRNGKey, Dict]:
"""Simulates `num_envs` episodes."""
with jax.default_device(device):
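Two changes run through this file: the evaluator functions now return EvaluationOutput, so the placeholder train_metrics={} argument disappears, and the SebulbaEvalFn alias, previously defined locally under a temporary note, is imported from stoix.base_types instead. Below is a small sketch of the SebulbaEvalFn contract, with a toy body in place of the real rollout logic; only the alias itself mirrors base_types.py.

# Sketch of the SebulbaEvalFn contract: (params, rng key) -> dict of episode
# metrics. The alias mirrors base_types.py; the function body is a toy stand-in.
from typing import Callable, Dict

import chex
import jax
from flax.core.frozen_dict import FrozenDict

SebulbaEvalFn = Callable[[FrozenDict, chex.PRNGKey], Dict[str, chex.Array]]


def toy_eval_fn(params: FrozenDict, key: chex.PRNGKey) -> Dict[str, chex.Array]:
    del params  # a real eval fn would roll out the policy with these params
    # Pretend eight evaluation episodes were run and return their returns.
    return {"episode_return": jax.random.uniform(key, (8,))}


typed_eval_fn: SebulbaEvalFn = toy_eval_fn
metrics = typed_eval_fn(FrozenDict({}), jax.random.PRNGKey(0))
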
6 changes: 3 additions & 3 deletions stoix/systems/awr/ff_awr.py
@@ -22,8 +22,8 @@
ActorApply,
ActorCriticOptStates,
ActorCriticParams,
AnakinExperimentOutput,
CriticApply,
ExperimentOutput,
LearnerFn,
LogEnvState,
)
@@ -323,7 +323,7 @@ def _actor_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: AWRLearnerState) -> ExperimentOutput[AWRLearnerState]:
def learner_fn(learner_state: AWRLearnerState) -> AnakinExperimentOutput[AWRLearnerState]:
"""Learner function.
This function represents the learner, it updates the network parameters
@@ -336,7 +336,7 @@ def learner_fn(learner_state: AWRLearnerState) -> ExperimentOutput[AWRLearnerSta
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,
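The substantive change in this file (and in the remaining systems files below) is the learner_fn annotation and return type moving from ExperimentOutput to AnakinExperimentOutput. For context, the sketch below is a self-contained toy version of the surrounding pattern from the hunk above: scan a batched update step for a fixed number of updates, then package the carried state and stacked metrics into the output tuple. Everything except the control flow is a stand-in.

# Toy version of the learner pattern around the change: jax.lax.scan over a
# batched update step, then wrap the results in AnakinExperimentOutput.
from typing import Dict, NamedTuple, Tuple

import jax
import jax.numpy as jnp


class AnakinExperimentOutput(NamedTuple):
    # Same toy stand-in as in the earlier sketch.
    learner_state: jnp.ndarray
    episode_metrics: Dict[str, jnp.ndarray]
    train_metrics: Dict[str, jnp.ndarray]


def batched_update_step(
    state: jnp.ndarray, _: None
) -> Tuple[jnp.ndarray, Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]]:
    # Stand-in for one rollout-plus-gradient-update step.
    new_state = state + 1.0
    episode_info = {"episode_return": new_state}
    loss_info = {"total_loss": jnp.square(new_state)}
    return new_state, (episode_info, loss_info)


def learner_fn(learner_state: jnp.ndarray) -> AnakinExperimentOutput:
    num_updates_per_eval = 4  # stand-in for config.arch.num_updates_per_eval
    learner_state, (episode_info, loss_info) = jax.lax.scan(
        batched_update_step, learner_state, None, num_updates_per_eval
    )
    return AnakinExperimentOutput(
        learner_state=learner_state,
        episode_metrics=episode_info,
        train_metrics=loss_info,
    )


output = learner_fn(jnp.array(0.0))  # output.train_metrics["total_loss"].shape == (4,)
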
6 changes: 3 additions & 3 deletions stoix/systems/awr/ff_awr_continuous.py
@@ -22,8 +22,8 @@
ActorApply,
ActorCriticOptStates,
ActorCriticParams,
AnakinExperimentOutput,
CriticApply,
ExperimentOutput,
LearnerFn,
LogEnvState,
)
@@ -323,7 +323,7 @@ def _actor_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: AWRLearnerState) -> ExperimentOutput[AWRLearnerState]:
def learner_fn(learner_state: AWRLearnerState) -> AnakinExperimentOutput[AWRLearnerState]:
"""Learner function.
This function represents the learner, it updates the network parameters
@@ -336,7 +336,7 @@ def learner_fn(learner_state: AWRLearnerState) -> ExperimentOutput[AWRLearnerSta
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,
8 changes: 5 additions & 3 deletions stoix/systems/ddpg/ff_d4pg.py
@@ -20,8 +20,8 @@

from stoix.base_types import (
ActorApply,
AnakinExperimentOutput,
ContinuousQApply,
ExperimentOutput,
LearnerFn,
LogEnvState,
Observation,
@@ -347,7 +347,9 @@ def _actor_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]:
def learner_fn(
learner_state: OffPolicyLearnerState,
) -> AnakinExperimentOutput[OffPolicyLearnerState]:
"""Learner function.
This function represents the learner, it updates the network parameters
@@ -360,7 +362,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,
8 changes: 5 additions & 3 deletions stoix/systems/ddpg/ff_ddpg.py
@@ -20,8 +20,8 @@

from stoix.base_types import (
ActorApply,
AnakinExperimentOutput,
ContinuousQApply,
ExperimentOutput,
LearnerFn,
LogEnvState,
Observation,
@@ -309,7 +309,9 @@ def _actor_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]:
def learner_fn(
learner_state: OffPolicyLearnerState,
) -> AnakinExperimentOutput[OffPolicyLearnerState]:
"""Learner function.
This function represents the learner, it updates the network parameters
@@ -323,7 +325,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,
8 changes: 5 additions & 3 deletions stoix/systems/ddpg/ff_td3.py
@@ -20,8 +20,8 @@

from stoix.base_types import (
ActorApply,
AnakinExperimentOutput,
ContinuousQApply,
ExperimentOutput,
LearnerFn,
LogEnvState,
Observation,
@@ -327,7 +327,9 @@ def _actor_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPolicyLearnerState]:
def learner_fn(
learner_state: OffPolicyLearnerState,
) -> AnakinExperimentOutput[OffPolicyLearnerState]:
"""Learner function.
This function represents the learner, it updates the network parameters
@@ -341,7 +343,7 @@ def learner_fn(learner_state: OffPolicyLearnerState) -> ExperimentOutput[OffPoli
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,
6 changes: 3 additions & 3 deletions stoix/systems/mpo/ff_mpo.py
@@ -20,8 +20,8 @@

from stoix.base_types import (
ActorApply,
AnakinExperimentOutput,
ContinuousQApply,
ExperimentOutput,
LearnerFn,
LogEnvState,
OnlineAndTarget,
@@ -406,7 +406,7 @@ def _q_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: MPOLearnerState) -> ExperimentOutput[MPOLearnerState]:
def learner_fn(learner_state: MPOLearnerState) -> AnakinExperimentOutput[MPOLearnerState]:
"""Learner function.
This function represents the learner, it updates the network parameters
@@ -419,7 +419,7 @@ def learner_fn(learner_state: MPOLearnerState) -> ExperimentOutput[MPOLearnerSta
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,
6 changes: 3 additions & 3 deletions stoix/systems/mpo/ff_mpo_continuous.py
@@ -20,8 +20,8 @@

from stoix.base_types import (
ActorApply,
AnakinExperimentOutput,
ContinuousQApply,
ExperimentOutput,
LearnerFn,
LogEnvState,
OnlineAndTarget,
@@ -422,7 +422,7 @@ def _q_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: MPOLearnerState) -> ExperimentOutput[MPOLearnerState]:
def learner_fn(learner_state: MPOLearnerState) -> AnakinExperimentOutput[MPOLearnerState]:
"""Learner function.
This function represents the learner, it updates the network parameters
@@ -435,7 +435,7 @@ def learner_fn(learner_state: MPOLearnerState) -> ExperimentOutput[MPOLearnerSta
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,
6 changes: 3 additions & 3 deletions stoix/systems/mpo/ff_vmpo.py
@@ -17,8 +17,8 @@

from stoix.base_types import (
ActorApply,
AnakinExperimentOutput,
CriticApply,
ExperimentOutput,
LearnerFn,
OnlineAndTarget,
)
@@ -305,7 +305,7 @@ def _critic_loss_fn(
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: VMPOLearnerState) -> ExperimentOutput[VMPOLearnerState]:
def learner_fn(learner_state: VMPOLearnerState) -> AnakinExperimentOutput[VMPOLearnerState]:
"""Learner function.
This function represents the learner, it updates the network parameters
@@ -318,7 +318,7 @@ def learner_fn(learner_state: VMPOLearnerState) -> ExperimentOutput[VMPOLearnerS
learner_state, (episode_info, loss_info) = jax.lax.scan(
batched_update_step, learner_state, None, config.arch.num_updates_per_eval
)
return ExperimentOutput(
return AnakinExperimentOutput(
learner_state=learner_state,
episode_metrics=episode_info,
train_metrics=loss_info,
(Diffs for the remaining changed files are not shown in this view.)
