Skip to content

Commit 9e2e491

Browse files
committed
update to 2.4.0 & add spinningup plotter
1 parent 0a66735 commit 9e2e491

File tree

6 files changed

+185
-15
lines changed

6 files changed

+185
-15
lines changed

README_zh.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ custom_logger.update(fieldvalues=变量值(list), total_steps=当前训练步数
6262
rl_plotter --save --show
6363
```
6464

65-
## Example
65+
## 例子
6666

6767
**1. 常用命令**
6868

rl_plotter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__all__ = ["logger", "log_utils", "plotter", "plot_utils"]
1+
__all__ = ["logger", "log_utils", "plotter", "plotter_spinup", "plot_utils"]

rl_plotter/plot_utils.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
__author__ = 'MICROYU'
55

6+
from datetime import date
67
import re
78
import matplotlib.pyplot as plt
89
import os.path as osp
@@ -162,7 +163,7 @@ def load_csv_results(dir, filename="monitor.csv"):
162163
#df.headers = headers # HACK to preserve backwards compatibility
163164
return df
164165

165-
def load_results(root_dir_or_dirs="./", filename="monitor.csv", filters=[]):
166+
def load_results(root_dir_or_dirs="./", filename="monitor.csv", filters=['']):
166167

167168
if isinstance(root_dir_or_dirs, str):
168169
rootdirs = [osp.expanduser(root_dir_or_dirs)]
@@ -368,21 +369,78 @@ def allequal(qs):
368369
plt.legend(
369370
[groups_results[key]['legend'] for key in groups_results.keys()],
370371
['%s (%i)'%(key.replace('without', 'w/o').replace('_', '-'), groups_results[key]['num']) for key in groups_results.keys()] if average_group else groups_results.keys(),
371-
loc=9 if legend_outside else legend_loc, bbox_to_anchor = (0.5,-0.1) if legend_outside else (1,1) if legend_outside else None, borderpad=legend_borderpad, labelspacing=legend_labelspacing,ncol= len(groups_results.keys()) if legend_outside else 1)
372+
loc=9 if legend_outside else legend_loc, bbox_to_anchor = (0.5,-0.1) if legend_outside else (1,1) if legend_outside else None, borderpad=legend_borderpad, labelspacing=legend_labelspacing, ncol=len(groups_results.keys()) if legend_outside else 1)
372373
else:
373374
plt.legend(
374375
[groups_results[key]['legend'] for key in groups_results.keys()],
375376
['%s'%(key.replace('without', 'w/o').replace('_', '-')) for key in groups_results.keys()] if average_group else groups_results.keys(),
376-
loc=9 if legend_outside else legend_loc, bbox_to_anchor = (0.5,-0.1) if legend_outside else (1,1) if legend_outside else None, borderpad=legend_borderpad, labelspacing=legend_labelspacing,ncol= len(groups_results.keys()) if legend_outside else 1)
377+
loc=9 if legend_outside else legend_loc, bbox_to_anchor = (0.5,-0.1) if legend_outside else (1,1) if legend_outside else None, borderpad=legend_borderpad, labelspacing=legend_labelspacing, ncol=len(groups_results.keys()) if legend_outside else 1)
377378
# add title
378379
plt.title(title)
379380
# add xlabels
380381
if xlabel is not None: plt.xlabel(xlabel)
381382

382383

384+
def plot_data(data, xaxis='total_steps', value="mean_score", condition="Condition1", smooth=1,
385+
legend_outside=False,
386+
legend_loc=0,
387+
legend_borderpad=1.0,
388+
legend_labelspacing=1.0,
389+
font_scale=1.5,
390+
**kwargs):
391+
import seaborn as sns
392+
if smooth > 1:
393+
"""
394+
smooth data with moving window average.
395+
that is,
396+
smoothed_y[t] = average(y[t-k], y[t-k+1], ..., y[t+k-1], y[t+k])
397+
where the "smooth" param is width of that window (2k+1)
398+
"""
399+
y = np.ones(smooth)
400+
for datum in data:
401+
x = np.asarray(datum[value])
402+
z = np.ones(len(x))
403+
smoothed_x = np.convolve(x,y,'same') / np.convolve(z,y,'same')
404+
datum[value] = smoothed_x
405+
if isinstance(data, list):
406+
data = pandas.concat(data, ignore_index=True)
407+
408+
data.sort_values(by='Condition1', axis=0)
409+
410+
sns.set(style="darkgrid", font_scale=font_scale)
411+
sns.lineplot(data=data, x=xaxis, y=value, hue=condition, ci='sd', **kwargs)
412+
handles, labels = plt.gca().get_legend_handles_labels()
413+
414+
plt.legend(
415+
handles[1:],
416+
['%s'%(key.replace('without', 'w/o').replace('_', '-')) for key in labels[1:]],
417+
loc=9 if legend_outside else legend_loc, bbox_to_anchor = (0.5,-0.1) if legend_outside else (1,1) if legend_outside else None, borderpad=legend_borderpad, labelspacing=legend_labelspacing, ncol=len(labels)-1 if legend_outside else 1)
418+
419+
xscale = np.max(np.asarray(data[xaxis])) > 5e3
420+
if xscale:
421+
# Just some formatting niceness: x-axis scale in scientific notation if max x is large
422+
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
423+
424+
plt.tight_layout(pad=0.5)
425+
426+
383427
if __name__ == "__main__":
384-
allresults = load_results("./logs", filename="evaluator.csv")
385-
for result in allresults:
386-
print(result["dirname"])
387-
plot_results(allresults, average_group=False, smooth_radius=0)
388-
#plt.savefig('figure', dpi=400)
428+
# allresults = load_results("./logs", filename="evaluator.csv")
429+
# for result in allresults:
430+
# print(result["dirname"])
431+
# plot_results(allresults, average_group=False, smooth_radius=0)
432+
# plt.show()
433+
434+
# allresults = load_results("./", filename="evaluator.csv")
435+
# datas = []
436+
# for result in allresults:
437+
# result['data'].insert(len(result['data'].columns),'Condition1', default_split_fn(result))
438+
# datas.append(result['data'])
439+
# plt.figure()
440+
# fig = plt.gcf()
441+
# fig.set_size_inches((16, 9), forward=False)
442+
443+
# plot_data(data=datas)
444+
# plt.show()
445+
pass
446+

rl_plotter/plotter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
def main():
1313
parser = argparse.ArgumentParser(description='plotter')
14-
parser.add_argument('--fig_length', type=int, default=8,
15-
help='matplotlib figure length (default: 8)')
14+
parser.add_argument('--fig_length', type=int, default=6,
15+
help='matplotlib figure length (default: 6)')
1616
parser.add_argument('--fig_width', type=int, default=6,
1717
help='matplotlib figure width (default: 6)')
1818
parser.add_argument('--style', default='seaborn',
@@ -68,7 +68,7 @@ def main():
6868
help='log dir (default: ./)')
6969
parser.add_argument('--filters', default=[''], nargs='+',
7070
help='filters of dirname')
71-
parser.add_argument('--filename', default='evaluator',
71+
parser.add_argument('--filename', default='evaluator.csv',
7272
help='csv filename')
7373
parser.add_argument('--show', action='store_true',
7474
help='show figure')

rl_plotter/plotter_spinup.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
4+
__author__ = 'MICROYU'
5+
6+
import argparse
7+
import matplotlib.pyplot as plt
8+
import matplotlib.ticker as mticker
9+
from rl_plotter import plot_utils as pu
10+
11+
12+
def main():
13+
parser = argparse.ArgumentParser(description='rl-plotter')
14+
parser.add_argument('--log_dir', default='./',
15+
help='log dir (default: ./)')
16+
parser.add_argument('--filters', default=[''], nargs='+',
17+
help='filters of dirname')
18+
parser.add_argument('--filename', default='evaluator.csv',
19+
help='csv filename')
20+
parser.add_argument('--show', action='store_true',
21+
help='show figure')
22+
parser.add_argument('--save', action='store_true',
23+
help='save figure')
24+
parser.add_argument('--dpi', type=int, default=400,
25+
help='figure dpi (default: 400)')
26+
parser.add_argument('--fig_length', type=int, default=6,
27+
help='matplotlib figure length (default: 6)')
28+
parser.add_argument('--fig_width', type=int, default=6,
29+
help='matplotlib figure width (default: 6)')
30+
31+
parser.add_argument('--title', default=None,
32+
help='matplotlib figure title (default: None)')
33+
parser.add_argument('--xlabel', default=None,
34+
help='matplotlib figure xlabel')
35+
parser.add_argument('--ylabel', default=None,
36+
help='matplotlib figure ylabel')
37+
parser.add_argument('--xkey', default='total_steps',
38+
help='x-axis key in csv file (default: l)')
39+
parser.add_argument('--ykey', default=['mean_score'], nargs='+',
40+
help='y-axis key in csv file (support multi) (default: r)')
41+
parser.add_argument('--smooth', type=int, default=1,
42+
help='smooth radius of y axis (default: 1)')
43+
parser.add_argument('--xlim', type=int, default=None,
44+
help='x-axis limitation (default: None)')
45+
46+
parser.add_argument('--legend_loc', type=int, default=0,
47+
help='location of legend')
48+
parser.add_argument('--legend_outside', action='store_true',
49+
help='place the legend outside of the figure')
50+
parser.add_argument('--borderpad', type=float, default=0.5,
51+
help='borderpad of legend (default: 0.5)')
52+
parser.add_argument('--labelspacing', type=float, default=0.5,
53+
help='labelspacing of legend (default: 0.5)')
54+
parser.add_argument('--font_scale', type=float, default=1,
55+
help='font_scale of seaborn (default: 1)')
56+
args = parser.parse_args()
57+
58+
if args.xlabel is None:
59+
args.xlabel = 'Timesteps'
60+
61+
if args.ylabel is None:
62+
args.ylabel = 'Episode Reward'
63+
64+
if '.' not in args.filename:
65+
args.filename = args.filename + '.csv'
66+
67+
# OpenAI baseline's monitor
68+
if args.filename == 'monitor.csv':
69+
args.xkey = 'l'
70+
args.ykey = ['r']
71+
72+
# OpenAI spinup's progress
73+
if args.filename == 'progress.txt' or args.filename == 'progress.csv':
74+
args.xkey = 'TotalEnvInteracts'
75+
args.ykey = ['AverageTestEpRet']
76+
77+
# rl-plotter's evaluator
78+
if args.filename == 'evaluator.csv':
79+
args.xkey = 'total_steps'
80+
args.ykey = ['mean_score']
81+
82+
if args.save is False:
83+
args.show = True
84+
85+
allresults = pu.load_results(args.log_dir, filename=args.filename, filters=args.filters)
86+
datas = []
87+
for result in allresults:
88+
result['data'].insert(len(result['data'].columns),'Condition1', pu.default_split_fn(result))
89+
datas.append(result['data'])
90+
pu.plot_data(data=datas, xaxis=args.xkey, value=args.ykey[0], smooth=args.smooth,
91+
legend_outside=args.legend_outside,
92+
legend_loc=args.legend_loc,
93+
legend_borderpad=args.borderpad,
94+
legend_labelspacing=args.labelspacing,
95+
font_scale=args.font_scale)
96+
plt.title(args.title)
97+
plt.xlabel(args.xlabel)
98+
plt.ylabel(args.ylabel)
99+
fig = plt.gcf()
100+
fig.set_size_inches((args.fig_length, args.fig_width), forward=False)
101+
102+
if args.xlim is not None:
103+
plt.xlim((0, args.xlim))
104+
105+
if args.save:
106+
plt.savefig(args.log_dir + 'figure', dpi=args.dpi, bbox_inches='tight')
107+
if args.show:
108+
plt.show()
109+
110+
111+
if __name__ == "__main__":
112+
main()

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setuptools.setup(
77
name="rl_plotter",
8-
version="2.3.6",
8+
version="2.4.0",
99
author="Xiaoyu Gong",
1010
author_email="[email protected]",
1111
description="A plotter for reinforcement learning (RL)",
@@ -19,7 +19,7 @@
1919
"Operating System :: OS Independent",
2020
],
2121
entry_points = {
22-
'console_scripts': ['rl_plotter=rl_plotter.plotter:main'],
22+
'console_scripts': ['rl_plotter=rl_plotter.plotter:main','rl_plotter_spinup=rl_plotter.plotter_spinup:main'],
2323
},
2424
python_requires='>=3.0',
2525
)

0 commit comments

Comments
 (0)