Commit cd4b095

change logger.py for evaluation
1 parent ca5c390 commit cd4b095

4 files changed: +59 additions, -79 deletions


README.md

Lines changed: 10 additions & 1 deletion
@@ -20,7 +20,16 @@ python setup.py install
 
 ## Usage
 
-Add our logger (compatible with [OpenAI-baseline](https://github.com/openai/baselines)) in your code or just use [OpenAI-baseline](https://github.com/openai/baselines) bench.Monitor (recommended):
+Add our logger to your evaluation code:
+
+```python
+from rl_plotter.logger import Logger
+logger = Logger(exp_name="your_exp_name", log_dir="./logs", env_name="your_env_name")
+····
+logger.update(score=evaluation_score_list, total_steps=current_training_steps)
+```
+
+or just use [OpenAI-baseline](https://github.com/openai/baselines) bench.Monitor (recommended):
 
 ```python
 from baselines import bench
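
For context, the snippet added to the README above can be driven as in the minimal sketch below; it is not part of this commit, and the experiment name, environment name, checkpoint interval, and randomly generated scores are placeholders standing in for a real training loop:

```python
# Minimal sketch (not from this commit) of feeding the evaluation logger.
# The experiment/env names and the random "scores" are placeholders for a
# real agent's evaluation returns.
import numpy as np
from rl_plotter.logger import Logger

logger = Logger(exp_name="demo_run", log_dir="./logs", env_name="Pendulum-v0")

for total_steps in range(5000, 50001, 5000):       # pretend training checkpoints
    # ... train the agent for another 5000 steps here ...
    scores = np.random.normal(loc=-200.0, scale=30.0, size=10).tolist()
    logger.update(score=scores, total_steps=total_steps)
```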

requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -2,5 +2,4 @@ pandas
 numpy
 statsmodels
 matplotlib
-tensorboardX
 glob

rl_plotter/logger.py

Lines changed: 48 additions & 76 deletions
@@ -5,88 +5,60 @@
 
 import csv
 import os
-import json
-import time
-import logging
+import json, time
 import numpy as np
 
-class Logger():
-    def __init__(self, exp_name, save=True, log_dir="./logs", env_name=None):
-        if save:
-            self.log_dir = log_dir + "/" + exp_name + "/"
-            if not os.path.exists(self.log_dir):
-                os.makedirs(self.log_dir)
-            self.csv_file = open(self.log_dir + 'monitor.csv', 'w')
-            header={"t_start": time.time(), 'env_id' : env_name}
-            header = '# {} \n'.format(json.dumps(header))
-            self.csv_file.write(header)
-            self.logger = csv.DictWriter(self.csv_file, fieldnames=('r', 'l', 't'))
-            self.logger.writeheader()
-            self.csv_file.flush()
+color2num = dict(
+    gray=30,
+    red=31,
+    green=32,
+    yellow=33,
+    blue=34,
+    magenta=35,
+    cyan=36,
+    white=37,
+    crimson=38
+)
 
-        self.step_counter = 0
-        self.episode_counter = 0
-        self.steps = []
-        self.rewards = []
-        self.losses = []
+def colorize(string, color, bold=False, highlight=False):
+    """
+    Colorize a string.
 
-        self.save = save
-        self.exp_name = exp_name
-        self.is_learning_start = False
-        self.start_time = time.time()
-
-        logging.basicConfig(level=logging.INFO, format='[' + exp_name + '] %(asctime)s: %(levelname)s %(message)s')
-        logging.info(self.exp_name + " start !")
+    This function was originally written by John Schulman.
+    """
+    attr = []
+    num = color2num[color]
+    if highlight: num += 10
+    attr.append(str(num))
+    if bold: attr.append('1')
+    return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string)
 
-    def add_step(self):
-        self.step_counter += 1
-        return np.sum(self.steps)
-
-    def add_episode(self):
-        self.steps.append(self.step_counter)
-        self.step_counter = 0
-        self.episode_counter += 1
-        return self.episode_counter
 
-    def add_reward(self, reward, freq=10):
-        self.rewards.append(reward)
-        total_step = np.sum(self.steps)
+class Logger():
+    def __init__(self, exp_name, log_dir="./logs", env_name=None):
+        self.log_dir = log_dir + "/" + exp_name + "/"
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+        self.csv_file = open(self.log_dir + 'evaluator.csv', 'w', encoding='utf8')
+        header={"t_start": time.time(), 'env_id' : env_name}
+        header = '# {} \n'.format(json.dumps(header))
+        self.csv_file.write(header)
+        self.logger = csv.DictWriter(self.csv_file, fieldnames=('mean_score', 'total_steps', 'std_score', 'max_score', 'min_score'))
+        self.logger.writeheader()
+        self.csv_file.flush()
 
-        if self.use_tensorboard:
-            self.tf_board_writer.add_scalar('Train/reward', reward, total_step)
-
-        if self.episode_counter % freq == 0:
-            if len(self.losses) == 0:
-                logging.info("episodes: %d, mean reward: %.2f, steps: %d, mean loss: nan" % \
-                    (self.episode_counter, np.mean(self.rewards[-freq:]), total_step))
-            else:
-                logging.info("episodes: %d, mean reward: %.2f, steps: %d, mean loss: %f" % \
-                    (self.episode_counter, np.mean(self.rewards[-freq:]), total_step, np.mean(self.losses[-freq:])))
-
-        if self.save:
-            epinfo = {"r": reward, "l": self.steps[-1], "t": time.time() - self.start_time}
-            self.logger.writerow(epinfo)
-            self.csv_file.flush()
+    def update(self, score, total_steps):
+        '''
+        Score is a list
+        '''
+        avg_score = np.mean(score)
+        std_score = np.std(score)
+        max_score = np.max(score)
+        min_score = np.min(score)
 
-    def add_loss(self, loss):
-        self.losses.append(loss)
-        total_step = np.sum(self.steps)
+        print(colorize(f"\nEvaluation over {len(score)} episodes after {total_steps}:", 'yellow', bold=True))
+        print(colorize(f"Avg: {avg_score:.3f} Std: {std_score:.3f} Max: {max_score:.3f} Min: {min_score:.3f}\n", 'yellow', bold=True))
 
-        if not self.is_learning_start:
-            logging.warn("start learning, loss data received.")
-            self.is_learning_start = True
-
-        #self.csv_file.write(str(total_step) +','+ str(loss)+'\n')
-        #self.csv_file.flush()
-
-    def reset(self):
-        self.episode_counter = 0
-        self.step_counter = 0
-        self.rewards = []
-        self.losses = []
-
-    def finish(self):
-        self.reset()
-        if self.save:
-            self.csv_file.close()
-            logging.info(self.exp_name + " finished !")
+        epinfo = {"mean_score": avg_score, "total_steps": total_steps, "std_score": std_score, "max_score": max_score, "min_score": min_score}
+        self.logger.writerow(epinfo)
+        self.csv_file.flush()
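
Assuming the new update() behaves as written above, the evaluator.csv it produces starts with one "# {...}" JSON metadata line followed by a regular CSV with the five columns named in the DictWriter. A hedged sketch of reading it back (for example for plotting, using the pandas already listed in requirements.txt) could look like the following; the log path is a hypothetical example:

```python
# Sketch (not part of this commit): read back an evaluator.csv written by the
# new Logger. The first line is a '# {...}' JSON comment, the rest is plain CSV.
import json
import pandas as pd

path = "./logs/demo_run/evaluator.csv"            # hypothetical experiment directory

with open(path) as f:
    meta = json.loads(f.readline().lstrip("# "))  # {"t_start": ..., "env_id": ...}

df = pd.read_csv(path, skiprows=1)                # mean_score, total_steps, std_score, max_score, min_score
print(meta["env_id"])
print(df[["total_steps", "mean_score", "std_score"]].tail())
```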

setup.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 setuptools.setup(
     name="rl_plotter",
-    version="2.1.0",
+    version="2.2.0",
     author="Gong Xiaoyu",
     author_email="[email protected]",
     description="A plotter for reinforcement learning (RL)",
