Skip to content

Commit 076eb09

Browse files
authored
Add files via upload
1 parent 5c1129d commit 076eb09

File tree

3 files changed

+23
-10
lines changed

3 files changed

+23
-10
lines changed

params.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,16 @@ class train_params:
4242
DENSE2_SIZE = 300 # Size of second hidden layer in networks
4343
FINAL_LAYER_INIT = 0.003 # Initialise networks' final layer weights in range +/-final_layer_init
4444
NUM_ATOMS = 51 # Number of atoms in output layer of distributional critic
45-
V_MIN = -20.0 # Lower bound of critic value output distribution
46-
V_MAX = 0.0 # Upper bound of critic value output distribution (V_min and V_max should be chosen based on the range of normalised reward values in the chosen env)
45+
V_MIN = -10.0 # Lower bound of critic value output distribution
46+
V_MAX = 10.0 # Upper bound of critic value output distribution (V_min and V_max should be chosen based on the range of normalised reward values in the chosen env)
4747
TAU = 0.001 # Parameter for soft target network updates
4848
USE_BATCH_NORM = False # Whether or not to use batch normalisation in the networks
4949

5050
# Files/Directories
51-
SAVE_CKPT_STEP = 10000 # Save checkpoint every save_ckpt_step training steps
52-
CKPT_DIR = './ckpts' # Directory for saving/loading checkpoints
53-
CKPT_FILE = None # Checkpoint file to load and resume training from (if None, train from scratch)
54-
LOG_DIR = './logs/train' # Directory for saving Tensorboard logs (if None, do not save logs)
51+
SAVE_CKPT_STEP = 10000 # Save checkpoint every save_ckpt_step training steps
52+
CKPT_DIR = './ckpts/' + ENV # Directory for saving/loading checkpoints
53+
CKPT_FILE = None # Checkpoint file to load and resume training from (if None, train from scratch)
54+
LOG_DIR = './logs/train/' + ENV # Directory for saving Tensorboard logs (if None, do not save logs)
5555

5656

5757
class test_params:
@@ -66,10 +66,10 @@ class test_params:
6666
MAX_EP_LENGTH = 1000 # Maximum number of steps per episode
6767

6868
# Files/directories
69-
CKPT_DIR = './ckpts' # Directory for saving/loading checkpoints
69+
CKPT_DIR = './ckpts/' + ENV # Directory for saving/loading checkpoints
7070
CKPT_FILE = None # Checkpoint file to load and test (if None, load latest ckpt)
7171
RESULTS_DIR = './test_results' # Directory for saving txt file of results (if None, do not save results)
72-
LOG_DIR = './logs/test' # Directory for saving Tensorboard logs (if None, do not save logs)
72+
LOG_DIR = './logs/test/' + ENV # Directory for saving Tensorboard logs (if None, do not save logs)
7373

7474

7575
class play_params:
@@ -83,7 +83,7 @@ class play_params:
8383
MAX_EP_LENGTH = 1000 # Maximum number of steps per episode
8484

8585
# Files/directories
86-
CKPT_DIR = './ckpts' # Directory for saving/loading checkpoints
86+
CKPT_DIR = './ckpts/' + ENV # Directory for saving/loading checkpoints
8787
CKPT_FILE = None # Checkpoint file to load and run (if None, load latest ckpt)
8888
RECORD_DIR = './video' # Directory to store recorded gif of gameplay (if None, do not record)
8989

test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
'''
22
## Test ##
3-
# Test a trained D4PG network. This can be run alongside training by running 'run_every_new_ckpt.sh'.
3+
# Test a trained D4PG network. This can be run alongside training by running 'test_every_new_ckpt.py'.
44
@author: Mark Sinton ([email protected])
55
'''
66

test_every_new_ckpt.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
'''
2+
## Test Every New Ckpt ##
3+
# Allows testing to be run alongside training: launches 'utils/run_every_new_ckpt.sh' with the checkpoint directory (test_params.CKPT_DIR); that script monitors the ckpt directory and runs test.py every time a new ckpt is added.
4+
@author: Mark Sinton ([email protected])
5+
'''
6+
7+
from subprocess import call
8+
from params import test_params
9+
10+
if __name__ == '__main__':
11+
12+
ckpt_dir = test_params.CKPT_DIR
13+
call(['bash', 'utils/run_every_new_ckpt.sh', ckpt_dir])

0 commit comments

Comments
 (0)