From e981050fd9a7199a4b4cc3226982fafdead50dbc Mon Sep 17 00:00:00 2001
From: Shangtong Zhang
Date: Mon, 24 Apr 2017 20:04:18 -0600
Subject: [PATCH] Classical control tasks, Mountain Car and Cart Pole

---
 .gitignore                                          |   2 +
 .../environment/CMakeLists.txt                      |  15 ++
 .../environment/cart_pole.hpp                       | 182 ++++++++++++++++++
 .../environment/mountain_car.hpp                    | 155 +++++++++++++++
 4 files changed, 354 insertions(+)
 create mode 100644 src/mlpack/methods/reinforcement_learning/environment/CMakeLists.txt
 create mode 100644 src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp
 create mode 100644 src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp

diff --git a/.gitignore b/.gitignore
index 6a6e0b36ecd..826abc20d12 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,5 @@ xcode*
 .DS_Store
 src/mlpack/core/util/gitversion.hpp
 src/mlpack/core/util/arma_config.hpp
+.idea
+cmake-build-*
\ No newline at end of file
diff --git a/src/mlpack/methods/reinforcement_learning/environment/CMakeLists.txt b/src/mlpack/methods/reinforcement_learning/environment/CMakeLists.txt
new file mode 100644
index 00000000000..58baa7e14d3
--- /dev/null
+++ b/src/mlpack/methods/reinforcement_learning/environment/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Define the files we need to compile
+# Anything not in this list will not be compiled into mlpack.
+set(SOURCES
+  mountain_car.hpp
+  cart_pole.hpp
+)
+
+# Add directory name to sources.
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+# Append sources (with directory name) to list of all mlpack sources (used at
+# the parent scope).
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE) diff --git a/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp b/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp new file mode 100644 index 00000000000..355b9f08b91 --- /dev/null +++ b/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp @@ -0,0 +1,182 @@ +/** + * @file cart_pole.hpp + * @author Shangtong Zhang + * + * This file is an implementation of Cart Pole task + * https://gym.openai.com/envs/CartPole-v0 + * + * TODO: refactor to OpenAI interface + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ + +#ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP +#define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP + +#include + +namespace mlpack { +namespace rl { + +namespace cart_pole_details { +// Some constants of Cart Pole task +constexpr double gravity = 9.8; +constexpr double massCart = 1.0; +constexpr double massPole = 0.1; +constexpr double totalMass = massCart + massPole; +constexpr double length = 0.5; +constexpr double poleMassLength = massPole * length; +constexpr double forceMag = 10.0; +constexpr double tau = 0.02; +constexpr double thetaThresholdRadians = 12 * 2 * 3.1416 / 360; +constexpr double xThreshold = 2.4; +} + +/** + * Implementation of Cart Pole task + */ +class CartPole +{ + public: + + /** + * Implementation of state of Cart Pole + * Each state is a tuple of (position, velocity, angle, angular velocity) + */ + class State + { + public: + //! Construct a state instance + State() : data(4) { } + + //! Construct a state instance from given data + State(arma::colvec data) : data(data) { } + + //! Get position + double X() const + { + return data[0]; + } + + //! 
Modify position + double& X() + { + return data[0]; + } + + //! Get velocity + double XDot() const + { + return data[1]; + } + + //! Modify velocity + double& XDot() + { + return data[1]; + } + + //! Get angle + double Theta() const + { + return data[2]; + } + + //! Modify angle + double& Theta() + { + return data[2]; + } + + //! Get angular velocity + double ThetaDot() const + { + return data[3]; + } + + //! Modify angular velocity + double& ThetaDot() + { + return data[3]; + } + + //! Encode the state to a column vector + const arma::colvec& Encode() const + { + return data; + } + + //! Whether current state is terminal state + bool IsTerminal() const + { + using namespace cart_pole_details; + return std::abs(X()) > xThreshold || + std::abs(Theta()) > thetaThresholdRadians; + } + + private: + //! Locally-stored (position, velocity, angle, angular velocity) + arma::colvec data; + }; + + /** + * Implementation of action of Cart Pole + */ + class Action + { + public: + enum Actions + { + backward, + forward + }; + + //! # of actions + static constexpr size_t count = 2; + }; + + /** + * Dynamics of Cart Pole + * Get next state and next action based on current state and current action + * @param state Current state + * @param action Current action + * @param nextState Next state + * @param reward Reward is always 1 + */ + void Sample(const State& state, const Action::Actions& action, + State& nextState, double& reward) + { + using namespace cart_pole_details; + double force = action ? 
forceMag : -forceMag; + double cosTheta = std::cos(state.Theta()); + double sinTheta = std::sin(state.Theta()); + double temp = (force + poleMassLength * state.ThetaDot() * state.ThetaDot() * sinTheta) / totalMass; + double thetaAcc = (gravity * sinTheta - cosTheta * temp) / + (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass)); + double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass; + nextState.X() = state.X() + tau * state.XDot(); + nextState.XDot() = state.XDot() + tau * xAcc; + nextState.Theta() = state.Theta() + tau * state.ThetaDot(); + nextState.ThetaDot() = state.ThetaDot() + tau * thetaAcc; + + reward = 1.0; + } + + /** + * Initial state representation is randomly generated within [-0.05, 0.05] + * @return Initial state for each episode + */ + State InitialSample() + { + return State((arma::randu(4) - 0.5) / 10.0); + } + +}; + +} // namespace rl +} // namespace mlpack + +#endif \ No newline at end of file diff --git a/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp b/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp new file mode 100644 index 00000000000..8575ea3bbd6 --- /dev/null +++ b/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp @@ -0,0 +1,155 @@ +/** + * @file mountain_car.hpp + * @author Shangtong Zhang + * + * This file is an implementation of Mountain Car task + * https://gym.openai.com/envs/MountainCar-v0 + * + * TODO: refactor to OpenAI interface + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. 
+ */
+
+#ifndef MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP
+#define MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace rl {
+
+namespace mountain_car_details {
+// Bounds for position and velocity; both are clamped each step by Sample().
+constexpr double positionMin = -1.2;
+constexpr double positionMax = 0.5;
+constexpr double velocityMin = -0.07;
+constexpr double velocityMax = 0.07;
+}
+
+/**
+ * Implementation of Mountain Car task.
+ */
+class MountainCar
+{
+ public:
+
+  /**
+   * Implementation of state of Mountain Car.
+   * Each state is a (velocity, position) pair.
+   */
+  class State
+  {
+   public:
+    //! Construct a state instance.
+    State(double velocity = 0, double position = 0) : data(2)
+    {
+      this->Velocity() = velocity;
+      this->Position() = position;
+    }
+
+    //! Encode the state to a column vector.
+    const arma::colvec& Encode() const
+    {
+      return data;
+    }
+
+    //! Get velocity.
+    double Velocity() const
+    {
+      return data[0];
+    }
+
+    //! Modify velocity.
+    double& Velocity()
+    {
+      return data[0];
+    }
+
+    //! Get position.
+    double Position() const
+    {
+      return data[1];
+    }
+
+    //! Modify position.
+    double& Position()
+    {
+      return data[1];
+    }
+
+    //! Whether current state is terminal state (the car reached the goal).
+    bool IsTerminal() const
+    {
+      using namespace mountain_car_details;
+      // Sample() clamps position to positionMax, so the goal is reached
+      // exactly; the epsilon only guards against rounding.
+      return Position() >= positionMax - 1e-5;
+    }
+
+   private:
+    //! Locally-stored velocity and position.
+    arma::colvec data;
+  };
+
+  /**
+   * Implementation of action of Mountain Car.
+   */
+  class Action
+  {
+   public:
+    enum Actions
+    {
+      backward,
+      stop,
+      forward
+    };
+
+    //! # of actions.
+    static constexpr size_t count = 3;
+  };
+
+  /**
+   * Dynamics of Mountain Car. Get next state and reward based on current
+   * state and current action.
+   *
+   * @param state Current state.
+   * @param action Current action.
+   * @param nextState Next state.
+   * @param reward Reward is always -1.
+   */
+  void Sample(const State& state, const Action::Actions& action,
+      State& nextState, double& reward)
+  {
+    using namespace mountain_car_details;
+    // Map backward/stop/forward to a thrust direction of -1/0/+1.
+    int direction = static_cast<int>(action) - 1;
+    nextState.Velocity() = state.Velocity() + 0.001 * direction -
+        0.0025 * std::cos(3 * state.Position());
+    nextState.Velocity() =
+        std::min(std::max(nextState.Velocity(), velocityMin), velocityMax);
+
+    nextState.Position() = state.Position() + nextState.Velocity();
+    nextState.Position() =
+        std::min(std::max(nextState.Position(), positionMin), positionMax);
+
+    reward = -1.0;
+    // Inelastic collision with the left wall: the car stops.
+    if (std::abs(nextState.Position() - positionMin) <= 1e-5)
+    {
+      nextState.Velocity() = 0.0;
+    }
+  }
+
+  /**
+   * Initial position is randomly generated within [-0.6, -0.4].
+   * Initial velocity is 0.
+   * @return Initial state for each episode.
+   */
+  State InitialSample()
+  {
+    State state;
+    state.Velocity() = 0.0;
+    state.Position() = arma::as_scalar(arma::randu(1)) * 0.2 - 0.6;
+    return state;
+  }
+
+};
+
+} // namespace rl
+} // namespace mlpack
+
+#endif