Commit a8d53f7

Python3 + Hyper-parameters update

1 parent c717e75 · commit a8d53f7

7 files changed: +405 -413 lines

DDPG.py

Lines changed: 43 additions & 62 deletions
@@ -1,9 +1,9 @@
+import copy
 import numpy as np
 import torch
 import torch.nn as nn
-from torch.autograd import Variable
 import torch.nn.functional as F
-import utils
+

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -23,11 +23,10 @@ def __init__(self, state_dim, action_dim, max_action):
         self.max_action = max_action


-    def forward(self, x):
-        x = F.relu(self.l1(x))
-        x = F.relu(self.l2(x))
-        x = self.max_action * torch.tanh(self.l3(x))
-        return x
+    def forward(self, state):
+        a = F.relu(self.l1(state))
+        a = F.relu(self.l2(a))
+        return self.max_action * torch.tanh(self.l3(a))


 class Critic(nn.Module):
@@ -39,79 +38,61 @@ def __init__(self, state_dim, action_dim):
         self.l3 = nn.Linear(300, 1)


-    def forward(self, x, u):
-        x = F.relu(self.l1(x))
-        x = F.relu(self.l2(torch.cat([x, u], 1)))
-        x = self.l3(x)
-        return x
+    def forward(self, state, action):
+        q = F.relu(self.l1(state))
+        q = F.relu(self.l2(torch.cat([q, action], 1)))
+        return self.l3(q)


 class DDPG(object):
-    def __init__(self, state_dim, action_dim, max_action):
+    def __init__(self, state_dim, action_dim, max_action, discount=0.99, tau=0.001):
         self.actor = Actor(state_dim, action_dim, max_action).to(device)
-        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
-        self.actor_target.load_state_dict(self.actor.state_dict())
+        self.actor_target = copy.deepcopy(self.actor)
         self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-4)

         self.critic = Critic(state_dim, action_dim).to(device)
-        self.critic_target = Critic(state_dim, action_dim).to(device)
-        self.critic_target.load_state_dict(self.critic.state_dict())
-        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), weight_decay=1e-2)
+        self.critic_target = copy.deepcopy(self.critic)
+        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), weight_decay=1e-2)
+
+        self.discount = discount
+        self.tau = tau


     def select_action(self, state):
         state = torch.FloatTensor(state.reshape(1, -1)).to(device)
         return self.actor(state).cpu().data.numpy().flatten()


-    def train(self, replay_buffer, iterations, batch_size=64, discount=0.99, tau=0.001):
-
-        for it in range(iterations):
-
-            # Sample replay buffer
-            x, y, u, r, d = replay_buffer.sample(batch_size)
-            state = torch.FloatTensor(x).to(device)
-            action = torch.FloatTensor(u).to(device)
-            next_state = torch.FloatTensor(y).to(device)
-            done = torch.FloatTensor(1 - d).to(device)
-            reward = torch.FloatTensor(r).to(device)
+    def train(self, replay_buffer, batch_size=64):
+        # Sample replay buffer
+        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

-            # Compute the target Q value
-            target_Q = self.critic_target(next_state, self.actor_target(next_state))
-            target_Q = reward + (done * discount * target_Q).detach()
+        # Compute the target Q value
+        target_Q = self.critic_target(next_state, self.actor_target(next_state))
+        target_Q = reward + (not_done * self.discount * target_Q).detach()

-            # Get current Q estimate
-            current_Q = self.critic(state, action)
+        # Get current Q estimate
+        current_Q = self.critic(state, action)

-            # Compute critic loss
-            critic_loss = F.mse_loss(current_Q, target_Q)
+        # Compute critic loss
+        critic_loss = F.mse_loss(current_Q, target_Q)

-            # Optimize the critic
-            self.critic_optimizer.zero_grad()
-            critic_loss.backward()
-            self.critic_optimizer.step()
+        # Optimize the critic
+        self.critic_optimizer.zero_grad()
+        critic_loss.backward()
+        self.critic_optimizer.step()

-            # Compute actor loss
-            actor_loss = -self.critic(state, self.actor(state)).mean()
-
-            # Optimize the actor
-            self.actor_optimizer.zero_grad()
-            actor_loss.backward()
-            self.actor_optimizer.step()
-
-            # Update the frozen target models
-            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
-                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
-
-            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
-                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
-
-
-    def save(self, filename, directory):
-        torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, filename))
-        torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))
+        # Compute actor loss
+        actor_loss = -self.critic(state, self.actor(state)).mean()
+
+        # Optimize the actor
+        self.actor_optimizer.zero_grad()
+        actor_loss.backward()
+        self.actor_optimizer.step()

+        # Update the frozen target models
+        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

-    def load(self, filename, directory):
-        self.actor.load_state_dict(torch.load('%s/%s_actor.pth' % (directory, filename)))
-        self.critic.load_state_dict(torch.load('%s/%s_critic.pth' % (directory, filename)))
+        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
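Note that the refactored train() no longer unpacks raw NumPy arrays and builds tensors itself; it assumes replay_buffer.sample(batch_size) already returns batched device tensors in the order (state, action, next_state, reward, not_done). The buffer implementation is not shown in this section, so the sketch below only illustrates that assumed interface, not the repository's actual utils code. Storing not_done (1 - done) directly is what lets train() form the target as reward + not_done * discount * target_Q with no conversion.

```python
# Illustration only: a buffer exposing the sample() interface the new train() assumes.
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ReplayBuffer(object):
    def __init__(self, state_dim, action_dim, max_size=int(1e6)):
        self.max_size = max_size
        self.ptr = 0
        self.size = 0
        self.state = np.zeros((max_size, state_dim))
        self.action = np.zeros((max_size, action_dim))
        self.next_state = np.zeros((max_size, state_dim))
        self.reward = np.zeros((max_size, 1))
        self.not_done = np.zeros((max_size, 1))

    def add(self, state, action, next_state, reward, done):
        # Store 1 - done so train() can mask the bootstrap term directly
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1.0 - done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        # Uniformly sample stored transitions and return them as device tensors
        ind = np.random.randint(0, self.size, size=batch_size)
        return (
            torch.FloatTensor(self.state[ind]).to(device),
            torch.FloatTensor(self.action[ind]).to(device),
            torch.FloatTensor(self.next_state[ind]).to(device),
            torch.FloatTensor(self.reward[ind]).to(device),
            torch.FloatTensor(self.not_done[ind]).to(device),
        )
```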

OurDDPG.py

Lines changed: 42 additions & 60 deletions
@@ -1,10 +1,11 @@
+import copy
 import numpy as np
 import torch
 import torch.nn as nn
-from torch.autograd import Variable
 import torch.nn.functional as F
 import utils

+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 # Re-tuned version of Deep Deterministic Policy Gradients (DDPG)
@@ -22,11 +23,10 @@ def __init__(self, state_dim, action_dim, max_action):
         self.max_action = max_action


-    def forward(self, x):
-        x = F.relu(self.l1(x))
-        x = F.relu(self.l2(x))
-        x = self.max_action * torch.tanh(self.l3(x))
-        return x
+    def forward(self, state):
+        a = F.relu(self.l1(state))
+        a = F.relu(self.l2(a))
+        return self.max_action * torch.tanh(self.l3(a))


 class Critic(nn.Module):
@@ -38,79 +38,61 @@ def __init__(self, state_dim, action_dim):
         self.l3 = nn.Linear(300, 1)


-    def forward(self, x, u):
-        x = F.relu(self.l1(torch.cat([x, u], 1)))
-        x = F.relu(self.l2(x))
-        x = self.l3(x)
-        return x
+    def forward(self, state, action):
+        q = F.relu(self.l1(torch.cat([state, action], 1)))
+        q = F.relu(self.l2(q))
+        return self.l3(q)


 class DDPG(object):
-    def __init__(self, state_dim, action_dim, max_action):
+    def __init__(self, state_dim, action_dim, max_action, discount=0.99, tau=0.005):
         self.actor = Actor(state_dim, action_dim, max_action).to(device)
-        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
-        self.actor_target.load_state_dict(self.actor.state_dict())
+        self.actor_target = copy.deepcopy(self.actor)
         self.actor_optimizer = torch.optim.Adam(self.actor.parameters())

         self.critic = Critic(state_dim, action_dim).to(device)
-        self.critic_target = Critic(state_dim, action_dim).to(device)
-        self.critic_target.load_state_dict(self.critic.state_dict())
+        self.critic_target = copy.deepcopy(self.critic)
         self.critic_optimizer = torch.optim.Adam(self.critic.parameters())

+        self.discount = discount
+        self.tau = tau
+

     def select_action(self, state):
         state = torch.FloatTensor(state.reshape(1, -1)).to(device)
         return self.actor(state).cpu().data.numpy().flatten()


-    def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.005):
-
-        for it in range(iterations):
-
-            # Sample replay buffer
-            x, y, u, r, d = replay_buffer.sample(batch_size)
-            state = torch.FloatTensor(x).to(device)
-            action = torch.FloatTensor(u).to(device)
-            next_state = torch.FloatTensor(y).to(device)
-            done = torch.FloatTensor(1 - d).to(device)
-            reward = torch.FloatTensor(r).to(device)
+    def train(self, replay_buffer, batch_size=100):
+        # Sample replay buffer
+        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

-            # Compute the target Q value
-            target_Q = self.critic_target(next_state, self.actor_target(next_state))
-            target_Q = reward + (done * discount * target_Q).detach()
+        # Compute the target Q value
+        target_Q = self.critic_target(next_state, self.actor_target(next_state))
+        target_Q = reward + (not_done * self.discount * target_Q).detach()

-            # Get current Q estimate
-            current_Q = self.critic(state, action)
+        # Get current Q estimate
+        current_Q = self.critic(state, action)

-            # Compute critic loss
-            critic_loss = F.mse_loss(current_Q, target_Q)
+        # Compute critic loss
+        critic_loss = F.mse_loss(current_Q, target_Q)

-            # Optimize the critic
-            self.critic_optimizer.zero_grad()
-            critic_loss.backward()
-            self.critic_optimizer.step()
+        # Optimize the critic
+        self.critic_optimizer.zero_grad()
+        critic_loss.backward()
+        self.critic_optimizer.step()

-            # Compute actor loss
-            actor_loss = -self.critic(state, self.actor(state)).mean()
-
-            # Optimize the actor
-            self.actor_optimizer.zero_grad()
-            actor_loss.backward()
-            self.actor_optimizer.step()
-
-            # Update the frozen target models
-            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
-                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
-
-            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
-                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
-
-
-    def save(self, filename, directory):
-        torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, filename))
-        torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))
+        # Compute actor loss
+        actor_loss = -self.critic(state, self.actor(state)).mean()
+
+        # Optimize the actor
+        self.actor_optimizer.zero_grad()
+        actor_loss.backward()
+        self.actor_optimizer.step()

+        # Update the frozen target models
+        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

-    def load(self, filename, directory):
-        self.actor.load_state_dict(torch.load('%s/%s_actor.pth' % (directory, filename)))
-        self.critic.load_state_dict(torch.load('%s/%s_critic.pth' % (directory, filename)))
+        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
+            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
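The same refactor changes how the agent is driven: discount and tau are now constructor arguments, and train() performs exactly one gradient update per call instead of looping over iterations, so it is meant to be called once per environment step. The loop below is a hedged sketch of that calling pattern only; the environment name comes from the README, while details such as the exploration-noise scale, the utils.ReplayBuffer signature, and the omission of warm-up and evaluation logic are simplifications for illustration, not taken from this commit.

```python
# Hedged sketch of the new calling pattern (not the repository's main.py):
# hyper-parameters live in the constructor, train() runs one update per call.
import gym
import numpy as np

import OurDDPG
import utils  # assumed to provide the ReplayBuffer whose sample() train() consumes

env = gym.make("HalfCheetah-v2")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

policy = OurDDPG.DDPG(state_dim, action_dim, max_action, discount=0.99, tau=0.005)
replay_buffer = utils.ReplayBuffer(state_dim, action_dim)  # constructor signature assumed

state, done = env.reset(), False
for t in range(int(1e6)):
    # Act with Gaussian exploration noise (scale chosen for illustration only)
    action = (
        policy.select_action(np.array(state))
        + np.random.normal(0, 0.1 * max_action, size=action_dim)
    ).clip(-max_action, max_action)

    next_state, reward, done, _ = env.step(action)
    replay_buffer.add(state, action, next_state, reward, float(done))
    state = next_state

    policy.train(replay_buffer, batch_size=100)  # one gradient step per env step

    if done:
        state, done = env.reset(), False
```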

README.md

Lines changed: 18 additions & 4 deletions
@@ -3,23 +3,37 @@
 PyTorch implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3). If you use our code or data please cite the [paper](https://arxiv.org/abs/1802.09477).

 Method is tested on [MuJoCo](http://www.mujoco.org/) continuous control tasks in [OpenAI gym](https://github.com/openai/gym).
-Networks are trained using [PyTorch 0.4](https://github.com/pytorch/pytorch) and Python 2.7.
+Networks are trained using [PyTorch 1.2](https://github.com/pytorch/pytorch) and Python 3.7.

 ### Usage
-The paper results can be reproduced exactly by running:
+The paper results can be reproduced by running:
 ```
 ./experiments.sh
 ```
 Experiments on single environments can be run by calling:
 ```
-python2 main.py --env HalfCheetah-v1
+python main.py --env HalfCheetah-v2
 ```

-Hyper-parameters can be modified with different arguments to main.py. We include an implementation of DDPG (DDPG.py) for easy comparison of hyper-parameters with TD3, this is not the implementation of "Our DDPG" as used in the paper (see OurDDPG.py).
+Hyper-parameters can be modified with different arguments to main.py. We include an implementation of DDPG (DDPG.py), which is not used in the paper, for easy comparison of hyper-parameters with TD3. This is not the implementation of "Our DDPG" as used in the paper (see OurDDPG.py).

 Algorithms which TD3 compares against (PPO, TRPO, ACKTR, DDPG) can be found at the [OpenAI baselines repository](https://github.com/openai/baselines).

 ### Results
+The code is no longer exactly representative of the code used in the paper; minor adjustments have been made to hyper-parameters, etc., to improve performance. The learning curves are still the original results found in the paper.
+
 Learning curves found in the paper are found under /learning_curves. Each learning curve is formatted as a NumPy array of 201 evaluations (shape (201,)), where each evaluation corresponds to the average total reward from running the policy for 10 episodes with no exploration. The first evaluation is the randomly initialized policy network (unused in the paper). Evaluations are performed every 5000 time steps, over a total of 1 million time steps.

 Numerical results can be found in the paper, or from the learning curves. Video of the learned agent can be found [here](https://youtu.be/x33Vw-6vzso).
+
+### Bibtex
+
+```
+@inproceedings{fujimoto2018addressing,
+  title={Addressing Function Approximation Error in Actor-Critic Methods},
+  author={Fujimoto, Scott and Hoof, Herke and Meger, David},
+  booktitle={International Conference on Machine Learning},
+  pages={1582--1591},
+  year={2018}
+}
+```
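Since the learning curves are shipped as plain NumPy arrays, they can be checked directly against the description in the README (201 evaluations, one every 5000 time steps). The file name in this sketch is hypothetical; substitute whatever actually appears under /learning_curves.

```python
# Sketch: inspect one learning curve (file name is hypothetical).
import numpy as np

curve = np.load("learning_curves/TD3_HalfCheetah-v1_0.npy")
assert curve.shape == (201,)      # initial policy + one evaluation every 5000 steps
steps = np.arange(201) * 5000     # x-axis in environment time steps
print("final evaluation (avg return over 10 episodes):", curve[-1])
```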
