
Commit e88f07e: ceer
1 parent 2749bd4

15 files changed: +2017 additions, 0 deletions

README.md (+58 lines)

## Replay Memory as An Empirical MDP: Combining Conservative Estimation with Experience Replay

![overview](ceer/pic/overview.svg)
### Overview

- PyTorch implementation of Conservative Estimation with Experience Replay ([CEER](https://openreview.net/forum?id=SjzFVSJUt8S)).

- The method is tested on the [Sokoban](https://github.com/mpSchrader/gym-sokoban), [Minigrid](https://github.com/Farama-Foundation/Minigrid), and [MinAtar](https://github.com/kenjyoung/MinAtar) environments.
### Installation

```
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install -r requirements.txt
```

- Tested with Python 3.7.11 and CUDA 11.4; a quick sanity check follows below.
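To confirm that the CUDA build of PyTorch is active (assuming a CUDA-capable GPU and driver are present), an optional check in a Python shell:

```
import torch
print(torch.__version__, torch.cuda.is_available())  # expect something like "1.11.0+cu113 True"
```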
### Running Experiments

```
python ceer/main.py
```

- Modify `atari_name_list` in `ceer/arguments.py` to select the environment, e.g. `'atari_name_list': ['Sokoban-Push_5x5_1_120']`.

- Other parameters, such as `sample_method_para` (alpha) and `policy_loss_para` (lambda), are also set in `ceer/arguments.py`; a configuration sketch follows this list.
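For example, a run on a MinAtar environment with non-zero alpha and lambda could be configured by editing `para_list_dict` in `ceer/arguments.py` roughly as below; only the relevant keys are shown and the values are illustrative, not the paper's settings:

```
para_list_dict = {
    'atari_name_list': ['MinAtar/Breakout-v0'],  # environment(s) to run
    'sample_method_list': ['kl'],                # 'uniform' falls back to standard DQN replay
    'sample_method_para_list': [0.2],            # alpha
    'policy_loss_para_list': [0.1],              # lambda
    # keep the remaining keys from the shipped arguments.py unchanged
}
```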
### Bibtex

```
@inproceedings{
zhang2023replay,
title={Replay Memory as An Empirical {MDP}: Combining Conservative Estimation with Experience Replay},
author={Hongming Zhang and Chenjun Xiao and Han Wang and Jun Jin and bo xu and Martin M{\"u}ller},
booktitle={The Eleventh International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=SjzFVSJUt8S}
}
```
### Acknowledgments

- Awesome environments used for testing:

  Sokoban: https://github.com/mpSchrader/gym-sokoban

  Minigrid: https://github.com/Farama-Foundation/Minigrid

  MinAtar: https://github.com/kenjyoung/MinAtar

- Some baselines can be found in the following works:

  TER: https://openreview.net/forum?id=OXRZeMmOI7a

  Dreamerv2: https://github.com/RajGhugare19/dreamerv2

  Tianshou: https://github.com/thu-ml/tianshou

agents.py (+149 lines)

import copy
import time

import utils
from env_wrappers import *
from collections import deque
import numpy as np
import torch
import torch.nn.functional as F
import hashlib
import pickle
from rl_algorithms import TD
from schedules import LinearSchedule
from buffers import BatchBuffer, Graph_buffer


class DQN_Agent():
    def __init__(self, env, net, args_dict):
        self.game_env = env
        self.args_dict = args_dict
        self.action_space = self.game_env.action_space
        self.action_space_set = set(range(self.action_space))
        self.state_space = self.game_env.observation_space
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # self.device = "cpu"
        self.net = net(self.action_space, self.state_space, args_dict['atari_name']).to(self.device)
        self.target_net = net(self.action_space, self.state_space, args_dict['atari_name']).to(self.device)
        self.target_net.load_state_dict(self.net.state_dict())

        self.exploration_decay = LinearSchedule(schedule_timesteps=args_dict['decay_step'], final_p=args_dict['exploration_final_eps'], initial_p=args_dict['exploration_initial_eps'])
        self.lr_decay = LinearSchedule(schedule_timesteps=args_dict['final_step'], final_p=0.)
        self.update = TD(self.net, self.target_net, self.lr_decay, self.device, args_dict)

        self.graph_buffer = Graph_buffer(args_dict, action_space=self.action_space)  # CEER
        self.batch_buffer = BatchBuffer(args_dict)  # DQN
        self.current_episode = [[] for _ in range(args_dict['number_env'])]

        self.max_q_mean = 0
        self.all_q_mean = 0
        self.density = 0

    def save_model(self, path):
        torch.save(self.net.state_dict(), path)
        # torch.save(self.net.state_dict(), path, _use_new_zipfile_serialization=False)

    def load_model(self, path):
        self.net.load_state_dict(torch.load(path))

    def act(self, states, rewards, dones, infos, train, current_step):
        # print([(s.dtype, s.shape) for s in states])
        states_tensor = torch.from_numpy(np.array(states)).to(self.device).float()
        # print(states)
        # print(states_tensor.shape)
        with torch.no_grad():
            q_values = self.net(states_tensor)
        q_values = q_values.detach().cpu().numpy()
        # print('q_values :', q_values.shape)
        actions = []
        if train:
            # epsilon-greedy action selection with a linearly decaying epsilon
            epsilon = self.exploration_decay.value(current_step)
            exploration_list = np.random.random(self.args_dict['number_env']) < epsilon
            for i in range(self.args_dict['number_env']):
                # print('number :', i)
                # print(args.number_env, q_values.shape, q_values[i], states_tensor.shape)
                if exploration_list[i]:
                    actions.append(np.random.randint(self.action_space))
                else:
                    actions.append(np.argmax(q_values[i]))

            self.train(states, actions, rewards, dones, infos, current_step)
        else:
            # evaluation: keep a small fixed exploration rate
            exploration_list = np.random.random(self.args_dict['number_env']) < 0.01  # 0.05
            for i in range(self.args_dict['number_env']):
                if exploration_list[i]:
                    actions.append(np.random.randint(self.action_space))
                else:
                    actions.append(np.argmax(q_values[i]))
        # actions = np.argmax(q_values, axis=1)
        # print(q_values)
        return actions

    def train(self, states, actions, rewards, dones, infos, current_step):
        if self.args_dict['sample_method'] != 'uniform':
            # CEER: store transitions in the graph buffer, keyed by an md5 hash of the pickled state
            if rewards is None:
                self.s_t = states
                self.a_t = actions
            else:
                s_t_key_list = []
                for i in range(self.args_dict['number_env']):
                    if dones[i]:
                        s_t_key = hashlib.md5(pickle.dumps(self.s_t[i])).hexdigest() + str(False)
                        s_t1_key = hashlib.md5(pickle.dumps(infos[i]['terminal_state'])).hexdigest() + str(True)
                        self.graph_buffer.add_data(self.s_t[i], self.a_t[i], rewards[i], dones[i],
                                                   infos[i]['terminal_state'], s_t_key, s_t1_key)
                        self.current_episode[i].reverse()
                        self.graph_buffer.update_node(self.args_dict['batch_size'], self.current_episode[i])
                        self.current_episode[i] = []
                    else:
                        s_t_key = hashlib.md5(pickle.dumps(self.s_t[i])).hexdigest() + str(False)
                        s_t1_key = hashlib.md5(pickle.dumps(states[i])).hexdigest() + str(False)
                        self.graph_buffer.add_data(self.s_t[i], self.a_t[i], rewards[i], dones[i], states[i],
                                                   s_t_key, s_t1_key)
                        self.current_episode[i].append(s_t_key)
                    s_t_key_list.append(s_t_key)

                self.s_t = states
                self.a_t = actions
        else:
            # uniform sampling: store transitions in the per-environment DQN replay buffers
            if rewards is None and dones is None:
                for i in range(self.batch_buffer.buffer_num):
                    self.batch_buffer.buffer_list[i].add_data(state_t=states[i], action_t=actions[i])
            else:
                for i in range(self.batch_buffer.buffer_num):
                    self.batch_buffer.buffer_list[i].add_data(
                        state_t=states[i],
                        action_t=actions[i],
                        reward_t=rewards[i],
                        terminal_t=dones[i])

        if current_step % self.args_dict['target_update_interval'] == 0:
            self.target_net.load_state_dict(self.net.state_dict())

        if current_step >= self.args_dict['learning_starts']:
            # print(np.shape(self.batch_buffer.buffer_list))
            if self.args_dict['sample_method'] != 'uniform':
                for _ in range(self.args_dict['batch_num']):
                    s_t, a_t, r_t, t_t, s_t1, target_q_t, updated_t1, \
                        all_target_q_t, not_exist_action_value = self.graph_buffer.sample_batch(self.args_dict['batch_size'])

                    s_t, one_hot_a_t, index, r_t, t_t, s_t1 = self.update.np2torch(
                        self.args_dict['batch_size'], self.action_space, s_t, a_t, r_t, t_t, s_t1)

                    max_q_mean, all_q_mean, density = self.update.learn(self.args_dict['sample_method'],
                        self.graph_buffer, self.args_dict['batch_size'], self.action_space,
                        s_t, one_hot_a_t, r_t, t_t, s_t1, target_q_t, updated_t1,
                        all_target_q_t, not_exist_action_value, self.args_dict['policy_loss_para'])

                    self.max_q_mean = max_q_mean
                    self.all_q_mean = all_q_mean
                    self.density = density
            else:
                for _ in range(self.args_dict['batch_num']):
                    n = int(self.args_dict['batch_size'] / self.args_dict['number_env'])
                    s_t, a_t, r_t, t_t, s_t1 = self.batch_buffer.sample_batch(current_step, n)
                    # print('state:', s_t)
                    s_t, one_hot_a_t, index, r_t, t_t, s_t1 = self.update.np2torch(
                        self.args_dict['batch_size'], self.action_space, s_t, a_t, r_t, t_t, s_t1)
                    self.update.learn(self.args_dict['sample_method'], None, self.args_dict['batch_size'], self.action_space, s_t, one_hot_a_t, r_t, t_t, s_t1)
                    # print('data shape:', s_t.shape, a_t.shape, ret.shape, v.shape, logp.shape, adv.shape)
                    # print('data type:', s.dtype, a.dtype, ret.dtype, v.dtype, logp.dtype, adv.dtype)
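Note that `act` both selects actions and, when `train=True`, pushes the previous transition into the replay structures, so the driver in `ceer/main.py` presumably follows a loop roughly like the sketch below. The `make_envs`/`Net` names and the vectorized reset/step interface are assumptions for illustration, not the repository's actual API:

```
# hypothetical driver loop around DQN_Agent (make_envs, Net and the env API are assumed)
env = make_envs(args_dict)                # vectorized env with args_dict['number_env'] copies
agent = DQN_Agent(env, Net, args_dict)
states = env.reset()
rewards, dones, infos = None, None, None  # no completed transition before the first step
for current_step in range(int(args_dict['final_step'])):
    actions = agent.act(states, rewards, dones, infos, train=True, current_step=current_step)
    states, rewards, dones, infos = env.step(actions)
```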

arguments.py (+45 lines)

# import pynvml
# pynvml.nvmlInit()

# gpu_num = pynvml.nvmlDeviceGetCount()
# if gpu_num:
#     CUDA_VISIBLE_DEVICES = gpu_num - 1  # use the last gpu
#     print('CUDA_VISIBLE_DEVICES:', CUDA_VISIBLE_DEVICES)
CUDA_VISIBLE_DEVICES = 0
number_env = 8

para_list_dict = {
    'atari_name_list': ['Sokoban-Push_5x5_1_120'],
    # 'atari_name_list': ['MinAtar/Asterix-v0', 'MinAtar/Breakout-v0', 'MinAtar/Freeway-v0', 'MinAtar/Seaquest-v0', 'MinAtar/SpaceInvaders-v0',
    #                     'MiniGrid-DoorKey-6x6-v0', 'MiniGrid-Unlock-v0', 'MiniGrid-RedBlueDoors-6x6-v0', 'MiniGrid-SimpleCrossingS9N1-v0',
    #                     'MiniGrid-SimpleCrossingS9N2-v0', 'MiniGrid-LavaCrossingS9N1-v0', 'MiniGrid-LavaCrossingS9N2-v0',
    #                     'Sokoban-Push_5x5_1_120', 'Sokoban-Push_6x6_1_120', 'Sokoban-Push_7x7_1_120', 'Sokoban-Push_6x6_3_120',
    #                     'Sokoban-Push_5x5_2_120', 'Sokoban-Push_6x6_2_120', 'Sokoban-Push_7x7_2_120'],
    'network_type_list': ['large'],         # 'larger', 'large', 'medium', 'small', 'mlp'
    'seed_list': [0],                       # list(range(21))
    'exploration_final_eps_list': [0.01],   # 0.1
    'batch_size_list': [32],                # 32, 64, 128, 256, 512
    'batch_num_list': [2],                  # replay ratio, 0.25, int(number_env*0.25)
    'double_dqn_list': [False],             # True, False
    'update_time_list': [1],                # number of updates for each batch
    'sample_method_list': ['kl'],           # uniform, kl
    'sample_method_para_list': [0.],        # alpha; e.g. [0., 0.2, 0.5, 0.8, 1.]
    'policy_loss_list': [True],             # True, False
    'policy_loss_para_list': [0.],          # lambda; e.g. [0., 0.01, 0.1, 1., 2., 5.]
    'tau_list': [1.],                       # temperature for softmax: 1., 0.1, 0.01
}

final_step = 2e6  # 2e6
learning_rate = 1e-4  # 3e-3, 1e-3, 3e-4, 1e-4, 3e-5, 1e-5
buffer_size = int(final_step / 20)  # int(5e4) # int(1e5) # 1_000_000, 1e6, 1e5, int(1e5/number_env)
learning_starts = final_step * 0.005  # 100, 10000
gamma = 0.99
target_update_interval = 1000
decay_step = final_step / 2  # number of steps over which the exploration rate is annealed (half of training)
exploration_initial_eps = 1.0
max_grad_norm = 10.
test_num = 100

FullyObs_minigrid = True
deterministic = False
fix_difficulty = False
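The `*_list` entries read as a hyperparameter grid, so `ceer/main.py` presumably expands `para_list_dict` into one argument dict per combination. A minimal sketch of that expansion, assuming the `_list` suffix is stripped to match the keys that `agents.py` reads from `args_dict`:

```
# hypothetical grid expansion; the actual logic lives in ceer/main.py
import itertools
from arguments import para_list_dict  # assumes this runs next to ceer/arguments.py

def expand_grid(para_list_dict):
    keys = [k[:-len('_list')] for k in para_list_dict]  # 'atari_name_list' -> 'atari_name'
    for values in itertools.product(*para_list_dict.values()):
        yield dict(zip(keys, values))

for args_dict in expand_grid(para_list_dict):
    print(args_dict['atari_name'], args_dict['sample_method'], args_dict['sample_method_para'])
```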
