Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update dqn_opt #25

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions examples/breakout_dqn_opt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
alg_para:
alg_name: DQNOpt
alg_config: {
'train_per_checkpoint': 50,
'prepare_times_per_train': 4,
'learning_starts': 10000,
'BUFFER_SIZE': 400000,
}

env_para:
env_name: AtariEnv
env_info: { 'name': BreakoutNoFrameskip-v4, 'vision': False}

agent_para:
agent_name: AtariDqn
agent_num : 1
agent_config: {
'max_steps': 2000,
'complete_step': 10000000,
'episode_count': 200000
}

model_para:
actor:
model_name: DqnCnnOpt
state_dim: [84,84,4]
action_dim: 4
model_config: {
'LR': 0.00015,
}

env_num: 4
benchmark:
archive_root: ./logs
162 changes: 162 additions & 0 deletions xt/algorithm/dqn/dqn_opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
"""Build DQN algorithm."""

import os
import numpy as np

from xt.algorithm import Algorithm
from xt.algorithm.dqn.default_config import BUFFER_SIZE, GAMMA, TARGET_UPDATE_FREQ, BATCH_SIZE
from xt.algorithm.replay_buffer import ReplayBuffer
from zeus.common.util.register import Registers
from xt.model import model_builder
from zeus.common.util.common import import_config

os.environ["KERAS_BACKEND"] = "tensorflow"


@Registers.algorithm
class DQNOpt(Algorithm):
"""Build Deep Q learning algorithm."""

def __init__(self, model_info, alg_config, **kwargs):
"""
Initialize DQN algorithm.

It contains four steps:
1. override the default config, with user's configuration;
2. create the default actor with Algorithm.__init__;
3. create once more actor, named by target_actor;
4. create the replay buffer for training.
:param model_info:
:param alg_config:
"""
import_config(globals(), alg_config)
model_info = model_info["actor"]
super(DQNOpt, self).__init__(
alg_name="dqn", model_info=model_info, alg_config=alg_config
)

self.target_actor = model_builder(model_info)
self.buff = ReplayBuffer(BUFFER_SIZE)
self.double_dqn = alg_config.get('double_dqn', False)

def train(self, **kwargs):
"""
Train process for DQN algorithm.

1. predict the newest state with actor & target actor;
2. calculate TD error;
3. train operation;
4. update target actor if need.
:return: loss of this train step.
"""
batch_size = BATCH_SIZE

batch = self.buff.get_batch(batch_size)
states = np.asarray([e[0] for e in batch])
actions = np.asarray([e[1] for e in batch])
rewards = np.asarray([e[2] for e in batch])
new_states = np.asarray([e[3] for e in batch])
dones = np.asarray([e[4] for e in batch])
if self.double_dqn:
y_t = self.actor.predict(states)
q_values, target_q_values = self.actor.predict(new_states, new_states)
best_action = np.argmax(q_values, 1)
# target_q_values = self.target_actor.predict(new_states)
max_q_val = target_q_values[range(len(batch)), best_action]
else:
y_t, target_q_values = self.actor.predict(states, new_states)
# target_q_values = self.target_actor.predict(new_states)
max_q_val = np.max(target_q_values, 1)

for k in range(len(batch)):
if dones[k]:
q_value = rewards[k]
else:
q_value = rewards[k] + GAMMA * max_q_val[k]
y_t[k][actions[k]] = q_value

loss = self.actor.train(states, y_t)

self.train_count += 1
if self.train_count % TARGET_UPDATE_FREQ == 0:
self.update_target()

return loss

def restore(self, model_name=None, model_weights=None):
"""
Restore model weights.

DQN will restore two model weights, actor & target.
:param model_name:
:param model_weights:
:return:
"""
if model_weights is not None:
self.actor.set_weights(model_weights)
self.target_actor.set_weights(model_weights)
else:
self.actor.load_model(model_name)
self.target_actor.load_model(model_name)

def prepare_data(self, train_data, **kwargs):
"""
Prepare the train data for DQN.

here, just add once new data into replay buffer.
:param train_data:
:return:
"""
buff = self.buff
data_len = len(train_data["done"])
for index in range(data_len):
data = (
train_data["cur_state"][index],
train_data["action"][index],
train_data["reward"][index],
train_data["next_state"][index],
train_data["done"][index],
)
buff.add(data) # Add replay buffer

def update_target(self):
"""
Synchronize the actor's weight to target.

:return:
"""
# weights = self.actor.get_weights()
# self.target_actor.set_weights(weights)
weights = self.actor.update_target()

def predict(self, state):
"""
Predict action.

The api will call the keras.model.predict as default,
if the inputs is different from the normal state,
You need overwrite this function.
"""
inputs = state.reshape((1, ) + state.shape)
out, _ = self.actor.predict(inputs, inputs)

