Skip to content

Commit

Permalink
Merge pull request mlpack#1070 from partobs-mdp/grad-opt
Browse files Browse the repository at this point in the history
Adding optimization features (cross-entropy layer, gradient clipping).
  • Loading branch information
zoq authored Jul 29, 2017
2 parents 00c746e + e310c98 commit 5c68061
Show file tree
Hide file tree
Showing 13 changed files with 453 additions and 13 deletions.
21 changes: 19 additions & 2 deletions src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,17 @@ class MiniBatchSGDType
* @param updatePolicy Instantiated update policy used to adjust the given
* parameters.
* @param decayPolicy Instantiated decay policy used to adjust the step size.
* @param resetPolicy Flag that determines whether update policy parameters
* are reset before every Optimize call.
*/
MiniBatchSGDType(const size_t batchSize = 1000,
const double stepSize = 0.01,
const size_t maxIterations = 100000,
const double tolerance = 1e-5,
const bool shuffle = true,
const UpdatePolicyType& updatePolicy = UpdatePolicyType(),
const DecayPolicyType& decayPolicy = DecayPolicyType());
const DecayPolicyType& decayPolicy = DecayPolicyType(),
const bool resetPolicy = true);

/**
* Optimize the given function using mini-batch SGD. The given starting point
Expand All @@ -122,10 +125,13 @@ class MiniBatchSGDType
* @tparam DecomposableFunctionType Type of the function to be optimized.
* @param function Function to optimize.
* @param iterate Starting point (will be modified).
* @param resetPolicy Flag indicating whether update policy
* should be reset before running optimization.
* @return Objective value of the final point.
*/
template<typename DecomposableFunctionType>
double Optimize(DecomposableFunctionType& function, arma::mat& iterate);
double Optimize(DecomposableFunctionType& function,
arma::mat& iterate);

//! Get the batch size.
size_t BatchSize() const { return batchSize; }
Expand All @@ -152,6 +158,13 @@ class MiniBatchSGDType
//! Modify whether or not the individual functions are shuffled.
bool& Shuffle() { return shuffle; }

//! Get whether or not the update policy parameters
//! are reset before Optimize call.
bool ResetPolicy() const { return resetPolicy; }
//! Modify whether or not the update policy parameters
//! are reset before Optimize call.
bool& ResetPolicy() { return resetPolicy; }

//! Get the update policy.
UpdatePolicyType UpdatePolicy() const { return updatePolicy; }
//! Modify the update policy.
Expand Down Expand Up @@ -184,6 +197,10 @@ class MiniBatchSGDType

//! The decay policy used to update the parameters in each iteration.
DecayPolicyType decayPolicy;

//! Flag that determines whether update policy parameters
//! are reset before every Optimize call.
bool resetPolicy;
};

using MiniBatchSGD = MiniBatchSGDType<VanillaUpdate, NoDecay>;
Expand Down
12 changes: 8 additions & 4 deletions src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@ MiniBatchSGDType<
const double tolerance,
const bool shuffle,
const UpdatePolicyType& updatePolicy,
const DecayPolicyType& decayPolicy) :
const DecayPolicyType& decayPolicy,
const bool resetPolicy) :
batchSize(batchSize),
stepSize(stepSize),
maxIterations(maxIterations),
tolerance(tolerance),
shuffle(shuffle),
updatePolicy(updatePolicy),
decayPolicy(decayPolicy)
decayPolicy(decayPolicy),
resetPolicy(resetPolicy)
{ /* Nothing to do. */ }

//! Optimize the function (minimize).
Expand All @@ -50,7 +52,8 @@ template<typename DecomposableFunctionType>
double MiniBatchSGDType<
UpdatePolicyType,
DecayPolicyType
>::Optimize(DecomposableFunctionType& function, arma::mat& iterate)
>::Optimize(DecomposableFunctionType& function,
arma::mat& iterate)
{
// Find the number of functions.
const size_t numFunctions = function.NumFunctions();
Expand All @@ -75,7 +78,8 @@ double MiniBatchSGDType<
overallObjective += function.Evaluate(iterate, i);

// Initialize the update policy.
updatePolicy.Initialize(iterate.n_rows, iterate.n_cols);
if (resetPolicy)
updatePolicy.Initialize(iterate.n_rows, iterate.n_cols);

// Now iterate!
arma::mat gradient(iterate.n_rows, iterate.n_cols);
Expand Down
21 changes: 19 additions & 2 deletions src/mlpack/core/optimizers/sgd/sgd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,17 @@ class SGD
* @param tolerance Maximum absolute tolerance to terminate algorithm.
* @param shuffle If true, the function order is shuffled; otherwise, each
* function is visited in linear order.
* @param updatePolicy Instantiated update policy used to adjust the given
* parameters.
* @param resetPolicy Flag that determines whether update policy parameters
* are reset before every Optimize call.
*/
SGD(const double stepSize = 0.01,
const size_t maxIterations = 100000,
const double tolerance = 1e-5,
const bool shuffle = true,
const UpdatePolicyType updatePolicy = UpdatePolicyType());
const UpdatePolicyType updatePolicy = UpdatePolicyType(),
const bool resetPolicy = true);

/**
* Optimize the given function using stochastic gradient descent. The given
Expand All @@ -110,7 +115,8 @@ class SGD
* @return Objective value of the final point.
*/
template<typename DecomposableFunctionType>
double Optimize(DecomposableFunctionType& function, arma::mat& iterate);
double Optimize(DecomposableFunctionType& function,
arma::mat& iterate);

