
Commit bc98e62

Camera-ready cleanup
1 parent 00ae5cd commit bc98e62


44 files changed: +145, -10 lines

DDPG.py

Lines changed: 2 additions & 2 deletions
@@ -94,8 +94,8 @@ def train(self, replay_buffer, iterations, batch_size=64, discount=0.99, tau=0.0
 
 			# Q target = reward + discount * Q(next_state, pi(next_state))
 			target_Q = self.critic_target(next_state, self.actor_target(next_state))
-			target_Q.volatile = False
 			target_Q = reward + (done * discount * target_Q)
+			target_Q.volatile = False
 
 			# Get current Q estimate
 			current_Q = self.critic(state, action)
@@ -120,7 +120,7 @@ def train(self, replay_buffer, iterations, batch_size=64, discount=0.99, tau=0.0
 			for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
 				target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
 
-			for param, target_param, in zip(self.actor.parameters(), self.actor_target.parameters()):
+			for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
 				target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
 
 
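Note on the first hunk: with the pre-0.4 PyTorch autograd API used here, next_state is created as a volatile Variable, so target_Q inherits volatile=True; the commit moves target_Q.volatile = False to after the full Bellman target reward + (done * discount * target_Q) has been formed, so the finished target can feed the critic loss without requiring gradients through the target networks. For orientation only, a minimal sketch of the same computation in PyTorch >= 0.4 (where Variable and volatile were removed) would use torch.no_grad(); the helper name below is hypothetical and not part of the repository:

import torch

def bellman_target(critic_target, actor_target, next_state, reward, done, discount=0.99):
	# Sketch only: build the critic's regression target without recording a graph,
	# playing the role of the volatile flag in DDPG.py above.
	# `done` follows the convention in that code (it is 1 - d, i.e. 0 at terminal states).
	with torch.no_grad():
		target_Q = critic_target(next_state, actor_target(next_state))
		target_Q = reward + done * discount * target_Q
	return target_Q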

OurDDPG.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

import utils


# Re-tuned version of Deep Deterministic Policy Gradients (DDPG)
# Paper: https://arxiv.org/abs/1509.02971


def var(tensor, volatile=False):
	if torch.cuda.is_available():
		return Variable(tensor, volatile=volatile).cuda()
	else:
		return Variable(tensor, volatile=volatile)


class Actor(nn.Module):
	def __init__(self, state_dim, action_dim, max_action):
		super(Actor, self).__init__()

		self.l1 = nn.Linear(state_dim, 400)
		self.l2 = nn.Linear(400, 300)
		self.l3 = nn.Linear(300, action_dim)

		self.max_action = max_action


	def forward(self, x):
		x = F.relu(self.l1(x))
		x = F.relu(self.l2(x))
		x = self.max_action * F.tanh(self.l3(x))
		return x


class Critic(nn.Module):
	def __init__(self, state_dim, action_dim):
		super(Critic, self).__init__()

		self.l1 = nn.Linear(state_dim + action_dim, 400)
		self.l2 = nn.Linear(400, 300)
		self.l3 = nn.Linear(300, 1)


	def forward(self, x, u):
		x = F.relu(self.l1(torch.cat([x, u], 1)))
		x = F.relu(self.l2(x))
		x = self.l3(x)
		return x


class DDPG(object):
	def __init__(self, state_dim, action_dim, max_action):
		self.actor = Actor(state_dim, action_dim, max_action)
		self.actor_target = Actor(state_dim, action_dim, max_action)
		self.actor_target.load_state_dict(self.actor.state_dict())
		self.actor_optimizer = torch.optim.Adam(self.actor.parameters())

		self.critic = Critic(state_dim, action_dim)
		self.critic_target = Critic(state_dim, action_dim)
		self.critic_target.load_state_dict(self.critic.state_dict())
		self.critic_optimizer = torch.optim.Adam(self.critic.parameters())

		if torch.cuda.is_available():
			self.actor = self.actor.cuda()
			self.actor_target = self.actor_target.cuda()
			self.critic = self.critic.cuda()
			self.critic_target = self.critic_target.cuda()

		self.criterion = nn.MSELoss()
		self.state_dim = state_dim


	def select_action(self, state):
		state = var(torch.FloatTensor(state.reshape(-1, self.state_dim)), volatile=True)
		return self.actor(state).cpu().data.numpy().flatten()


	def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.005):

		for it in range(iterations):

			# Sample replay buffer
			x, y, u, r, d = replay_buffer.sample(batch_size)
			state = var(torch.FloatTensor(x))
			action = var(torch.FloatTensor(u))
			next_state = var(torch.FloatTensor(y), volatile=True)
			done = var(torch.FloatTensor(1 - d))
			reward = var(torch.FloatTensor(r))

			# Q target = reward + discount * Q(next_state, pi(next_state))
			target_Q = self.critic_target(next_state, self.actor_target(next_state))
			target_Q = reward + (done * discount * target_Q)
			target_Q.volatile = False

			# Get current Q estimate
			current_Q = self.critic(state, action)

			# Compute critic loss
			critic_loss = self.criterion(current_Q, target_Q)

			# Optimize the critic
			self.critic_optimizer.zero_grad()
			critic_loss.backward()
			self.critic_optimizer.step()

			# Compute actor loss
			actor_loss = -self.critic(state, self.actor(state)).mean()

			# Optimize the actor
			self.actor_optimizer.zero_grad()
			actor_loss.backward()
			self.actor_optimizer.step()

			# Update the frozen target models
			for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
				target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

			for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
				target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)


	def save(self, filename, directory):
		torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, filename))
		torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))


	def load(self, filename, directory):
		self.actor.load_state_dict(torch.load('%s/%s_actor.pth' % (directory, filename)))
		self.critic.load_state_dict(torch.load('%s/%s_critic.pth' % (directory, filename)))
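For orientation, a hypothetical driver for this class is sketched below. The environment name, the old 4-tuple gym step API of that era, and the utils.ReplayBuffer add()/sample() interface (tuples stored as (state, next_state, action, reward, done), matching the unpacking x, y, u, r, d in train()) are assumptions about the surrounding repository rather than part of this file:

# Hypothetical usage sketch for OurDDPG.DDPG. The buffer interface and the
# environment name are assumptions; only DDPG's own methods come from the file above.
import gym
import numpy as np

import utils
from OurDDPG import DDPG

env = gym.make("Pendulum-v0")          # assumed environment
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

policy = DDPG(state_dim, action_dim, max_action)
replay_buffer = utils.ReplayBuffer()   # assumed to store (state, next_state, action, reward, done) tuples

state = env.reset()
for t in range(10000):
	action = policy.select_action(np.array(state))
	next_state, reward, done, _ = env.step(action)   # old gym 4-tuple step API
	replay_buffer.add((state, next_state, action, reward, float(done)))
	state = env.reset() if done else next_state

policy.train(replay_buffer, iterations=200, batch_size=100, discount=0.99, tau=0.005)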

TD3.py

Lines changed: 3 additions & 3 deletions
@@ -108,11 +108,11 @@ def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.
 			next_action = self.actor_target(next_state) + var(torch.FloatTensor(noise))
 			next_action = next_action.clamp(-self.max_action, self.max_action)
 
-			# Q target = reward + discount * min(Qi(next_state, pi(next_state)))
+			# Q target = reward + discount * min_i(Qi(next_state, pi(next_state)))
 			target_Q1, target_Q2 = self.critic_target(next_state, next_action)
-			target_Q = torch.min(torch.cat([target_Q1, target_Q2], 1), 1)[0].view(-1, 1)
-			target_Q.volatile = False
+			target_Q = torch.min(target_Q1, target_Q2)
 			target_Q = reward + (done * discount * target_Q)
+			target_Q.volatile = False
 
 			# Get current Q estimates
 			current_Q1, current_Q2 = self.critic(state, action)
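The replaced line built the clipped double-Q target by concatenating the two critic outputs and taking a row-wise minimum; the new line takes the elementwise minimum of the two (batch_size, 1) tensors directly, which produces the same values with less reshaping. A small standalone check (illustrative, not from the repository):

# Both expressions yield the per-sample minimum of the two critic estimates
# when target_Q1 and target_Q2 have shape (batch_size, 1).
import torch

target_Q1 = torch.randn(4, 1)
target_Q2 = torch.randn(4, 1)

old = torch.min(torch.cat([target_Q1, target_Q2], 1), 1)[0].view(-1, 1)
new = torch.min(target_Q1, target_Q2)

assert torch.equal(old, new)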

learning_curves/Ant/TD3_Ant-v1_0.npy    0 Bytes    Binary file not shown.
learning_curves/Ant/TD3_Ant-v1_1.npy    0 Bytes    Binary file not shown.
learning_curves/Ant/TD3_Ant-v1_2.npy    0 Bytes    Binary file not shown.
learning_curves/Ant/TD3_Ant-v1_3.npy    0 Bytes    Binary file not shown.
learning_curves/Ant/TD3_Ant-v1_4.npy    0 Bytes    Binary file not shown.
learning_curves/Ant/TD3_Ant-v1_5.npy    0 Bytes    Binary file not shown.
learning_curves/Ant/TD3_Ant-v1_6.npy    0 Bytes    Binary file not shown.
