-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path eval_td3.py
97 lines (77 loc) · 2.15 KB
/
eval_td3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
import os
import pandas as pd
from core.envs import make_envs
from core.td3_trainer import TD3Trainer
from vis import evaluate_in_batch
parser = argparse.ArgumentParser()
parser.add_argument(
"--log-dir",
default="data/",
type=str,
help="The directory where you want to store the data. "
"Default: ./data/"
)
parser.add_argument(
"--num-envs",
default=10,
type=int,
help="The number of parallel environments for evaluation. Default: 10"
)
parser.add_argument(
"--seed",
default=0,
type=int,
help="The random seed. Default: 0"
)
parser.add_argument(
"--num-episodes",
default=50,
type=int,
)
parser.add_argument(
"--env-id",
# === See here! We will use test environment for eval! ===
default="MetaDrive-Tut-Test-v0",
type=str,
)
if __name__ == '__main__':
args = parser.parse_args()
log_dir = args.log_dir
num_episodes = args.num_episodes
num_envs = args.num_envs
env_id = args.env_id
if "MetaDrive" in env_id:
from core.utils import register_metadrive
register_metadrive()
envs = make_envs(
env_id=env_id,
log_dir=log_dir,
num_envs=num_envs,
asynchronous=True,
)
# env = envs.envs[0]
state_dim = envs.observation_space.shape[0]
action_dim = envs.action_space.shape[0]
max_action = float(envs.action_space.high[0])
kwargs = {
"state_dim": state_dim,
"action_dim": action_dim,
"max_action": max_action,
}
trainer = TD3Trainer(**kwargs)
trainer.load(os.path.join(log_dir))
def _policy(obs):
return trainer.select_action_in_batch(obs)
eval_reward, eval_info = evaluate_in_batch(
policy=_policy,
envs=envs,
num_episodes=num_episodes,
)
df = pd.DataFrame({"rewards": eval_info["rewards"], "successes": eval_info["successes"]})
path = "{}/eval_results.csv".format(log_dir)
df.to_csv(path)
print("The average return after running {} agent for {} episodes in {} environment: {}.\n" \
"Result is saved at: {}".format(
"TD3", num_episodes, env_id, eval_reward, path
))