custom_policy.py
import numpy as np
import tensorflow as tf
from stable_baselines.common.tf_layers import conv, linear, conv_to_fc
from stable_baselines.common.policies import LstmPolicy, FeedForwardPolicy


def modified_cnn(scaled_images, **kwargs):
    """
    Custom CNN feature extractor: two convolutional layers followed by a
    512-unit fully connected layer, used in place of the default Nature CNN.

    :param scaled_images: (TensorFlow Tensor) Image input placeholder, scaled to [0, 1]
    :param kwargs: (dict) Extra keyword arguments passed to the conv layers
    :return: (TensorFlow Tensor) The extracted features
    """
    activ = tf.nn.relu
    layer_1 = activ(conv(scaled_images, 'c1', n_filters=32, filter_size=4, stride=2, init_scale=np.sqrt(2), **kwargs))
    layer_2 = activ(conv(layer_1, 'c2', n_filters=64, filter_size=4, stride=1, init_scale=np.sqrt(2), **kwargs))
    layer_2 = conv_to_fc(layer_2)
    return activ(linear(layer_2, 'fc1', n_hidden=512, init_scale=np.sqrt(2)))


class CustomCnnLnLstmPolicy(LstmPolicy):
    """
    Policy object that implements an actor-critic, using a layer normalized LSTM with the custom CNN feature extraction.

    :param sess: (TensorFlow session) The current TensorFlow session
    :param ob_space: (Gym Space) The observation space of the environment
    :param ac_space: (Gym Space) The action space of the environment
    :param n_env: (int) The number of environments to run
    :param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The number of batches to run (n_env * n_steps)
    :param n_lstm: (int) The number of LSTM cells (for recurrent policies)
    :param reuse: (bool) If the policy is reusable or not
    :param _kwargs: (dict) Extra keyword arguments for the CNN feature extraction
    """

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, **_kwargs):
        super(CustomCnnLnLstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
                                                    layer_norm=True, cnn_extractor=modified_cnn,
                                                    feature_extraction="cnn", **_kwargs)


class CustomCnnLstmPolicy(LstmPolicy):
    """
    Policy object that implements an actor-critic, using an LSTM (without layer normalization) with the custom CNN feature extraction.

    :param sess: (TensorFlow session) The current TensorFlow session
    :param ob_space: (Gym Space) The observation space of the environment
    :param ac_space: (Gym Space) The action space of the environment
    :param n_env: (int) The number of environments to run
    :param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The number of batches to run (n_env * n_steps)
    :param n_lstm: (int) The number of LSTM cells (for recurrent policies)
    :param reuse: (bool) If the policy is reusable or not
    :param _kwargs: (dict) Extra keyword arguments for the CNN feature extraction
    """

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, **_kwargs):
        super(CustomCnnLstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
                                                  layer_norm=False, cnn_extractor=modified_cnn,
                                                  feature_extraction="cnn", **_kwargs)


class CustomCnnPolicy(FeedForwardPolicy):
    """
    Policy object that implements an actor-critic, using the custom CNN feature extractor (modified_cnn).

    :param sess: (TensorFlow session) The current TensorFlow session
    :param ob_space: (Gym Space) The observation space of the environment
    :param ac_space: (Gym Space) The action space of the environment
    :param n_env: (int) The number of environments to run
    :param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The number of batches to run (n_env * n_steps)
    :param reuse: (bool) If the policy is reusable or not
    :param _kwargs: (dict) Extra keyword arguments for the CNN feature extraction
    """

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **_kwargs):
        super(CustomCnnPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse,
                                              cnn_extractor=modified_cnn, feature_extraction="cnn", **_kwargs)
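

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): train PPO2 with one of the
    # custom policies defined above. The environment id ("BreakoutNoFrameskip-v4") and the
    # hyperparameters below are placeholder assumptions, not values taken from this repository.
    from stable_baselines import PPO2
    from stable_baselines.common.cmd_util import make_atari_env
    from stable_baselines.common.vec_env import VecFrameStack

    # Vectorized Atari environment with 4 parallel copies and 4-frame stacking.
    env = VecFrameStack(make_atari_env("BreakoutNoFrameskip-v4", num_env=4, seed=0), n_stack=4)

    # Feed-forward variant; swap in CustomCnnLstmPolicy or CustomCnnLnLstmPolicy for a
    # recurrent policy (PPO2 then requires nminibatches to divide the number of environments).
    model = PPO2(CustomCnnPolicy, env, verbose=1)
    model.learn(total_timesteps=10000)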