scpo_core.py
import numpy as np
import scipy.signal
from gym.spaces import Box, Discrete

import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical

EPS = 1e-8


def diagonal_gaussian_kl(mu0, log_std0, mu1, log_std1):
    """
    Mean KL divergence between two batches of diagonal Gaussian distributions,
    where each distribution is specified by its means and log stds.
    (https://en.wikipedia.org/wiki/Kullback-Leibler_divergence#Multivariate_normal_distributions)
    """
    var0, var1 = torch.exp(2 * log_std0), torch.exp(2 * log_std1)
    pre_sum = 0.5 * (((mu1 - mu0)**2 + var0) / (var1 + EPS) - 1) + log_std1 - log_std0
    all_kls = torch.sum(pre_sum, axis=1)
    return torch.mean(all_kls)
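# Sanity check (illustrative, not part of the original module): the KL
# divergence between a batch of diagonal Gaussians and itself is ~0
# (exactly zero up to the EPS regularizer in the denominator).
#
#   >>> mu, log_std = torch.zeros(4, 3), torch.zeros(4, 3)
#   >>> float(diagonal_gaussian_kl(mu, log_std, mu, log_std))  # ~0.0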
def combined_shape(length, shape=None):
    # Prepend a batch/length dimension to a (possibly scalar or tuple) shape.
    if shape is None:
        return (length,)
    return (length, shape) if np.isscalar(shape) else (length, *shape)
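# Example (illustrative):
#
#   >>> combined_shape(10)          # (10,)
#   >>> combined_shape(10, 3)       # (10, 3)
#   >>> combined_shape(10, (2, 4))  # (10, 2, 4)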
def mlp(sizes, activation, output_activation=nn.Identity):
    # Stack Linear layers of the given sizes, with `activation` between hidden
    # layers and `output_activation` after the final layer.
    layers = []
    for j in range(len(sizes) - 1):
        act = activation if j < len(sizes) - 2 else output_activation
        layers += [nn.Linear(sizes[j], sizes[j + 1]), act()]
    return nn.Sequential(*layers)


def count_vars(module):
    # Total number of parameters in a module.
    return sum([np.prod(p.shape) for p in module.parameters()])
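# Example (illustrative): a two-hidden-layer tanh network mapping 8-dim inputs
# to 2-dim outputs.
#
#   >>> net = mlp([8, 64, 64, 2], nn.Tanh)
#   >>> count_vars(net)  # (8*64 + 64) + (64*64 + 64) + (64*2 + 2) = 4866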
def discount_cumsum(x, discount):
    """
    Magic from rllab for computing discounted cumulative sums of vectors.

    input:
        vector x: [x0, x1, x2]

    output:
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2]
    """
    return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]
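# Example (illustrative):
#
#   >>> discount_cumsum(np.array([1., 1., 1.]), 0.5)  # -> [1.75, 1.5, 1.]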
def expected_cost(cost_ret, discount):
    # One-step difference of the discounted cost return:
    # out[t] = cost_ret[t] - discount * cost_ret[t + 1].
    return cost_ret[:-1] - discount * cost_ret[1:]
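# Example (illustrative): for cost_ret = [3., 2., 1.] and discount = 1.0,
# expected_cost returns [3 - 2, 2 - 1] = [1., 1.].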
def future_max(x):
    # Maximum over the remaining suffix x[i:], for each index i.
    return [np.max(x[i:]) for i in range(len(x))]


def future_max_norm(x, pnorm):
    # p-norm over the remaining suffix x[i:]; as pnorm grows this approaches
    # future_max (the inf-norm of a nonnegative vector is its maximum).
    return [np.linalg.norm(x[i:], ord=pnorm) for i in range(len(x))]
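# Example (illustrative):
#
#   >>> future_max([1., 3., 2.])               # -> [3., 3., 2.]
#   >>> future_max_norm([1., 3., 2.], np.inf)  # -> [3., 3., 2.]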
class Actor(nn.Module):

    def _distribution(self, obs):
        raise NotImplementedError

    def _log_prob_from_distribution(self, pi, act):
        raise NotImplementedError

    def forward(self, obs, act=None):
        # Produce action distributions for given observations, and
        # optionally compute the log likelihood of given actions under
        # those distributions.
        pi = self._distribution(obs)
        logp_a = None
        if act is not None:
            logp_a = self._log_prob_from_distribution(pi, act)
        return pi, logp_a

    def _d_kl(self, obs, old_mu, old_log_std, device):
        raise NotImplementedError
class MLPCategoricalActor(Actor):

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        self.logits_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)

    def _distribution(self, obs):
        logits = self.logits_net(obs)
        return Categorical(logits=logits)

    def _log_prob_from_distribution(self, pi, act):
        return pi.log_prob(act)

    def _d_kl(self, obs, old_mu, old_log_std, device):
        # KL computation is only implemented for the Gaussian policy below.
        raise NotImplementedError
class MLPGaussianActor(Actor):

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        log_std = -0.5 * np.ones(act_dim, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))
        self.mu_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)

    def _distribution(self, obs):
        mu = self.mu_net(obs)
        # std = torch.clamp(0.01 + 0.99 * torch.exp(self.log_std), max=10)
        std = torch.exp(self.log_std)
        return Normal(mu, std)

    def _log_prob_from_distribution(self, pi, act):
        return pi.log_prob(act).sum(axis=-1)  # last-axis sum needed for Torch Normal distribution

    def _d_kl(self, obs, old_mu, old_log_std, device):
        # KL divergence between the old policy (old_mu, old_log_std) and the
        # current policy, evaluated at the same observations.
        mu = self.mu_net(obs.to(device))
        log_std = self.log_std
        return diagonal_gaussian_kl(old_mu.to(device), old_log_std.to(device), mu, log_std)
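# Example (illustrative): a Gaussian policy over a 2-dim continuous action
# space, evaluated on a batch of 5 observations.
#
#   >>> pi_net = MLPGaussianActor(obs_dim=4, act_dim=2, hidden_sizes=(64, 64),
#   ...                           activation=nn.Tanh)
#   >>> dist, logp = pi_net(torch.zeros(5, 4), act=torch.zeros(5, 2))
#   >>> dist.mean.shape, logp.shape  # (torch.Size([5, 2]), torch.Size([5]))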
class MLPCritic(nn.Module):

    def __init__(self, obs_dim, hidden_sizes, activation):
        super().__init__()
        self.v_net = mlp([obs_dim] + list(hidden_sizes) + [1], activation)

    def forward(self, obs):
        return torch.squeeze(self.v_net(obs), -1)  # critical to ensure v has right shape


class MLPMaxCostCritic(nn.Module):

    def __init__(self, obs_dim, hidden_sizes, activation):
        super().__init__()
        # Softplus output keeps the predicted maximum future cost non-negative.
        self.v_net = mlp([obs_dim] + list(hidden_sizes) + [1], activation, output_activation=nn.Softplus)

    def forward(self, obs):
        return torch.squeeze(self.v_net(obs), -1)  # critical to ensure v has right shape
class MLPActorCritic(nn.Module):

    def __init__(self, observation_space, action_space,
                 hidden_sizes=(64, 64), activation=nn.Tanh):
        super().__init__()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Designed specifically for SCPO, which requires an additional entry M
        # in the observation space; hence the +1.
        obs_dim = observation_space.shape[0] + 1

        # policy builder depends on action space
        if isinstance(action_space, Box):
            self.pi = MLPGaussianActor(obs_dim, action_space.shape[0], hidden_sizes, activation).to(self.device)
        elif isinstance(action_space, Discrete):
            self.pi = MLPCategoricalActor(obs_dim, action_space.n, hidden_sizes, activation).to(self.device)

        # build value function
        self.v = MLPCritic(obs_dim, hidden_sizes, activation).to(self.device)
        # build (maximum) cost value function
        self.vc = MLPMaxCostCritic(obs_dim, hidden_sizes, activation).to(self.device)

    def step(self, obs):
        # Note: the returned mean and log-std assume a Gaussian policy
        # (pi.mean / pi.stddev are not meaningful for Categorical policies).
        with torch.no_grad():
            obs = obs.to(self.device)
            pi = self.pi._distribution(obs)
            a = pi.sample()
            logp_a = self.pi._log_prob_from_distribution(pi, a)
            v = self.v(obs)
            vc = self.vc(obs)
        return (a.cpu().numpy(), v.cpu().numpy(), vc.cpu().numpy(), logp_a.cpu().numpy(),
                pi.mean.cpu().numpy(), torch.log(pi.stddev).cpu().numpy())

    def act(self, obs):
        return self.step(obs)[0]
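# Example usage (illustrative; the environment side is assumed): build an
# actor-critic for a 4-dim Box observation space and a 2-dim Box action space.
# The observation passed to step() must already include the extra SCPO entry M,
# so it has observation_space.shape[0] + 1 entries.
#
#   >>> obs_space = Box(low=-1.0, high=1.0, shape=(4,))
#   >>> act_space = Box(low=-1.0, high=1.0, shape=(2,))
#   >>> ac = MLPActorCritic(obs_space, act_space)
#   >>> obs = torch.zeros(1, 5)  # 4 raw dims + 1 for M
#   >>> a, v, vc, logp, mu, log_std = ac.step(obs)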