Skip to content

Commit

Permalink
fix: make FlattenObservationWrapper also flatten next_obs (#115)
Browse files Browse the repository at this point in the history
* fix: flattenObservationWrapper to flatten next_obs as well.
  • Loading branch information
JesseSilverberg committed Sep 12, 2024
1 parent 1889659 commit 6125d36
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions stoix/wrappers/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,34 @@ def __init__(self, env: Environment) -> None:
obs_shape = self._env.observation_spec().agent_view.shape
self._obs_shape = (np.prod(obs_shape),)

def _flatten(self, obs: Observation) -> Array:
agent_view = obs.agent_view.astype(jnp.float32)
return agent_view.reshape(self._obs_shape)

def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
state, timestep = self._env.reset(key)
agent_view = timestep.observation.agent_view.astype(jnp.float32)
agent_view = agent_view.reshape(self._obs_shape)
agent_view = self._flatten(timestep.observation)
timestep = timestep.replace(
observation=timestep.observation._replace(agent_view=agent_view),
)
if "next_obs" in timestep.extras:
agent_view = self._flatten(timestep.extras["next_obs"])
timestep.extras["next_obs"] = timestep.extras["next_obs"]._replace(
agent_view=agent_view
)
return state, timestep

def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
state, timestep = self._env.step(state, action)
agent_view = timestep.observation.agent_view.astype(jnp.float32)
agent_view = agent_view.reshape(self._obs_shape)
agent_view = self._flatten(timestep.observation)
timestep = timestep.replace(
observation=timestep.observation._replace(agent_view=agent_view),
)
if "next_obs" in timestep.extras:
agent_view = self._flatten(timestep.extras["next_obs"])
timestep.extras["next_obs"] = timestep.extras["next_obs"]._replace(
agent_view=agent_view
)
return state, timestep

def observation_spec(self) -> Spec:
Expand Down

0 comments on commit 6125d36

Please sign in to comment.