forked from mlpack/mlpack
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request mlpack#1056 from ShangtongZhang/rl-policy
Add aggregated policy for async rl methods
- Loading branch information
Showing
2 changed files
with
81 additions
and
0 deletions.
There are no files selected for viewing
1 change: 1 addition & 0 deletions
1
src/mlpack/methods/reinforcement_learning/policy/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
80 changes: 80 additions & 0 deletions
80
src/mlpack/methods/reinforcement_learning/policy/aggregated_policy.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
/** | ||
* @file aggregated_policy.hpp | ||
* @author Shangtong Zhang | ||
* | ||
* This file is the implementation of AggregatedPolicy class. | ||
* An aggregated policy will randomly select a child policy under a given | ||
* distribution at each time step. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_METHODS_RL_POLICY_AGGREGATED_POLICY_HPP | ||
#define MLPACK_METHODS_RL_POLICY_AGGREGATED_POLICY_HPP | ||
|
||
#include <mlpack/prereqs.hpp> | ||
#include <mlpack/core/dists/discrete_distribution.hpp> | ||
|
||
namespace mlpack { | ||
namespace rl { | ||
|
||
/** | ||
* @tparam PolicyType The type of the child policy. | ||
*/ | ||
template <typename PolicyType> | ||
class AggregatedPolicy | ||
{ | ||
public: | ||
//! Convenient typedef for action. | ||
using ActionType = typename PolicyType::ActionType; | ||
|
||
/** | ||
* @param policies Child policies. | ||
* @param distribution Probability distribution for each child policy. | ||
* User should make sure its size is same as the number of policies | ||
* and the sum of its element is equal to 1. | ||
*/ | ||
AggregatedPolicy(std::vector<PolicyType> policies, | ||
const arma::colvec& distribution) : | ||
policies(std::move(policies)), | ||
sampler({distribution}) | ||
{ /* Nothing to do here. */ }; | ||
|
||
/** | ||
* Sample an action based on given action values. | ||
* | ||
* @param actionValue Values for each action. | ||
* @param deterministic Always select the action greedily. | ||
* @return Sampled action. | ||
*/ | ||
ActionType Sample(const arma::colvec& actionValue, bool deterministic = false) | ||
{ | ||
if (deterministic) | ||
return policies.front().Sample(actionValue, true); | ||
size_t selected = arma::as_scalar(sampler.Random()); | ||
return policies[selected].Sample(actionValue, false); | ||
} | ||
|
||
/** | ||
* Exploration probability will anneal at each step. | ||
*/ | ||
void Anneal() | ||
{ | ||
for (PolicyType& policy : policies) | ||
policy.Anneal(); | ||
} | ||
|
||
private: | ||
//! Locally-stored child policies. | ||
std::vector<PolicyType> policies; | ||
|
||
//! Locally-stored sampler under the given distribution. | ||
distribution::DiscreteDistribution sampler; | ||
}; | ||
|
||
} // namespace rl | ||
} // namespace mlpack | ||
|
||
#endif |