return np.argmax(out)
160 changes: 160 additions & 0 deletions xt/model/dqn/dqn_cnn_opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from xt.model.tf_compat import tf
from xt.model.tf_compat import Conv2D, Dense, Flatten, Input, Model, Adam, Lambda, K, MSE
from xt.model.dqn.default_config import LR
from xt.model.dqn.dqn_mlp import layer_normalize, layer_add
from xt.model import XTModel
from xt.model.tf_utils import TFVariables
from zeus.common.util.common import import_config

from zeus.common.util.register import Registers
tf.disable_eager_execution()

@Registers.model
class DqnCnnOpt(XTModel):
"""Docstring for DqnCnn."""

def __init__(self, model_info):
model_config = model_info.get('model_config', None)
import_config(globals(), model_config)

self.state_dim = model_info['state_dim']
self.action_dim = model_info['action_dim']
self.learning_rate = LR
self.dueling = model_config.get('dueling', False)

super().__init__(model_info)

def create_model(self, model_info):
self.network = self.create_actor_model(model_info)
self.target_network = self.create_actor_model(model_info)
self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)

self.v = self.network.outputs[0]
self.target_v = self.target_network.outputs[0]

self.obs = self.network.inputs[0]
self.tn_obs = self.target_network.inputs[0]

self.full_model = DQNBase(self.network, self.target_network)

self.build_graph()
# self.sess.run(tf.initialize_all_variables())
return self.full_model

def create_actor_model(self, model_info):
"""Create Deep-Q CNN network."""
state = Input(shape=self.state_dim, dtype="uint8")
state1 = Lambda(lambda x: K.cast(x, dtype='float32') / 255.)(state)
convlayer = Conv2D(32, (8, 8), strides=(4, 4), activation='relu', padding='valid')(state1)
convlayer = Conv2D(64, (4, 4), strides=(2, 2), activation='relu', padding='valid')(convlayer)
convlayer = Conv2D(64, (3, 3), strides=(1, 1), activation='relu', padding='valid')(convlayer)
flattenlayer = Flatten()(convlayer)
denselayer = Dense(256, activation='relu')(flattenlayer)
value = Dense(self.action_dim, activation='linear')(denselayer)
if self.dueling:
adv = Dense(1, activation='linear')(denselayer)
mean = Lambda(layer_normalize)(value)
value = Lambda(layer_add)([adv, mean])
model = Model(inputs=state, outputs=value)
return model

def build_train_graph(self):
self.obs = tf.placeholder(tf.uint8, name="infer_input",
shape=(None,) + tuple(self.state_dim))

self.v = self.network(self.obs)
self.true_v = tf.placeholder(tf.float32, name="true_v",
shape=self.v.shape)

loss = tf.keras.losses.mean_squared_error(self.true_v, self.v)
self.loss = loss
self.train_op = self.optimizer.minimize(loss)

def build_infer_graph(self):
self.tn_obs = tf.placeholder(tf.uint8, name="tn_input",
shape=(None,) + tuple(self.state_dim))
self.target_v = self.target_network(self.tn_obs)
self.v = self.network(self.obs)

def build_graph(self):
self.build_train_graph()
self.build_infer_graph()
self.sess.run(tf.initialize_all_variables())
self.explore_paras = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES,
scope="explore_agent")

def train(self, state, label):
"""Train the model."""
with self.graph.as_default():
K.set_session(self.sess)
feed_dict = {
self.obs: state,
self.true_v: label
}
_, loss = self.sess.run([self.train_op, self.loss], feed_dict)
# loss = self.model.train_on_batch(state, label)
return loss

def predict(self, state, tn_state):
"""
Do predict use the newest model.

:param state:
:return:
"""
with self.graph.as_default():
K.set_session(self.sess)
feed_dict = {self.obs: state, self.tn_obs: tn_state}
return self.sess.run([self.v, self.target_v], feed_dict)

def update_target(self):
try:
with self.graph.as_default():
weights = self.network.get_weights()
self.target_network.set_weights(weights)
except ValueError:
for t in self.explore_paras:
print(t.name)
raise RuntimeError("DEBUG")

def get_weights(self):
with self.graph.as_default():
return self.model.get_weights()

def set_weights(self, weights):
try:
with self.graph.as_default():
self.model.set_weights(weights)
except ValueError as e:
for t in self.explore_paras:
print(t.name)
raise RuntimeError("DEBUG")


class DQNBase(Model):
"""Model that combine the representation and prediction (value+policy) network."""

def __init__(self, actor_network: Model, target_actor_network: Model):
super().__init__()
self.actor_network = actor_network
self.target_actor_network = target_actor_network