-
Notifications
You must be signed in to change notification settings - Fork 817
/
obs_filter.py
212 lines (177 loc) · 6.09 KB
/
obs_filter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# Third party code
#
# The following code is copied or modified from:
# https://github.com/ray-project/ray/blob/master/python/ray/rllib/utils/filter.py
import numpy as np
class Filter:
    """Interface for objects that process input, possibly statefully.

    This base class supplies no behavior: every method below raises
    ``NotImplementedError`` and must be overridden by a concrete filter.
    """

    def apply_changes(self, other, *args, **kwargs):
        """Fold the "new state" held by ``other`` into this filter."""
        raise NotImplementedError()

    def copy(self):
        """Build a new filter whose state matches this one.

        Returns:
            A copy of ``self``.
        """
        raise NotImplementedError()

    def sync(self, other):
        """Replace all of this filter's state with a copy of ``other``'s."""
        raise NotImplementedError()

    def clear_buffer(self):
        """Snapshot the current state and reset the accumulated state."""
        raise NotImplementedError()

    def as_serializable(self):
        """Return a form of this filter that is safe to serialize."""
        raise NotImplementedError()
# http://www.johndcook.com/blog/standard_deviation/
class RunningStat(object):
    """Online accumulator for the element-wise mean and variance of samples.

    Uses Welford's single-pass update for arrays of one fixed shape, and can
    merge another accumulator's statistics in place. Internal state:

    * ``_n`` -- number of samples seen,
    * ``_M`` -- running mean (array of ``shape``),
    * ``_S`` -- running sum of squared deviations from the mean.
    """

    def __init__(self, shape=None):
        """Create an empty accumulator.

        Args:
            shape: Shape of each sample passed to ``push``. ``None`` (the
                default) means scalar samples, i.e. shape ``()``.
        """
        # NumPy deprecated ``shape=None`` as an alias for ``()`` (it emits a
        # DeprecationWarning), so normalise it explicitly.
        if shape is None:
            shape = ()
        self._n = 0
        self._M = np.zeros(shape)
        self._S = np.zeros(shape)

    def copy(self):
        """Return an independent copy (arrays are copied, not shared)."""
        other = RunningStat(self.shape)
        other._n = self._n
        other._M = np.copy(self._M)
        other._S = np.copy(self._S)
        return other

    def push(self, x):
        """Add a single sample to the statistics.

        Args:
            x: Array-like sample whose shape must equal ``self.shape``.

        Raises:
            ValueError: If ``x`` does not have the expected shape.
        """
        x = np.asarray(x)
        # Unvectorized update of the running statistics.
        if x.shape != self._M.shape:
            raise ValueError(
                "Unexpected input shape {}, expected {}, value = {}".format(
                    x.shape, self._M.shape, x))
        n1 = self._n
        self._n += 1
        if self._n == 1:
            self._M[...] = x
        else:
            delta = x - self._M
            self._M[...] += delta / self._n
            self._S[...] += delta * delta * n1 / self._n

    def update(self, other):
        """Merge the statistics of ``other`` into this accumulator in place.

        The result matches having pushed both sample streams into a single
        accumulator (pairwise combination of count, mean and ``S``).

        Args:
            other: Another ``RunningStat``; it is not modified.
        """
        n1 = self._n
        n2 = other._n
        if n2 == 0:
            # Nothing to merge. Returning early also avoids the 0/0 (NaN) case
            # when both are empty, and avoids recomputing (n1 * M) / n1, which
            # can drift the mean by an ulp every time an empty buffer is merged.
            return
        n = n1 + n2
        delta = self._M - other._M
        self._M = (n1 * self._M + n2 * other._M) / n
        self._S = self._S + other._S + delta * delta * n1 * n2 / n
        self._n = n

    def __repr__(self):
        return '(n={}, mean_mean={}, mean_std={})'.format(
            self.n, np.mean(self.mean), np.mean(self.std))

    @property
    def n(self):
        """Number of samples accumulated so far."""
        return self._n

    @property
    def mean(self):
        """Running mean (the internal array, not a copy)."""
        return self._M

    @property
    def var(self):
        """Sample variance ``S / (n - 1)``.

        With fewer than two samples this returns ``mean ** 2`` instead; that
        fallback is kept unchanged from upstream and is not a true variance.
        """
        return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)

    @property
    def std(self):
        """Element-wise square root of ``var``."""
        return np.sqrt(self.var)

    @property
    def shape(self):
        """Shape of a single sample."""
        return self._M.shape
