Skip to content

Commit

Permalink
Support for tensor logging to tensorboard (rll#88)
Browse files Browse the repository at this point in the history
Add a customized tensor scalar to tensorboard by using the
custom_scalar plugin in tensorboard. Each line in the scalar
corresponds to an element in the tensor.

Wrap the tensorboard logging module into a new class `Summary`
in file rllab/misc/tensor_summary.py. It supports both the
simple value and tensor logging. It also saves the
computation graph created by rllab.

To record the tensor into tensorboard, use the
`record_tensor` function in file rllab/misc/logger.py.

Refer to: rll#39, rll#38
  • Loading branch information
CatherineSue authored May 26, 2018
1 parent 5c42053 commit 4d2417e
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 38 deletions.
66 changes: 28 additions & 38 deletions rllab/misc/logger.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
from enum import Enum

from rllab.misc import tabulate
from rllab.misc import mkdir_p, colorize
from rllab.misc import get_all_parameters
from contextlib import contextmanager
import numpy as np
import base64
import csv
import datetime
import json
import os
import os.path as osp
import pickle
import sys
import datetime
from contextlib import contextmanager
from enum import Enum

import dateutil.tz
import csv
import joblib
import json
import pickle
import base64
import tensorflow as tf
import numpy as np

from rllab.misc.autoargs import get_all_parameters
from rllab.misc.console import mkdir_p, colorize
from rllab.misc.tabulate import tabulate
from rllab.misc.tensorboard_output import TensorBoardOutput

_prefixes = []
_prefix_str = ''
Expand All @@ -32,7 +33,6 @@
_tabular_fds = {}
_tabular_header_written = set()

_tensorboard_writer = None
_snapshot_dir = None
_snapshot_mode = 'all'
_snapshot_gap = 1
Expand All @@ -43,6 +43,8 @@
_tensorboard_default_step = 0
_tensorboard_step_key = None

_tensorboard = TensorBoardOutput()


def _add_output(file_name, arr, fds, mode='a'):
if file_name not in arr:
Expand Down Expand Up @@ -83,17 +85,7 @@ def remove_tabular_output(file_name):


def set_tensorboard_dir(dir_name):
global _tensorboard_writer
if not dir_name:
if _tensorboard_writer:
_tensorboard_writer.close()
_tensorboard_writer = None
else:
mkdir_p(os.path.dirname(dir_name))
_tensorboard_writer = tf.summary.FileWriter(dir_name)
_tensorboard_default_step = 0
assert _tensorboard_writer is not None
print("tensorboard data will be logged into:", dir_name)
_tensorboard.set_dir(dir_name)


def set_snapshot_dir(dir_name):
Expand Down Expand Up @@ -157,9 +149,15 @@ def log(s, with_prefix=True, with_timestamp=True, color=None):


def record_tabular(key, val):
_tensorboard.record_scalar(str(key), val)
_tabular.append((_tabular_prefix_str + str(key), str(val)))


def record_tensor(key, val):
"""Record tf.Tensor into tensorboard with Tensor.name and its value."""
_tensorboard.record_tensor(key, val)


def push_tabular_prefix(key):
_tabular_prefixes.append(key)
global _tabular_prefix_str
Expand Down Expand Up @@ -214,20 +212,12 @@ def refresh(self):


def dump_tensorboard(*args, **kwargs):
if len(_tabular) > 0 and _tensorboard_writer:
if len(_tabular) > 0:
tabular_dict = dict(_tabular)
if _tensorboard_step_key and _tensorboard_step_key in tabular_dict:
step = tabular_dict[_tensorboard_step_key]
else:
global _tensorboard_default_step
step = _tensorboard_default_step
_tensorboard_default_step += 1

summary = tf.Summary()
for k, v in tabular_dict.items():
summary.value.add(tag=k, simple_value=float(v))
_tensorboard_writer.add_summary(summary, int(step))
_tensorboard_writer.flush()
step = None
if _tensorboard_step_key and _tensorboard_step_key in tabular_dict:
step = tabular_dict[_tensorboard_step_key]
_tensorboard.dump_tensorboard(step)


def dump_tabular(*args, **kwargs):
Expand Down
88 changes: 88 additions & 0 deletions rllab/misc/tensorboard_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import os

import numpy as np
import tensorflow as tf
from tensorboard import summary as summary_lib
from tensorboard.plugins.custom_scalar import layout_pb2

import rllab.misc.logger
from rllab.misc.console import mkdir_p


class TensorBoardOutput:
def __init__(self):
self._scalars = tf.Summary()
self._scope_tensor = {}

self._default_step = 0
self._writer = None

def set_dir(self, dir_name):
if not dir_name:
if self._writer:
self._writer.close()
self._writer = None
else:
mkdir_p(os.path.dirname(dir_name))
self._writer = tf.summary.FileWriter(dir_name)
self._default_step = 0
assert self._writer is not None
rllab.misc.logger.log("tensorboard data will be logged into:" +
dir_name)

def dump_tensorboard(self, step=None):
if not self._writer:
return
run_step = self._default_step
if step:
run_step = step
else:
self._default_step += 1

self._dump_graph()
self._dump_scalars(run_step)
self._dump_tensors()

def record_scalar(self, key, val):
self._scalars.value.add(tag=str(key), simple_value=float(val))

def record_tensor(self, key, val):
scope = str(key).split('/', 1)[0]
if scope not in self._scope_tensor:
self._scope_tensor[scope] = [key]
else:
if key not in self._scope_tensor[scope]:
self._scope_tensor[scope].append(key)

for idx, v in np.ndenumerate(np.array(val)):
self._scalars.value.add(
tag=key + '/' + str(idx).strip('()'), simple_value=float(v))

def _dump_graph(self):
self._writer.add_graph(tf.get_default_graph())
self._writer.flush()

def _dump_scalars(self, step):
self._writer.add_summary(self._scalars, int(step))
self._writer.flush()
del self._scalars.value[:]

def _dump_tensors(self):
layout_categories = []

for scope in self._scope_tensor:
chart = []
for name in self._scope_tensor[scope]:
chart.append(
layout_pb2.Chart(
title=name,
multiline=layout_pb2.MultilineChartContent(
tag=[r'name(?!.*margin.*)'.replace('name', name)
])))
category = layout_pb2.Category(title=scope, chart=chart)
layout_categories.append(category)

if layout_categories:
layout_summary = summary_lib.custom_scalar_pb(
layout_pb2.Layout(category=layout_categories))
self._writer.add_summary(layout_summary)

0 comments on commit 4d2417e

Please sign in to comment.