Skip to content

Commit

Permalink
Merge pull request mlpack#1056 from ShangtongZhang/rl-policy
Browse files Browse the repository at this point in the history
Add aggregated policy for async rl methods
  • Loading branch information
rcurtin authored Jul 19, 2017
2 parents 251a6a9 + 658492f commit 923df37
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Define the files we need to compile
# Anything not in this list will not be compiled into mlpack.
# Files to compile into mlpack; anything absent from this list is skipped.
set(SOURCES
  aggregated_policy.hpp
  greedy_policy.hpp
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/**
* @file aggregated_policy.hpp
* @author Shangtong Zhang
*
* This file is the implementation of AggregatedPolicy class.
* An aggregated policy will randomly select a child policy under a given
* distribution at each time step.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_RL_POLICY_AGGREGATED_POLICY_HPP
#define MLPACK_METHODS_RL_POLICY_AGGREGATED_POLICY_HPP

#include <mlpack/prereqs.hpp>
#include <mlpack/core/dists/discrete_distribution.hpp>

namespace mlpack {
namespace rl {

/**
* @tparam PolicyType The type of the child policy.
*/
/**
 * An aggregated policy randomly selects one of its child policies, according
 * to a fixed probability distribution, at each time step.  This allows, e.g.,
 * asynchronous RL workers to explore with different degrees of greediness.
 *
 * @tparam PolicyType The type of the child policy.
 */
template<typename PolicyType>
class AggregatedPolicy
{
 public:
  //! Convenient typedef for the action type of the child policy.
  using ActionType = typename PolicyType::ActionType;

  /**
   * Construct an aggregated policy over the given child policies.
   *
   * @param policies Child policies.
   * @param distribution Probability distribution for each child policy.
   *     The user should make sure its size is the same as the number of
   *     policies and that the sum of its elements is equal to 1.
   */
  AggregatedPolicy(std::vector<PolicyType> policies,
                   const arma::colvec& distribution) :
      policies(std::move(policies)),
      sampler({distribution})
  { /* Nothing to do here. */ }

  /**
   * Sample an action based on given action values.
   *
   * @param actionValue Values for each action.
   * @param deterministic If true, select the action greedily; this delegates
   *     to the first child policy.  Otherwise, a child policy is drawn from
   *     the stored distribution and the action is sampled from it.
   * @return Sampled action.
   */
  ActionType Sample(const arma::colvec& actionValue, bool deterministic = false)
  {
    if (deterministic)
      return policies.front().Sample(actionValue, true);

    // Draw the index of a child policy under the given distribution.
    const size_t selected = arma::as_scalar(sampler.Random());
    return policies[selected].Sample(actionValue, false);
  }

  /**
   * Anneal the exploration probability of every child policy.
   */
  void Anneal()
  {
    for (PolicyType& policy : policies)
      policy.Anneal();
  }

 private:
  //! Locally-stored child policies.
  std::vector<PolicyType> policies;

  //! Locally-stored sampler under the given distribution.
  distribution::DiscreteDistribution sampler;
};

} // namespace rl
} // namespace mlpack

#endif

0 comments on commit 923df37

Please sign in to comment.