Skip to content

Commit

Permalink
Add PPO evaluation
Browse files Browse the repository at this point in the history
Closes: #75
  • Loading branch information
wil3 committed Jun 14, 2020
1 parent 677995c commit 2104806
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 6 deletions.
3 changes: 2 additions & 1 deletion examples/gymfc_nf/policies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

from gymfc_nf.policies.pidpolicy import PidPolicy
__all__ = ['PidPolicy']
from gymfc_nf.policies.baselinespolicy import PpoBaselinesPolicy
__all__ = ['PidPolicy', 'PpoBaselinesPolicy']
14 changes: 14 additions & 0 deletions examples/gymfc_nf/policies/baselinespolicy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import numpy as np
import tensorflow as tf
from .policy import Policy
class PpoBaselinesPolicy(Policy):
def __init__(self, sess):
graph = tf.get_default_graph()
self.x = graph.get_tensor_by_name('pi/ob:0')
self.y = graph.get_tensor_by_name('pi/pol/final/BiasAdd:0')
self.sess = sess

def action(self, state, sim_time=0, desired=np.zeros(3), actual=np.zeros(3) ):

y_out = self.sess.run(self.y, feed_dict={self.x:[state] })
return y_out[0]
63 changes: 63 additions & 0 deletions examples/gymfc_nf/utils/monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import tensorflow as tf
import os.path
import time


class CheckpointMonitor:
"""Helper class to monitor the Tensorflow checkpoints and call a callback
when a new checkpoint has been created."""

def __init__(self, checkpoint_dir, callback):
"""
Args:
checkpoint_dir: Directory to monitor where new checkpoint
directories will be created
callback: A callback for when a new checkpoint is created.
"""
self.checkpoint_dir = checkpoint_dir
self.callback = callback
# Track which checkpoints have already been called.
self.processed = []

self.watching = True

def _check_new_checkpoint(self):
"""Update the queue with newly found checkpoints.
When a checkpoint directory is created a 'checkpoint' file is created
containing a list of all the checkpoints. We can monitor this file to
determine when new checkpoints have been created.
"""
# TODO (wfk) check if there is a way to get a callback when a file has
# changed.

ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
for path in ckpt.all_model_checkpoint_paths:
checkpoint_filename = os.path.split(path)[-1]
if tf.train.checkpoint_exists(path):
# Make sure there is a checkpoint meta file before allowing it
# to be processed
meta_file = path + ".meta"
if os.path.isfile(meta_file):
if (checkpoint_filename not in self.processed):
self.callback(checkpoint_filename)
self.processed.append(checkpoint_filename)
else:
print ("Meta file {} doesn't exist.".format(meta_file))

def start(self):

# Sit and wait until the checkpoint directory is created, otherwise we
# can't monitor it. If it never gets created this could be an indicator
# something is wrong with the trainer.
c=0
while not os.path.isdir(self.checkpoint_dir):
print("[WARN {}] Directory {} doesn't exist yet, waiting until "
"created...".format(c, self.checkpoint_dir))
time.sleep(30)
c+=1

while self.watching:
self._check_new_checkpoint()
time.sleep(10)

120 changes: 120 additions & 0 deletions examples/ppo_baselines_evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import argparse
from pathlib import Path
import os.path
import numpy as np
import tensorflow as tf
import gym
from gymfc_nf.envs import *
from gymfc_nf.utils.monitor import CheckpointMonitor
from gymfc_nf.policies import PpoBaselinesPolicy


if __name__ == "__main__":
parser = argparse.ArgumentParser("Evaluate OpenAI Baseline PPO1 checkpoints.")
parser.add_argument('ckpt_dir', help="Directory where checkpoints are saved. ")
parser.add_argument('--twin', default="./gymfc_nf/twins/nf1/model.sdf",
help="File path of the aircraft digitial twin/model SDF.")
parser.add_argument('--eval-dir',
help="Directory where evaluation logs are saved.")
parser.add_argument('--gym-id', default="gymfc_nf-step-v1")
parser.add_argument('--num-trials', type=int, default=1)
# Provide a seed so the same setpoint will be created. Useful for debugging
parser.add_argument('--seed', help='RNG seed', type=int, default=-1)

args = parser.parse_args()

seed = np.random.randint(0, 1e6) if args.seed < 0 else args.seed
gym_id = args.gym_id
ckpt_dir = args.ckpt_dir
model_dir = Path(ckpt_dir).parent
eval_dir = args.eval_dir if args.eval_dir else os.path.join(model_dir,
"evaluations")
num_trials = args.num_trials
print ("Saving evaluations to {}".format(eval_dir))

