Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
qlan3 committed Jul 2, 2023
1 parent 0df0ea0 commit dfc8e92
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
29 changes: 26 additions & 3 deletions agents/RNNIndentity.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def __init__(self, cfg):
cfg["optimizer"]["name"], cfg["optimizer"]["kwargs"], None
)
# Load data
self.batches = self.load_data(self.cfg["seq_len"], self.cfg["datapath"])
self.seed = random.PRNGKey(self.cfg["seed"])
if len(self.cfg["datapath"]) > 0:
self.batches = self.load_data(self.cfg["seq_len"], self.cfg["datapath"])
else:
self.batches = self.load_uniform_data(self.cfg["seq_len"])

def create_meta_net(self):
self.cfg["meta_net"]["mlp_dims"] = tuple(self.cfg["meta_net"]["mlp_dims"])
Expand All @@ -68,6 +72,25 @@ def load_data(self, seq_len, datapath):
batches = np.array(batches)
return jax.device_put(batches)

def load_uniform_data(self, seq_len):
    """Build a synthetic identity dataset from uniform random noise.

    Six chunks of shape (395, 500) are sampled and concatenated along the
    last axis into one (395, 3000) array. The array is then cut into
    ``3000 // seq_len`` consecutive windows of ``seq_len`` columns, and each
    window is stored twice, so input and target are the same array.

    Side effects: sets ``self.batch_size`` and ``self.num_batch``, advances
    ``self.seed`` by one split per sampled chunk, and logs the dataset size
    through ``self.logger``.

    Args:
        seq_len: Number of columns in each window.

    Returns:
        The result of ``jax.device_put`` applied to the stacked windows,
        shaped (num_batch, 2, batch_size, seq_len) when num_batch > 0.
    """
    # `total_len` replaces the former local `len`, which shadowed the builtin.
    self.batch_size, total_len, max_steps = 395, 3000, 500
    self.num_batch = total_len // seq_len

    # Chunk i is drawn from the uniform distribution on [i - 1, i + 1), so
    # only the first chunk lies in [-1, 1) and later chunks are shifted up.
    # NOTE(review): the previous comment claimed all data is in [-1, 1];
    # confirm whether the per-chunk shift is intended.
    chunks = []
    for i in range(total_len // max_steps):
        # First half of the split is consumed here; the second half becomes
        # the new running key, so every chunk uses a fresh subkey.
        seed, self.seed = random.split(self.seed)
        chunks.append(
            random.uniform(
                seed, (self.batch_size, max_steps), minval=i - 1, maxval=i + 1
            )
        )
    xs = jnp.concatenate(chunks, axis=-1)
    self.logger.info(f"dataset size: {xs.shape}, batch_size: {self.batch_size}, num_batch: {self.num_batch}")

    # Each entry is an (input, target) pair holding the same window.
    batches = []
    for b in range(self.num_batch):
        start = b * seq_len
        window = xs[:, start:start + seq_len]
        batches.append([window, window])
    batches = np.array(batches)
    return jax.device_put(batches)

def compute_loss(self, param, hidden_state, batch):
x, y = batch[0], batch[1]
hidden_state, pred_y = lax.scan(
Expand Down Expand Up @@ -102,10 +125,10 @@ def train_step(self, param, hidden_state, optim_state, batch):

def train(self):
# Initialize model parameter
seed = random.PRNGKey(self.cfg["seed"])
model_seed, self.seed = random.split(self.seed)
dummy_input = jnp.array([0.0])
dummy_hidden_state = self.model.init_hidden_state(dummy_input)
param = self.model.init(seed, dummy_hidden_state, dummy_input)
param = self.model.init(model_seed, dummy_hidden_state, dummy_input)

# Set optimizer state
optim_state = self.optimizer.init(param)
Expand Down
2 changes: 1 addition & 1 deletion components/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __call__(self, h, g):
o1 = self.mlp1(x1)
# Add a small bias so that m_sign=1 initially
m_sign_raw = jnp.tanh(o1[..., 0]+self.bias)
m_sign = lax.stop_gradient((m_sign_raw >= 0.0) - m_sign_raw) + m_sign_raw
m_sign = lax.stop_gradient(2.0*(m_sign_raw >= 0.0) - 1.0 - m_sign_raw) + m_sign_raw
m = g_sign[..., 0] * m_sign * jnp.exp(o1[..., 1])
# Compute v: 2nd pseudo moment estimate
h2, x2 = self.rnn2(h2, 2.0*g_log)
Expand Down
2 changes: 1 addition & 1 deletion configs/sds_meta.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"env": [
{
"name": [["small_dense_long"]],
"name": [["small_dense_sparse"]],
"num_envs": [512],
"train_steps": [3e7]
}
Expand Down

0 comments on commit dfc8e92

Please sign in to comment.