-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathplot.py
160 lines (141 loc) · 7.3 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import argparse
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from saferl_plotter import plot_utils as pu
def _build_parser():
    """Build the command-line parser for the plotter.

    Returns:
        argparse.ArgumentParser: parser holding every plotting option that
        ``main`` reads.
    """
    parser = argparse.ArgumentParser(description='plotter')
    # Figure geometry and appearance.
    parser.add_argument('--fig_length', type=float, default=6,
                        help='matplotlib figure length (default: 6)')
    parser.add_argument('--fig_width', type=float, default=4.5,
                        help='matplotlib figure width (default: 4.5)')
    parser.add_argument('--style', default='seaborn-white',
                        help='matplotlib figure style (default: seaborn-white)')
    parser.add_argument('--title', default=None,
                        help='matplotlib figure title (default: None)')
    parser.add_argument('--xlabel', default=None,
                        help='matplotlib figure xlabel')
    # Which csv columns to plot.
    parser.add_argument('--xkey', default='total_steps',
                        help='x-axis key in csv file (default: total_steps)')
    parser.add_argument('--ykey', default=['EpRet'], nargs='+',
                        help='y-axis key in csv file (support multi) (default: EpRet)')
    parser.add_argument('--yduel', action='store_true',
                        help='dual y axis (use if there are two ykeys)')
    parser.add_argument('--ylabel', default=None,
                        help='matplotlib figure ylabel')
    # Smoothing and resampling.
    parser.add_argument('--smooth', type=int, default=10,
                        help='smooth radius of y axis (default: 10)')
    parser.add_argument('--resample', type=int, default=512,
                        help='if not zero, size of the uniform grid in x direction to resample onto. Resampling is performed via symmetric EMA smoothing (see the docstring for symmetric_ema). Zero disables resampling; note that if average_group is True, resampling is necessary. (default: 512)')
    parser.add_argument('--smooth_step', type=float, default=1.0,
                        help='when resampling (i.e. when resample > 0 or average_group is True), use this EMA decay parameter (in units of the new grid step). See docstrings for decay_steps in symmetric_ema or one_sided_ema functions.')
    # Group aggregation. These default to on; the --no_* flags write to the same
    # destination so each option can actually be switched off.
    parser.add_argument('--avg_group', action='store_true', default=True,
                        help='average the curves in the same group and plot the mean (default: on)')
    parser.add_argument('--no_avg_group', dest='avg_group', action='store_false',
                        help="don't average the curves in the same group")
    parser.add_argument('--shaded_std', action='store_true', default=True,
                        help='shaded region corresponding to standard deviation of the group (default: on)')
    parser.add_argument('--no_shaded_std', dest='shaded_std', action='store_false',
                        help="don't shade the standard deviation of the group")
    parser.add_argument('--shaded_err', action='store_true', default=True,
                        help='shaded region corresponding to error in mean estimate (default: on)')
    parser.add_argument('--no_shaded_err', dest='shaded_err', action='store_false',
                        help="don't shade the error in mean estimate")
    # Legend.
    parser.add_argument('--legend_loc', type=int, default=0,
                        help='location of legend')
    parser.add_argument('--legend_outside', action='store_true',
                        help='place the legend outside of the figure')
    parser.add_argument('--borderpad', type=float, default=0.5,
                        help='borderpad of legend (default: 0.5)')
    parser.add_argument('--labelspacing', type=float, default=0.5,
                        help='labelspacing of legend (default: 0.5)')
    parser.add_argument('--no_legend_group_num', action='store_true',
                        help="don't show num of group in legend")
    # Time-based x axis.
    parser.add_argument('--time', action='store_true',
                        help='enable this will activate parameters about time')
    parser.add_argument('--time_unit', default='h',
                        help='parameters about time, x axis time unit (default: h)')
    parser.add_argument('--time_interval', type=float, default=None,
                        help='parameters about time, x axis tick interval in time_unit (default: 2 for h, 20 for min, otherwise 1)')
    # Axis formatting and limits.
    parser.add_argument('--xformat', default='',
                        help='x-axis format (eng, log, sci)')
    parser.add_argument('--xlim', type=float, default=None,
                        help='upper x-axis limit; the lower limit is 0 (default: None)')
    parser.add_argument('--ylim', type=float, default=None,
                        help='upper y-axis limit; the lower limit is 0 (default: None)')
    # Input / output.
    parser.add_argument('--log_dir', default='./',
                        help='log dir; the figure is also saved here with --save (default: ./)')
    parser.add_argument('--filters', default=[''], nargs='+',
                        help='filters of dirname')
    parser.add_argument('--filename', default='logger.csv',
                        help='csv filename')
    parser.add_argument('--show', action='store_true',
                        help='show figure (implied when --save is not given)')
    parser.add_argument('--save', action='store_true',
                        help='save figure')
    parser.add_argument('--dpi', type=int, default=400,
                        help='figure dpi (default: 400)')
    return parser


