import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import Distribution, Normal
# following the SAC authors' and OpenAI implementations
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
ACTION_BOUND_EPSILON = 1E-6
# these numbers are from the MBPO paper
mbpo_target_entropy_dict = {'Hopper-v2':-1, 'HalfCheetah-v2':-3, 'Walker2d-v2':-3, 'Ant-v2':-4, 'Humanoid-v2':-2}
mbpo_epoches = {'Hopper-v2':125, 'Walker2d-v2':300, 'Ant-v2':300, 'HalfCheetah-v2':400, 'Humanoid-v2':300}
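
## Illustrative sketch, not part of the original code: how the per-environment target
## entropies above might be looked up, with a fallback to the common SAC heuristic of
## -act_dim (negative action dimension) for environments not in the dict.
## The helper name get_target_entropy is hypothetical.
def get_target_entropy(env_name, act_dim):
    # return the MBPO paper's value if the environment is listed, otherwise -act_dim
    return mbpo_target_entropy_dict.get(env_name, -act_dim)
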
def weights_init_(m):
    # weight init helper function
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=1)
        torch.nn.init.constant_(m.bias, 0)

class ReplayBuffer:
    """
    A simple FIFO experience replay buffer
    """
    def __init__(self, obs_dim, act_dim, size):
        """
        :param obs_dim: size of the observation
        :param act_dim: size of the action
        :param size: size of the buffer
        """
        ## init buffers as numpy arrays
        self.obs1_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.obs2_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.acts_buf = np.zeros([size, act_dim], dtype=np.float32)
        self.rews_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.ptr, self.size, self.max_size = 0, 0, size

        self.count_buf = np.zeros(size)  ## how many times each stored transition has been sampled
        self.max_count = 0  ## running max of the sample counts
        self.mean_count = 0  ## running mean of the sample counts

    def store(self, obs, act, rew, next_obs, done):
        """
        data will get stored in the pointer's location
        """
        self.obs1_buf[self.ptr] = obs
        self.obs2_buf[self.ptr] = next_obs
        self.acts_buf[self.ptr] = act
        self.rews_buf[self.ptr] = rew
        self.done_buf[self.ptr] = done
        self.count_buf[self.ptr] = 0  ## reset the sample count for the overwritten slot
        ## move the pointer to the next location in the buffer
        self.ptr = (self.ptr+1) % self.max_size
        ## keep track of the current buffer size
        self.size = min(self.size+1, self.max_size)

    def sample_batch(self, batch_size=32, idxs=None):
        """
        :param batch_size: size of minibatch
        :param idxs: specify indexes if you want specific data points
        :return: mini-batch data as a dictionary
        """
        if idxs is None:
            idxs = np.random.randint(0, self.size, size=batch_size)
        self.count_buf[idxs] += 1  ## update per-transition sample counts
        self.max_count = max(self.max_count, self.count_buf.max())
        self.mean_count = self.count_buf.mean()
        return dict(obs1=self.obs1_buf[idxs],
                    obs2=self.obs2_buf[idxs],
                    acts=self.acts_buf[idxs],
                    rews=self.rews_buf[idxs],
                    done=self.done_buf[idxs],
                    count=self.count_buf[idxs],
                    idxs=idxs)
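
## Illustrative usage sketch, not part of the original code: filling the buffer with
## random transitions and drawing a minibatch. The dimensions below are arbitrary
## placeholders, and the helper name is hypothetical.
def _replay_buffer_usage_example():
    buffer = ReplayBuffer(obs_dim=4, act_dim=2, size=10)
    for _ in range(15):  # more than `size`, so the oldest entries get overwritten (FIFO)
        obs = np.random.randn(4)
        act = np.random.randn(2)
        next_obs = np.random.randn(4)
        buffer.store(obs, act, rew=0.0, next_obs=next_obs, done=False)
    batch = buffer.sample_batch(batch_size=8)
    # batch['obs1'] has shape (8, 4); batch['idxs'] records which slots were drawn
    return batch
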

class Mlp(nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        hidden_sizes,
        hidden_activation=F.relu
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_activation = hidden_activation
        ## here we use ModuleList so that the layers in it can be
        ## detected by .parameters() call
        self.hidden_layers = nn.ModuleList()
        in_size = input_size

        ## initialize each hidden layer
        for i, next_size in enumerate(hidden_sizes):
            fc_layer = nn.Linear(in_size, next_size)
            in_size = next_size
            self.hidden_layers.append(fc_layer)

        ## create the last fully connected layer; all layers are then
        ## initialized by weights_init_ (Xavier-uniform weights, zero bias)
        self.last_fc_layer = nn.Linear(in_size, output_size)
        self.apply(weights_init_)

    def forward(self, input):
        h = input
        for i, fc_layer in enumerate(self.hidden_layers):
            h = fc_layer(h)
            h = self.hidden_activation(h)
        output = self.last_fc_layer(h)
        return output
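
## Illustrative sketch, not part of the original code: constructing an Mlp and checking
## that the layers registered through the ModuleList are visible to .parameters(), which
## is what the comment in __init__ refers to. Sizes are arbitrary; the name is hypothetical.
def _mlp_usage_example():
    net = Mlp(input_size=8, output_size=3, hidden_sizes=(32, 32))
    x = torch.randn(5, 8)  # batch of 5 inputs
    y = net(x)             # y has shape (5, 3)
    n_params = sum(p.numel() for p in net.parameters())  # covers both hidden layers and the last layer
    return y.shape, n_params
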
class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)
    Note: this is not very numerically stable.
    """
    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def log_prob(self, value, pre_tanh_value=None):
        """
        return the log probability of a value
        :param value: some value, x
        :param pre_tanh_value: arctanh(x)
        :return: the log probability of x under this distribution
        """
        # use the arctanh formula to compute arctanh(value)
        if pre_tanh_value is None:
            pre_tanh_value = torch.log(
                (1+value) / (1-value)
            ) / 2
        return self.normal.log_prob(pre_tanh_value) - \
               torch.log(1 - value * value + self.epsilon)

    def sample(self, return_pretanh_value=False):
        """
        Gradients will and should *not* pass through this operation.
        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        z = self.normal.sample().detach()

        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.
        Implement: tanh(mu + sigma * xi), with xi ~ N(0, 1)
        z here is mu + sigma * xi
        """
        z = (
            self.normal_mean +
            self.normal_std *
            Normal(  ## this part is xi ~ N(0,1); zeros_like/ones_like keep the noise on the same device/dtype as the parameters
                torch.zeros_like(self.normal_mean),
                torch.ones_like(self.normal_std)
            ).sample()
        )
        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)
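
## Note on the log_prob formula above (explanatory sketch, not part of the original code):
## for x = tanh(z), the change-of-variables formula gives
##     log p_X(x) = log p_Z(z) - log|d tanh(z)/dz| = log p_Z(z) - log(1 - tanh(z)^2),
## which is exactly normal.log_prob(pre_tanh_value) - log(1 - value^2 + epsilon).
## The small check below compares TanhNormal.log_prob against that quantity computed
## directly; the helper name is hypothetical.
def _check_tanh_normal_log_prob():
    mean = torch.zeros(3)
    std = torch.ones(3)
    dist = TanhNormal(mean, std)
    x, z = dist.rsample(return_pretanh_value=True)
    lp_class = dist.log_prob(x, pre_tanh_value=z)
    lp_manual = Normal(mean, std).log_prob(z) - torch.log(1 - torch.tanh(z) ** 2 + 1e-6)
    return torch.allclose(lp_class, lp_manual)
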
class TanhGaussianPolicy(Mlp):
    """
    A Gaussian policy network with Tanh to enforce action limits
    """
    def __init__(
        self,
        obs_dim,
        action_dim,
        hidden_sizes,
        hidden_activation=F.relu,
        action_limit=1.0
    ):
        super().__init__(
            input_size=obs_dim,
            output_size=action_dim,
            hidden_sizes=hidden_sizes,
            hidden_activation=hidden_activation,
        )
        last_hidden_size = obs_dim
        if len(hidden_sizes) > 0:
            last_hidden_size = hidden_sizes[-1]
        ## this is the layer that gives log_std; like the other layers,
        ## it is initialized by weights_init_ below
        self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim)
        ## action limit: for example, humanoid has an action limit of -0.4 to 0.4
        self.action_limit = action_limit
        self.apply(weights_init_)

    def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=True,
    ):
        """
        :param obs: Observation
        :param deterministic: If True, take the deterministic (test) action, i.e. tanh(mean)
        :param return_log_prob: If True, also return the log probability of the action
        """
        h = obs
        for fc_layer in self.hidden_layers:
            h = self.hidden_activation(fc_layer(h))
        mean = self.last_fc_layer(h)

        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)

        normal = Normal(mean, std)

        if deterministic:
            pre_tanh_value = mean
            action = torch.tanh(mean)
        else:
            pre_tanh_value = normal.rsample()
            action = torch.tanh(pre_tanh_value)

        if return_log_prob:
            log_prob = normal.log_prob(pre_tanh_value)
            ## tanh change-of-variables correction, then sum over action dimensions
            log_prob -= torch.log(1 - action.pow(2) + ACTION_BOUND_EPSILON)
            log_prob = log_prob.sum(1, keepdim=True)
        else:
            log_prob = None

        return (
            action * self.action_limit, mean, log_std, log_prob, std, pre_tanh_value,
        )
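
## Illustrative usage sketch, not part of the original code: running the policy on a
## batch of observations to get squashed actions and their log probabilities.
## Dimensions are arbitrary placeholders; the helper name is hypothetical.
def _policy_usage_example():
    policy = TanhGaussianPolicy(obs_dim=11, action_dim=3, hidden_sizes=(64, 64))
    obs = torch.randn(16, 11)  # batch of 16 observations
    action, mean, log_std, log_prob, std, pre_tanh = policy(obs)
    # action lies in (-action_limit, action_limit) with shape (16, 3); log_prob has shape (16, 1)
    det_action, *_ = policy(obs, deterministic=True, return_log_prob=False)
    return action.shape, log_prob.shape, det_action.shape
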
def soft_update_model1_with_model2(model1, model2, rou):
    """
    used to polyak update a target network
    :param model1: a pytorch model
    :param model2: a pytorch model of the same class
    :param rou: the Polyak coefficient rho; the update is model1 <- rou*model1 + (1-rou)*model2
    """
    for model1_param, model2_param in zip(model1.parameters(), model2.parameters()):
        model1_param.data.copy_(rou*model1_param.data + (1-rou)*model2_param.data)
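
## Illustrative usage sketch, not part of the original code: keeping a target network
## as a slowly moving average of an online network, as typically done in SAC-style
## algorithms. The coefficient 0.995 is an arbitrary example value; the name is hypothetical.
def _polyak_update_example():
    import copy
    q_net = Mlp(input_size=14, output_size=1, hidden_sizes=(64, 64))  # e.g. Q(s, a) with obs_dim + act_dim = 14
    q_target = copy.deepcopy(q_net)
    # after each gradient step on q_net, move q_target slightly towards q_net:
    soft_update_model1_with_model2(q_target, q_net, rou=0.995)
    return q_target
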
def test_agent(agent, test_env, max_ep_len, logger, n_eval=1):
    """
    This will test the agent's performance by running <n_eval> episodes
    During the runs, the agent should only take deterministic action
    This function assumes the agent has a <get_test_action()> function
    :param agent: agent instance
    :param test_env: the environment used for testing
    :param max_ep_len: max length of an episode
    :param logger: logger to store info in
    :param n_eval: number of episodes to run the agent
    :return: test return for each episode as a numpy array
    """
    ep_return_list = np.zeros(n_eval)
    for j in range(n_eval):
        o, r, d, ep_ret, ep_len = test_env.reset(), 0, False, 0, 0
        while not (d or (ep_len == max_ep_len)):
            # Take deterministic actions at test time
            a = agent.get_test_action(o)
            o, r, d, _ = test_env.step(a)
            ep_ret += r
            ep_len += 1
        ep_return_list[j] = ep_ret
        if logger is not None:
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
    return ep_return_list
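
## Illustrative usage sketch, not part of the original code: evaluating a trained agent
## on a separate environment instance. The gym import and the agent object are assumptions;
## the agent only needs to expose a get_test_action(obs) method, as the docstring states.
## (Uses the old gym API, matching test_env.step returning 4 values above.)
##     import gym
##     test_env = gym.make('Hopper-v2')
##     returns = test_agent(agent, test_env, max_ep_len=1000, logger=None, n_eval=10)
##     print(returns.mean(), returns.std())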