//! Get the step size.
double StepSize() const { return stepSize; }
Expand All @@ -132,6 +138,13 @@ class SGD
//! Modify whether or not the individual functions are shuffled.
bool& Shuffle() { return shuffle; }

//! Get whether or not the update policy parameters
//! are reset before Optimize call.
bool ResetPolicy() const { return resetPolicy; }
//! Modify whether or not the update policy parameters
//! are reset before Optimize call.
bool& ResetPolicy() { return resetPolicy; }

//! Get the update policy.
const UpdatePolicyType& UpdatePolicy() const { return updatePolicy; }
//! Modify the update policy.
Expand All @@ -153,6 +166,10 @@ class SGD

//! The update policy used to update the parameters in each iteration.
UpdatePolicyType updatePolicy;

//! Flag indicating whether update policy
//! should be reset before running optimization.
bool resetPolicy;
};

using StandardSGD = SGD<VanillaUpdate>;
Expand Down
9 changes: 6 additions & 3 deletions src/mlpack/core/optimizers/sgd/sgd_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ SGD<UpdatePolicyType>::SGD(
const size_t maxIterations,
const double tolerance,
const bool shuffle,
const UpdatePolicyType updatePolicy) :
const UpdatePolicyType updatePolicy,
const bool resetPolicy) :
stepSize(stepSize),
maxIterations(maxIterations),
tolerance(tolerance),
shuffle(shuffle),
updatePolicy(updatePolicy)
updatePolicy(updatePolicy),
resetPolicy(resetPolicy)
{ /* Nothing to do. */ }

//! Optimize the function (minimize).
Expand Down Expand Up @@ -65,7 +67,8 @@ double SGD<UpdatePolicyType>::Optimize(
overallObjective += function.Evaluate(iterate, i);

// Initialize the update policy.
updatePolicy.Initialize(iterate.n_rows, iterate.n_cols);
if (resetPolicy)
updatePolicy.Initialize(iterate.n_rows, iterate.n_cols);

// Now iterate!
arma::mat gradient(iterate.n_rows, iterate.n_cols);
Expand Down
109 changes: 109 additions & 0 deletions src/mlpack/core/optimizers/sgd/update_policies/gradient_clipping.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/**
* @file gradient_clipping.hpp
* @author Konstantin Sidorov
*
* Gradient clipping update wrapper.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_CORE_OPTIMIZERS_SGD_GRADIENT_CLIPPING_HPP
#define MLPACK_CORE_OPTIMIZERS_SGD_GRADIENT_CLIPPING_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace optimization {

/**
* Interface for wrapping around update policies (e.g., VanillaUpdate)
* and feeding a clipped gradient to them instead of the normal one.
* (Clipping here is implemented as
* \f$ g_{\text{clipped}} = \max(g_{\text{min}}, \min(g_{\text{min}}, g))) \f$.)
*
* @tparam UpdatePolicy A type of UpdatePolicy that sould be wrapped around.
*/
template<typename UpdatePolicyType>
class GradientClipping
{
public:
/**
* Constructor for creating a GradientClipping instance.
*
* @param minGradient Minimum possible value of gradient element.
* @param maxGradient Maximum possible value of gradient element.
* @param updatePolicy An instance of the UpdatePolicyType
* used for actual optimization.
*/
GradientClipping(const double minGradient,
const double maxGradient,
UpdatePolicyType& updatePolicy) :
minGradient(minGradient),
maxGradient(maxGradient),
updatePolicy(updatePolicy)
{
// Nothing to do here
}

/**
* The Initialize method is called by SGD Optimizer method before the start of
* the iteration update process. Here we just do whatever initialization
* is needed for the actual update policy.
*
* @param rows Number of rows in the gradient matrix.
* @param cols Number of columns in the gradient matrix.
*/
void Initialize(const size_t rows, const size_t cols)
{
updatePolicy.Initialize(rows, cols);
}

/**
* Update step. First, the gradient is clipped, and then the actual update
* policy does whatever update it needs to do.
*
* @param iterate Parameters that minimize the function.
* @param stepSize Step size to be used for the given iteration.
* @param gradient The gradient matrix.
*/
void Update(arma::mat& iterate,
const double stepSize,
const arma::mat& gradient)
{
// First, clip the gradient.
arma::mat clippedGradient = arma::clamp(gradient, minGradient, maxGradient);
// And only then do the update.
updatePolicy.Update(iterate, stepSize, clippedGradient);
}

//! Get the update policy.
UpdatePolicyType& UpdatePolicy() const { return updatePolicy; }
//! Modify the update policy.
UpdatePolicyType& UpdatePolicy() { return updatePolicy; }

//! Get the minimum gradient value.
double MinGradient() const { return minGradient; }
//! Modify the minimum gradient value.
double& MinGradient() { return minGradient; }

//! Get the maximum gradient value.
double MaxGradient() const { return maxGradient; }
//! Modify the maximum gradient value.
double& MaxGradient() { return maxGradient; }
private:
//! Minimum possible value of gradient element.
double minGradient;

//! Maximum possible value of gradient element.
double maxGradient;

//! An instance of the UpdatePolicy used for actual optimization.
UpdatePolicyType updatePolicy;
};

} // namespace optimization
} // namespace mlpack

#endif
2 changes: 2 additions & 0 deletions src/mlpack/methods/ann/layer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ set(SOURCES
constant_impl.hpp
convolution.hpp
convolution_impl.hpp
cross_entropy_error.hpp
cross_entropy_error_impl.hpp
dropconnect.hpp
dropconnect_impl.hpp
dropout.hpp
Expand Down
Loading

0 comments on commit 5c68061

Please sign in to comment.