Hi. I'm currently developing another algorithm on top of TRPO, and I found that training keeps failing after a certain number of epochs with a CUDA out-of-memory error. At first I thought this was my fault, but it turns out that the example code given in trpo_pendulum.py leaks memory as well. My guess is that it happens in the policy optimization step when the Hessian is calculated, but I'm not sure (I put a small standalone sketch of that pattern right after the script below). This is the example code, modified from trpo_pendulum.py, that I used for GPU utilization:
#!/usr/bin/env python3
"""This is an example to train a task with TRPO algorithm (PyTorch).

Here it runs InvertedDoublePendulum-v2 environment with 100 iterations.
"""
import torch

from garage import wrap_experiment
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.sampler import LocalSampler
from garage.torch import set_gpu_mode
from garage.torch.algos import TRPO
from garage.torch.policies import GaussianMLPPolicy
from garage.torch.value_functions import GaussianMLPValueFunction
from garage.trainer import Trainer

set_gpu_mode(False)
torch.set_num_threads(1)
if torch.cuda.is_available():
    set_gpu_mode(True)
torch.multiprocessing.set_start_method('spawn')


@wrap_experiment
def trpo_pendulum(ctxt=None, seed=1):
    """Train TRPO with InvertedDoublePendulum-v2 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.

    """
    set_seed(seed)
    env = GymEnv('InvertedDoublePendulum-v2')

    trainer = Trainer(ctxt)

    policy = GaussianMLPPolicy(env.spec,
                               hidden_sizes=[32, 32],
                               hidden_nonlinearity=torch.tanh,
                               output_nonlinearity=None)

    value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                              hidden_sizes=(32, 32),
                                              hidden_nonlinearity=torch.tanh,
                                              output_nonlinearity=None)

    sampler = LocalSampler(agents=policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length)

    algo = TRPO(env_spec=env.spec,
                policy=policy,
                value_function=value_function,
                sampler=sampler,
                discount=0.99,
                center_adv=False)

    if torch.cuda.is_available():
        algo.to()
    trainer.setup(algo, env)
    trainer.train(n_epochs=100, batch_size=1024, plot=True)


trpo_pendulum(seed=1)
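As for why I suspect the Hessian: as far as I can tell, TRPO's conjugate-gradient step builds a Hessian-vector product with create_graph=True, so if any tensor produced there stays referenced between epochs, the whole autograd graph it hangs onto stays on the GPU. Here is a minimal standalone sketch of that pattern just to show what I mean; it is not garage code:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Standalone sketch of a Hessian-vector product, the pattern a
# conjugate-gradient TRPO step relies on (not garage code).
params = torch.randn(10, requires_grad=True, device=device)
loss = (params ** 2).sum()

# create_graph=True keeps the first-order graph alive so that we can
# differentiate through the gradient.
(grad,) = torch.autograd.grad(loss, params, create_graph=True)

vector = torch.randn_like(grad)
(hvp,) = torch.autograd.grad(grad.dot(vector), params, retain_graph=True)

# If a reference to `grad` or `hvp` (or anything built from them) survives
# past the update, the graph created above cannot be freed and
# torch.cuda.memory_allocated() keeps growing epoch after epoch.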
And this is the to() function I added in torch/algos/trpo.py:
def to(self, device=None):
    """Put all the networks within the model on device.

    Args:
        device (str): ID of GPU or CPU.

    """
    from garage.torch import global_device
    if device is None:
        device = global_device()
    logger.log('Using device: ' + str(device))
    self.policy = self.policy.to(device)
    self._old_policy = self._old_policy.to(device)
    self._value_function = self._value_function.to(device)
And I modified the tensors in _train_once of torch/algos/vpg.py to use the GPU as follows:
def _train_once(self, itr, eps):
    """Train the algorithm once.

    Args:
        itr (int): Iteration number.
        eps (EpisodeBatch): A batch of collected paths.

    Returns:
        numpy.float64: Calculated mean value of undiscounted returns.

    """
    obs = torch.Tensor(eps.padded_observations).to(global_device())
    rewards = torch.Tensor(eps.padded_rewards).to(global_device())
    returns = torch.Tensor(
        np.stack([
            discount_cumsum(reward, self.discount)
            for reward in eps.padded_rewards
        ])).to(global_device())
    valids = eps.lengths
    with torch.no_grad():
        baselines = self._value_function(obs)

    if self._maximum_entropy:
        policy_entropies = self._compute_policy_entropy(obs)
        rewards += self._policy_ent_coeff * policy_entropies

    obs_flat = torch.Tensor(eps.observations).to(global_device())
    actions_flat = torch.Tensor(eps.actions).to(global_device())
    rewards_flat = torch.Tensor(eps.rewards).to(global_device())
    returns_flat = torch.cat(filter_valids(returns, valids))
    advs_flat = self._compute_advantage(rewards, valids, baselines)

    with torch.no_grad():
        policy_loss_before = self._compute_loss_with_adv(
            obs_flat, actions_flat, rewards_flat, advs_flat)
        vf_loss_before = self._value_function.compute_loss(
            obs_flat, returns_flat)
        kl_before = self._compute_kl_constraint(obs)

    self._train(obs_flat, actions_flat, rewards_flat, returns_flat,
                advs_flat)

    with torch.no_grad():
        policy_loss_after = self._compute_loss_with_adv(
            obs_flat, actions_flat, rewards_flat, advs_flat)
        vf_loss_after = self._value_function.compute_loss(
            obs_flat, returns_flat)
        kl_after = self._compute_kl_constraint(obs)
        policy_entropy = self._compute_policy_entropy(obs)

    with tabular.prefix(self.policy.name):
        tabular.record('/LossBefore', policy_loss_before.item())
        tabular.record('/LossAfter', policy_loss_after.item())
        tabular.record('/dLoss',
                       (policy_loss_before - policy_loss_after).item())
        tabular.record('/KLBefore', kl_before.item())
        tabular.record('/KL', kl_after.item())
        tabular.record('/Entropy', policy_entropy.mean().item())

    with tabular.prefix(self._value_function.name):
        tabular.record('/LossBefore', vf_loss_before.item())
        tabular.record('/LossAfter', vf_loss_after.item())
        tabular.record('/dLoss',
                       vf_loss_before.item() - vf_loss_after.item())

    self._old_policy.load_state_dict(self.policy.state_dict())

    undiscounted_returns = log_performance(itr,
                                           eps,
                                           discount=self._discount)
    return np.mean(undiscounted_returns)
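To pinpoint which of these tensors actually survive from one epoch to the next, I think a small diagnostic like the one below could be called at the end of _train_once. It just walks Python's gc to count live CUDA tensors; this is only a debugging sketch, not garage code:

import gc

import torch


def live_cuda_tensors():
    """Count live CUDA tensors and their total size in bytes.

    Debugging helper only: gc.get_objects() can be noisy, so treat the
    numbers as an indication of growth rather than an exact accounting.
    """
    count, total_bytes = 0, 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                count += 1
                total_bytes += obj.element_size() * obj.nelement()
        except Exception:
            # Some tracked objects raise on attribute access; skip them.
            pass
    return count, total_bytes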
Then I checked the memory at each epoch in the train function of torch/algos/vpg.py like this:
def train(self, trainer):
    """Obtain samplers and start actual training for each epoch.

    Args:
        trainer (Trainer): Gives the algorithm the access to
            :method:`~Trainer.step_epochs()`, which provides services
            such as snapshotting and sampler control.

    Returns:
        float: The average return in last epoch cycle.

    """
    last_return = None

    for _ in trainer.step_epochs():
        for _ in range(self._n_samples):
            eps = trainer.obtain_episodes(trainer.step_itr)
            last_return = self._train_once(trainer.step_itr, eps)
            trainer.step_itr += 1
        print(torch.cuda.memory_allocated())

    return last_return
The printed value kept gradually increasing at each epoch, e.g. 66560 -> 74752 -> 82944 -> 99328 -> 107520 -> 115712 ...
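For what it's worth, logging the peak allocation next to the current one would probably show whether the growth happens during the update itself or comes from tensors that persist after it. Again just a sketch, and reset_peak_memory_stats needs a reasonably recent PyTorch (>= 1.4):

# Per-epoch memory report (sketch only).
print('allocated after epoch :', torch.cuda.memory_allocated())
print('peak during this epoch:', torch.cuda.max_memory_allocated())
torch.cuda.reset_peak_memory_stats()  # start the next epoch's peak from zero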
If I'm doing something wrong, please help me out! Thanks :)