class MeanStdFilter(Filter):
    """Keeps track of a running mean for seen states.

    The filter is used to normalize observations and is updated online
    according to the observations seen by all actors.

    Attributes:
        shape: Shape of a single observation, as given to the constructor.
        demean: If true, ``__call__`` subtracts the running mean.
        destd: If true, ``__call__`` divides by the running std.
        clip: If truthy, ``__call__`` clips output to ``[-clip, clip]``.
        rs: ``RunningStat`` over every observation pushed through this filter
            plus any buffers merged in via ``apply_changes``.
        buffer: ``RunningStat`` over observations pushed since the last
            ``clear_buffer``; this is what other filters merge from.
    """

    is_concurrent = False

    def __init__(self, shape, demean=True, destd=True, clip=10.0):
        self.shape = shape
        self.demean = demean
        self.destd = destd
        self.clip = clip
        self.rs = RunningStat(shape)
        # In distributed rollouts, each worker sees different states.
        # The buffer is used to keep track of deltas amongst all the
        # observation filters.
        self.buffer = RunningStat(shape)

    def clear_buffer(self):
        """Discard the buffered deltas by replacing them with an empty stat."""
        self.buffer = RunningStat(self.shape)

    def apply_changes(self, other, with_buffer=False):
        """Applies updates from the buffer of another filter.

        ``other.buffer`` is merged into ``self.rs``; calling this twice with
        the same ``other`` merges that buffer twice.

        Params:
            other (MeanStdFilter): Other filter to apply info from
            with_buffer (bool): Flag for specifying if the buffer should be
                copied from other.

        Examples:
            >>> a = MeanStdFilter(())
            >>> _ = a(1)
            >>> _ = a(2)
            >>> print([a.rs.n, float(a.rs.mean), a.buffer.n])
            [2, 1.5, 2]
            >>> b = MeanStdFilter(())
            >>> _ = b(10)
            >>> a.apply_changes(b, with_buffer=False)
            >>> print([a.rs.n, float(a.rs.mean), a.buffer.n])
            [3, 4.333333333333333, 2]
            >>> a.apply_changes(b, with_buffer=True)
            >>> print([a.rs.n, float(a.rs.mean), a.buffer.n])
            [4, 5.75, 1]
        """
        self.rs.update(other.buffer)
        if with_buffer:
            self.buffer = other.buffer.copy()

    def copy(self):
        """Returns a copy of Filter."""
        other = MeanStdFilter(self.shape)
        other.sync(self)
        return other

    def as_serializable(self):
        """Return a detached copy of this filter for serialization."""
        return self.copy()

    def sync(self, other):
        """Syncs all fields together from other filter.

        Settings are copied by value; ``rs`` and ``buffer`` are deep-copied.

        Examples:
            >>> a = MeanStdFilter(())
            >>> _ = a(1)
            >>> _ = a(2)
            >>> print([a.rs.n, float(a.rs.mean), a.buffer.n])
            [2, 1.5, 2]
            >>> b = MeanStdFilter(())
            >>> _ = b(10)
            >>> print([b.rs.n, float(b.rs.mean), b.buffer.n])
            [1, 10.0, 1]
            >>> a.sync(b)
            >>> print([a.rs.n, float(a.rs.mean), a.buffer.n])
            [1, 10.0, 1]
        """
        # NOTE(review): this check is an ``assert`` and is skipped under -O.
        assert other.shape == self.shape, "Shapes don't match!"
        self.demean = other.demean
        self.destd = other.destd
        self.clip = other.clip
        self.rs = other.rs.copy()
        self.buffer = other.buffer.copy()

    def __call__(self, x, update=True):
        """Normalize ``x``, optionally updating the statistics first.

        Args:
            x: One observation of shape ``self.rs.shape``, or a batch with one
                extra leading axis (each row is pushed separately).
            update: If true, push ``x`` into both ``rs`` and ``buffer`` before
                normalizing, so the statistics used include ``x`` itself.

        Returns:
            ``x`` as an array, demeaned / divided by std / clipped according
            to ``demean``, ``destd`` and ``clip``.
        """
        x = np.asarray(x)
        if update:
            if len(x.shape) == len(self.rs.shape) + 1:
                # The vectorized case: iterate over the leading batch axis.
                for row in x:
                    self.rs.push(row)
                    self.buffer.push(row)
            else:
                # The unvectorized case.
                self.rs.push(x)
                self.buffer.push(x)
        if self.demean:
            x = x - self.rs.mean
        if self.destd:
            # The epsilon avoids division by zero when the std is 0.
            x = x / (self.rs.std + 1e-8)
        if self.clip:
            x = np.clip(x, -self.clip, self.clip)
        return x

    def __repr__(self):
        return 'MeanStdFilter({}, {}, {}, {}, {}, {})'.format(
            self.shape, self.demean, self.destd, self.clip, self.rs,
            self.buffer)