Skip to content

Commit cc9c3d2

Browse files
committed
brand new version
1 parent 131f9cf commit cc9c3d2

File tree

11 files changed

+165
-663
lines changed

11 files changed

+165
-663
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,4 +130,5 @@ dmypy.json
130130

131131
## mine
132132
backup/
133-
rl_plotter-preview/
133+
rl_plotter-preview/
134+
rl_plotter-history/

README.md

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
![PyPI](https://img.shields.io/pypi/v/rl_plotter?style=flat-square) ![GitHub](https://img.shields.io/github/license/gxywy/rl-plotter?style=flat-square) ![GitHub last commit](https://img.shields.io/github/last-commit/gxywy/rl-plotter?style=flat-square)
44

5-
This is a simple tool which can plot learning curves easily for reinforcement learning.
5+
This is a simple tool which can plot learning curves easily for reinforcement learning (RL).
66

77
## Installation
88

@@ -15,65 +15,66 @@ pip install rl_plotter
1515
from source
1616

1717
```
18-
python3 setup.py install
18+
python setup.py install
1919
```
2020

2121
## Examples
2222

23-
First, add a logger in your code (for example: DQN):
23+
First, add our logger (compatible with [OpenAI-baseline](https://github.com/openai/baselines)) in your code
24+
25+
or just [OpenAI-baseline](https://github.com/openai/baselines) bench.Monitor (recommended)
2426

2527
```python
26-
from rl_plotter.logger import Logger
27-
28-
def train(name):
29-
dqn = DQN()
30-
logger = Logger(name, env_name='PongNoFrameskip-v4', use_tensorboard=False)
31-
32-
while True:
33-
s = env.reset()
34-
while True:
35-
total_step = logger.add_step()
36-
a = dqn.select_action(s, EPSILON)
37-
s_, r, done, info = env.step(a)
38-
39-
dqn.store_transition(s, a, r, s_)
40-
episode_reward += r
41-
42-
if dqn.replay_memory.memory_counter > REPLAY_MEMORY_SIZE:
43-
loss = dqn.learn()
44-
logger.add_loss(loss.cpu().item())
45-
if done:
46-
break
47-
s = s_
48-
logger.add_episode()
49-
logger.add_reward(episode_reward, freq=10)
50-
logger.finish()
28+
from baselines import bench
29+
env = bench.Monitor(env, log_dir)
5130
```
5231

5332
After the training or when you are training your agent, you can plot the learning curves in this way:
5433

5534
```
56-
python -m rl_plotter.plotter
35+
python -m rl_plotter.plotter --save --show
5736
```
5837
for help use:
5938
```
6039
python -m rl_plotter.plotter --help
6140
```
6241

63-
The learning curves looks like this:
42+
and you can find parameters to custom the style of your curves.
43+
44+
```
45+
optional arguments:
46+
-h, --help show this help message and exit
47+
--fig_length matplotlib figure length (default: 6)
48+
--fig_width matplotlib figure width (default: 6)
49+
--style matplotlib figure style (default: seaborn)
50+
--title matplotlib figure title (default: None)
51+
--xlabel matplotlib figure xlabel
52+
--xkey x-axis key in csv file (default: l)
53+
--ykey y-axis key in csv file (default: r)
54+
--smooth smooth radius of y axis (default: 1)
55+
--ylabel matplotlib figure ylabel
56+
--avg_group average the curves in the same group and plot the mean
57+
--shaded_std shaded region corresponding to standard deviation of the group
58+
--shaded_err shaded region corresponding to error in mean estimate of the group
59+
--legend_outside place the legend outside of the figure
60+
--time enable this will set x_key to t, and activate parameters about time
61+
--time_unit parameters about time, x axis time unit (default: h)
62+
--time_interval parameters about time, x axis time interval (default: 1)
63+
--xformat x-axis format
64+
--xlim x-axis limitation (default: None)
65+
--log_dir log dir (default: ./logs/)
66+
--filename csv filename
67+
--show show figure
68+
--save save figure
69+
--dpi DPI figure dpi (default: 400)
70+
```
71+
72+
finally, the learning curves looks like this:
6473
<div align="center"><img width="400" height="400" src="https://github.com/gxywy/rl-plotter/blob/master/imgs/figure_1.png?raw=true"/></div>
65-
<div align="center"><img width="400" height="400" src="https://github.com/gxywy/rl-plotter/blob/master/imgs/figure_2.png?raw=true"/></div>
66-
<div align="center"><img width="400" height="400" src="https://github.com/gxywy/rl-plotter/blob/master/imgs/figure_3.png?raw=true"/></div>
67-
And you can custom the style of your curves by use parameter of `rl_plotter.plotter`or modifying`rl_plotter.plotter`
6874

6975
## Features
70-
- [x] reinforcement learning plot tools
71-
- [x] timestamp x axis features
72-
- [x] history experiment data plot tools
73-
- [x] x axis formatter features
74-
- [x] multiprocessing algorithm x.monitor logger
75-
- [x] compatible with [OpenAI-baseline](https://github.com/openai/baselines) monitor data style
76-
- [ ] compatible with [OpenAI-baseline](https://github.com/openai/baselines) progress data style
77-
- [x] custom scalars logger (can be used to analyze any variable in training)
78-
- [ ] ~~basic data plot tools(including ML-Loss plot)~~
79-
- [ ] ~~dynamic plot tools~~
76+
- [x] custom logger, style, key, label, interval, and so on ...
77+
- [x] multi-experiment plotter
78+
- [x] x-axis formatter features
79+
- [x] x-axis formatter features
80+
- [x] compatible with [OpenAI-baseline](https://github.com/openai/baselines) monitor data style

imgs/figure_1.png

124 KB
Loading

imgs/figure_2.png

-370 KB
Binary file not shown.

imgs/figure_3.png

-280 KB
Binary file not shown.

imgs/screenshot_1.png

-177 KB
Binary file not shown.

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ numpy==1.16.5
33
statsmodels==0.10.1
44
matplotlib==3.1.2
55
tensorboardX==1.9
6+
glob

rl_plotter/logger.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,17 @@
66
import csv
77
import os
88
import json
9-
import random
109
import time
1110
import logging
12-
import matplotlib.pyplot as plt
1311
import numpy as np
1412

1513
class Logger():
16-
def __init__(self, exp_name, save=True, save_dir="./logs", env_name=None, use_tensorboard=False):
14+
def __init__(self, exp_name, save=True, log_dir="./logs", env_name=None, use_tensorboard=False):
1715
if save:
18-
self.save_dir = save_dir + "/" + exp_name + "/"
19-
if not os.path.exists(self.save_dir):
20-
os.makedirs(self.save_dir)
21-
self.csv_file = open(self.save_dir + 'monitor.csv', 'w')
16+
self.log_dir = log_dir + "/" + exp_name + "/"
17+
if not os.path.exists(self.log_dir):
18+
os.makedirs(self.log_dir)
19+
self.csv_file = open(self.log_dir + 'monitor.csv', 'w')
2220
header={"t_start": time.time(), 'env_id' : env_name}
2321
header = '# {} \n'.format(json.dumps(header))
2422
self.csv_file.write(header)

0 commit comments

Comments
 (0)