Skip to content

Commit

Permalink
neural_net: simulate net with random outcomes
Browse files Browse the repository at this point in the history
  • Loading branch information
Shen-En Shih committed Jan 24, 2018
1 parent 16871c4 commit c0dcf26
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 26 deletions.
8 changes: 7 additions & 1 deletion agents/include/MCTS/Config.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,16 @@ namespace mcts
public:
// Creates a configuration with an empty neural-net path and the
// random-outcome flag cleared. The flag is initialized here so that
// IsNeuralNetRandom() never reads an indeterminate bool when
// SetNeuralNetPath() has not been called yet.
Config() : neural_net_path_(), neural_net_is_random_(false) {}

void SetNeuralNetPath(std::string const& filename) { neural_net_path_ = filename; }
// Records which neural-net file to use and whether that net should be
// treated as random.
//   filename:  stored verbatim as the neural-net path.
//   is_random: stored as the random flag; defaults to false so callers that
//              pass only a path keep working.
void SetNeuralNetPath(std::string const& filename, bool is_random = false)
{
    neural_net_is_random_ = is_random;
    neural_net_path_ = filename;
}
// Returns a reference to the stored neural-net path.
std::string const& GetNeuralNetPath() const
{
    return neural_net_path_;
}
// Returns the random flag last stored by SetNeuralNetPath().
bool IsNeuralNetRandom() const
{
    return neural_net_is_random_;
}


private:
std::string neural_net_path_;
bool neural_net_is_random_;
};
}
4 changes: 2 additions & 2 deletions agents/include/MCTS/inspector/InteractiveShell.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ namespace mcts
controller_ = controller;
}

