Skip to content

Commit

Permalink
Classical control tasks, Mountain Car and Cart Pole
Browse files Browse the repository at this point in the history
  • Loading branch information
ShangtongZhang committed Apr 25, 2017
1 parent a41cda1 commit e981050
Show file tree
Hide file tree
Showing 4 changed files with 354 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ xcode*
.DS_Store
src/mlpack/core/util/gitversion.hpp
src/mlpack/core/util/arma_config.hpp
.idea
cmake-build-*
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Files that belong to this directory.
# Any file missing from this list will not be compiled into mlpack.
set(SOURCES
  mountain_car.hpp
  cart_pole.hpp
)

# Prefix every source file with the path of this directory.
set(DIR_SRCS)
foreach(file ${SOURCES})
  list(APPEND DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()
# Propagate the fully-qualified sources to the list of all mlpack sources
# (used at the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
182 changes: 182 additions & 0 deletions src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/**
* @file cart_pole.hpp
* @author Shangtong Zhang
*
* This file is an implementation of the Cart Pole task:
* https://gym.openai.com/envs/CartPole-v0
*
* TODO: refactor to OpenAI interface
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
#define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace rl {

namespace cart_pole_details {
// Constants of the Cart Pole task, matching the classic formulation
// (and OpenAI Gym's CartPole-v0).

//! Gravitational acceleration.
constexpr double gravity = 9.8;
//! Mass of the cart.
constexpr double massCart = 1.0;
//! Mass of the pole.
constexpr double massPole = 0.1;
//! Combined mass of the cart and the pole.
constexpr double totalMass = massCart + massPole;
//! Length parameter of the pole (half the pole length in the classic task).
constexpr double length = 0.5;
//! Mass of the pole multiplied by its length parameter.
constexpr double poleMassLength = massPole * length;
//! Magnitude of the force applied to the cart at each step.
constexpr double forceMag = 10.0;
//! Time step of the Euler integration performed in Sample().
constexpr double tau = 0.02;
//! Full-precision pi (the reference task uses math.pi, not 3.1416).
constexpr double pi = 3.14159265358979323846;
//! Episode terminates when |angle| exceeds 12 degrees, expressed in radians.
constexpr double thetaThresholdRadians = 12 * 2 * pi / 360;
//! Episode terminates when |position| exceeds this bound.
constexpr double xThreshold = 2.4;
}

/**
* Implementation of Cart Pole task
*/
class CartPole
{
public:

/**
* Implementation of state of Cart Pole
* Each state is a tuple of (position, velocity, angle, angular velocity)
*/
class State
{
public:
//! Construct a state instance
State() : data(4) { }

//! Construct a state instance from given data
State(arma::colvec data) : data(data) { }

//! Get position
double X() const
{
return data[0];
}

//! Modify position
double& X()
{
return data[0];
}

//! Get velocity
double XDot() const
{
return data[1];
}

//! Modify velocity
double& XDot()
{
return data[1];
}

//! Get angle
double Theta() const
{
return data[2];
}

//! Modify angle
double& Theta()
{
return data[2];
}

//! Get angular velocity
double ThetaDot() const
{
return data[3];
}

//! Modify angular velocity
double& ThetaDot()
{
return data[3];
}

//! Encode the state to a column vector
const arma::colvec& Encode() const
{
return data;
}

//! Whether current state is terminal state
bool IsTerminal() const
{
using namespace cart_pole_details;
return std::abs(X()) > xThreshold ||
std::abs(Theta()) > thetaThresholdRadians;
}

private:
//! Locally-stored (position, velocity, angle, angular velocity)
arma::colvec data;
};

/**
* Implementation of action of Cart Pole
*/
class Action
{
public:
enum Actions
{
backward,
forward
};

//! # of actions
static constexpr size_t count = 2;
};

/**
* Dynamics of Cart Pole
* Get next state and next action based on current state and current action
* @param state Current state
* @param action Current action
* @param nextState Next state
* @param reward Reward is always 1
*/
void Sample(const State& state, const Action::Actions& action,
State& nextState, double& reward)
{
using namespace cart_pole_details;
double force = action ? forceMag : -forceMag;
double cosTheta = std::cos(state.Theta());
double sinTheta = std::sin(state.Theta());
double temp = (force + poleMassLength * state.ThetaDot() * state.ThetaDot() * sinTheta) / totalMass;
double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
(length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
nextState.X() = state.X() + tau * state.XDot();
nextState.XDot() = state.XDot() + tau * xAcc;
nextState.Theta() = state.Theta() + tau * state.ThetaDot();
nextState.ThetaDot() = state.ThetaDot() + tau * thetaAcc;

reward = 1.0;
}

/**
* Initial state representation is randomly generated within [-0.05, 0.05]
* @return Initial state for each episode
*/
State InitialSample()
{
return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0);
}

};

} // namespace rl
} // namespace mlpack

