Commit 9147627

TD3 implementation in pytorch (#1890)
* TD3 Torch (examples, benchmark, test)
* Change to Trainer
* Update examples
1 parent e32cd06 commit 9147627

25 files changed, +946 -69 lines changed

benchmarks/src/garage_benchmarks/benchmark_algos.py

Lines changed: 3 additions & 3 deletions
@@ -1,9 +1,9 @@
 """Benchmarking for algorithms."""
 # yapf: disable
-from garage_benchmarks.experiments.algos import (ddpg_garage_tf,
-                                                 her_garage_tf,
+from garage_benchmarks.experiments.algos import (ddpg_garage_tf, her_garage_tf,
                                                  ppo_garage_pytorch,
                                                  ppo_garage_tf,
+                                                 td3_garage_pytorch,
                                                  td3_garage_tf,
                                                  trpo_garage_pytorch,
                                                  trpo_garage_tf,
@@ -40,7 +40,7 @@ def td3_benchmarks():
     td3_env_ids = [
         env_id for env_id in MuJoCo1M_ENV_SET if env_id != 'Reacher-v2'
     ]
-
+    iterate_experiments(td3_garage_pytorch, td3_env_ids)
     iterate_experiments(td3_garage_tf, td3_env_ids)



benchmarks/src/garage_benchmarks/benchmark_auto.py

Lines changed: 1 addition & 2 deletions
@@ -2,8 +2,7 @@
 # yapf: disable
 from garage_benchmarks.experiments.algos import (ddpg_garage_tf,
                                                  ppo_garage_pytorch,
-                                                 ppo_garage_tf,
-                                                 td3_garage_tf,
+                                                 ppo_garage_tf, td3_garage_tf,
                                                  trpo_garage_pytorch,
                                                  trpo_garage_tf,
                                                  vpg_garage_pytorch,

benchmarks/src/garage_benchmarks/experiments/algos/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -4,6 +4,8 @@
 from garage_benchmarks.experiments.algos.ppo_garage_pytorch import (
     ppo_garage_pytorch)
 from garage_benchmarks.experiments.algos.ppo_garage_tf import ppo_garage_tf
+from garage_benchmarks.experiments.algos.td3_garage_pytorch import (
+    td3_garage_pytorch)
 from garage_benchmarks.experiments.algos.td3_garage_tf import td3_garage_tf
 from garage_benchmarks.experiments.algos.trpo_garage_pytorch import (
     trpo_garage_pytorch)
@@ -14,6 +16,6 @@

 __all__ = [
     'ddpg_garage_tf', 'her_garage_tf', 'ppo_garage_pytorch', 'ppo_garage_tf',
-    'td3_garage_tf', 'trpo_garage_pytorch', 'trpo_garage_tf',
-    'vpg_garage_pytorch', 'vpg_garage_tf'
+    'td3_garage_pytorch', 'td3_garage_tf', 'trpo_garage_pytorch',
+    'trpo_garage_tf', 'vpg_garage_pytorch', 'vpg_garage_tf'
 ]

benchmarks/src/garage_benchmarks/experiments/algos/td3_garage_pytorch.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+"""A regression test for automatic benchmarking garage-PyTorch-TD3."""
+import torch
+from torch.nn import functional as F
+
+from garage import wrap_experiment
+from garage.envs import GymEnv, normalize
+from garage.experiment import deterministic
+from garage.np.exploration_policies import AddGaussianNoise
+from garage.np.policies import UniformRandomPolicy
+from garage.replay_buffer import PathBuffer
+from garage.torch import prefer_gpu
+from garage.torch.algos import TD3
+from garage.torch.policies import DeterministicMLPPolicy
+from garage.torch.q_functions import ContinuousMLPQFunction
+from garage.trainer import TFTrainer
+
+hyper_parameters = {
+    'policy_lr': 1e-3,
+    'qf_lr': 1e-3,
+    'policy_hidden_sizes': [256, 256],
+    'qf_hidden_sizes': [256, 256],
+    'n_epochs': 250,
+    'steps_per_epoch': 40,
+    'batch_size': 100,
+    'start_steps': 1000,
+    'update_after': 1000,
+    'grad_steps_per_env_step': 50,
+    'discount': 0.99,
+    'target_update_tau': 0.005,
+    'replay_buffer_size': int(1e6),
+    'sigma': 0.1,
+    'policy_noise': 0.2,
+    'policy_noise_clip': 0.5,
+    'buffer_batch_size': 100,
+    'min_buffer_size': int(1e4),
+}
+
+
+@wrap_experiment(snapshot_mode='last')
+def td3_garage_pytorch(ctxt, env_id, seed):
+    """Create garage PyTorch TD3 model and training.
+
+    Args:
+        ctxt (garage.experiment.ExperimentContext): The experiment
+            configuration used by Trainer to create the
+            snapshotter.
+        env_id (str): Environment id of the task.
+        seed (int): Random positive integer for the trial.
+
+    """
+    deterministic.set_seed(seed)
+
+    with TFTrainer(ctxt) as trainer:
+        num_timesteps = hyper_parameters['n_epochs'] * hyper_parameters[
+            'steps_per_epoch'] * hyper_parameters['batch_size']
+        env = normalize(GymEnv(env_id))
+
+        policy = DeterministicMLPPolicy(
+            env_spec=env.spec,
+            hidden_sizes=hyper_parameters['policy_hidden_sizes'],
+            hidden_nonlinearity=F.relu,
+            output_nonlinearity=torch.tanh)
+
+        exploration_policy = AddGaussianNoise(
+            env.spec,
+            policy,
+            total_timesteps=num_timesteps,
+            max_sigma=hyper_parameters['sigma'],
+            min_sigma=hyper_parameters['sigma'])
+
+        uniform_random_policy = UniformRandomPolicy(env.spec)
+
+        qf1 = ContinuousMLPQFunction(
+            env_spec=env.spec,
+            hidden_sizes=hyper_parameters['qf_hidden_sizes'],
+            hidden_nonlinearity=F.relu)
+
+        qf2 = ContinuousMLPQFunction(
+            env_spec=env.spec,
+            hidden_sizes=hyper_parameters['qf_hidden_sizes'],
+            hidden_nonlinearity=F.relu)
+
+        replay_buffer = PathBuffer(
+            capacity_in_transitions=hyper_parameters['replay_buffer_size'])
+
+        td3 = TD3(env_spec=env.spec,
+                  policy=policy,
+                  qf1=qf1,
+                  qf2=qf2,
+                  exploration_policy=exploration_policy,
+                  uniform_random_policy=uniform_random_policy,
+                  replay_buffer=replay_buffer,
+                  steps_per_epoch=hyper_parameters['steps_per_epoch'],
+                  policy_lr=hyper_parameters['policy_lr'],
+                  qf_lr=hyper_parameters['qf_lr'],
+                  target_update_tau=hyper_parameters['target_update_tau'],
+                  discount=hyper_parameters['discount'],
+                  grad_steps_per_env_step=hyper_parameters[
+                      'grad_steps_per_env_step'],
+                  start_steps=hyper_parameters['start_steps'],
+                  min_buffer_size=hyper_parameters['min_buffer_size'],
+                  buffer_batch_size=hyper_parameters['buffer_batch_size'],
+                  policy_optimizer=torch.optim.Adam,
+                  qf_optimizer=torch.optim.Adam,
+                  policy_noise_clip=hyper_parameters['policy_noise_clip'],
+                  policy_noise=hyper_parameters['policy_noise'])
+
+        prefer_gpu()
+        td3.to()
+        trainer.setup(td3, env)
+        trainer.train(n_epochs=hyper_parameters['n_epochs'],
+                      batch_size=hyper_parameters['batch_size'])
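
For orientation, this is roughly how the new experiment function gets driven by the benchmark harness. The call pattern is taken from the benchmark_algos.py change above; the garage_benchmarks.helper and garage_benchmarks.parameters module paths are my assumption about where iterate_experiments and MuJoCo1M_ENV_SET are defined.

# A minimal sketch, not part of the commit. Mirrors td3_benchmarks() in
# benchmark_algos.py above; the helper/parameter module paths are assumptions.
from garage_benchmarks.experiments.algos import (td3_garage_pytorch,
                                                 td3_garage_tf)
from garage_benchmarks.helper import iterate_experiments
from garage_benchmarks.parameters import MuJoCo1M_ENV_SET

# Reacher-v2 is excluded from the MuJoCo 1M set, exactly as above.
td3_env_ids = [env_id for env_id in MuJoCo1M_ENV_SET if env_id != 'Reacher-v2']

iterate_experiments(td3_garage_pytorch, td3_env_ids)  # new PyTorch TD3
iterate_experiments(td3_garage_tf, td3_env_ids)  # existing TF TD3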

benchmarks/src/garage_benchmarks/experiments/algos/td3_garage_tf.py

Lines changed: 6 additions & 6 deletions
@@ -14,12 +14,12 @@
 hyper_parameters = {
     'policy_lr': 1e-3,
     'qf_lr': 1e-3,
-    'policy_hidden_sizes': [400, 300],
-    'qf_hidden_sizes': [400, 300],
-    'n_epochs': 8,
-    'steps_per_epoch': 20,
-    'n_exploration_steps': 250,
-    'n_train_steps': 1,
+    'policy_hidden_sizes': [256, 256],
+    'qf_hidden_sizes': [256, 256],
+    'n_epochs': 250,
+    'steps_per_epoch': 40,
+    'n_exploration_steps': 100,
+    'n_train_steps': 50,
     'discount': 0.99,
     'tau': 0.005,
     'replay_buffer_size': int(1e6),
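
A quick sanity check on what the new values mean for run length, assuming (as in the PyTorch script above) that the total number of environment steps comes out to n_epochs * steps_per_epoch * n_exploration_steps:

# Back-of-the-envelope only; the formula is an assumption based on how
# num_timesteps is computed in td3_garage_pytorch.py above.
old_total = 8 * 20 * 250    # 40,000 env steps with the previous settings
new_total = 250 * 40 * 100  # 1,000,000 env steps with the new settings
print(old_total, new_total)

The new TF total matches the 1e6 timesteps the PyTorch benchmark computes for its noise schedule, so the two benchmarks now collect comparable amounts of experience.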

examples/torch/mtsac_metaworld_mt10.py

Lines changed: 2 additions & 2 deletions
@@ -43,8 +43,8 @@ def mtsac_metaworld_mt10(ctxt=None, *, seed, _gpu, n_tasks, timesteps):
     """
     deterministic.set_seed(seed)
     trainer = Trainer(ctxt)
-    mt10 = metaworld.MT10()
-    mt10_test = metaworld.MT10()
+    mt10 = metaworld.MT10()  # pylint: disable=no-member
+    mt10_test = metaworld.MT10()  # pylint: disable=no-member

     # pylint: disable=missing-return-doc, missing-return-type-doc
     def wrap(env, _):

examples/torch/mtsac_metaworld_mt50.py

Lines changed: 2 additions & 2 deletions
@@ -51,8 +51,8 @@ def mtsac_metaworld_mt50(ctxt=None,
     """
     deterministic.set_seed(seed)
     trainer = Trainer(ctxt)
-    mt50 = metaworld.MT50()
-    mt50_test = metaworld.MT50()
+    mt50 = metaworld.MT50()  # pylint: disable=no-member
+    mt50_test = metaworld.MT50()  # pylint: disable=no-member
     train_task_sampler = MetaWorldTaskSampler(
         mt50,
         'train',

examples/torch/td3_halfcheetah.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+"""An example to train TD3 algorithm on HalfCheetah-v2 with PyTorch."""
+import torch
+from torch.nn import functional as F
+
+from garage import wrap_experiment
+from garage.envs import GymEnv, normalize
+from garage.experiment.deterministic import set_seed
+from garage.np.exploration_policies import AddGaussianNoise
+from garage.np.policies import UniformRandomPolicy
+from garage.replay_buffer import PathBuffer
+from garage.torch.algos import TD3
+from garage.torch.policies import DeterministicMLPPolicy
+from garage.torch.q_functions import ContinuousMLPQFunction
+from garage.trainer import Trainer
+
+
+@wrap_experiment(snapshot_mode='none')
+def td3_half_cheetah(ctxt=None, seed=1):
+    """Train TD3 with HalfCheetah-v2 environment.
+
+    Args:
+        ctxt (garage.experiment.ExperimentContext): The experiment
+            configuration used by Trainer to create the snapshotter.
+        seed (int): Used to seed the random number generator to produce
+            determinism.
+    """
+    set_seed(seed)
+
+    n_epochs = 500
+    steps_per_epoch = 20
+    sampler_batch_size = 250
+    num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
+
+    trainer = Trainer(ctxt)
+    env = normalize(GymEnv('HalfCheetah-v2'))
+
+    policy = DeterministicMLPPolicy(env_spec=env.spec,
+                                    hidden_sizes=[256, 256],
+                                    hidden_nonlinearity=F.relu,
+                                    output_nonlinearity=torch.tanh)
+
+    exploration_policy = AddGaussianNoise(env.spec,
+                                          policy,
+                                          total_timesteps=num_timesteps,
+                                          max_sigma=0.1,
+                                          min_sigma=0.1)
+
+    uniform_random_policy = UniformRandomPolicy(env.spec)
+
+    qf1 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[256, 256],
+                                 hidden_nonlinearity=F.relu)
+
+    qf2 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[256, 256],
+                                 hidden_nonlinearity=F.relu)
+
+    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
+
+    td3 = TD3(env_spec=env.spec,
+              policy=policy,
+              qf1=qf1,
+              qf2=qf2,
+              replay_buffer=replay_buffer,
+              policy_optimizer=torch.optim.Adam,
+              qf_optimizer=torch.optim.Adam,
+              exploration_policy=exploration_policy,
+              uniform_random_policy=uniform_random_policy,
+              target_update_tau=0.005,
+              discount=0.99,
+              policy_noise_clip=0.5,
+              policy_noise=0.2,
+              policy_lr=1e-3,
+              qf_lr=1e-3,
+              steps_per_epoch=steps_per_epoch,
+              start_steps=1000,
+              grad_steps_per_env_step=50,
+              min_buffer_size=1000,
+              buffer_batch_size=100)
+
+    trainer.setup(algo=td3, env=env)
+    trainer.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
+
+
+td3_half_cheetah(seed=0)

examples/torch/td3_pendulum.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+"""An example to train TD3 on InvertedDoublePendulum-v2 with PyTorch."""
+import torch
+from torch.nn import functional as F
+
+from garage import wrap_experiment
+from garage.envs import GymEnv, normalize
+from garage.experiment.deterministic import set_seed
+from garage.np.exploration_policies import AddGaussianNoise
+from garage.np.policies import UniformRandomPolicy
+from garage.replay_buffer import PathBuffer
+from garage.torch import prefer_gpu
+from garage.torch.algos import TD3
+from garage.torch.policies import DeterministicMLPPolicy
+from garage.torch.q_functions import ContinuousMLPQFunction
+from garage.trainer import Trainer
+
+
+@wrap_experiment(snapshot_mode='none')
+def td3_pendulum(ctxt=None, seed=1):
+    """Train TD3 with InvertedDoublePendulum-v2 environment.
+
+    Args:
+        ctxt (garage.experiment.ExperimentContext): The experiment
+            configuration used by Trainer to create the snapshotter.
+        seed (int): Used to seed the random number generator to produce
+            determinism.
+
+    """
+    set_seed(seed)
+    n_epochs = 750
+    steps_per_epoch = 40
+    sampler_batch_size = 100
+    num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
+
+    trainer = Trainer(ctxt)
+    env = normalize(GymEnv('InvertedDoublePendulum-v2'))
+
+    policy = DeterministicMLPPolicy(env_spec=env.spec,
+                                    hidden_sizes=[256, 256],
+                                    hidden_nonlinearity=F.relu,
+                                    output_nonlinearity=torch.tanh)
+
+    exploration_policy = AddGaussianNoise(env.spec,
+                                          policy,
+                                          total_timesteps=num_timesteps,
+                                          max_sigma=0.1,
+                                          min_sigma=0.1)
+
+    uniform_random_policy = UniformRandomPolicy(env.spec)
+
+    qf1 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[256, 256],
+                                 hidden_nonlinearity=F.relu)
+
+    qf2 = ContinuousMLPQFunction(env_spec=env.spec,
+                                 hidden_sizes=[256, 256],
+                                 hidden_nonlinearity=F.relu)
+
+    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
+
+    td3 = TD3(env_spec=env.spec,
+              policy=policy,
+              qf1=qf1,
+              qf2=qf2,
+              replay_buffer=replay_buffer,
+              policy_optimizer=torch.optim.Adam,
+              qf_optimizer=torch.optim.Adam,
+              exploration_policy=exploration_policy,
+              uniform_random_policy=uniform_random_policy,
+              target_update_tau=0.005,
+              discount=0.99,
+              policy_noise_clip=0.5,
+              policy_noise=0.2,
+              policy_lr=1e-3,
+              qf_lr=1e-3,
+              steps_per_epoch=steps_per_epoch,
+              start_steps=1000,
+              grad_steps_per_env_step=1,
+              min_buffer_size=int(1e4),
+              buffer_batch_size=100)
+
+    prefer_gpu()
+    td3.to()
+    trainer.setup(algo=td3, env=env)
+    trainer.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
+
+
+td3_pendulum()
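
One detail in both examples: AddGaussianNoise receives total_timesteps=num_timesteps, but max_sigma and min_sigma are both 0.1, so the exploration noise never actually decays and the schedule length has no effect. A minimal variant sketch, assuming AddGaussianNoise anneals linearly from max_sigma down to min_sigma over total_timesteps (which is how the arguments read; the 0.01 floor is an illustrative value, not from the commit):

# Sketch only: decaying exploration noise instead of the constant 0.1 above.
# Assumes AddGaussianNoise anneals from max_sigma to min_sigma over
# total_timesteps; 0.01 is an illustrative floor, not taken from the commit.
exploration_policy = AddGaussianNoise(env.spec,
                                      policy,
                                      total_timesteps=num_timesteps,
                                      max_sigma=0.1,
                                      min_sigma=0.01)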

setup.cfg

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ use_parentheses = True
 force_sort_within_sections = True
 force_alphabetical_sort_within_sections = True
 lexicographical = True
-multi_line_output = 1
+multi_line_output = 0
 sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,TESTS,LOCALFOLDER
 known_first_party = garage
 known_tests = tests, garage_benchmarks
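
This isort setting is what drives the import reflows in the benchmark files above: as I read isort's multi_line_output modes, 1 (VERTICAL) puts one imported name per line inside the parentheses, while 0 (GRID) packs names onto each line up to the configured line length. Roughly:

# multi_line_output = 1 (VERTICAL): one name per continuation line.
from garage_benchmarks.experiments.algos import (ddpg_garage_tf,
                                                 her_garage_tf,
                                                 ppo_garage_pytorch)

# multi_line_output = 0 (GRID): names packed up to the line-length limit.
from garage_benchmarks.experiments.algos import (ddpg_garage_tf, her_garage_tf,
                                                 ppo_garage_pytorch)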