void SetConfig(agents::MCTSAgentConfig const& config) {
void SetConfig(agents::MCTSAgentConfig const& config, std::mt19937 & random) {
try {
state_value_func_.reset(new mcts::policy::simulation::NeuralNetworkStateValueFunction(config.mcts));
state_value_func_.reset(new mcts::policy::simulation::NeuralNetworkStateValueFunction(config.mcts, random));
}
catch (std::exception ex) {
state_value_func_.reset(nullptr);
Expand Down
13 changes: 7 additions & 6 deletions agents/include/MCTS/policy/Simulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ namespace mcts
class NeuralNetworkStateValueFunction
{
public:
NeuralNetworkStateValueFunction(Config const& config)
: net_(), current_player_viewer_()
NeuralNetworkStateValueFunction(Config const& config, std::mt19937 & random)
: net_(), current_player_viewer_(), random_(random)
{
net_.Load(config.GetNeuralNetPath());
net_.Load(config.GetNeuralNetPath(), config.IsNeuralNetRandom());
}

// State value is in range [-1, 1]
Expand All @@ -153,7 +153,7 @@ namespace mcts
double GetStateValue(state::State const& state) {
current_player_viewer_.Reset(state);

double score = net_.Predict(&current_player_viewer_);
double score = net_.Predict(&current_player_viewer_, random_);

if (!state.GetCurrentPlayerId().IsFirst()) {
score = -score;
Expand Down Expand Up @@ -358,6 +358,7 @@ namespace mcts
private:
neural_net::NeuralNetwork net_;
StateDataBridge current_player_viewer_;
std::mt19937 & random_;
};

class RandomPlayoutWithHeuristicEarlyCutoffPolicy
Expand Down Expand Up @@ -389,7 +390,7 @@ namespace mcts
public:
RandomPlayoutWithHeuristicEarlyCutoffPolicy(state::PlayerSide side, std::mt19937 & rand, Config & config) :
rand_(rand),
state_value_func_(config)
state_value_func_(config, rand_)
{
}

Expand Down Expand Up @@ -446,7 +447,7 @@ namespace mcts
HeuristicPlayoutWithHeuristicEarlyCutoffPolicy(std::mt19937 & rand, Config const& config) :
rand_(rand),
decision_(), decision_idx_(0),
state_value_func_(config)
state_value_func_(config, rand_)
{
}

Expand Down
5 changes: 4 additions & 1 deletion agents/include/alphazero/trainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace alphazero
TrainerConfigs() :
threads_(2),
best_net_path_(),
best_net_is_random_(false),
competitor_net_path_(),
kTrainingDataCapacityPowerOfTwo(10),
kMinimumTraningData(0),
Expand All @@ -28,6 +29,8 @@ namespace alphazero
int threads_;

std::string best_net_path_;
bool best_net_is_random_;

std::string competitor_net_path_;

size_t kTrainingDataCapacityPowerOfTwo;
Expand Down Expand Up @@ -79,7 +82,7 @@ namespace alphazero
threads_.Initialize(configs_.threads_);
training_data_.Initialize(configs.kTrainingDataCapacityPowerOfTwo);

best_neural_net_.Load(configs.best_net_path_);
best_neural_net_.Load(configs.best_net_path_, configs.best_net_is_random_);
neural_net_.Load(configs.best_net_path_);

optimizer_.Initialize();
Expand Down
6 changes: 3 additions & 3 deletions agents/include/neural_net/NeuralNetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ namespace neural_net {
static void CreateWithRandomWeights(std::string const& path);

void Save(std::string const& path) const;
void Load(std::string const& path);
void Load(std::string const& path, bool is_random = false);

void CopyFrom(NeuralNetwork const& rhs);

Expand All @@ -111,8 +111,8 @@ namespace neural_net {
NeuralNetworkInput const& input,
NeuralNetworkOutput const& output);

double Predict(IInputGetter * input);
void Predict(impl::NeuralNetworkInputImpl const& input, std::vector<double> & results);
double Predict(IInputGetter * input, std::mt19937 & random);
void Predict(impl::NeuralNetworkInputImpl const& input, std::vector<double> & results, std::mt19937 & random);

private:
impl::NeuralNetworkImpl * impl_;
Expand Down
29 changes: 19 additions & 10 deletions agents/src/neural_net/NeuralNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,9 @@ namespace neural_net {
net.save(filename);
}

void Load(std::string const& filename) {
void Load(std::string const& filename, bool is_random) {
net_.load(filename);
random_net_ = is_random;
}

void CopyFrom(NeuralNetworkImpl const& rhs) {
Expand Down Expand Up @@ -317,23 +318,31 @@ namespace neural_net {
return { correct, total };
}

// Scores every sample held by `input` and writes the scores to `results`,
// one entry per sample, in the same order as input.GetData().
//   input:   container of samples; each element is passed to the
//            single-sample Predict() overload.
//   results: cleared first, then filled with one score per sample.
//   random:  forwarded unchanged to the single-sample Predict() overload.
void Predict(impl::NeuralNetworkInputImpl const& input, std::vector<double> & results, std::mt19937 & random) {
    auto const& input_data = input.GetData();
    results.clear();
    results.reserve(input_data.size());
    for (size_t idx = 0; idx < input_data.size(); ++idx) {
        // Index by idx so each sample gets its own prediction rather than
        // repeating the score of the first sample.
        results.push_back(Predict(input_data[idx], random));
    }
}

double Predict(IInputGetter * input) {
double Predict(IInputGetter * input, std::mt19937 & random) {
tiny_dnn::tensor_t data;
impl::InputDataConverter().Convert(input, data);
return Predict(data, random);
}

// Produces the score for a single input tensor.
// When the net is not flagged as random, the score is element [0][0] of the
// network's prediction for `data` and `random` is left untouched.
// Otherwise the network is bypassed and the score is drawn from `random`
// using a uniform real distribution constructed with bounds -1.0 and 1.0.
double Predict(tiny_dnn::tensor_t const& data, std::mt19937 & random) {
    if (!random_net_) {
        return net_.predict(data)[0][0];
    }

    std::uniform_real_distribution<double> outcome(-1.0, 1.0);
    return outcome(random);
}

private:
tiny_dnn::network<tiny_dnn::graph> net_;
bool random_net_;
};
}

Expand Down Expand Up @@ -393,12 +402,12 @@ namespace neural_net {
// Public entry point that forwards `path` unchanged to
// impl::NeuralNetworkImpl::CreateWithRandomWeights().
void NeuralNetwork::CreateWithRandomWeights(std::string const& path)
{
    impl::NeuralNetworkImpl::CreateWithRandomWeights(path);
}
void NeuralNetwork::Load(std::string const& path) {
void NeuralNetwork::Load(std::string const& path, bool is_random) {
// reload neural net
delete impl_;
impl_ = new impl::NeuralNetworkImpl();

impl_->Load(path);
impl_->Load(path, is_random);
}

void NeuralNetwork::CopyFrom(NeuralNetwork const& rhs) {
Expand All @@ -423,13 +432,13 @@ namespace neural_net {
return impl_->Verify(*input.impl_, *output.impl_);
}

void NeuralNetwork::Predict(impl::NeuralNetworkInputImpl const& input, std::vector<double> & results)
void NeuralNetwork::Predict(impl::NeuralNetworkInputImpl const& input, std::vector<double> & results, std::mt19937 & random)
{
return impl_->Predict(input, results);
return impl_->Predict(input, results, random);
}

double NeuralNetwork::Predict(IInputGetter * input)
double NeuralNetwork::Predict(IInputGetter * input, std::mt19937 & random)
{
return impl_->Predict(input);
return impl_->Predict(input, random);
}
}
5 changes: 4 additions & 1 deletion agents/test/alphazero_e2e_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ int main(void)
trainer_config.evaluation.agent_config.callback_interval_ms = 100;

trainer_config.kMinimumTraningData = trainer_config.optimizer.batch_size * 1;

neural_net::NeuralNetwork::CreateWithRandomWeights(trainer_config.best_net_path_);
trainer_config.best_net_path_ = "best_net";
trainer_config.best_net_is_random_ = true;

trainer_config.competitor_net_path_ = "competitor_net";

// create a random model
// TODO: only create a random model if best_net does not exist
neural_net::NeuralNetwork::CreateWithRandomWeights(trainer_config.best_net_path_);

std::string model_path = "";
trainer.Initialize(trainer_config, random);
Expand Down
8 changes: 6 additions & 2 deletions ui/src/UI/GameEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ namespace ui
public:
GameEngineImpl() :
logger_(),
random_(),
running_(false),
controller_(),
shell_(),
board_getter_(logger_),
config_()
{}
{
random_.seed(std::random_device()());
}

int Initialize(int root_sample_count) {
try {
Expand Down Expand Up @@ -49,7 +52,7 @@ namespace ui
config_.tree_samples = root_sample_count;
config_.mcts.SetNeuralNetPath("neural_net");

shell_.SetConfig(config_);
shell_.SetConfig(config_, random_);

return 0;
}
Expand Down Expand Up @@ -174,6 +177,7 @@ namespace ui

private:
GameEngineLogger logger_;
std::mt19937 random_;
std::atomic<bool> running_;
std::unique_ptr<agents::MCTSRunner> controller_;
mcts::inspector::InteractiveShell shell_;
Expand Down

0 comments on commit c0dcf26

Please sign in to comment.