-
Notifications
You must be signed in to change notification settings - Fork 1
/
play_dqn.py
76 lines (56 loc) · 2.24 KB
/
play_dqn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python3
import gym
import ptan
import argparse
import numpy as np
import time
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from tensorboardX import SummaryWriter
from lib import dqn_model, common, rainbow_model
# Master switch for screenshot dumping. Screenshots are only written when this
# is True AND the script is started with ``--dump-every N`` (N > 0).
dump_images = True
# Directory prefix that screenshot file names are appended to.
dump_directory = "screenshots/"

# Pause between rendered steps, in seconds.
FRAME_DELAY = 1 / 120


def main():
    """Load a saved Rainbow DQN checkpoint and play the 'pacman' setup forever.

    The agent always picks the action with the highest Q-value. After every
    finished game the number of games played and the mean / standard deviation
    of all game scores so far are printed. The loop never terminates on its
    own; stop the script with Ctrl-C.

    Command-line flags:
        --cuda          run the network on the GPU instead of the CPU.
        --dump-every N  save the first channel of the observation as a PNG
                        every N frames (default 0 = never).
    """
    # Local import: only needed to create the screenshot directory.
    import os

    saves_filename = "pacman_19000000.dat"
    params = common.HYPERPARAMS['pacman']

    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", default=False, action="store_true", help="Enable cuda")
    parser.add_argument("--dump-every", type=int, default=0, metavar="N",
                        help="Save a screenshot every N frames (0 disables)")
    args = parser.parse_args()
    device = torch.device("cuda" if args.cuda else "cpu")

    env = gym.make(params['env_name'])
    env = params["env_wrapper_test"](env)

    # To play a plain DQN checkpoint instead, build dqn_model.DQN with the
    # same two arguments here.
    net = rainbow_model.RainbowDQN(env.observation_space.shape, env.action_space.n).to(device)

    # map_location lets a checkpoint that was saved on a GPU be loaded on a
    # CPU-only run (and vice versa).
    state = torch.load(params["save_dir"] + saves_filename, map_location=device)
    net.load_state_dict(state)
    net.eval()

    # Effective screenshot period; 0 means "never dump".
    dump_every = args.dump_every if dump_images else 0
    if dump_every > 0 and dump_directory:
        os.makedirs(dump_directory, exist_ok=True)

    game_scores = []
    current_game_score = 0
    frame_idx = 0
    obs = env.reset()

    with torch.no_grad():
        while True:
            if dump_every > 0 and frame_idx % dump_every == 0:
                plt.imshow(np.asarray(obs)[0, :, :], cmap="gray")
                plt.savefig("{}{}_{}.png".format(dump_directory, params["run_name"], frame_idx))
                # Close the figure so images do not pile up across frames.
                plt.close()
            frame_idx += 1

            # Add a leading batch dimension before feeding the network.
            batch = torch.tensor(np.expand_dims(np.asarray(obs), axis=0)).to(device)
            q_values = net.qvals(batch).detach().cpu().numpy()[0]
            action = int(np.argmax(q_values))

            obs, reward, done, _ = env.step(action)
            current_game_score += reward

            if done:
                game_scores.append(current_game_score)
                current_game_score = 0
                score_std = np.std(game_scores)
                score_mean = np.mean(game_scores)
                print("Games played, mean, std:\t{}\t{}\t{}".format(len(game_scores), score_mean, score_std))
                obs = env.reset()

            env.render()
            time.sleep(FRAME_DELAY)


if __name__ == "__main__":
    main()