Merge pull request mlpack#1037 from shikharbhardwaj/parallel_sgd
Implementation of parallel SGD
Showing 24 changed files with 1,127 additions and 42 deletions.
src/mlpack/core/optimizers/parallel_sgd/CMakeLists.txt (13 additions & 0 deletions)
set(SOURCES
  parallel_sgd.hpp
  parallel_sgd_impl.hpp
  sparse_test_function.hpp
  sparse_test_function_impl.hpp
)

set(DIR_SRCS)
foreach(file ${SOURCES})
  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()

set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
src/mlpack/core/optimizers/parallel_sgd/decay_policies/CMakeLists.txt (11 additions & 0 deletions)
set(SOURCES
  constant_step.hpp
  exponential_backoff.hpp
)

set(DIR_SRCS)
foreach(file ${SOURCES})
  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()

set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
src/mlpack/core/optimizers/parallel_sgd/decay_policies/constant_step.hpp (55 additions & 0 deletions)
/**
 * @file constant_step.hpp
 * @author Shikhar Bhardwaj
 *
 * Constant step size policy for parallel Stochastic Gradient Descent.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license. You should have received a copy of the
 * 3-clause BSD license along with mlpack. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef MLPACK_CORE_OPTIMIZERS_PARALLEL_SGD_CONSTANT_STEP_HPP
#define MLPACK_CORE_OPTIMIZERS_PARALLEL_SGD_CONSTANT_STEP_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace optimization {

/**
 * Implementation of the ConstantStep stepsize decay policy for parallel SGD.
 */
class ConstantStep
{
 public:
  /**
   * Member initialization constructor.
   *
   * The defaults here are not necessarily good for the given problem, so it is
   * suggested that the values used be tailored to the task at hand.
   *
   * @param step The initial stepsize to use.
   */
  ConstantStep(const double step = 0.01) : step(step) { /* Nothing to do */ }

  /**
   * This function is called in each iteration before the gradient update.
   *
   * @param numEpoch The iteration number for which the stepsize is to be
   *     calculated.
   * @return The step size for the current iteration.
   */
  double StepSize(const size_t /* numEpoch */)
  {
    return step;
  }

 private:
  //! The initial stepsize, which remains unchanged.
  double step;
};

} // namespace optimization
} // namespace mlpack

#endif
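The decay policy is duck-typed: ParallelSGD only ever calls StepSize(numEpoch) on it, so any class exposing that one member can be swapped in. As a minimal sketch of the interface (the InverseScalingStep class below is hypothetical, written for this note and not part of this commit), a 1/k-style schedule could look like:

// Hypothetical decay policy, shown only to illustrate the StepSize()
// interface that ConstantStep implements above; not part of this commit.
class InverseScalingStep
{
 public:
  InverseScalingStep(const double step = 0.01) : step(step) { }

  // Called before each gradient update, exactly like ConstantStep::StepSize().
  double StepSize(const size_t numEpoch)
  {
    // Scale the initial stepsize down as 1 / (k + 1).
    return step / (numEpoch + 1);
  }

 private:
  //! The initial stepsize.
  double step;
};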
src/mlpack/core/optimizers/parallel_sgd/decay_policies/exponential_backoff.hpp (92 additions & 0 deletions)
/**
 * @file exponential_backoff.hpp
 * @author Shikhar Bhardwaj
 *
 * Exponential backoff step size decay policy for parallel Stochastic Gradient
 * Descent.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license. You should have received a copy of the
 * 3-clause BSD license along with mlpack. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef MLPACK_CORE_OPTIMIZERS_PARALLEL_SGD_EXP_BACKOFF_HPP
#define MLPACK_CORE_OPTIMIZERS_PARALLEL_SGD_EXP_BACKOFF_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace optimization {

/**
 * Exponential backoff stepsize reduction policy for parallel SGD.
 *
 * For more information, see the following.
 *
 * @misc{1106.5730,
 *   Author = {Feng Niu and Benjamin Recht and Christopher Re and Stephen J.
 *             Wright},
 *   Title = {HOGWILD!: A Lock-Free Approach to Parallelizing Stochastic
 *            Gradient Descent},
 *   Year = {2011},
 *   Eprint = {arXiv:1106.5730},
 * }
 *
 * This stepsize update scheme gives robust 1/k convergence rates to the
 * implementation of parallel SGD.
 */
class ExponentialBackoff
{
 public:
  /**
   * Member initializer constructor to construct the exponential backoff policy
   * with the required parameters.
   *
   * @param firstBackoffEpoch The number of updates to run before the first
   *     stepsize backoff.
   * @param step The initial stepsize (gamma).
   * @param beta The reduction factor. This should be a value in range (0, 1).
   */
  ExponentialBackoff(const size_t firstBackoffEpoch,
                     const double step,
                     const double beta) :
      firstBackoffEpoch(firstBackoffEpoch),
      cutoffEpoch(firstBackoffEpoch),
      step(step),
      beta(beta)
  { /* Nothing to do. */ }

  /**
   * Get the step size for the current gradient update.
   *
   * @param numEpoch The iteration number of the current update.
   * @return The stepsize for the current iteration.
   */
  double StepSize(const size_t numEpoch)
  {
    if (numEpoch >= cutoffEpoch)
    {
      step *= beta;
      cutoffEpoch += firstBackoffEpoch / beta;
    }
    return step;
  }

 private:
  //! The first iteration at which the stepsize should be reduced.
  size_t firstBackoffEpoch;

  //! The iteration at which the next decay will be performed.
  size_t cutoffEpoch;

  //! The initial stepsize.
  double step;

  //! The reduction factor; should be in range (0, 1).
  double beta;
};

} // namespace optimization
} // namespace mlpack

#endif
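To make the schedule concrete, here is a minimal sketch tracing StepSize() over a few epochs; the parameter values (100, 0.1, 0.5) are arbitrary choices for illustration. Note that StepSize() is stateful, so it must be queried with non-decreasing epoch numbers, as the optimizer does.

#include <cstddef>
#include <iostream>
#include <mlpack/core/optimizers/parallel_sgd/decay_policies/exponential_backoff.hpp>

int main()
{
  // firstBackoffEpoch = 100, step = 0.1, beta = 0.5.
  mlpack::optimization::ExponentialBackoff decay(100, 0.1, 0.5);

  // The stepsize is halved at epoch 100, and thereafter every
  // firstBackoffEpoch / beta = 200 epochs (at 300, 500, ...).
  for (size_t epoch = 0; epoch <= 600; epoch += 100)
    std::cout << "epoch " << epoch << ": " << decay.StepSize(epoch) << "\n";
  // Prints 0.1, 0.05, 0.05, 0.025, 0.025, 0.0125, 0.0125.

  return 0;
}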
src/mlpack/core/optimizers/parallel_sgd/parallel_sgd.hpp (151 additions & 0 deletions)
/**
 * @file parallel_sgd.hpp
 * @author Shikhar Bhardwaj
 *
 * Parallel Stochastic Gradient Descent.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license. You should have received a copy of the
 * 3-clause BSD license along with mlpack. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef MLPACK_CORE_OPTIMIZERS_PARALLEL_SGD_HPP
#define MLPACK_CORE_OPTIMIZERS_PARALLEL_SGD_HPP

#include <mlpack/prereqs.hpp>
#include <mlpack/core/math/random.hpp>
#include "decay_policies/constant_step.hpp"

namespace mlpack {
namespace optimization {

/**
 * An implementation of parallel stochastic gradient descent using the
 * lock-free HOGWILD! approach.
 *
 * For more information, see the following.
 *
 * @misc{1106.5730,
 *   Author = {Feng Niu and Benjamin Recht and Christopher Re and Stephen J.
 *             Wright},
 *   Title = {HOGWILD!: A Lock-Free Approach to Parallelizing Stochastic
 *            Gradient Descent},
 *   Year = {2011},
 *   Eprint = {arXiv:1106.5730},
 * }
 *
 * For parallel SGD to work, a SparseFunctionType template parameter is
 * required. This class must implement the following functions:
 *
 *   size_t NumFunctions();
 *   double Evaluate(const arma::mat& coordinates, const size_t i);
 *   void Gradient(const arma::mat& coordinates,
 *                 const size_t i,
 *                 arma::sp_mat& gradient);
 *
 * In these functions the parameter i refers to which individual function (or
 * gradient) is being evaluated. In the case of a data-dependent function, i
 * would refer to the index of the datapoint (or training example). The data
 * is distributed uniformly among the threads made available to the program by
 * the OpenMP runtime.
 *
 * The Gradient function interface differs slightly from the
 * DecomposableFunctionType interface: it takes a sparse matrix as the
 * out-param for the gradient, as ParallelSGD is only expected to be relevant
 * in situations where the computed gradient is sparse.
 *
 * @tparam DecayPolicyType Step size update policy used by parallel SGD to
 *     update the stepsize after each iteration.
 */
template <typename DecayPolicyType = ConstantStep>
class ParallelSGD
{
 public:
  /**
   * Construct the parallel SGD optimizer to optimize the given function with
   * the given parameters. One iteration means one batch of datapoints
   * processed by each thread.
   *
   * The defaults here are not necessarily good for the given problem, so it
   * is suggested that the values used be tailored to the task at hand.
   *
   * @param maxIterations Maximum number of iterations allowed (0 means no
   *     limit).
   * @param threadShareSize Number of datapoints to be processed in one
   *     iteration by each thread.
   * @param tolerance Maximum absolute tolerance to terminate the algorithm.
   * @param shuffle If true, the function order is shuffled; otherwise, each
   *     function is visited in linear order.
   * @param decayPolicy The step size update policy to use.
   */
  ParallelSGD(const size_t maxIterations,
              const size_t threadShareSize,
              const double tolerance = 1e-5,
              const bool shuffle = true,
              const DecayPolicyType& decayPolicy = DecayPolicyType());

  /**
   * Optimize the given function using the parallel SGD algorithm. The given
   * starting point will be modified to store the finishing point of the
   * algorithm, and the value of the loss function at the final point is
   * returned.
   *
   * @tparam SparseFunctionType Type of function to be optimized.
   * @param function Function to be optimized (minimized).
   * @param iterate Starting point (will be modified).
   * @return Objective value at the final point.
   */
  template <typename SparseFunctionType>
  double Optimize(SparseFunctionType& function, arma::mat& iterate);

  //! Get the maximum number of iterations (0 indicates no limit).
  size_t MaxIterations() const { return maxIterations; }
  //! Modify the maximum number of iterations (0 indicates no limit).
  size_t& MaxIterations() { return maxIterations; }

  //! Get the number of datapoints to be processed in one iteration by each
  //! thread.
  size_t ThreadShareSize() const { return threadShareSize; }
  //! Modify the number of datapoints to be processed in one iteration by each
  //! thread.
  size_t& ThreadShareSize() { return threadShareSize; }

  //! Get the tolerance for termination.
  double Tolerance() const { return tolerance; }
  //! Modify the tolerance for termination.
  double& Tolerance() { return tolerance; }

  //! Get whether or not the individual functions are shuffled.
  bool Shuffle() const { return shuffle; }
  //! Modify whether or not the individual functions are shuffled.
  bool& Shuffle() { return shuffle; }

  //! Get the step size decay policy.
  const DecayPolicyType& DecayPolicy() const { return decayPolicy; }
  //! Modify the step size decay policy.
  DecayPolicyType& DecayPolicy() { return decayPolicy; }

 private:
  //! The maximum number of allowed iterations.
  size_t maxIterations;

  //! The number of datapoints to be processed in one iteration by each
  //! thread.
  size_t threadShareSize;

  //! The tolerance for termination.
  double tolerance;

  //! Controls whether or not the individual functions are shuffled when
  //! iterating.
  bool shuffle;

  //! The step size decay policy.
  DecayPolicyType decayPolicy;
};

} // namespace optimization
} // namespace mlpack

// Include implementation.
#include "parallel_sgd_impl.hpp"

#endif
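For reference, a minimal sketch of a caller. The toy function below is written for this note (not part of this commit) and satisfies the SparseFunctionType requirements listed in the class documentation above; the optimizer parameters are arbitrary illustrative choices, and the default ConstantStep policy is used.

#include <iostream>
#include <mlpack/core.hpp>
#include <mlpack/core/optimizers/parallel_sgd/parallel_sgd.hpp>

// Toy SparseFunctionType: f(x) = sum_i (x[i] - 1)^2. Each f_i touches a
// single coordinate, so its gradient is sparse, as ParallelSGD expects.
class ToySparseFunction
{
 public:
  size_t NumFunctions() { return 10; }

  double Evaluate(const arma::mat& coordinates, const size_t i)
  {
    const double d = coordinates(i, 0) - 1.0;
    return d * d;
  }

  void Gradient(const arma::mat& coordinates,
                const size_t i,
                arma::sp_mat& gradient)
  {
    gradient.zeros(coordinates.n_rows, coordinates.n_cols);
    gradient(i, 0) = 2 * (coordinates(i, 0) - 1.0);
  }
};

int main()
{
  ToySparseFunction f;
  arma::mat iterate(10, 1, arma::fill::zeros);

  // maxIterations = 0 (no limit; run until the tolerance is met), with each
  // thread processing 4 datapoints per iteration and a constant step of 0.01.
  mlpack::optimization::ParallelSGD<mlpack::optimization::ConstantStep>
      optimizer(0, 4, 1e-5, true, mlpack::optimization::ConstantStep(0.01));

  const double objective = optimizer.Optimize(f, iterate);
  std::cout << "Final objective: " << objective << std::endl;
  return 0;
}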