def _resolve_style(style):
    """Return a style name that exists in this matplotlib installation.

    matplotlib 3.6 renamed the bundled ``seaborn*`` styles to
    ``seaborn-v0_8*`` and later releases dropped the old names. If *style* is
    an old seaborn name that is no longer available but its renamed
    counterpart is, return the renamed one; otherwise return *style* unchanged.
    """
    # NOTE(review): assumes plot_results hands `style` to matplotlib's style
    # library — confirm in saferl_plotter.plot_utils.
    available = plt.style.available
    if style in available:
        return style
    if style == 'seaborn' or style.startswith('seaborn-'):
        renamed = 'seaborn-v0_8' + style[len('seaborn'):]
        if renamed in available:
            return renamed
    return style


def _finalize_args(args):
    """Fill in argument values that depend on other arguments.

    Mutates *args* in place (xlabel, time_interval, filename, show).

    Returns:
        The divisor applied to x values (3600 for hours, 60 for minutes,
        1 otherwise), forwarded to ``plot_results`` as ``xscale``.
    """
    xscale = 1
    if args.time:
        if args.xlabel is None:
            args.xlabel = 'Training time'
        if args.time_unit == 'h':
            xscale, default_interval = 60 * 60, 2
        elif args.time_unit == 'min':
            xscale, default_interval = 60, 20
        else:
            default_interval = 1
        # Only fall back to the per-unit default; an explicit --time_interval wins.
        if args.time_interval is None:
            args.time_interval = default_interval
    elif args.xlabel is None:
        args.xlabel = 'Total Interactions'
    if '.' not in args.filename:
        args.filename = args.filename + '.csv'
    # Without --save there is nothing else to do with the figure, so show it.
    if not args.save:
        args.show = True
    return xscale


def _format_x_axis(ax, args):
    """Apply the x-axis tick locator/formatter selected by the CLI options."""
    if args.time:
        if args.time_unit in ('h', 'min'):
            ax.xaxis.set_major_locator(mticker.MultipleLocator(args.time_interval))
            # %g keeps fractional ticks (e.g. 0.5h) that %d would truncate,
            # and prints whole numbers without a decimal point.
            ax.xaxis.set_major_formatter(mticker.FormatStrFormatter('%g' + args.time_unit))
    elif args.xformat == 'eng':
        ax.xaxis.set_major_formatter(mticker.EngFormatter())
    elif args.xformat == 'log':
        ax.xaxis.set_major_formatter(mticker.LogFormatter())
    elif args.xformat == 'sci':
        ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0), useMathText=True)
    else:
        ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0), useMathText=False)


def main():
    """Parse CLI options, plot the logged results, then save and/or show."""
    import os  # local import: only needed to build the save path

    args = _build_parser().parse_args()
    xscale = _finalize_args(args)

    allresults = pu.load_results(args.log_dir, filename=args.filename, filters=args.filters)
    pu.plot_results(allresults,
                    fig_length=args.fig_length,
                    fig_width=args.fig_width,
                    style=_resolve_style(args.style),
                    title=args.title,
                    xlabel=args.xlabel,
                    ylabel=args.ylabel,
                    xkey=args.xkey,
                    ykey=args.ykey,
                    yduel=args.yduel,
                    xscale=xscale,
                    smooth_radius=args.smooth,
                    resample=args.resample,
                    smooth_step=args.smooth_step,
                    average_group=args.avg_group,
                    shaded_std=args.shaded_std,
                    shaded_err=args.shaded_err,
                    legend_outside=args.legend_outside,
                    legend_loc=args.legend_loc,
                    legend_group_num=not args.no_legend_group_num,
                    legend_borderpad=args.borderpad,
                    legend_labelspacing=args.labelspacing,
                    filename=args.filename)

    ax = plt.gca()  # axes left current by plot_results
    _format_x_axis(ax, args)
    if args.xlim is not None:
        ax.set_xlim((0, args.xlim))
    if args.ylim is not None:
        ax.set_ylim((0, args.ylim))
    if args.ykey == ['EpCost']:
        # Zero-cost reference line. axhline spans the full x range, so it also
        # works without --xlim (hlines(0, 0, None) raises).
        ax.axhline(0, color='red', linestyle='dashed')
    if args.save:
        # os.path.join tolerates a log_dir given without a trailing separator.
        plt.savefig(os.path.join(args.log_dir, 'figure'), dpi=args.dpi, bbox_inches='tight')
    if args.show:
        plt.show()
# Run the plotting CLI only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()