#endif
155 changes: 155 additions & 0 deletions src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/**
* @file mountain_car.hpp
* @author Shangtong Zhang
*
* This file is an implementation of the Mountain Car task:
* https://gym.openai.com/envs/MountainCar-v0
*
* TODO: refactor to OpenAI interface
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP
#define MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace rl {

namespace mountain_car_details {
// Bounds of the Mountain Car state space; Sample() clamps the next state
// into these ranges.
//! Lower bound of the car's position (the left wall).
constexpr double positionMin = -1.2;
//! Upper bound of the car's position; reaching it ends the episode.
constexpr double positionMax = 0.5;
//! Lower bound of the car's velocity.
constexpr double velocityMin = -0.07;
//! Upper bound of the car's velocity.
constexpr double velocityMax = 0.07;
}

/**
 * Implementation of the Mountain Car task. An underpowered car must rock
 * back and forth in a valley to build enough momentum to reach the goal at
 * the right-hand hilltop (positionMax).
 */
class MountainCar
{
 public:

  /**
   * Implementation of state of Mountain Car.
   * Each state is a (velocity, position) pair.
   */
  class State
  {
   public:
    /**
     * Construct a state instance.
     *
     * @param velocity Initial velocity.
     * @param position Initial position.
     */
    State(double velocity = 0, double position = 0) : data(2)
    {
      this->Velocity() = velocity;
      this->Position() = position;
    }

    //! Encode the state to a column vector.
    const arma::colvec& Encode() const
    {
      return data;
    }

    //! Get velocity.
    double Velocity() const
    {
      return data[0];
    }

    //! Modify velocity.
    double& Velocity()
    {
      return data[0];
    }

    //! Get position.
    double Position() const
    {
      return data[1];
    }

    //! Modify position.
    double& Position()
    {
      return data[1];
    }

    //! Whether current state is terminal: the car reached the goal position.
    bool IsTerminal() const
    {
      using namespace mountain_car_details;
      return std::abs(Position() - positionMax) <= 1e-5;
    }

   private:
    //! Locally-stored velocity and position.
    arma::colvec data;
  };

  /**
   * Implementation of action of Mountain Car.
   */
  class Action
  {
   public:
    enum Actions
    {
      backward,
      stop,
      forward
    };

    //! # of actions.
    static constexpr size_t count = 3;
  };

  /**
   * Dynamics of Mountain Car.
   * Get next state based on current state and current action.
   *
   * @param state Current state.
   * @param action Current action.
   * @param nextState Next state, filled in by this function.
   * @param reward Reward, always -1.
   */
  void Sample(const State& state, const Action::Actions& action,
      State& nextState, double& reward)
  {
    using namespace mountain_car_details;
    // Map {backward, stop, forward} to directions {-1, 0, 1}.
    int direction = action - 1;
    nextState.Velocity() = state.Velocity() + 0.001 * direction -
        0.0025 * std::cos(3 * state.Position());
    nextState.Velocity() =
        std::min(std::max(nextState.Velocity(), velocityMin), velocityMax);

    nextState.Position() = state.Position() + nextState.Velocity();
    nextState.Position() =
        std::min(std::max(nextState.Position(), positionMin), positionMax);

    reward = -1.0;
    // Hitting the left wall is an inelastic collision: kill the velocity
    // only while the car is still moving left. A car at the wall that is
    // already moving right keeps its velocity so it can pull away.
    if (std::abs(nextState.Position() - positionMin) <= 1e-5 &&
        nextState.Velocity() < 0.0)
    {
      nextState.Velocity() = 0.0;
    }
  }

  /**
   * Initial position is randomly generated within [-0.6, -0.4].
   * Initial velocity is 0.
   *
   * @return Initial state for each episode.
   */
  State InitialSample()
  {
    State state;
    state.Velocity() = 0.0;
    state.Position() = arma::as_scalar(arma::randu(1)) * 0.2 - 0.6;
    return state;
  }

};

} // namespace rl
} // namespace mlpack

#endif

0 comments on commit e981050

Please sign in to comment.