diff --git a/purejaxrl/wrappers.py b/purejaxrl/wrappers.py index cd28111..bc8a994 100644 --- a/purejaxrl/wrappers.py +++ b/purejaxrl/wrappers.py @@ -320,3 +320,46 @@ def step(self, key, state, action, params=None): env_state=env_state, ) return obs, state, reward / jnp.sqrt(state.var + 1e-8), done, info + +class MaskedObservationWrapper(GymnaxWrapper): + """Mask parts observations of the environment.""" + + def __init__(self, env: environment.Environment,config: dict): + super().__init__(env) + self.config = config # {'obs_idx':[],'mu':0.0,'sigma':0.1} + def observation_space(self, params) -> spaces.Box: + assert isinstance( + self._env.observation_space(params), spaces.Box + ), "Only Box spaces are supported for now." + return spaces.Box( + low=self._env.observation_space(params).low, + high=self._env.observation_space(params).high, + shape=(np.prod((len(self.config['obs_idx'])),),), + dtype=self._env.observation_space(params).dtype, + ) + + @partial(jax.jit, static_argnums=(0,)) + def reset( + self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None + ) -> Tuple[chex.Array, environment.EnvState]: + obs, state = self._env.reset(key, params) + obs = jnp.reshape(obs, (-1,)) + obs = jnp.take(obs,jnp.array(self.config['obs_idx'])) + noise = (self.config['mu'] + self.config['mu'] * jax.random.normal(key, shape=obs.shape)) + return obs+noise, state + + @partial(jax.jit, static_argnums=(0,)) + def step( + self, + key: chex.PRNGKey, + state: environment.EnvState, + action: Union[int, float], + params: Optional[environment.EnvParams] = None, + ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]: + obs, state, reward, done, info = self._env.step( + key, state, action, params + ) + obs = jnp.reshape(obs, (-1,)) + obs = jnp.take(obs,jnp.array(self.config['obs_idx'])) + noise = (self.config['mu'] + self.config['mu'] * jax.random.normal(key, shape=obs.shape)) + return obs+noise, state, reward, done, info \ No newline at end of file