forked from mlpack/mlpack
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Classical control tasks, Mountain Car and Cart Pole
- Loading branch information
1 parent
a41cda1
commit e981050
Showing
4 changed files
with
354 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,5 @@ xcode* | |
.DS_Store | ||
src/mlpack/core/util/gitversion.hpp | ||
src/mlpack/core/util/arma_config.hpp | ||
.idea | ||
cmake-build-* |
15 changes: 15 additions & 0 deletions
15
src/mlpack/methods/reinforcement_learning/environment/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Define the files we need to compile | ||
# Anything not in this list will not be compiled into mlpack. | ||
set(SOURCES | ||
mountain_car.hpp | ||
cart_pole.hpp | ||
) | ||
|
||
# Add directory name to sources. | ||
set(DIR_SRCS) | ||
foreach(file ${SOURCES}) | ||
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) | ||
endforeach() | ||
# Append sources (with directory name) to list of all mlpack sources (used at | ||
# the parent scope). | ||
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE) |
182 changes: 182 additions & 0 deletions
182
src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
/** | ||
* @file cart_pole.hpp | ||
* @author Shangtong Zhang | ||
* | ||
* This file is an implementation of Cart Pole task | ||
* https://gym.openai.com/envs/CartPole-v0 | ||
* | ||
* TODO: refactor to OpenAI interface | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
|
||
#ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP | ||
#define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP | ||
|
||
#include <mlpack/prereqs.hpp> | ||
|
||
namespace mlpack { | ||
namespace rl { | ||
|
||
namespace cart_pole_details { | ||
// Some constants of Cart Pole task | ||
constexpr double gravity = 9.8; | ||
constexpr double massCart = 1.0; | ||
constexpr double massPole = 0.1; | ||
constexpr double totalMass = massCart + massPole; | ||
constexpr double length = 0.5; | ||
constexpr double poleMassLength = massPole * length; | ||
constexpr double forceMag = 10.0; | ||
constexpr double tau = 0.02; | ||
constexpr double thetaThresholdRadians = 12 * 2 * 3.1416 / 360; | ||
constexpr double xThreshold = 2.4; | ||
} | ||
|
||
/** | ||
* Implementation of Cart Pole task | ||
*/ | ||
class CartPole | ||
{ | ||
public: | ||
|
||
/** | ||
* Implementation of state of Cart Pole | ||
* Each state is a tuple of (position, velocity, angle, angular velocity) | ||
*/ | ||
class State | ||
{ | ||
public: | ||
//! Construct a state instance | ||
State() : data(4) { } | ||
|
||
//! Construct a state instance from given data | ||
State(arma::colvec data) : data(data) { } | ||
|
||
//! Get position | ||
double X() const | ||
{ | ||
return data[0]; | ||
} | ||
|
||
//! Modify position | ||
double& X() | ||
{ | ||
return data[0]; | ||
} | ||
|
||
//! Get velocity | ||
double XDot() const | ||
{ | ||
return data[1]; | ||
} | ||
|
||
//! Modify velocity | ||
double& XDot() | ||
{ | ||
return data[1]; | ||
} | ||
|
||
//! Get angle | ||
double Theta() const | ||
{ | ||
return data[2]; | ||
} | ||
|
||
//! Modify angle | ||
double& Theta() | ||
{ | ||
return data[2]; | ||
} | ||
|
||
//! Get angular velocity | ||
double ThetaDot() const | ||
{ | ||
return data[3]; | ||
} | ||
|
||
//! Modify angular velocity | ||
double& ThetaDot() | ||
{ | ||
return data[3]; | ||
} | ||
|
||
//! Encode the state to a column vector | ||
const arma::colvec& Encode() const | ||
{ | ||
return data; | ||
} | ||
|
||
//! Whether current state is terminal state | ||
bool IsTerminal() const | ||
{ | ||
using namespace cart_pole_details; | ||
return std::abs(X()) > xThreshold || | ||
std::abs(Theta()) > thetaThresholdRadians; | ||
} | ||
|
||
private: | ||
//! Locally-stored (position, velocity, angle, angular velocity) | ||
arma::colvec data; | ||
}; | ||
|
||
/** | ||
* Implementation of action of Cart Pole | ||
*/ | ||
class Action | ||
{ | ||
public: | ||
enum Actions | ||
{ | ||
backward, | ||
forward | ||
}; | ||
|
||
//! # of actions | ||
static constexpr size_t count = 2; | ||
}; | ||
|
||
/** | ||
* Dynamics of Cart Pole | ||
* Get next state and next action based on current state and current action | ||
* @param state Current state | ||
* @param action Current action | ||
* @param nextState Next state | ||
* @param reward Reward is always 1 | ||
*/ | ||
void Sample(const State& state, const Action::Actions& action, | ||
State& nextState, double& reward) | ||
{ | ||
using namespace cart_pole_details; | ||
double force = action ? forceMag : -forceMag; | ||
double cosTheta = std::cos(state.Theta()); | ||
double sinTheta = std::sin(state.Theta()); | ||
double temp = (force + poleMassLength * state.ThetaDot() * state.ThetaDot() * sinTheta) / totalMass; | ||
double thetaAcc = (gravity * sinTheta - cosTheta * temp) / | ||
(length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass)); | ||
double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass; | ||
nextState.X() = state.X() + tau * state.XDot(); | ||
nextState.XDot() = state.XDot() + tau * xAcc; | ||
nextState.Theta() = state.Theta() + tau * state.ThetaDot(); | ||
nextState.ThetaDot() = state.ThetaDot() + tau * thetaAcc; | ||
|
||
reward = 1.0; | ||
} | ||
|
||
/** | ||
* Initial state representation is randomly generated within [-0.05, 0.05] | ||
* @return Initial state for each episode | ||
*/ | ||
State InitialSample() | ||
{ | ||
return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0); | ||
} | ||
|
||
}; | ||
|
||
} // namespace rl | ||
} // namespace mlpack | ||
|
||
#endif |
155 changes: 155 additions & 0 deletions
155
src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
/** | ||
* @file mountain_car.hpp | ||
* @author Shangtong Zhang | ||
* | ||
* This file is an implementation of Mountain Car task | ||
* https://gym.openai.com/envs/MountainCar-v0 | ||
* | ||
* TODO: refactor to OpenAI interface | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
|
||
#ifndef MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP | ||
#define MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP | ||
|
||
#include <mlpack/prereqs.hpp> | ||
|
||
namespace mlpack { | ||
namespace rl { | ||
|
||
namespace mountain_car_details { | ||
constexpr double positionMin = -1.2; | ||
constexpr double positionMax = 0.5; | ||
constexpr double velocityMin = -0.07; | ||
constexpr double velocityMax = 0.07; | ||
} | ||
|
||
/** | ||
* Implementation of Mountain Car task | ||
*/ | ||
class MountainCar | ||
{ | ||
public: | ||
|
||
/** | ||
* Implementation of state of Mountain Car | ||
* Each state is a (velocity, position) pair | ||
*/ | ||
class State | ||
{ | ||
public: | ||
//! Construct a state instance | ||
State(double velocity = 0, double position = 0) : data(2) | ||
{ | ||
this->Velocity() = velocity; | ||
this->Position() = position; | ||
} | ||
|
||
//! Encode the state to a column vector | ||
const arma::colvec& Encode() const | ||
{ | ||
return data; | ||
} | ||
|
||
//! Get velocity | ||
double Velocity() const | ||
{ | ||
return data[0]; | ||
} | ||
|
||
//! Modify velocity | ||
double& Velocity() | ||
{ | ||
return data[0]; | ||
} | ||
|
||
//! Get position | ||
double Position() const | ||
{ | ||
return data[1]; | ||
} | ||
|
||
//! Modify position | ||
double& Position() | ||
{ | ||
return data[1]; | ||
} | ||
|
||
//! Whether current state is terminal state | ||
bool IsTerminal() const | ||
{ | ||
using namespace mountain_car_details; | ||
return std::abs(Position() - positionMax) <= 1e-5; | ||
} | ||
|
||
private: | ||
//! Locally-stored velocity and position | ||
arma::colvec data; | ||
}; | ||
|
||
/** | ||
* Implementation of action of Mountain Car | ||
*/ | ||
class Action | ||
{ | ||
public: | ||
enum Actions | ||
{ | ||
backward, | ||
stop, | ||
forward | ||
}; | ||
|
||
//! # of actions | ||
static constexpr size_t count = 3; | ||
}; | ||
|
||
/** | ||
* Dynamics of Mountain Car | ||
* Get next state and next action based on current state and current action | ||
* @param state Current state | ||
* @param action Current action | ||
* @param nextState Next state | ||
* @param reward Reward is always -1 | ||
*/ | ||
void Sample(const State& state, const Action::Actions& action, | ||
State& nextState, double& reward) | ||
{ | ||
using namespace mountain_car_details; | ||
int direction = action - 1; | ||
nextState.Velocity() = state.Velocity() + 0.001 * direction - 0.0025 * std::cos(3 * state.Position()); | ||
nextState.Velocity() = std::min(std::max(nextState.Velocity(), velocityMin), velocityMax); | ||
|
||
nextState.Position() = state.Position() + nextState.Velocity(); | ||
nextState.Position() = std::min(std::max(nextState.Position(), positionMin), positionMax); | ||
|
||
reward = -1.0; | ||
if (std::abs(nextState.Position() - positionMin) <= 1e-5) | ||
{ | ||
nextState.Velocity() = 0.0; | ||
} | ||
} | ||
|
||
/** | ||
* Initial position is randomly generated within [-0.6, -0.4] | ||
* Initial velocity is 0 | ||
* @return Initial state for each episode | ||
*/ | ||
State InitialSample() | ||
{ | ||
State state; | ||
state.Velocity() = 0.0; | ||
state.Position() = arma::as_scalar(arma::randu(1)) * 0.2 - 0.6; | ||
return state; | ||
} | ||
|
||
}; | ||
|
||
} // namespace rl | ||
} // namespace mlpack | ||
|
||
#endif |