
[Question] Why resample SDE noise matrices in PPO optimization? #1929

Closed
brn-dev opened this issue May 19, 2024 · 4 comments · Fixed by #1933
Assignees: araffin
Labels: question (Further information is requested)

Comments

brn-dev (Contributor) commented May 19, 2024

❓ Question

I'm currently implementing a personal RL library and using SB3 as inspiration. I recently implemented SDE and I'm confused about lines 214/215 of your PPO implementation. There, the SDE noise matrices are resampled before each PPO update, but as far as I can see, this doesn't do anything, since the noise matrices are not used during the PPO updates, only during exploration. Let me elaborate:

In PPO's train(), we reset the noise matrices and then call evaluate_actions:

if self.use_sde:
    self.policy.reset_noise(self.batch_size)

values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)

reset_noise, which just calls sample_weights, does not change any state except the exploration matrices:

def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None:
    """
    Sample weights for the noise exploration matrix,
    using a centered Gaussian distribution.

    :param log_std:
    :param batch_size:
    """
    std = self.get_std(log_std)
    self.weights_dist = Normal(th.zeros_like(std), std)
    # Reparametrization trick to pass gradients
    self.exploration_mat = self.weights_dist.rsample()
    # Pre-compute matrices in case of parallel exploration
    self.exploration_matrices = self.weights_dist.rsample((batch_size,))

The exploration matrices are only used in get_noise, which is in turn only used in sample:

def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
    latent_sde = latent_sde if self.learn_features else latent_sde.detach()
    # Default case: only one exploration matrix
    if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
        return th.mm(latent_sde, self.exploration_mat)
    # Use batch matrix multiplication for efficient computation
    # (batch_size, n_features) -> (batch_size, 1, n_features)
    latent_sde = latent_sde.unsqueeze(dim=1)
    # (batch_size, 1, n_actions)
    noise = th.bmm(latent_sde, self.exploration_matrices)
    return noise.squeeze(dim=1)

def sample(self) -> th.Tensor:
    noise = self.get_noise(self._latent_sde)
    actions = self.distribution.mean + noise
    if self.bijector is not None:
        return self.bijector.forward(actions)
    return actions

However, evaluate_actions does not call sample():

def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
    """
    Evaluate actions according to the current policy,
    given the observations.

    :param obs: Observation
    :param actions: Actions
    :return: estimated value, log likelihood of taking those actions
        and entropy of the action distribution.
    """
    # Preprocess the observation if needed
    features = self.extract_features(obs)
    if self.share_features_extractor:
        latent_pi, latent_vf = self.mlp_extractor(features)
    else:
        pi_features, vf_features = features
        latent_pi = self.mlp_extractor.forward_actor(pi_features)
        latent_vf = self.mlp_extractor.forward_critic(vf_features)
    distribution = self._get_action_dist_from_latent(latent_pi)
    log_prob = distribution.log_prob(actions)
    values = self.value_net(latent_vf)
    entropy = distribution.entropy()
    return values, log_prob, entropy

So, in conclusion, the reset_noise call in the PPO train function is useless, isn't it?
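
For what it's worth, here is a minimal sketch of the check I have in mind (my own code, not SB3's; Pendulum-v1 is just an arbitrary small continuous-action env): resampling the noise between two evaluate_actions calls should leave the outputs unchanged.

import torch as th

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Arbitrary small continuous-action env, just to build an SDE policy
vec_env = make_vec_env("Pendulum-v1", n_envs=1)
model = PPO("MlpPolicy", vec_env, use_sde=True)

obs, _ = model.policy.obs_to_tensor(vec_env.reset())

with th.no_grad():
    # Sample some actions, then evaluate them before and after reset_noise
    actions, _, _ = model.policy(obs)
    values_1, log_prob_1, _ = model.policy.evaluate_actions(obs, actions)
    model.policy.reset_noise(1)  # resample the exploration matrices
    values_2, log_prob_2, _ = model.policy.evaluate_actions(obs, actions)

# If reset_noise had any effect on the update, these would differ
assert th.allclose(log_prob_1, log_prob_2)
assert th.allclose(values_1, values_2)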

brn-dev added the question label on May 19, 2024
araffin self-assigned this on May 21, 2024
araffin (Member) commented May 21, 2024

Hello,
thanks for the question.
Indeed, you should not need to re-sample the noise, but if you comment it out, it will throw an error, if I remember correctly.
Back then (4 years ago), I didn't have much time to investigate why, and re-sampling was a good enough solution.
So, if you comment it out and it works now, I would be happy to receive a PR =)
If it doesn't work and you find out why, I'm also happy to hear the answer ;)
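
A rough smoke test could look like the sketch below (PPONoTrainResample is a made-up name, not part of SB3; it simply turns reset_noise into a no-op during the update phase, which is equivalent to commenting out the call in train()):

from unittest.mock import patch

from stable_baselines3 import PPO


class PPONoTrainResample(PPO):  # hypothetical subclass, not part of SB3
    def train(self) -> None:
        # Rollout collection still resamples noise as usual; only the
        # gradient-update phase skips it, mimicking the removed call.
        with patch.object(self.policy, "reset_noise", lambda *args, **kwargs: None):
            super().train()


model = PPONoTrainResample("MlpPolicy", "Pendulum-v1", use_sde=True, n_steps=256, verbose=1)
model.learn(total_timesteps=2_048)  # should complete without errors if the call is indeed unnecessary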

brn-dev (Contributor, Author) commented May 22, 2024

After removing the reset_noise call, I tested it on the half-cheetah env and it worked without any errors. PR: #1933

[screenshots of the training logs and results]

araffin (Member) commented May 23, 2024

> After removing the reset_noise call, I tested it on the half-cheetah env

Looking at the std value, it doesn't seem like you were using gSDE, but I tried with python3 -m rl_zoo3.train --algo ppo --env HalfCheetahBulletEnv-v0 -n 20000 --seed 2 -param n_epochs:5 (a Bullet env) and it does indeed work.

brn-dev (Contributor, Author) commented May 23, 2024

[screenshot of the training log showing the std value]

The currently published version does consistently give me a std of a little less than 1.0 with SDE enabled:

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Parallel environments
vec_env = make_vec_env("HalfCheetah-v4", n_envs=4)

model = PPO("MlpPolicy", vec_env, use_sde=True, sde_sample_freq=10, verbose=2)
model.learn(total_timesteps=250000)
model.save("ppo_cartpole")

del model # remove to demonstrate saving and loading

model = PPO.load("ppo_cartpole")

obs = vec_env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")
