
Memory leak on trpo algorithm(pytorch) #2342

@MinchoU

Description


Hi. I'm currently developing another algorithm based on TRPO. I found that training keeps failing after a certain number of epochs because of a CUDA out-of-memory error. At first I thought this was my fault, but I figured out that the example code given in trpo_pendulum.py also leaks memory. I guess this happens in the policy optimization step when calculating the Hessian, but I'm not sure (see the sketch after the script below). This is the example code, modified from trpo_pendulum.py, that I used for GPU utilization.

#!/usr/bin/env python3
"""This is an example to train a task with TRPO algorithm (PyTorch).

Here it runs InvertedDoublePendulum-v2 environment with 100 iterations.
"""
import torch

from garage import wrap_experiment
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.sampler import LocalSampler
from garage.torch.algos import TRPO
from garage.torch.policies import GaussianMLPPolicy
from garage.torch.value_functions import GaussianMLPValueFunction
from garage.trainer import Trainer

from garage.torch import set_gpu_mode
set_gpu_mode(False)
torch.set_num_threads(1)
if torch.cuda.is_available():
    set_gpu_mode(True)
torch.multiprocessing.set_start_method('spawn')

@wrap_experiment
def trpo_pendulum(ctxt=None, seed=1):
    """Train TRPO with InvertedDoublePendulum-v2 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.

    """
    set_seed(seed)
    env = GymEnv('InvertedDoublePendulum-v2')

    trainer = Trainer(ctxt)

    policy = GaussianMLPPolicy(env.spec,
                               hidden_sizes=[32, 32],
                               hidden_nonlinearity=torch.tanh,
                               output_nonlinearity=None)

    value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                              hidden_sizes=(32, 32),
                                              hidden_nonlinearity=torch.tanh,
                                              output_nonlinearity=None)

    sampler = LocalSampler(agents=policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length)

    algo = TRPO(env_spec=env.spec,
                policy=policy,
                value_function=value_function,
                sampler=sampler,
                discount=0.99,
                center_adv=False)
    if torch.cuda.is_available():
        algo.to()
    trainer.setup(algo, env)
    trainer.train(n_epochs=100, batch_size=1024, plot=True)


trpo_pendulum(seed=1)
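
This is only a generic, self-contained sketch of the pattern I suspect (it is not garage's actual optimizer code): any tensor that still carries a grad_fn, for example intermediate results of a Hessian-vector product built with create_graph=True, retains the whole autograd graph, so holding such tensors across iterations makes torch.cuda.memory_allocated() grow.

import torch

# Generic illustration only (not garage code): keeping a reference to a
# tensor that still has a grad_fn retains its entire autograd graph, so
# allocated CUDA memory grows every iteration.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = torch.nn.Linear(512, 512).to(device)
params = list(net.parameters())
kept = []

for step in range(5):
    x = torch.randn(256, 512, device=device)
    loss = net(x).pow(2).mean()
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # `grads` still reference the graph (create_graph=True); storing them
    # without .detach() is enough to make memory_allocated() climb.
    kept.append(grads[0])            # keeps a full graph per step
    # kept.append(grads[0].detach()) # would keep only the values
    if device == 'cuda':
        print(step, torch.cuda.memory_allocated())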

And this is the to() function I added in torch/algos/trpo.py:

def to(self, device=None):
        """Put all the networks within the model on device.

        Args:
            device (str): ID of GPU or CPU.
        
        """
        from dowel import logger  # logger.log below comes from dowel
        from garage.torch import global_device

        if device is None:
            device = global_device()
        logger.log('Using device: ' + str(device))
        self.policy = self.policy.to(device)
        self._old_policy = self._old_policy.to(device)
        self._value_function = self._value_function.to(device)
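
To double-check that the networks actually end up on the GPU, a small sanity check like the following can be used (hypothetical helper, not part of garage):

def assert_on_device(module, device):
    """Hypothetical check: assert every parameter of `module` is on `device`."""
    expected = torch.device(device)
    for name, p in module.named_parameters():
        assert p.device.type == expected.type, (name, p.device)

# e.g. after calling algo.to():
# assert_on_device(algo.policy, global_device())
# assert_on_device(algo._old_policy, global_device())
# assert_on_device(algo._value_function, global_device())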

And I modified the tensors in _train_once of torch/algos/vpg.py to use the GPU as follows:

def _train_once(self, itr, eps):
        """Train the algorithm once.

        Args:
            itr (int): Iteration number.
            eps (EpisodeBatch): A batch of collected paths.

        Returns:
            numpy.float64: Calculated mean value of undiscounted returns.

        """
        obs = torch.Tensor(eps.padded_observations).to(global_device())
        rewards = torch.Tensor(eps.padded_rewards).to(global_device())
        returns = torch.Tensor(
            np.stack([
                discount_cumsum(reward, self.discount)
                for reward in eps.padded_rewards
            ])).to(global_device())
        valids = eps.lengths
        with torch.no_grad():
            baselines = self._value_function(obs)

        if self._maximum_entropy:
            policy_entropies = self._compute_policy_entropy(obs)
            rewards += self._policy_ent_coeff * policy_entropies

        obs_flat = torch.Tensor(eps.observations).to(global_device())
        actions_flat = torch.Tensor(eps.actions).to(global_device())
        rewards_flat = torch.Tensor(eps.rewards).to(global_device())
        returns_flat = torch.cat(filter_valids(returns, valids))
        advs_flat = self._compute_advantage(rewards, valids, baselines)

        with torch.no_grad():
            policy_loss_before = self._compute_loss_with_adv(
                obs_flat, actions_flat, rewards_flat, advs_flat)
            vf_loss_before = self._value_function.compute_loss(
                obs_flat, returns_flat)
            kl_before = self._compute_kl_constraint(obs)

        self._train(obs_flat, actions_flat, rewards_flat, returns_flat,
                    advs_flat)

        with torch.no_grad():
            policy_loss_after = self._compute_loss_with_adv(
                obs_flat, actions_flat, rewards_flat, advs_flat)
            vf_loss_after = self._value_function.compute_loss(
                obs_flat, returns_flat)
            kl_after = self._compute_kl_constraint(obs)
            policy_entropy = self._compute_policy_entropy(obs)

        with tabular.prefix(self.policy.name):
            tabular.record('/LossBefore', policy_loss_before.item())
            tabular.record('/LossAfter', policy_loss_after.item())
            tabular.record('/dLoss',
                           (policy_loss_before - policy_loss_after).item())
            tabular.record('/KLBefore', kl_before.item())
            tabular.record('/KL', kl_after.item())
            tabular.record('/Entropy', policy_entropy.mean().item())

        with tabular.prefix(self._value_function.name):
            tabular.record('/LossBefore', vf_loss_before.item())
            tabular.record('/LossAfter', vf_loss_after.item())
            tabular.record('/dLoss',
                           vf_loss_before.item() - vf_loss_after.item())

        self._old_policy.load_state_dict(self.policy.state_dict())

        undiscounted_returns = log_performance(itr,
                                               eps,
                                               discount=self._discount)
        return np.mean(undiscounted_returns)
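
To narrow down which part of _train_once is responsible, a small helper like this (hypothetical, standard PyTorch only) can print how much allocated memory each block adds:

import contextlib
import torch

@contextlib.contextmanager
def track_cuda_memory(tag):
    """Print the change in allocated CUDA memory caused by the wrapped block."""
    before = torch.cuda.memory_allocated()
    yield
    after = torch.cuda.memory_allocated()
    print(f'{tag}: {after - before:+d} bytes (total {after})')

# e.g. inside _train_once:
# with track_cuda_memory('policy/value function update'):
#     self._train(obs_flat, actions_flat, rewards_flat, returns_flat,
#                 advs_flat)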

Then I checked the memory at each epoch in the train function of torch/algos/vpg.py like this:

def train(self, trainer):
        """Obtain samplers and start actual training for each epoch.

        Args:
            trainer (Trainer): Gives the algorithm the access to
                :method:`~Trainer.step_epochs()`, which provides services
                such as snapshotting and sampler control.

        Returns:
            float: The average return in last epoch cycle.

        """
        last_return = None

        for _ in trainer.step_epochs():
            for _ in range(self._n_samples):
                eps = trainer.obtain_episodes(trainer.step_itr)
                last_return = self._train_once(trainer.step_itr, eps)
                trainer.step_itr += 1
                print(torch.cuda.memory_allocated())

        return last_return

And the printed value kept gradually increasing at each epoch, like 66560 -> 74752 -> 82944 -> 99328 -> 107520 -> 115712 ...
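
To see what is actually accumulating, a generic diagnostic like the following can list the live CUDA tensors after each epoch (plain PyTorch and gc, nothing garage-specific):

import gc
import torch

def dump_cuda_tensors():
    """List live CUDA tensors (type, shape, dtype) to spot what accumulates."""
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                print(type(obj).__name__, tuple(obj.shape), obj.dtype)
        except Exception:
            # Some tracked objects raise on inspection; skip them.
            pass

# e.g. call dump_cuda_tensors() right after print(torch.cuda.memory_allocated())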

If I'm doing something wrong, please help me! Thanks :)
