Skip to content

Commit c717e75

Browse files
committed
cleanup + 1e6 max buffer
1 parent 25dfc0a commit c717e75

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

TD3.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, state_dim, action_dim, max_action):
7676
self.critic = Critic(state_dim, action_dim).to(device)
7777
self.critic_target = Critic(state_dim, action_dim).to(device)
7878
self.critic_target.load_state_dict(self.critic.state_dict())
79-
self.critic_optimizer = torch.optim.Adam(self.critic.parameters())
79+
self.critic_optimizer = torch.optim.Adam(self.critic.parameters())
8080

8181
self.max_action = max_action
8282

@@ -102,7 +102,6 @@ def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.
102102
noise = torch.FloatTensor(u).data.normal_(0, policy_noise).to(device)
103103
noise = noise.clamp(-noise_clip, noise_clip)
104104
next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)
105-
next_action = next_action.clamp(-self.max_action, self.max_action)
106105

107106
# Compute the target Q value
108107
target_Q1, target_Q2 = self.critic_target(next_state, next_action)

utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,21 @@
33
# Code based on:
44
# https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py
55

6-
# Simple replay buffer
6+
# Expects tuples of (state, next_state, action, reward, done)
77
class ReplayBuffer(object):
8-
def __init__(self):
8+
def __init__(self, max_size=1e6):
99
self.storage = []
10+
self.max_size = max_size
11+
self.ptr = 0
1012

11-
# Expects tuples of (state, next_state, action, reward, done)
1213
def add(self, data):
13-
self.storage.append(data)
14+
if len(self.storage) == self.max_size:
15+
self.storage[int(self.ptr)] = data
16+
self.ptr = (self.ptr + 1) % self.max_size
17+
else:
18+
self.storage.append(data)
1419

15-
def sample(self, batch_size=100):
20+
def sample(self, batch_size):
1621
ind = np.random.randint(0, len(self.storage), size=batch_size)
1722
x, y, u, r, d = [], [], [], [], []
1823

0 commit comments

Comments
 (0)