env = gym.make(gym_id)
env.seed(seed)
env.set_aircraft_model(args.twin)

log_header = ""
def make_header(ob_size):
"""Make the log header.
This needs to be done dynamically because the observations which are
used as input to the NN may differ.
"""
entries = []
entries.append("t")
for i in range(ob_size):
entries.append("ob{}".format(i))
for i in range(4):
entries.append("ac{}".format(i))
for i in range(4):
entries.append("y{}".format(i))
entries.append("p") # roll rate
entries.append("q") # pitch rate
entries.append("r") # yaw rate
entries.append("p-sp") # roll rate setpoint
entries.append("q-sp") # pitch rate setpoint
entries.append("r-sp") # yaw rate setpoint
for i in range(4):
entries.append("w{}".format(i)) # ESC rpms
entries.append("reward")

log_header = ",".join(entries)

def callback(checkpoint):
print ("Callback ", checkpoint)

ckpt_eval_dir = os.path.join(eval_dir, checkpoint)
Path(ckpt_eval_dir).mkdir(parents=True, exist_ok=True)

# TODO (wfk) I'm pretty sure this just takes the last checkpoint
# written defined by 'model_checkpoint_path' in the checkpoint file
# should look at how to specify the exact one.
checkpoint = tf.train.get_checkpoint_state(ckpt_dir)
input_checkpoint = checkpoint.model_checkpoint_path
print ("Using checkpoint=", input_checkpoint)
with tf.Session() as sess:
saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
clear_devices=True)
saver.restore(sess, input_checkpoint)
pi = PpoBaselinesPolicy(sess)


for i in range(num_trials):

pi.reset()
ob = env.reset()
if len(log_header) == 0:
make_header(len(ob))

log_file = os.path.join(ckpt_eval_dir, "trial-{}.csv".format(i))

sim_time = 0
actual = np.zeros(3)

logs = []
while True:
ac = pi.action(ob, env.sim_time, env.angular_rate_sp,
env.imu_angular_velocity_rpy)
ob, reward, done, _ = env.step(ac)

log = ([env.sim_time] +
ob.tolist() + # The observations are the NN input
ac.tolist() + # The actions are the NN output
env.y.tolist() + # Y is the output sent to the ESC

env.imu_angular_velocity_rpy.tolist() + # Angular velocites
env.angular_rate_sp.tolist() + #
env.esc_motor_angular_velocity.tolist() +
[reward])# The reward that would have been given for the action, can be helpful for debugging

logs.append(log)

if done:
break
np.savetxt(log_file, logs, delimiter=",", header=log_header)

monitor = CheckpointMonitor(args.ckpt_dir, callback)
monitor.start()
11 changes: 6 additions & 5 deletions examples/ppo_example.py → examples/ppo_baselines_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_commit_hash():

def get_training_name():
now = datetime.datetime.now()
timestamp = now.strftime('%y%m%d-%H%M%S')
timestamp = now.strftime('%Y%m%d-%H%M%S')
return "baselines_" + get_commit_hash() + "_" + timestamp


Expand Down Expand Up @@ -80,19 +80,20 @@ def policy_fn(name, ob_space, ac_space):

if __name__ == '__main__':
parser = argparse.ArgumentParser("Synthesize a neuro-flight controller.")
parser.add_argument('--checkpoint_dir', default="../../checkpoints")
parser.add_argument('--model_dir', default="../../models",
help="Directory where models are saved to.")
parser.add_argument('--twin', default="./gymfc_nf/twins/nf1/model.sdf",
help="File path of the aircraft digitial twin/model SDF.")
parser.add_argument('--gym-id', default="gymfc_nf-step-v1")
parser.add_argument('--timesteps', type=int, default=10e6)
parser.add_argument('--ckpt-freq', type=int, default=100e3)
args = parser.parse_args()

training_dir = os.path.join(args.checkpoint_dir, get_training_name())
print ("Storing results to ", training_dir)
training_dir = os.path.join(args.model_dir, get_training_name())

seed = np.random.randint(0, 1e6)
ckpt_dir = os.path.join(training_dir, "checkpoints")
ckpt_dir = os.path.abspath(os.path.join(training_dir, "checkpoints"))
print ("Saving checkpoints to {}.".format(ckpt_dir))
render = False
# How many timesteps until a checkpoint is saved
ckpt_freq = args.ckpt_freq
Expand Down

0 comments on commit 2104806

Please sign in to comment.