diff --git a/COPYRIGHT.txt b/COPYRIGHT.txt index d39c4a2b5cb..bf9841da5f7 100644 --- a/COPYRIGHT.txt +++ b/COPYRIGHT.txt @@ -78,6 +78,7 @@ Copyright: Copyright 2017, Sagar B Hathwar Copyright 2017, Nishanth Hegde Copyright 2017, Parminder Singh + Copyright 2017, CodeAi License: BSD-3-clause All rights reserved. diff --git a/HISTORY.md b/HISTORY.md index e5883ec98b5..5541f6326df 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,6 +1,11 @@ ### mlpack ?.?.? ###### ????-??-?? +### mlpack 2.2.1 +###### 2017-04-13 + * Compilation fix for mlpack_nca and mlpack_test on older Armadillo versions + (#984). + ### mlpack 2.2.0 ###### 2017-03-21 * Bugfix for mlpack_knn program (#816). diff --git a/doc/guide/build.hpp b/doc/guide/build.hpp index b9f9913d03e..6b999913ecd 100644 --- a/doc/guide/build.hpp +++ b/doc/guide/build.hpp @@ -23,14 +23,14 @@ href="https://keon.io/mlpack/mlpack-on-windows/">Keon's excellent tutorial. @section Download latest mlpack build Download latest mlpack build from here: -mlpack-2.2.0 +mlpack-2.2.1 @section builddir Creating Build Directory Once the mlpack source is unpacked, you should create a build directory. @code -$ cd mlpack-2.2.0 +$ cd mlpack-2.2.1 $ mkdir build @endcode diff --git a/src/mlpack/core.hpp b/src/mlpack/core.hpp index 0880064c5eb..94a2ab84e75 100644 --- a/src/mlpack/core.hpp +++ b/src/mlpack/core.hpp @@ -219,6 +219,7 @@ * - Sagar B Hathwar * - Nishanth Hegde * - Parminder Singh + * - CodeAi (deep learning bug detector) */ // First, include all of the prerequisites. diff --git a/src/mlpack/core/optimizers/CMakeLists.txt b/src/mlpack/core/optimizers/CMakeLists.txt index c0b5f563f58..072fd27b32e 100644 --- a/src/mlpack/core/optimizers/CMakeLists.txt +++ b/src/mlpack/core/optimizers/CMakeLists.txt @@ -10,6 +10,7 @@ set(DIRS sa sdp sgd + smorms3 ) foreach(dir ${DIRS}) diff --git a/src/mlpack/core/optimizers/ada_delta/ada_delta.hpp b/src/mlpack/core/optimizers/ada_delta/ada_delta.hpp index 8589374460c..9a209c63e7b 100644 --- a/src/mlpack/core/optimizers/ada_delta/ada_delta.hpp +++ b/src/mlpack/core/optimizers/ada_delta/ada_delta.hpp @@ -34,10 +34,10 @@ namespace optimization { * * @code * @article{Zeiler2012, - * author = {Matthew D. Zeiler}, - * title = {{ADADELTA:} An Adaptive Learning Rate Method}, - * journal = {CoRR}, - * year = {2012} + * author = {Matthew D. Zeiler}, + * title = {{ADADELTA:} An Adaptive Learning Rate Method}, + * journal = {CoRR}, + * year = {2012} * } * @endcode * diff --git a/src/mlpack/core/optimizers/ada_delta/ada_delta_update.hpp b/src/mlpack/core/optimizers/ada_delta/ada_delta_update.hpp index 391e2a88bd3..fc3d0bcde02 100644 --- a/src/mlpack/core/optimizers/ada_delta/ada_delta_update.hpp +++ b/src/mlpack/core/optimizers/ada_delta/ada_delta_update.hpp @@ -29,10 +29,10 @@ namespace optimization { * * @code * @article{Zeiler2012, - * author = {Matthew D. Zeiler}, - * title = {{ADADELTA:} An Adaptive Learning Rate Method}, - * journal = {CoRR}, - * year = {2012} + * author = {Matthew D. 
Zeiler}, + * title = {{ADADELTA:} An Adaptive Learning Rate Method}, + * journal = {CoRR}, + * year = {2012} * } * @endcode * diff --git a/src/mlpack/core/optimizers/ada_grad/ada_grad.hpp b/src/mlpack/core/optimizers/ada_grad/ada_grad.hpp index e96b619e858..6c622c8e3e7 100644 --- a/src/mlpack/core/optimizers/ada_grad/ada_grad.hpp +++ b/src/mlpack/core/optimizers/ada_grad/ada_grad.hpp @@ -30,13 +30,13 @@ namespace optimization { * * @code * @article{duchi2011adaptive, - * author = {Duchi, John and Hazan, Elad and Singer, Yoram}, - * title = {Adaptive subgradient methods for online learning and stochastic optimization}, - * journal = {Journal of Machine Learning Research}, - * volume = {12}, - * number = {Jul}, - * pages = {2121--2159}, - * year = {2011} + * author = {Duchi, John and Hazan, Elad and Singer, Yoram}, + * title = {Adaptive subgradient methods for online learning and stochastic optimization}, + * journal = {Journal of Machine Learning Research}, + * volume = {12}, + * number = {Jul}, + * pages = {2121--2159}, + * year = {2011} * } * @endcode * diff --git a/src/mlpack/core/optimizers/ada_grad/ada_grad_update.hpp b/src/mlpack/core/optimizers/ada_grad/ada_grad_update.hpp index 2fb52573b1b..b086926388b 100644 --- a/src/mlpack/core/optimizers/ada_grad/ada_grad_update.hpp +++ b/src/mlpack/core/optimizers/ada_grad/ada_grad_update.hpp @@ -58,8 +58,8 @@ class AdaGradUpdate * gradient matrix is initialized to the zeros matrix with the same size as * gradient matrix (see mlpack::optimization::SGD::Optimizer). * - * @param rows number of rows in the gradient matrix. - * @param cols number of columns in the gradient matrix. + * @param rows Number of rows in the gradient matrix. + * @param cols Number of columns in the gradient matrix. */ void Initialize(const size_t rows, const size_t cols) diff --git a/src/mlpack/core/optimizers/adam/CMakeLists.txt b/src/mlpack/core/optimizers/adam/CMakeLists.txt index 3cbcfd84e79..1377bbb0e15 100644 --- a/src/mlpack/core/optimizers/adam/CMakeLists.txt +++ b/src/mlpack/core/optimizers/adam/CMakeLists.txt @@ -1,6 +1,8 @@ set(SOURCES adam.hpp adam_impl.hpp + adam_update.hpp + adamax_update.hpp ) set(DIR_SRCS) diff --git a/src/mlpack/core/optimizers/adam/adam.hpp b/src/mlpack/core/optimizers/adam/adam.hpp index 62af4c6ac82..23615fdd9e6 100644 --- a/src/mlpack/core/optimizers/adam/adam.hpp +++ b/src/mlpack/core/optimizers/adam/adam.hpp @@ -20,6 +20,10 @@ #include +#include +#include "adam_update.hpp" +#include "adamax_update.hpp" + namespace mlpack { namespace optimization { @@ -33,10 +37,11 @@ namespace optimization { * * @code * @article{Kingma2014, - * author = {Diederik P. Kingma and Jimmy Ba}, - * title = {Adam: {A} Method for Stochastic Optimization}, - * journal = {CoRR}, - * year = {2014} + * author = {Diederik P. Kingma and Jimmy Ba}, + * title = {Adam: {A} Method for Stochastic Optimization}, + * journal = {CoRR}, + * year = {2014}, + * url = {http://arxiv.org/abs/1412.6980} * } * @endcode * @@ -60,9 +65,13 @@ namespace optimization { * * @tparam DecomposableFunctionType Decomposable objective function type to be * minimized. + * @tparam UpdateRule Adam optimizer update rule to be used. */ -template -class Adam +template< + typename DecomposableFunctionType, + typename UpdateRule = AdamUpdate +> +class AdamType { public: /** @@ -84,18 +93,15 @@ class Adam * @param tolerance Maximum absolute tolerance to terminate algorithm. * @param shuffle If true, the function order is shuffled; otherwise, each * function is visited in linear order. 
- * @param adaMax If true, then the AdaMax optimizer is used; otherwise, by - * default the Adam optimizer is used. */ - Adam(DecomposableFunctionType& function, - const double stepSize = 0.001, - const double beta1 = 0.9, - const double beta2 = 0.999, - const double eps = 1e-8, - const size_t maxIterations = 100000, - const double tolerance = 1e-5, - const bool shuffle = true, - const bool adaMax = false); + AdamType(DecomposableFunctionType& function, + const double stepSize = 0.001, + const double beta1 = 0.9, + const double beta2 = 0.999, + const double eps = 1e-8, + const size_t maxIterations = 100000, + const double tolerance = 1e-5, + const bool shuffle = true); /** * Optimize the given function using Adam. The given starting point will be @@ -105,82 +111,61 @@ class Adam * @param iterate Starting point (will be modified). * @return Objective value of the final point. */ - double Optimize(arma::mat& iterate); + double Optimize(arma::mat& iterate){ return optimizer.Optimize(iterate); } //! Get the instantiated function to be optimized. - const DecomposableFunctionType& Function() const { return function; } + const DecomposableFunctionType& Function() const + { + return optimizer.Function(); + } //! Modify the instantiated function. - DecomposableFunctionType& Function() { return function; } + DecomposableFunctionType& Function() { return optimizer.Function(); } //! Get the step size. - double StepSize() const { return stepSize; } + double StepSize() const { return optimizer.StepSize(); } //! Modify the step size. - double& StepSize() { return stepSize; } + double& StepSize() { return optimizer.StepSize(); } //! Get the smoothing parameter. - double Beta1() const { return beta1; } + double Beta1() const { return optimizer.UpdatePolicy().Beta1(); } //! Modify the smoothing parameter. - double& Beta1() { return beta1; } + double& Beta1() { return optimizer.UpdatePolicy().Beta1(); } //! Get the second moment coefficient. - double Beta2() const { return beta2; } + double Beta2() const { return optimizer.UpdatePolicy().Beta2(); } //! Modify the second moment coefficient. - double& Beta2() { return beta2; } + double& Beta2() { return optimizer.UpdatePolicy().Beta2(); } //! Get the value used to initialise the mean squared gradient parameter. - double Epsilon() const { return eps; } + double Epsilon() const { return optimizer.UpdatePolicy().Epsilon(); } //! Modify the value used to initialise the mean squared gradient parameter. - double& Epsilon() { return eps; } + double& Epsilon() { return optimizer.UpdatePolicy().Epsilon(); } //! Get the maximum number of iterations (0 indicates no limit). - size_t MaxIterations() const { return maxIterations; } + size_t MaxIterations() const { return optimizer.MaxIterations(); } //! Modify the maximum number of iterations (0 indicates no limit). - size_t& MaxIterations() { return maxIterations; } + size_t& MaxIterations() { return optimizer.MaxIterations(); } //! Get the tolerance for termination. - double Tolerance() const { return tolerance; } + double Tolerance() const { return optimizer.Tolerance(); } //! Modify the tolerance for termination. - double& Tolerance() { return tolerance; } + double& Tolerance() { return optimizer.Tolerance(); } //! Get whether or not the individual functions are shuffled. - bool Shuffle() const { return shuffle; } + bool Shuffle() const { return optimizer.Shuffle(); } //! Modify whether or not the individual functions are shuffled. - bool& Shuffle() { return shuffle; } - - //! 
Get whether or not the AdaMax optimizer is specified. - bool AdaMax() const { return adaMax; } - //! Modify wehther or not the AdaMax optimizer is to be used. - bool& AdaMax() { return adaMax; } + bool& Shuffle() { return optimizer.Shuffle(); } private: - //! The instantiated function. - DecomposableFunctionType& function; - - //! The step size for each example. - double stepSize; - - //! Exponential decay rate for the first moment estimates. - double beta1; - - //! Exponential decay rate for the weighted infinity norm estimates. - double beta2; - - //! The value used to initialise the mean squared gradient parameter. - double eps; - - //! The maximum number of allowed iterations. - size_t maxIterations; - - //! The tolerance for termination. - double tolerance; + //! The Stochastic Gradient Descent object with Adam policy. + SGD optimizer; +}; - //! Controls whether or not the individual functions are shuffled when - //! iterating. - bool shuffle; +template +using Adam = AdamType; - //! Specifies whether or not the AdaMax optimizer is to be used. - bool adaMax; -}; +template +using AdaMax = AdamType; } // namespace optimization } // namespace mlpack diff --git a/src/mlpack/core/optimizers/adam/adam_impl.hpp b/src/mlpack/core/optimizers/adam/adam_impl.hpp index 8a6be66c079..82408de71e0 100644 --- a/src/mlpack/core/optimizers/adam/adam_impl.hpp +++ b/src/mlpack/core/optimizers/adam/adam_impl.hpp @@ -21,161 +21,26 @@ namespace mlpack { namespace optimization { -template -Adam::Adam(DecomposableFunctionType& function, - const double stepSize, - const double beta1, - const double beta2, - const double eps, - const size_t maxIterations, - const double tolerance, - const bool shuffle, - const bool adaMax) : - function(function), - stepSize(stepSize), - beta1(beta1), - beta2(beta2), - eps(eps), - maxIterations(maxIterations), - tolerance(tolerance), - shuffle(shuffle), - adaMax(adaMax) +template +AdamType::AdamType( + DecomposableFunctionType& function, + const double stepSize, + const double beta1, + const double beta2, + const double epsilon, + const size_t maxIterations, + const double tolerance, + const bool shuffle) : + optimizer(function, + stepSize, + maxIterations, + tolerance, + shuffle, + UpdateRule(epsilon, + beta1, + beta2)) { /* Nothing to do. */ } -//! Optimize the function (minimize). -template -double Adam::Optimize(arma::mat& iterate) -{ - // Find the number of functions to use. - const size_t numFunctions = function.NumFunctions(); - - // This is used only if shuffle is true. - arma::Col visitationOrder; - if (shuffle) - visitationOrder = arma::shuffle(arma::linspace>(0, - (numFunctions - 1), numFunctions)); - - // To keep track of where we are and how things are going. - size_t currentFunction = 0; - double overallObjective = 0; - double lastObjective = DBL_MAX; - - // Calculate the first objective function. - for (size_t i = 0; i < numFunctions; ++i) - overallObjective += function.Evaluate(iterate, i); - - // Now iterate! - arma::mat gradient(iterate.n_rows, iterate.n_cols); - - // Exponential moving average of gradient values. - arma::mat m = arma::zeros(iterate.n_rows, iterate.n_cols); - - /** - * Initialize either the exponentially weighted infinity norm for AdaMax - * optimizer (u) or exponential moving average of squared gradient values - * for Adam optimizer (v). 
- */ - arma::mat u, v; - if (adaMax) - { - u = arma::zeros(iterate.n_rows, iterate.n_cols); - } - else - { - v = arma::zeros(iterate.n_rows, iterate.n_cols); - } - - for (size_t i = 1; i != maxIterations; ++i, ++currentFunction) - { - // Is this iteration the start of a sequence? - if ((currentFunction % numFunctions) == 0) - { - // Output current objective function. - Log::Info << "Adam: iteration " << i << ", objective " << overallObjective - << "." << std::endl; - - if (std::isnan(overallObjective) || std::isinf(overallObjective)) - { - Log::Warn << "Adam: converged to " << overallObjective - << "; terminating with failure. Try a smaller step size?" - << std::endl; - return overallObjective; - } - - if (std::abs(lastObjective - overallObjective) < tolerance) - { - Log::Info << "Adam: minimized within tolerance " << tolerance << "; " - << "terminating optimization." << std::endl; - return overallObjective; - } - - // Reset the counter variables. - lastObjective = overallObjective; - overallObjective = 0; - currentFunction = 0; - - if (shuffle) // Determine order of visitation. - visitationOrder = arma::shuffle(visitationOrder); - } - - // Evaluate the gradient for this iteration. - if (shuffle) - function.Gradient(iterate, visitationOrder[currentFunction], gradient); - else - function.Gradient(iterate, currentFunction, gradient); - - // And update the iterate. - m *= beta1; - m += (1 - beta1) * gradient; - - if (adaMax) - { - // Update the exponentially weighted infinity norm. - u *= beta2; - u = arma::max(u, arma::abs(gradient)); - } - else - { - v *= beta2; - v += (1 - beta2) * (gradient % gradient); - } - - const double biasCorrection1 = 1.0 - std::pow(beta1, (double) i); - const double biasCorrection2 = 1.0 - std::pow(beta2, (double) i); - - if (adaMax) - { - if (biasCorrection1 != 0.0) - iterate -= (stepSize / biasCorrection1 * m / (u + eps)); - } - else - { - /** - * It should be noted that the term, m / (arma::sqrt(v) + eps), in the - * following expression is an approximation of the following actual term; - * m / (arma::sqrt(v) + (arma::sqrt(biasCorrection2) * eps). - */ - iterate -= (stepSize * std::sqrt(biasCorrection2) / biasCorrection1) * - m / (arma::sqrt(v) + eps); - } - - // Now add that to the overall objective function. - if (shuffle) - overallObjective += function.Evaluate(iterate, - visitationOrder[currentFunction]); - else - overallObjective += function.Evaluate(iterate, currentFunction); - } - - Log::Info << "Adam: maximum iterations (" << maxIterations << ") reached; " - << "terminating optimization." << std::endl; - // Calculate final objective. - overallObjective = 0; - for (size_t i = 0; i < numFunctions; ++i) - overallObjective += function.Evaluate(iterate, i); - return overallObjective; -} - } // namespace optimization } // namespace mlpack diff --git a/src/mlpack/core/optimizers/adam/adam_update.hpp b/src/mlpack/core/optimizers/adam/adam_update.hpp new file mode 100644 index 00000000000..0f64c82327c --- /dev/null +++ b/src/mlpack/core/optimizers/adam/adam_update.hpp @@ -0,0 +1,149 @@ +/** + * @file adam_update.hpp + * @author Ryan Curtin + * @author Vasanth Kalingeri + * @author Marcus Edel + * @author Vivek Pal + * + * Adam optimizer. Adam is an an algorithm for first-order gradient-based + * optimization of stochastic objective functions, based on adaptive estimates + * of lower-order moments. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. 
You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MLPACK_CORE_OPTIMIZERS_ADAM_ADAM_UPDATE_HPP +#define MLPACK_CORE_OPTIMIZERS_ADAM_ADAM_UPDATE_HPP + +#include + +namespace mlpack { +namespace optimization { + +/** + * Adam is an optimizer that computes individual adaptive learning rates for + * different parameters from estimates of first and second moments of the + * gradients as given in the section 7 of the following paper. + * + * For more information, see the following. + * + * @code + * @article{Kingma2014, + * author = {Diederik P. Kingma and Jimmy Ba}, + * title = {Adam: {A} Method for Stochastic Optimization}, + * journal = {CoRR}, + * year = {2014}, + * url = {http://arxiv.org/abs/1412.6980} + * } + * @endcode + */ +class AdamUpdate +{ + public: + /** + * Construct the Adam update policy with the given parameters. + * + * @param epsilon The epsilon value used to initialise the squared gradient + * parameter. + * @param beta1 The smoothing parameter. + * @param beta2 The second moment coefficient. + */ + AdamUpdate(const double epsilon = 1e-8, + const double beta1 = 0.9, + const double beta2 = 0.999) : + epsilon(epsilon), + beta1(beta1), + beta2(beta2), + iteration(0) + { + // Nothing to do. + } + + /** + * The Initialize method is called by SGD Optimizer method before the start of + * the iteration update process. + * + * @param rows Number of rows in the gradient matrix. + * @param cols Number of columns in the gradient matrix. + */ + void Initialize(const size_t rows, + const size_t cols) + { + m = arma::zeros(rows, cols); + v = arma::zeros(rows, cols); + } + + /** + * Update step for Adam. + * + * @param iterate Parameters that minimize the function. + * @param stepSize Step size to be used for the given iteration. + * @param gradient The gradient matrix. + */ + void Update(arma::mat& iterate, + const double stepSize, + const arma::mat& gradient) + { + // Increment the iteration counter variable. + ++iteration; + + // And update the iterate. + m *= beta1; + m += (1 - beta1) * gradient; + + v *= beta2; + v += (1 - beta2) * (gradient % gradient); + + const double biasCorrection1 = 1.0 - std::pow(beta1, iteration); + const double biasCorrection2 = 1.0 - std::pow(beta2, iteration); + + /** + * It should be noted that the term, m / (arma::sqrt(v) + eps), in the + * following expression is an approximation of the following actual term; + * m / (arma::sqrt(v) + (arma::sqrt(biasCorrection2) * eps). + */ + iterate -= (stepSize * std::sqrt(biasCorrection2) / biasCorrection1) * + m / (arma::sqrt(v) + epsilon); + } + + //! Get the value used to initialise the squared gradient parameter. + double Epsilon() const { return epsilon; } + //! Modify the value used to initialise the squared gradient parameter. + double& Epsilon() { return epsilon; } + + //! Get the smoothing parameter. + double Beta1() const { return beta1; } + //! Modify the smoothing parameter. + double& Beta1() { return beta1; } + + //! Get the second moment coefficient. + double Beta2() const { return beta2; } + //! Modify the second moment coefficient. + double& Beta2() { return beta2; } + + private: + // The epsilon value used to initialise the squared gradient parameter. + double epsilon; + + // The smoothing parameter. + double beta1; + + // The second moment coefficient. + double beta2; + + // The exponential moving average of gradient values. 
+ arma::mat m; + + // The exponential moving average of squared gradient values. + arma::mat v; + + // The number of iterations. + double iteration; +}; + +} // namespace optimization +} // namespace mlpack + +#endif diff --git a/src/mlpack/core/optimizers/adam/adamax_update.hpp b/src/mlpack/core/optimizers/adam/adamax_update.hpp new file mode 100644 index 00000000000..9ac281cca8e --- /dev/null +++ b/src/mlpack/core/optimizers/adam/adamax_update.hpp @@ -0,0 +1,146 @@ +/** + * @file adamax_update.hpp + * @author Ryan Curtin + * @author Vasanth Kalingeri + * @author Marcus Edel + * @author Vivek Pal + * + * AdaMax update rule. Adam is an algorithm for first-order gradient- + * based optimization of stochastic objective functions, based on adaptive + * estimates of lower-order moments. AdaMax is simply a variant of Adam based + * on the infinity norm. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MLPACK_CORE_OPTIMIZERS_ADAM_ADAMAX_UPDATE_HPP +#define MLPACK_CORE_OPTIMIZERS_ADAM_ADAMAX_UPDATE_HPP + +#include + +namespace mlpack { +namespace optimization { + +/** + * AdaMax is a variant of Adam, an optimizer that computes individual adaptive + * learning rates for different parameters from estimates of first and second + * moments of the gradients, based on the infinity norm, as given in section + * 7 of the following paper. + * + * For more information, see the following. + * + * @code + * @article{Kingma2014, + * author = {Diederik P. Kingma and Jimmy Ba}, + * title = {Adam: {A} Method for Stochastic Optimization}, + * journal = {CoRR}, + * year = {2014}, + * url = {http://arxiv.org/abs/1412.6980} + * } + * @endcode + */ +class AdaMaxUpdate +{ + public: + /** + * Construct the AdaMax update policy with the given parameters. + * + * @param epsilon The epsilon value used to initialise the squared gradient + * parameter. + * @param beta1 The smoothing parameter. + * @param beta2 The second moment coefficient. + */ + AdaMaxUpdate(const double epsilon = 1e-8, + const double beta1 = 0.9, + const double beta2 = 0.999) : + epsilon(epsilon), + beta1(beta1), + beta2(beta2), + iteration(0) + { + // Nothing to do. + } + + /** + * The Initialize method is called by SGD Optimizer method before the start of + * the iteration update process. + * + * @param rows Number of rows in the gradient matrix. + * @param cols Number of columns in the gradient matrix. + */ + void Initialize(const size_t rows, + const size_t cols) + { + m = arma::zeros(rows, cols); + u = arma::zeros(rows, cols); + } + + /** + * Update step for AdaMax. + * + * @param iterate Parameters that minimize the function. + * @param stepSize Step size to be used for the given iteration. + * @param gradient The gradient matrix. + */ + void Update(arma::mat& iterate, + const double stepSize, + const arma::mat& gradient) + { + // Increment the iteration counter variable. + ++iteration; + + // And update the iterate. + m *= beta1; + m += (1 - beta1) * gradient; + + // Update the exponentially weighted infinity norm. + u *= beta2; + u = arma::max(u, arma::abs(gradient)); + + const double biasCorrection1 = 1.0 - std::pow(beta1, iteration); + + if (biasCorrection1 != 0) + iterate -= (stepSize / biasCorrection1 * m / (u + epsilon)); + } + + //! 
Get the value used to initialise the squared gradient parameter. + double Epsilon() const { return epsilon; } + //! Modify the value used to initialise the squared gradient parameter. + double& Epsilon() { return epsilon; } + + //! Get the smoothing parameter. + double Beta1() const { return beta1; } + //! Modify the smoothing parameter. + double& Beta1() { return beta1; } + + //! Get the second moment coefficient. + double Beta2() const { return beta2; } + //! Modify the second moment coefficient. + double& Beta2() { return beta2; } + + private: + // The epsilon value used to initialise the squared gradient parameter. + double epsilon; + + // The smoothing parameter. + double beta1; + + // The second moment coefficient. + double beta2; + + // The exponential moving average of gradient values. + arma::mat m; + + // The exponentially weighted infinity norm. + arma::mat u; + + // The number of iterations. + double iteration; +}; + +} // namespace optimization +} // namespace mlpack + +#endif diff --git a/src/mlpack/core/optimizers/gradient_descent/gradient_descent.hpp b/src/mlpack/core/optimizers/gradient_descent/gradient_descent.hpp index 78ee1038130..79fb8dded78 100644 --- a/src/mlpack/core/optimizers/gradient_descent/gradient_descent.hpp +++ b/src/mlpack/core/optimizers/gradient_descent/gradient_descent.hpp @@ -66,9 +66,9 @@ class GradientDescent * @param tolerance Maximum absolute tolerance to terminate algorithm. */ GradientDescent(FunctionType& function, - const double stepSize = 0.01, - const size_t maxIterations = 100000, - const double tolerance = 1e-5); + const double stepSize = 0.01, + const size_t maxIterations = 100000, + const double tolerance = 1e-5); /** * Optimize the given function using gradient descent. The given starting diff --git a/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp b/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp index 422e6c7cc2f..8c6f64168fb 100644 --- a/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp +++ b/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp @@ -94,11 +94,15 @@ class L_BFGS * finishing point of the algorithm, and the final objective value is * returned. * + * This overload will be removed in mlpack 3.0.0---you should set + * maxIterations in the constructor instead. + * * @param iterate Starting point (will be modified). * @param maxIterations Maximum number of iterations (0 specifies no limit). * @return Objective value of the final point. */ - double Optimize(arma::mat& iterate, const size_t maxIterations); + mlpack_deprecated double Optimize(arma::mat& iterate, + const size_t maxIterations); //! Return the function that is being optimized. 
const FunctionType& Function() const { return function; } diff --git a/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp b/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp index f8e44a383b7..38fa92bd5b1 100644 --- a/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp +++ b/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp @@ -326,9 +326,11 @@ L_BFGS::MinPointIterate() const } template -inline double L_BFGS::Optimize(arma::mat& iterate) +inline double L_BFGS::Optimize(arma::mat& iterate, + const size_t maxIterations) { - return Optimize(iterate, maxIterations); + this->maxIterations = maxIterations; + return Optimize(iterate); } /** @@ -341,8 +343,7 @@ inline double L_BFGS::Optimize(arma::mat& iterate) * @param iterate Starting point (will be modified) */ template -double L_BFGS::Optimize(arma::mat& iterate, - const size_t maxIterations) +double L_BFGS::Optimize(arma::mat& iterate) { // Ensure that the cubes holding past iterations' information are the right // size. Also set the current best point value to the maximum. diff --git a/src/mlpack/core/optimizers/rmsprop/CMakeLists.txt b/src/mlpack/core/optimizers/rmsprop/CMakeLists.txt index 75c30c67bb9..da3c7681030 100644 --- a/src/mlpack/core/optimizers/rmsprop/CMakeLists.txt +++ b/src/mlpack/core/optimizers/rmsprop/CMakeLists.txt @@ -1,6 +1,7 @@ set(SOURCES rmsprop.hpp rmsprop_impl.hpp + rmsprop_update.hpp ) set(DIR_SRCS) diff --git a/src/mlpack/core/optimizers/rmsprop/rmsprop.hpp b/src/mlpack/core/optimizers/rmsprop/rmsprop.hpp index 58969e5fcb5..c3da802d701 100644 --- a/src/mlpack/core/optimizers/rmsprop/rmsprop.hpp +++ b/src/mlpack/core/optimizers/rmsprop/rmsprop.hpp @@ -2,8 +2,9 @@ * @file rmsprop.hpp * @author Ryan Curtin * @author Marcus Edel + * @author Vivek Pal * - * RMSprop optimizer. RmsProp is an optimizer that utilizes the magnitude of + * RMSProp optimizer. RMSProp is an optimizer that utilizes the magnitude of * recent gradients to normalize the gradients. * * mlpack is free software; you may redistribute it and/or modify it under the @@ -16,11 +17,14 @@ #include +#include +#include "rmsprop_update.hpp" + namespace mlpack { namespace optimization { /** - * RMSprop is an optimizer that utilizes the magnitude of recent gradients to + * RMSProp is an optimizer that utilizes the magnitude of recent gradients to * normalize the gradients. In its basic form, given a step rate \f$ \gamma \f$ * and a decay term \f$ \alpha \f$ we perform the following updates: * @@ -34,13 +38,13 @@ namespace optimization { * * @code * @misc{tieleman2012, - * title={Lecture 6.5 - rmsprop, COURSERA: Neural Networks for Machine - * Learning}, - * year={2012} + * title = {Lecture 6.5 - rmsprop, COURSERA: Neural Networks for Machine + * Learning}, + * year = {2012} * } * @endcode * - * For RMSprop to work, a DecomposableFunctionType template parameter is + * For RMSProp to work, a DecomposableFunctionType template parameter is * required. This class must implement the following function: * * size_t NumFunctions(); @@ -61,11 +65,11 @@ namespace optimization { * minimized. */ template -class RMSprop +class RMSProp { public: /** - * Construct the RMSprop optimizer with the given function and parameters. The + * Construct the RMSProp optimizer with the given function and parameters. The * defaults here are not necessarily good for the given problem, so it is * suggested that the values used be tailored to the task at hand. 
The * maximum number of iterations refers to the maximum number of points that @@ -76,88 +80,72 @@ class RMSprop * @param stepSize Step size for each iteration. * @param alpha Smoothing constant, similar to that used in AdaDelta and * momentum methods. - * @param eps Value used to initialise the mean squared gradient parameter. + * @param epsilon Value used to initialise the mean squared gradient parameter. * @param maxIterations Maximum number of iterations allowed (0 means no * limit). * @param tolerance Maximum absolute tolerance to terminate algorithm. * @param shuffle If true, the function order is shuffled; otherwise, each * function is visited in linear order. */ - RMSprop(DecomposableFunctionType& function, - const double stepSize = 0.01, - const double alpha = 0.99, - const double eps = 1e-8, - const size_t maxIterations = 100000, - const double tolerance = 1e-5, - const bool shuffle = true); + RMSProp(DecomposableFunctionType& function, + const double stepSize = 0.01, + const double alpha = 0.99, + const double epsilon = 1e-8, + const size_t maxIterations = 100000, + const double tolerance = 1e-5, + const bool shuffle = true); /** - * Optimize the given function using RMSprop. The given starting point will be + * Optimize the given function using RMSProp. The given starting point will be * modified to store the finishing point of the algorithm, and the final * objective value is returned. * * @param iterate Starting point (will be modified). * @return Objective value of the final point. */ - double Optimize(arma::mat& iterate); + double Optimize(arma::mat& iterate) { return optimizer.Optimize(iterate); } //! Get the instantiated function to be optimized. - const DecomposableFunctionType& Function() const { return function; } + const DecomposableFunctionType& Function() const + { + return optimizer.Function(); + } //! Modify the instantiated function. - DecomposableFunctionType& Function() { return function; } + DecomposableFunctionType& Function() { return optimizer.Function(); } //! Get the step size. - double StepSize() const { return stepSize; } + double StepSize() const { return optimizer.StepSize(); } //! Modify the step size. - double& StepSize() { return stepSize; } + double& StepSize() { return optimizer.StepSize(); } //! Get the smoothing parameter. - double Alpha() const { return alpha; } + double Alpha() const { return optimizer.UpdatePolicy().Alpha(); } //! Modify the smoothing parameter. - double& Alpha() { return alpha; } + double& Alpha() { return optimizer.UpdatePolicy().Alpha(); } //! Get the value used to initialise the mean squared gradient parameter. - double Epsilon() const { return eps; } + double Epsilon() const { return optimizer.UpdatePolicy().Epsilon(); } //! Modify the value used to initialise the mean squared gradient parameter. - double& Epsilon() { return eps; } + double& Epsilon() { return optimizer.UpdatePolicy().Epsilon(); } //! Get the maximum number of iterations (0 indicates no limit). - size_t MaxIterations() const { return maxIterations; } + size_t MaxIterations() const { return optimizer.MaxIterations(); } //! Modify the maximum number of iterations (0 indicates no limit). - size_t& MaxIterations() { return maxIterations; } + size_t& MaxIterations() { return optimizer.MaxIterations(); } //! Get the tolerance for termination. - double Tolerance() const { return tolerance; } + double Tolerance() const { return optimizer.Tolerance(); } //! Modify the tolerance for termination. 
- double& Tolerance() { return tolerance; } + double& Tolerance() { return optimizer.Tolerance(); } //! Get whether or not the individual functions are shuffled. - bool Shuffle() const { return shuffle; } + bool Shuffle() const { return optimizer.Shuffle(); } //! Modify whether or not the individual functions are shuffled. - bool& Shuffle() { return shuffle; } + bool& Shuffle() { return optimizer.Shuffle(); } private: - //! The instantiated function. - DecomposableFunctionType& function; - - //! The step size for each example. - double stepSize; - - //! The smoothing parameter. - double alpha; - - //! The value used to initialise the mean squared gradient parameter. - double eps; - - //! The maximum number of allowed iterations. - size_t maxIterations; - - //! The tolerance for termination. - double tolerance; - - //! Controls whether or not the individual functions are shuffled when - //! iterating. - bool shuffle; + //! The Stochastic Gradient Descent object with RMSPropUpdate policy. + SGD optimizer; }; } // namespace optimization diff --git a/src/mlpack/core/optimizers/rmsprop/rmsprop_impl.hpp b/src/mlpack/core/optimizers/rmsprop/rmsprop_impl.hpp index a06814b08f9..3dcd8606954 100644 --- a/src/mlpack/core/optimizers/rmsprop/rmsprop_impl.hpp +++ b/src/mlpack/core/optimizers/rmsprop/rmsprop_impl.hpp @@ -2,8 +2,9 @@ * @file rmsprop_impl.hpp * @author Ryan Curtin * @author Marcus Edel + * @author Vivek Pal * - * Implementation of the RMSprop optimizer. + * Implementation of the RMSProp constructor. * * mlpack is free software; you may redistribute it and/or modify it under the * terms of the 3-clause BSD license. You should have received a copy of the @@ -20,112 +21,22 @@ namespace mlpack { namespace optimization { template -RMSprop::RMSprop(DecomposableFunctionType& function, +RMSProp::RMSProp(DecomposableFunctionType& function, const double stepSize, const double alpha, - const double eps, + const double epsilon, const size_t maxIterations, const double tolerance, const bool shuffle) : - function(function), - stepSize(stepSize), - alpha(alpha), - eps(eps), - maxIterations(maxIterations), - tolerance(tolerance), - shuffle(shuffle) + optimizer(function, + stepSize, + maxIterations, + tolerance, + shuffle, + RMSPropUpdate(epsilon, + alpha)) { /* Nothing to do. */ } -//! Optimize the function (minimize). -template -double RMSprop::Optimize(arma::mat& iterate) -{ - // Find the number of functions to use. - const size_t numFunctions = function.NumFunctions(); - - // This is used only if shuffle is true. - arma::Col visitationOrder; - if (shuffle) - visitationOrder = arma::shuffle(arma::linspace>(0, - (numFunctions - 1), numFunctions)); - - // To keep track of where we are and how things are going. - size_t currentFunction = 0; - double overallObjective = 0; - double lastObjective = DBL_MAX; - - // Calculate the first objective function. - for (size_t i = 0; i < numFunctions; ++i) - overallObjective += function.Evaluate(iterate, i); - - // Now iterate! - arma::mat gradient(iterate.n_rows, iterate.n_cols); - - // Leaky sum of squares of parameter gradient. - arma::mat meanSquaredGradient = arma::zeros(iterate.n_rows, - iterate.n_cols); - - for (size_t i = 1; i != maxIterations; ++i, ++currentFunction) - { - // Is this iteration the start of a sequence? - if ((currentFunction % numFunctions) == 0) - { - // Output current objective function. - Log::Info << "RMSprop: iteration " << i << ", objective " - << overallObjective << "." 
<< std::endl; - - if (std::isnan(overallObjective) || std::isinf(overallObjective)) - { - Log::Warn << "RMSprop: converged to " << overallObjective - << "; terminating with failure. Try a smaller step size?" - << std::endl; - return overallObjective; - } - - if (std::abs(lastObjective - overallObjective) < tolerance) - { - Log::Info << "RMSprop: minimized within tolerance " << tolerance << "; " - << "terminating optimization." << std::endl; - return overallObjective; - } - - // Reset the counter variables. - lastObjective = overallObjective; - overallObjective = 0; - currentFunction = 0; - - if (shuffle) // Determine order of visitation. - visitationOrder = arma::shuffle(visitationOrder); - } - - // Evaluate the gradient for this iteration. - if (shuffle) - function.Gradient(iterate, visitationOrder[currentFunction], gradient); - else - function.Gradient(iterate, currentFunction, gradient); - - // And update the iterate. - meanSquaredGradient *= alpha; - meanSquaredGradient += (1 - alpha) * (gradient % gradient); - iterate -= stepSize * gradient / (arma::sqrt(meanSquaredGradient) + eps); - - // Now add that to the overall objective function. - if (shuffle) - overallObjective += function.Evaluate(iterate, - visitationOrder[currentFunction]); - else - overallObjective += function.Evaluate(iterate, currentFunction); - } - - Log::Info << "RMSprop: maximum iterations (" << maxIterations << ") reached; " - << "terminating optimization." << std::endl; - // Calculate final objective. - overallObjective = 0; - for (size_t i = 0; i < numFunctions; ++i) - overallObjective += function.Evaluate(iterate, i); - return overallObjective; -} - } // namespace optimization } // namespace mlpack diff --git a/src/mlpack/core/optimizers/rmsprop/rmsprop_update.hpp b/src/mlpack/core/optimizers/rmsprop/rmsprop_update.hpp new file mode 100644 index 00000000000..c86f09f0c3c --- /dev/null +++ b/src/mlpack/core/optimizers/rmsprop/rmsprop_update.hpp @@ -0,0 +1,117 @@ +/** + * @file rmsprop_update.hpp + * @author Ryan Curtin + * @author Marcus Edel + * @author Vivek Pal + * + * RMSProp optimizer. RMSProp is an optimizer that utilizes the magnitude of + * recent gradients to normalize the gradients. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MLPACK_CORE_OPTIMIZERS_RMSPROP_RMSPROP_UPDATE_HPP +#define MLPACK_CORE_OPTIMIZERS_RMSPROP_RMSPROP_UPDATE_HPP + +#include + +namespace mlpack { +namespace optimization { + +/** + * RMSProp is an optimizer that utilizes the magnitude of recent gradients to + * normalize the gradients. In its basic form, given a step rate \f$ \gamma \f$ + * and a decay term \f$ \alpha \f$ we perform the following updates: + * + * \f{eqnarray*}{ + * r_t &=& (1 - \gamma) f'(\Delta_t)^2 + \gamma r_{t - 1} \\ + * v_{t + 1} &=& \frac{\alpha}{\sqrt{r_t}}f'(\Delta_t) \\ + * \Delta_{t + 1} &=& \Delta_t - v_{t + 1} + * \f} + * + * For more information, see the following. + * + * @code + * @misc{tieleman2012, + * title = {Lecture 6.5 - rmsprop, COURSERA: Neural Networks for Machine + * Learning}, + * year = {2012} + * } + * @endcode + */ +class RMSPropUpdate +{ + public: + /** + * Construct the RMSProp update policy with the given parameters. + * + * @param epsilon The epsilon value used to initialise the squared gradient + * parameter. 
+ * @param alpha The smoothing parameter. + */ + RMSPropUpdate(const double epsilon = 1e-8, + const double alpha = 0.99) : + epsilon(epsilon), + alpha(alpha) + { + // Nothing to do. + } + + /** + * The Initialize method is called by SGD Optimizer method before the start of + * the iteration update process. + * + * @param rows Number of rows in the gradient matrix. + * @param cols Number of columns in the gradient matrix. + */ + void Initialize(const size_t rows, + const size_t cols) + { + // Leaky sum of squares of parameter gradient. + meanSquaredGradient = arma::zeros(rows, cols); + } + + /** + * Update step for RMSProp. + * + * @param iterate Parameters that minimize the function. + * @param stepSize Step size to be used for the given iteration. + * @param gradient The gradient matrix. + */ + void Update(arma::mat& iterate, + const double stepSize, + const arma::mat& gradient) + { + meanSquaredGradient *= alpha; + meanSquaredGradient += (1 - alpha) * (gradient % gradient); + iterate -= stepSize * gradient / (arma::sqrt(meanSquaredGradient) + + epsilon); + } + + //! Get the value used to initialise the squared gradient parameter. + double Epsilon() const { return epsilon; } + //! Modify the value used to initialise the squared gradient parameter. + double& Epsilon() { return epsilon; } + + //! Get the smoothing parameter. + double Alpha() const { return alpha; } + //! Modify the smoothing parameter. + double& Alpha() { return alpha; } + + private: + // The epsilon value used to initialise the squared gradient parameter. + double epsilon; + + // The smoothing parameter. + double alpha; + + // Leaky sum of squares of parameter gradient. + arma::mat meanSquaredGradient; +}; + +} // namespace optimization +} // namespace mlpack + +#endif \ No newline at end of file diff --git a/src/mlpack/core/optimizers/sgd/update_policies/momentum_update.hpp b/src/mlpack/core/optimizers/sgd/update_policies/momentum_update.hpp index 2947c355f90..1ea08560349 100644 --- a/src/mlpack/core/optimizers/sgd/update_policies/momentum_update.hpp +++ b/src/mlpack/core/optimizers/sgd/update_policies/momentum_update.hpp @@ -42,23 +42,23 @@ namespace optimization { * * @code * @article{rumelhart1988learning, - * title={Learning representations by back-propagating errors}, - * author={Rumelhart, David E. and Hinton, Geoffrey E. and - * Williams, Ronald J.}, - * journal={Cognitive Modeling}, - * volume={5}, - * number={3}, - * pages={1}, - * year={1988} + * title = {Learning representations by back-propagating errors}, + * author = {Rumelhart, David E. and Hinton, Geoffrey E. and + * Williams, Ronald J.}, + * journal = {Cognitive Modeling}, + * volume = {5}, + * number = {3}, + * pages = {1}, + * year = {1988} * } * * @code * @book{Goodfellow-et-al-2016, - * title={Deep Learning}, - * author={Ian Goodfellow and Yoshua Bengio and Aaron Courville}, - * publisher={MIT Press}, - * note={\url{http://www.deeplearningbook.org}}, - * year={2016} + * title = {Deep Learning}, + * author = {Ian Goodfellow and Yoshua Bengio and Aaron Courville}, + * publisher = {MIT Press}, + * note = {\url{http://www.deeplearningbook.org}}, + * year = {2016} * } */ class MomentumUpdate @@ -78,14 +78,14 @@ class MomentumUpdate * matrix is initialized to the zeros matrix with the same size as the * gradient matrix (see mlpack::optimization::SGD::Optimizer ) * - * @param n_rows number of rows in the gradient matrix. - * @param n_cols number of columns in the gradient matrix. + * @param rows Number of rows in the gradient matrix. 
+ * @param cols Number of columns in the gradient matrix. */ - void Initialize(const size_t n_rows, - const size_t n_cols) + void Initialize(const size_t rows, + const size_t cols) { //Initialize am empty velocity matrix. - velocity = arma::zeros(n_rows, n_cols); + velocity = arma::zeros(rows, cols); } /** diff --git a/src/mlpack/core/optimizers/sgd/update_policies/vanilla_update.hpp b/src/mlpack/core/optimizers/sgd/update_policies/vanilla_update.hpp index 1bd85bf1c37..d76264864b7 100644 --- a/src/mlpack/core/optimizers/sgd/update_policies/vanilla_update.hpp +++ b/src/mlpack/core/optimizers/sgd/update_policies/vanilla_update.hpp @@ -36,10 +36,10 @@ class VanillaUpdate * the iteration update process. The vanilla update doesn't initialize * anything. * - * @param n_rows number of rows in the gradient matrix. - * @param n_cols number of columns in the gradient matrix. + * @param rows Number of rows in the gradient matrix. + * @param cols Number of columns in the gradient matrix. */ - void Initialize(const size_t /* n_rows */, const size_t /* n_cols */) + void Initialize(const size_t /* rows */, const size_t /* cols */) { /* Do nothing. */ } /** diff --git a/src/mlpack/core/optimizers/smorms3/CMakeLists.txt b/src/mlpack/core/optimizers/smorms3/CMakeLists.txt new file mode 100644 index 00000000000..e77beb68b1d --- /dev/null +++ b/src/mlpack/core/optimizers/smorms3/CMakeLists.txt @@ -0,0 +1,12 @@ +set(SOURCES + smorms3.hpp + smorms3_impl.hpp + smorms3_update.hpp +) + +set(DIR_SRCS) +foreach(file ${SOURCES}) + set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) +endforeach() + +set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE) diff --git a/src/mlpack/core/optimizers/smorms3/smorms3.hpp b/src/mlpack/core/optimizers/smorms3/smorms3.hpp new file mode 100644 index 00000000000..d8b67497079 --- /dev/null +++ b/src/mlpack/core/optimizers/smorms3/smorms3.hpp @@ -0,0 +1,145 @@ +/** + * @file smorms3.hpp + * @author Vivek Pal + * + * SMORMS3 i.e. squared mean over root mean squared cubed optimizer. It is a + * hybrid of RMSprop, which estimates a safe and optimal distance based on + * curvature and Yann LeCun’s method in "No more pesky learning rates". + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MLPACK_CORE_OPTIMIZERS_SMORMS3_SMORMS3_HPP +#define MLPACK_CORE_OPTIMIZERS_SMORMS3_SMORMS3_HPP + +#include +#include + +#include "smorms3_update.hpp" + +namespace mlpack { +namespace optimization { + +/** + * SMORMS3 is an optimizer that estimates a safe and optimal distance based on + * curvature and normalizing the stepsize in the parameter space. It is a hybrid + * of RMSprop and Yann LeCun’s method in "No more pesky learning rates". + * + * For more information, see the following. + * + * @code + * @misc{Funk2015, + * author = {Simon Funk}, + * title = {RMSprop loses to SMORMS3 - Beware the Epsilon!}, + * year = {2015} + * url = {http://sifter.org/~simon/journal/20150420.html} + * } + * @endcode + * + * + * For SMORMS3 to work, a DecomposableFunctionType template parameter is + * required. 
This class must implement the following function: + * + * size_t NumFunctions(); + * double Evaluate(const arma::mat& coordinates, const size_t i); + * void Gradient(const arma::mat& coordinates, + * const size_t i, + * arma::mat& gradient); + * + * NumFunctions() should return the number of functions (\f$n\f$), and in the + * other two functions, the parameter i refers to which individual function (or + * gradient) is being evaluated. So, for the case of a data-dependent function, + * such as NCA (see mlpack::nca::NCA), NumFunctions() should return the number + * of points in the dataset, and Evaluate(coordinates, 0) will evaluate the + * objective function on the first point in the dataset (presumably, the dataset + * is held internally in the DecomposableFunctionType). + * + * @tparam DecomposableFunctionType Decomposable objective function type to be + * minimized. + */ +template +class SMORMS3 +{ + public: + /** + * Construct the SMORMS3 optimizer with the given function and parameters. The + * defaults here are not necessarily good for the given problem, so it is + * suggested that the values used be tailored to the task at hand. The + * maximum number of iterations refers to the maximum number of points that + * are processed (i.e., one iteration equals one point; one iteration does not + * equal one pass over the dataset). + * + * @param function Function to be optimized (minimized). + * @param stepSize Step size for each iteration. + * @param epsilon Value used to initialise the mean squared gradient + * parameter. + * @param maxIterations Maximum number of iterations allowed (0 means no + * limit). + * @param tolerance Maximum absolute tolerance to terminate algorithm. + * @param shuffle If true, the function order is shuffled; otherwise, each + * function is visited in linear order. + */ + SMORMS3(DecomposableFunctionType& function, + const double stepSize = 0.001, + const double epsilon = 1e-16, + const size_t maxIterations = 100000, + const double tolerance = 1e-5, + const bool shuffle = true); + + /** + * Optimize the given function using SMORMS3. The given starting point will + * be modified to store the finishing point of the algorithm, and the final + * objective value is returned. + * + * @param iterate Starting point (will be modified). + * @return Objective value of the final point. + */ + double Optimize(arma::mat& iterate) { return optimizer.Optimize(iterate); } + + //! Get the instantiated function to be optimized. + const DecomposableFunctionType& Function() const + { + return optimizer.Function(); + } + //! Modify the instantiated function. + DecomposableFunctionType& Function() { return optimizer.Function(); } + + //! Get the step size. + double StepSize() const { return optimizer.StepSize(); } + //! Modify the step size. + double& StepSize() { return optimizer.StepSize(); } + + //! Get the value used to initialise the mean squared gradient parameter. + double Epsilon() const { return optimizer.UpdatePolicy().Epsilon(); } + //! Modify the value used to initialise the mean squared gradient parameter. + double& Epsilon() { return optimizer.UpdatePolicy().Epsilon(); } + + //! Get the maximum number of iterations (0 indicates no limit). + size_t MaxIterations() const { return optimizer.MaxIterations(); } + //! Modify the maximum number of iterations (0 indicates no limit). + size_t& MaxIterations() { return optimizer.MaxIterations(); } + + //! Get the tolerance for termination. + double Tolerance() const { return optimizer.Tolerance(); } + //! 
Modify the tolerance for termination. + double& Tolerance() { return optimizer.Tolerance(); } + + //! Get whether or not the individual functions are shuffled. + bool Shuffle() const { return optimizer.Shuffle(); } + //! Modify whether or not the individual functions are shuffled. + bool& Shuffle() { return optimizer.Shuffle(); } + + private: + //! The Stochastic Gradient Descent object with SMORMS3Update update policy. + SGD optimizer; +}; + +} // namespace optimization +} // namespace mlpack + +// Include implementation. +#include "smorms3_impl.hpp" + +#endif diff --git a/src/mlpack/core/optimizers/smorms3/smorms3_impl.hpp b/src/mlpack/core/optimizers/smorms3/smorms3_impl.hpp new file mode 100644 index 00000000000..68ced5913aa --- /dev/null +++ b/src/mlpack/core/optimizers/smorms3/smorms3_impl.hpp @@ -0,0 +1,39 @@ +/** + * @file smorms3_impl.hpp + * @author Vivek Pal + * + * Implementation of the SMORMS3 constructor. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MLPACK_CORE_OPTIMIZERS_SMORMS3_SMORMS3_IMPL_HPP +#define MLPACK_CORE_OPTIMIZERS_SMORMS3_SMORMS3_IMPL_HPP + +// In case it hasn't been included yet. +#include "smorms3.hpp" + +namespace mlpack { +namespace optimization { + +template +SMORMS3::SMORMS3(DecomposableFunctionType& function, + const double stepSize, + const double epsilon, + const size_t maxIterations, + const double tolerance, + const bool shuffle) : + optimizer(function, + stepSize, + maxIterations, + tolerance, + shuffle, + SMORMS3Update(epsilon)) +{ /* Nothing to do. */ } + +} // namespace optimization +} // namespace mlpack + +#endif diff --git a/src/mlpack/core/optimizers/smorms3/smorms3_update.hpp b/src/mlpack/core/optimizers/smorms3/smorms3_update.hpp new file mode 100644 index 00000000000..3cf35eb7df2 --- /dev/null +++ b/src/mlpack/core/optimizers/smorms3/smorms3_update.hpp @@ -0,0 +1,111 @@ +/** + * @file smorms3_update.hpp + * @author Vivek Pal + * + * SMORMS3 update for Stochastic Gradient Descent. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef MLPACK_CORE_OPTIMIZERS_SMORMS3_SMORMS3_UPDATE_HPP +#define MLPACK_CORE_OPTIMIZERS_SMORMS3_SMORMS3_UPDATE_HPP + +#include + +namespace mlpack { +namespace optimization { + +/** + * SMORMS3 is an optimizer that estimates a safe and optimal distance based on + * curvature and normalizing the stepsize in the parameter space. It is a hybrid + * of RMSprop and Yann LeCun’s method in "No more pesky learning rates". + * + * For more information, see the following. + * + * @code + * @misc{Funk2015, + * author = {Simon Funk}, + * title = {RMSprop loses to SMORMS3 - Beware the Epsilon!}, + * year = {2015} + * url = {http://sifter.org/~simon/journal/20150420.html} + * } + * @endcode + */ + +class SMORMS3Update +{ + public: + /** + * Construct the SMORMS3 update policy with given epsilon parameter. + * + * @param epsilon Value used to initialise the mean squared gradient + * parameter. + */ + SMORMS3Update(const double epsilon = 1e-16) : epsilon(epsilon) + { /* Do nothing. 
*/ } + + /** + * The Initialize method is called by SGD::Optimize method with UpdatePolicy + * SMORMS3Update before the start of the iteration update process. + * + * @param rows Number of rows in the gradient matrix. + * @param cols Number of columns in the gradient matrix. + */ + void Initialize(const size_t rows, + const size_t cols) + { + // Initialise the parameters mem, g and g2. + mem = arma::ones(rows, cols); + g = arma::zeros(rows, cols); + g2 = arma::zeros(rows, cols); + } + + /** + * Update step for SMORMS3. + * + * @param iterate Parameter that minimizes the function. + * @param stepSize Step size to be used for the given iteration. + * @param gradient The gradient matrix. + */ + void Update(arma::mat& iterate, + const double stepSize, + const arma::mat& gradient) + { + // Update the iterate. + arma::mat r = 1 / (mem + 1); + + g = (1 - r) % g; + g += r % gradient; + + g2 = (1 - r) % g2; + g2 += r % (gradient % gradient); + + arma::mat x = (g % g) / (g2 + epsilon); + + x.transform( [stepSize](double &v) { return std::min(v, stepSize); } ); + + iterate -= gradient % x / (arma::sqrt(g2) + epsilon); + + mem %= (1 - x); + mem += 1; + } + + //! Get the value used to initialise the mean squared gradient parameter. + double Epsilon() const { return epsilon; } + //! Modify the value used to initialise the mean squared gradient parameter. + double& Epsilon() { return epsilon; } + + private: + //! The value used to initialise the mean squared gradient parameter. + double epsilon; + + // The parameters mem, g and g2. + arma::mat mem, g, g2; +}; + +} // namespace optimization +} // namespace mlpack + +#endif \ No newline at end of file diff --git a/src/mlpack/core/util/version.hpp b/src/mlpack/core/util/version.hpp index b3dea0c3a52..a921444257e 100644 --- a/src/mlpack/core/util/version.hpp +++ b/src/mlpack/core/util/version.hpp @@ -17,13 +17,13 @@ // The version of mlpack. If this is a git repository, this will be a version // with higher number than the most recent release. #define MLPACK_VERSION_MAJOR 2 -#define MLPACK_VERSION_MINOR 0 +#define MLPACK_VERSION_MINOR 2 #define MLPACK_VERSION_PATCH "x" // Reverse compatibility; these macros will be removed in future versions of // mlpack (3.0.0 and newer)! #define __MLPACK_VERSION_MAJOR 2 -#define __MLPACK_VERSION_MINOR 0 +#define __MLPACK_VERSION_MINOR 2 #define __MLPACK_VERSION_PATCH "x" // The name of the version (for use by --version). diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt index b7bedde853d..e0ceb91c56d 100644 --- a/src/mlpack/methods/CMakeLists.txt +++ b/src/mlpack/methods/CMakeLists.txt @@ -21,6 +21,7 @@ set(DIRS approx_kfn amf ann + block_krylov_svd cf dbscan decision_stump diff --git a/src/mlpack/methods/ann/ffn.hpp b/src/mlpack/methods/ann/ffn.hpp index 2bd20cfa2fb..466900bbca0 100644 --- a/src/mlpack/methods/ann/ffn.hpp +++ b/src/mlpack/methods/ann/ffn.hpp @@ -93,7 +93,7 @@ class FFN */ template< template class OptimizerType = - mlpack::optimization::RMSprop, + mlpack::optimization::RMSProp, typename... OptimizerTypeArgs > void Train(const arma::mat& predictors, @@ -102,7 +102,7 @@ class FFN /** * Train the feedforward network on the given input data. By default, the - * RMSprop optimization algorithm is used, but others can be specified + * RMSProp optimization algorithm is used, but others can be specified * (such as mlpack::optimization::SGD). 
* * This will use the existing model parameters as a starting point for the @@ -114,7 +114,7 @@ class FFN * @param responses Outputs results from input training variables. */ template< - template class OptimizerType = mlpack::optimization::RMSprop + template class OptimizerType = mlpack::optimization::RMSProp > void Train(const arma::mat& predictors, const arma::mat& responses); @@ -204,7 +204,7 @@ class FFN void Gradient(); /** - * Reset the module infomration (weights/parameters). + * Reset the module information (weights/parameters). */ void ResetParameters(); diff --git a/src/mlpack/methods/ann/layer/add_merge_impl.hpp b/src/mlpack/methods/ann/layer/add_merge_impl.hpp index 0ef62c2d0ac..583a4c1bee6 100644 --- a/src/mlpack/methods/ann/layer/add_merge_impl.hpp +++ b/src/mlpack/methods/ann/layer/add_merge_impl.hpp @@ -14,7 +14,7 @@ #define MLPACK_METHODS_ANN_LAYER_ADD_MERGE_IMPL_HPP // In case it hasn't yet been included. -#include "add_merge_impl.hpp" +#include "add_merge.hpp" namespace mlpack { namespace ann /** Artificial Neural Network. */ { diff --git a/src/mlpack/methods/block_krylov_svd/CMakeLists.txt b/src/mlpack/methods/block_krylov_svd/CMakeLists.txt new file mode 100644 index 00000000000..6380befb28a --- /dev/null +++ b/src/mlpack/methods/block_krylov_svd/CMakeLists.txt @@ -0,0 +1,15 @@ +# Define the files we need to compile. +# Anything not in this list will not be compiled into mlpack. +set(SOURCES + randomized_block_krylov_svd.hpp + randomized_block_krylov_svd.cpp +) + +# Add directory name to sources. +set(DIR_SRCS) +foreach(file ${SOURCES}) + set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) +endforeach() +# Append sources (with directory name) to list of all mlpack sources (used at +# the parent scope). +set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE) diff --git a/src/mlpack/methods/block_krylov_svd/randomized_block_krylov_svd.cpp b/src/mlpack/methods/block_krylov_svd/randomized_block_krylov_svd.cpp new file mode 100644 index 00000000000..57e9a5d201b --- /dev/null +++ b/src/mlpack/methods/block_krylov_svd/randomized_block_krylov_svd.cpp @@ -0,0 +1,96 @@ +/** + * @file randomized_block_krylov_svd.cpp + * @author Marcus Edel + * + * Implementation of the randomized block krylov SVD method. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ + +#include "randomized_block_krylov_svd.hpp" + +namespace mlpack { +namespace svd { + +RandomizedBlockKrylovSVD::RandomizedBlockKrylovSVD(const arma::mat& data, + arma::mat& u, + arma::vec& s, + arma::mat& v, + const size_t maxIterations, + const size_t rank, + const size_t blockSize) : + maxIterations(maxIterations), + blockSize(blockSize) +{ + if (rank == 0) + { + Apply(data, u, s, v, data.n_rows); + } + else + { + Apply(data, u, s, v, rank); + } +} + +RandomizedBlockKrylovSVD::RandomizedBlockKrylovSVD(const size_t maxIterations, + const size_t blockSize) : + maxIterations(maxIterations), + blockSize(blockSize) +{ + /* Nothing to do here */ +} + +void RandomizedBlockKrylovSVD::Apply(const arma::mat& data, + arma::mat& u, + arma::vec& s, + arma::mat& v, + const size_t rank) +{ + arma::mat Q, R, block, blockIteration; + + if (blockSize == 0) + { + blockSize = rank + 10; + } + + // Random block initialization. 
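+ // The Gaussian test matrix G below has blockSize columns. Following the
+ // block Krylov scheme of Musco & Musco (2015), the subspace is built block
+ // by block as K = [Q0, Q1, ..., Qq], where Q0 = orth(data * G),
+ // Q(i+1) = orth(data * data.t() * Qi) and q = maxIterations.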
+ arma::mat G = arma::randn(data.n_cols, blockSize); + + // Construct and orthonormalize Krylov subspace. + arma::mat K(data.n_rows, blockSize * (maxIterations + 1)); + + // Create a working matrix using data from writable auxiliary memory + // (K matrix). Doing so avoids an unnecessary copy in an upcoming step. + block = arma::mat(K.memptr(), data.n_rows, blockSize, false, false); + arma::qr_econ(block, R, data * G); + + for (size_t blockOffset = block.n_elem; blockOffset < K.n_elem; + blockOffset += block.n_elem) + { + // Temporary working matrix to store the result in the correct place. + blockIteration = arma::mat(K.memptr() + blockOffset, block.n_rows, + block.n_cols, false, false); + + arma::qr_econ(blockIteration, R, data * (data.t() * block)); + + // Update working matrix for the next iteration. + block = arma::mat(K.memptr() + blockOffset, block.n_rows, block.n_cols, + false, false); + } + + arma::qr_econ(Q, R, K); + + // Approximate eigenvalues and eigenvectors using the Rayleigh–Ritz method. + arma::svd_econ(u, s, v, Q.t() * data); + + // The SVD above was computed on the projected matrix Q.t() * data; map its + // left singular vectors back to the original space. + u = Q * u; +} + +} // namespace svd +} // namespace mlpack diff --git a/src/mlpack/methods/block_krylov_svd/randomized_block_krylov_svd.hpp b/src/mlpack/methods/block_krylov_svd/randomized_block_krylov_svd.hpp new file mode 100644 index 00000000000..06ef8b4c2b6 --- /dev/null +++ b/src/mlpack/methods/block_krylov_svd/randomized_block_krylov_svd.hpp @@ -0,0 +1,128 @@ +/** + * @file randomized_block_krylov_svd.hpp + * @author Marcus Edel + * + * An implementation of the randomized block krylov SVD method. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ + +#ifndef MLPACK_METHODS_BLOCK_KRYLOV_SVD_RANDOMIZED_BLOCK_KRYLOV_SVD_HPP +#define MLPACK_METHODS_BLOCK_KRYLOV_SVD_RANDOMIZED_BLOCK_KRYLOV_SVD_HPP +#include + +namespace mlpack { +namespace svd { + +/** + * Randomized block krylov SVD is a matrix factorization that is based on + * randomized matrix approximation techniques, developed in + * "Randomized Block Krylov Methods for Stronger and Faster Approximate + * Singular Value Decomposition". + * + * For more information, see the following. + * + * @code + * @inproceedings{Musco2015, + * author = {Cameron Musco and Christopher Musco}, + * title = {Randomized Block Krylov Methods for Stronger and Faster + * Approximate Singular Value Decomposition}, + * booktitle = {Advances in Neural Information Processing Systems 28: Annual + * Conference on Neural Information Processing Systems 2015, + * December 7-12, 2015, Montreal, Quebec, Canada}, + * pages = {1396--1404}, + * year = {2015}, + * } + * @endcode + * + * An example of how to use the interface is shown below: + * + * @code + * arma::mat data; // The data matrix to be factorized. + * + * const size_t rank = 20; // Rank used for the decomposition. + * + * // Make a RandomizedBlockKrylovSVD object. + * RandomizedBlockKrylovSVD bSVD; + * + * arma::mat u, v; + * arma::vec s; + * + * // Use the Apply() method to get a factorization. + * bSVD.Apply(data, u, s, v, rank); + * @endcode + */ +class RandomizedBlockKrylovSVD +{ + public: + /** + * Create object for the randomized block krylov SVD method.
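+ * This overload runs the decomposition immediately on the given data and
+ * stores the results in the u, s and v output arguments.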
+ * + * @param data Data matrix. + * @param u First unitary matrix. + * @param v Second unitary matrix. + * @param s Diagonal matrix of singular values. + * @param maxIterations Number of iterations for the power method + * (Default: 2). + * @param rank Rank of the approximation (Default: number of rows.) + * @param blockSize The block size, must be >= rank (Default: rank + 10). + */ + RandomizedBlockKrylovSVD(const arma::mat& data, + arma::mat& u, + arma::vec& s, + arma::mat& v, + const size_t maxIterations = 2, + const size_t rank = 0, + const size_t blockSize = 0); + + /** + * Create object for the randomized block krylov SVD method. + * + * @param maxIterations Number of iterations for the power method + * (Default: 2). + * @param blockSize The block size, must be >= rank (Default: rank + 10). + */ + RandomizedBlockKrylovSVD(const size_t maxIterations = 2, + const size_t blockSize = 0); + + /** + * Apply Principal Component Analysis to the provided data set using the + * randomized block krylov SVD. + * + * @param data Data matrix. + * @param u First unitary matrix. + * @param v Second unitary matrix. + * @param s Diagonal matrix of singular values. + * @param rank Rank of the approximation. + */ + void Apply(const arma::mat& data, + arma::mat& u, + arma::vec& s, + arma::mat& v, + const size_t rank); + + //! Get the number of iterations for the power method. + size_t MaxIterations() const { return maxIterations; } + //! Modify the number of iterations for the power method. + size_t& MaxIterations() { return maxIterations; } + + //! Get the block size. + size_t BlockSize() const { return blockSize; } + //! Modify the block size. + size_t& BlockSize() { return blockSize; } + + private: + //! Locally stored number of iterations for the power method. + size_t maxIterations; + + //! The block size value. + size_t blockSize; +}; + +} // namespace svd +} // namespace mlpack + +#endif diff --git a/src/mlpack/methods/emst/dtb_impl.hpp b/src/mlpack/methods/emst/dtb_impl.hpp index 1d87ea73f19..60320757413 100644 --- a/src/mlpack/methods/emst/dtb_impl.hpp +++ b/src/mlpack/methods/emst/dtb_impl.hpp @@ -18,27 +18,25 @@ namespace mlpack { namespace emst { //! Call the tree constructor that does mapping. -template +template TreeType* BuildTree( - MatType& dataset, + MatType&& dataset, std::vector& oldFromNew, - const typename std::enable_if_t< - tree::TreeTraits::RearrangesDataset, TreeType - >* = 0) + const typename std::enable_if< + tree::TreeTraits::RearrangesDataset>::type* = 0) { - return new TreeType(dataset, oldFromNew); + return new TreeType(std::forward(dataset), oldFromNew); } //! Call the tree constructor that does not do mapping. -template +template TreeType* BuildTree( - const MatType& dataset, + MatType&& dataset, const std::vector& /* oldFromNew */, - const typename std::enable_if_t< - tree::TreeTraits::RearrangesDataset == false, TreeType - >* = 0) + const typename std::enable_if< + !tree::TreeTraits::RearrangesDataset>::type* = 0) { - return new TreeType(dataset); + return new TreeType(std::forward(dataset)); } /** @@ -55,8 +53,7 @@ DualTreeBoruvka::DualTreeBoruvka( const MatType& dataset, const bool naive, const MetricType metric) : - tree(naive ? NULL : BuildTree(const_cast(dataset), - oldFromNew)), + tree(naive ? NULL : BuildTree(dataset, oldFromNew)), data(naive ? 
dataset : tree->Dataset()), ownTree(!naive), naive(naive), diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_model.cpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_model.cpp index 0599a4ce08c..9b35b531529 100644 --- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_model.cpp +++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_model.cpp @@ -68,13 +68,13 @@ HoeffdingTreeModel& HoeffdingTreeModel::operator=( // Create the right tree. type = other.type; - if (type == GINI_HOEFFDING) + if (other.giniHoeffdingTree && (type == GINI_HOEFFDING)) giniHoeffdingTree = new GiniHoeffdingTreeType(*other.giniHoeffdingTree); - else if (type == GINI_BINARY) + else if (other.giniBinaryTree && (type == GINI_BINARY)) giniBinaryTree = new GiniBinaryTreeType(*other.giniBinaryTree); - else if (type == INFO_HOEFFDING) + else if (other.infoHoeffdingTree && (type == INFO_HOEFFDING)) infoHoeffdingTree = new InfoHoeffdingTreeType(*other.infoHoeffdingTree); - else if (type == INFO_BINARY) + else if (other.infoBinaryTree && (type == INFO_BINARY)) infoBinaryTree = new InfoBinaryTreeType(*other.infoBinaryTree); return *this; diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp index 27c54d53a27..66b0e6519d1 100644 --- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp +++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp @@ -24,29 +24,27 @@ namespace mlpack { namespace kmeans { //! Call the tree constructor that does mapping. -template +template TreeType* BuildTree( - const typename TreeType::Mat& dataset, + MatType&& dataset, std::vector& oldFromNew, - const typename std::enable_if_t< - tree::TreeTraits::RearrangesDataset, TreeType - >* = 0) + const typename std::enable_if< + tree::TreeTraits::RearrangesDataset>::type* = 0) { // This is a hack. I know this will be BinarySpaceTree, so force a leaf size // of two. - return new TreeType(dataset, oldFromNew, 1); + return new TreeType(std::forward(dataset), oldFromNew, 1); } //! Call the tree constructor that does not do mapping. 
-template +template TreeType* BuildTree( - const typename TreeType::Mat& dataset, + MatType&& dataset, const std::vector& /* oldFromNew */, - const typename std::enable_if_t< - !tree::TreeTraits::RearrangesDataset, TreeType - >* = 0) + const typename std::enable_if< + !tree::TreeTraits::RearrangesDataset>::type* = 0) { - return new TreeType(dataset); + return new TreeType(std::forward(dataset)); } template class OptimizerType = mlpack::optimization::L_BFGS + template class OptimizerType = mlpack::optimization::L_BFGS > void Train(const MatType& predictors, const arma::Row& responses); diff --git a/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp b/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp index 4974c975a59..94945258474 100644 --- a/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp +++ b/src/mlpack/methods/logistic_regression/logistic_regression_impl.hpp @@ -66,7 +66,7 @@ LogisticRegression::LogisticRegression( } template -template class OptimizerType> +template class OptimizerType> void LogisticRegression::Train(const MatType& predictors, const arma::Row& responses) { diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp index 3218659142b..ec814581b54 100644 --- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp +++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp @@ -22,31 +22,7 @@ namespace mlpack { namespace neighbor { //! Call the tree constructor that does mapping. -template -TreeType* BuildTree( - const MatType& dataset, - std::vector& oldFromNew, - typename std::enable_if_t< - tree::TreeTraits::RearrangesDataset, TreeType - >* = 0) -{ - return new TreeType(dataset, oldFromNew); -} - -//! Call the tree constructor that does not do mapping. -template -TreeType* BuildTree( - const MatType& dataset, - const std::vector& /* oldFromNew */, - const typename std::enable_if_t< - !tree::TreeTraits::RearrangesDataset, TreeType - >* = 0) -{ - return new TreeType(dataset); -} - -//! Call the tree construct that does mapping. -template +template TreeType* BuildTree( MatType&& dataset, std::vector& oldFromNew, @@ -54,19 +30,19 @@ TreeType* BuildTree( tree::TreeTraits::RearrangesDataset, TreeType >* = 0) { - return new TreeType(std::move(dataset), oldFromNew); + return new TreeType(std::forward(dataset), oldFromNew); } //! Call the tree constructor that does not do mapping. -template +template TreeType* BuildTree( MatType&& dataset, - std::vector& /* oldFromNew */, - typename std::enable_if_t< + const std::vector& /* oldFromNew */, + const typename std::enable_if_t< !tree::TreeTraits::RearrangesDataset, TreeType >* = 0) { - return new TreeType(std::move(dataset)); + return new TreeType(std::forward(dataset)); } // Construct the object. @@ -84,7 +60,7 @@ SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn, const double epsilon, const MetricType metric) : referenceTree(mode == NAIVE_MODE ? NULL : - BuildTree(referenceSetIn, oldFromNewReferences)), + BuildTree(referenceSetIn, oldFromNewReferences)), referenceSet(mode == NAIVE_MODE ? &referenceSetIn : &referenceTree->Dataset()), treeOwner(mode != NAIVE_MODE), @@ -115,9 +91,8 @@ SingleTreeTraversalType>::NeighborSearch(MatType&& referenceSetIn, const double epsilon, const MetricType metric) : referenceTree(mode == NAIVE_MODE ? NULL : - BuildTree(std::move(referenceSetIn), - oldFromNewReferences)), - referenceSet(mode == NAIVE_MODE ? 
new MatType(std::move(referenceSetIn)) : + BuildTree(std::move(referenceSetIn), oldFromNewReferences)), + referenceSet(mode == NAIVE_MODE ? new MatType(std::move(referenceSetIn)) : &referenceTree->Dataset()), treeOwner(mode != NAIVE_MODE), setOwner(mode == NAIVE_MODE), @@ -220,8 +195,7 @@ SingleTreeTraversalType>::NeighborSearch(const NeighborSearchMode mode, // Build the tree on the empty dataset, if necessary. if (mode != NAIVE_MODE) { - referenceTree = BuildTree(*referenceSet, - oldFromNewReferences); + referenceTree = BuildTree(*referenceSet, oldFromNewReferences); treeOwner = true; } } @@ -278,7 +252,7 @@ SingleTreeTraversalType>::NeighborSearch(NeighborSearch&& other) : { // Clear the other model. other.referenceSet = new MatType(); - other.referenceTree = BuildTree(*other.referenceSet, + other.referenceTree = BuildTree(*other.referenceSet, other.oldFromNewReferences); other.treeOwner = true; other.setOwner = true; @@ -373,7 +347,7 @@ NeighborSearch(*other.referenceSet, + other.referenceTree = BuildTree(*other.referenceSet, other.oldFromNewReferences); other.treeOwner = true; other.setOwner = true; @@ -424,8 +398,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Train( // We may need to rebuild the tree. if (searchMode != NAIVE_MODE) { - referenceTree = BuildTree(referenceSet, - oldFromNewReferences); + referenceTree = BuildTree(referenceSet, oldFromNewReferences); treeOwner = true; } else @@ -465,7 +438,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Train(MatType&& referenceSetIn) // We may need to rebuild the tree. if (searchMode != NAIVE_MODE) { - referenceTree = BuildTree(std::move(referenceSetIn), + referenceTree = BuildTree(std::move(referenceSetIn), oldFromNewReferences); treeOwner = true; } @@ -656,7 +629,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search( // Build the query tree. Timer::Stop("computing_neighbors"); Timer::Start("tree_building"); - Tree* queryTree = BuildTree(querySet, oldFromNewQueries); + Tree* queryTree = BuildTree(querySet, oldFromNewQueries); Timer::Stop("tree_building"); Timer::Start("computing_neighbors"); diff --git a/src/mlpack/methods/pca/decomposition_policies/CMakeLists.txt b/src/mlpack/methods/pca/decomposition_policies/CMakeLists.txt index 968c7cc4bb9..85bbb7c305b 100644 --- a/src/mlpack/methods/pca/decomposition_policies/CMakeLists.txt +++ b/src/mlpack/methods/pca/decomposition_policies/CMakeLists.txt @@ -2,6 +2,7 @@ # Anything not in this list will not be compiled into mlpack. set(SOURCES exact_svd_method.hpp + randomized_block_krylov_method.hpp randomized_svd_method.hpp quic_svd_method.hpp ) diff --git a/src/mlpack/methods/pca/decomposition_policies/exact_svd_method.hpp b/src/mlpack/methods/pca/decomposition_policies/exact_svd_method.hpp index fe0fb0c9ffe..dd5fa1e987e 100644 --- a/src/mlpack/methods/pca/decomposition_policies/exact_svd_method.hpp +++ b/src/mlpack/methods/pca/decomposition_policies/exact_svd_method.hpp @@ -26,7 +26,7 @@ namespace pca { */ class ExactSVDPolicy { - public: + public: /** * Apply Principal Component Analysis to the provided data set using the * exact SVD method. 
diff --git a/src/mlpack/methods/pca/decomposition_policies/quic_svd_method.hpp b/src/mlpack/methods/pca/decomposition_policies/quic_svd_method.hpp index df18f0ba882..f3ecc2103d8 100644 --- a/src/mlpack/methods/pca/decomposition_policies/quic_svd_method.hpp +++ b/src/mlpack/methods/pca/decomposition_policies/quic_svd_method.hpp @@ -25,8 +25,7 @@ namespace pca { */ class QUICSVDPolicy { - public: - + public: /** * Use QUIC-SVD method to perform the principal components analysis (PCA). * @@ -83,12 +82,12 @@ class QUICSVDPolicy //! Modify the cumulative probability for Monte Carlo error lower bound. double& Delta() { return delta; } - private: - //! Error tolerance fraction for calculated subspace. - double epsilon; + private: + //! Error tolerance fraction for calculated subspace. + double epsilon; - //! Cumulative probability for Monte Carlo error lower bound. - double delta; + //! Cumulative probability for Monte Carlo error lower bound. + double delta; }; } // namespace pca diff --git a/src/mlpack/methods/pca/decomposition_policies/randomized_block_krylov_method.hpp b/src/mlpack/methods/pca/decomposition_policies/randomized_block_krylov_method.hpp new file mode 100644 index 00000000000..957cd3bbd61 --- /dev/null +++ b/src/mlpack/methods/pca/decomposition_policies/randomized_block_krylov_method.hpp @@ -0,0 +1,101 @@ +/** + * @file randomized_block_krylov_method.hpp + * @author Marcus Edel + * + * Implementation of the randomized block krylov SVD method for use in the + * Principal Components Analysis method. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ + +#ifndef MLPACK_METHODS_PCA_DECOMPOSITION_POLICIES_RANDOMIZED_BLOCK_KRYLOV_HPP +#define MLPACK_METHODS_PCA_DECOMPOSITION_POLICIES_RANDOMIZED_BLOCK_KRYLOV_HPP + +#include +#include + +namespace mlpack { +namespace pca { + +/** + * Implementation of the randomized block krylov SVD policy. + */ +class RandomizedBlockKrylovSVDPolicy +{ + public: + /** + * Use randomized block krylov SVD method to perform the principal components + * analysis (PCA). + * + * @param maxIterations Number of iterations for the power method + * (Default: 2). + * @param blockSize The block size, must be >= rank (Default: rank + 10). + */ + RandomizedBlockKrylovSVDPolicy(const size_t maxIterations = 2, + const size_t blockSize = 0) : + maxIterations(maxIterations), + blockSize(blockSize) + { + /* Nothing to do here */ + } + + /** + * Apply Principal Component Analysis to the provided data set using the + * randomized block krylov SVD method. + * + * @param data Data matrix. + * @param centeredData Centered data matrix. + * @param transformedData Matrix to put results of PCA into. + * @param eigVal Vector to put eigenvalues into. + * @param eigvec Matrix to put eigenvectors (loadings) into. + * @param rank Rank of the decomposition. + */ + void Apply(const arma::mat& data, + const arma::mat& centeredData, + arma::mat& transformedData, + arma::vec& eigVal, + arma::mat& eigvec, + const size_t rank) + { + // This matrix will store the right singular values; we do not need them. + arma::mat v; + + // Do singular value decomposition using the randomized block krylov SVD + // algorithm. 
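+ // If blockSize is zero (the default), RandomizedBlockKrylovSVD::Apply()
+ // falls back to a block size of rank + 10.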
+ svd::RandomizedBlockKrylovSVD rsvd(maxIterations, blockSize); + rsvd.Apply(centeredData, eigvec, eigVal, v, rank); + + // Now we must square the singular values to get the eigenvalues. + // In addition we must divide by the number of points, because the + // covariance matrix is X * X' / (N - 1). + eigVal %= eigVal / (data.n_cols - 1); + + // Project the samples to the principals. + transformedData = arma::trans(eigvec) * centeredData; + } + + //! Get the number of iterations for the power method. + size_t MaxIterations() const { return maxIterations; } + //! Modify the number of iterations for the power method. + size_t& MaxIterations() { return maxIterations; } + + //! Get the block size. + size_t BlockSize() const { return blockSize; } + //! Modify the block size. + size_t& BlockSize() { return blockSize; } + + private: + //! Locally stored number of iterations for the power method. + size_t maxIterations; + + //! Locally stored block size value. + size_t blockSize; +}; + +} // namespace pca +} // namespace mlpack + +#endif diff --git a/src/mlpack/methods/pca/decomposition_policies/randomized_svd_method.hpp b/src/mlpack/methods/pca/decomposition_policies/randomized_svd_method.hpp index f7f9089b4c6..148dfb9e02a 100644 --- a/src/mlpack/methods/pca/decomposition_policies/randomized_svd_method.hpp +++ b/src/mlpack/methods/pca/decomposition_policies/randomized_svd_method.hpp @@ -16,7 +16,6 @@ #include #include -#include namespace mlpack { namespace pca { @@ -26,7 +25,7 @@ namespace pca { */ class RandomizedSVDPolicy { - public: + public: /** * Use randomized SVD method to perform the principal components analysis * (PCA). @@ -88,12 +87,12 @@ class RandomizedSVDPolicy //! Modify the number of iterations for the power method. size_t& MaxIterations() { return maxIterations; } - private: - //! Locally stored size of the normalized power iterations. - size_t iteratedPower; + private: + //! Locally stored size of the normalized power iterations. + size_t iteratedPower; - //! Locally stored number of iterations for the power method. - size_t maxIterations; + //! Locally stored number of iterations for the power method. + size_t maxIterations; }; } // namespace pca diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp index 5c6571a7997..f1a813c1668 100644 --- a/src/mlpack/methods/range_search/range_search_impl.hpp +++ b/src/mlpack/methods/range_search/range_search_impl.hpp @@ -21,49 +21,25 @@ namespace mlpack { namespace range { -template +template TreeType* BuildTree( - typename TreeType::Mat& dataset, + MatType&& dataset, std::vector& oldFromNew, - typename std::enable_if_t< - tree::TreeTraits::RearrangesDataset, TreeType - >* = 0) + const typename std::enable_if< + tree::TreeTraits::RearrangesDataset>::type* = 0) { - return new TreeType(dataset, oldFromNew); + return new TreeType(std::forward(dataset), oldFromNew); } //! Call the tree constructor that does not do mapping. 
-template +template TreeType* BuildTree( - const typename TreeType::Mat& dataset, + MatType&& dataset, const std::vector& /* oldFromNew */, - const typename std::enable_if_t< - !tree::TreeTraits::RearrangesDataset, TreeType - >* = 0) + const typename std::enable_if< + !tree::TreeTraits::RearrangesDataset>::type* = 0) { - return new TreeType(dataset); -} - -template -TreeType* BuildTree( - typename TreeType::Mat&& dataset, - std::vector& oldFromNew, - const typename std::enable_if_t< - tree::TreeTraits::RearrangesDataset, TreeType - >* = 0) -{ - return new TreeType(std::move(dataset), oldFromNew); -} - -template -TreeType* BuildTree( - typename TreeType::Mat&& dataset, - const std::vector& /* oldFromNew */, - const typename std::enable_if_t< - !tree::TreeTraits::RearrangesDataset, TreeType - >* = 0) -{ - return new TreeType(std::move(dataset)); + return new TreeType(std::forward(dataset)); } template::RangeSearch( const bool naive, const bool singleMode, const MetricType metric) : - referenceTree(naive ? NULL : BuildTree( - const_cast(referenceSetIn), oldFromNewReferences)), + referenceTree(naive ? NULL : BuildTree(referenceSetIn, + oldFromNewReferences)), referenceSet(naive ? &referenceSetIn : &referenceTree->Dataset()), treeOwner(!naive), // If in naive mode, we are not building any trees. setOwner(false), @@ -497,8 +473,7 @@ void RangeSearch::Search( // Build the query tree. Timer::Stop("range_search/computing_neighbors"); Timer::Start("range_search/tree_building"); - Tree* queryTree = BuildTree(const_cast(querySet), - oldFromNewQueries); + Tree* queryTree = BuildTree(querySet, oldFromNewQueries); Timer::Stop("range_search/tree_building"); Timer::Start("range_search/computing_neighbors"); diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp index fca31657ad3..23984132fd9 100644 --- a/src/mlpack/methods/rann/ra_search_impl.hpp +++ b/src/mlpack/methods/rann/ra_search_impl.hpp @@ -23,51 +23,25 @@ namespace neighbor { namespace aux { //! Call the tree constructor that does mapping. -template +template TreeType* BuildTree( - const typename TreeType::Mat& dataset, + MatType&& dataset, std::vector& oldFromNew, - typename std::enable_if_t< - tree::TreeTraits::RearrangesDataset, TreeType - >* = 0) + typename std::enable_if< + tree::TreeTraits::RearrangesDataset>::type* = 0) { - return new TreeType(dataset, oldFromNew); + return new TreeType(std::forward(dataset), oldFromNew); } //! Call the tree constructor that does not do mapping. -template +template TreeType* BuildTree( - const typename TreeType::Mat& dataset, + MatType&& dataset, const std::vector& /* oldFromNew */, - const typename std::enable_if_t< - !tree::TreeTraits::RearrangesDataset, TreeType - >* = 0) + const typename std::enable_if< + !tree::TreeTraits::RearrangesDataset>::type* = 0) { - return new TreeType(dataset); -} - -//! Call the tree constructor that does mapping. -template -TreeType* BuildTree( - typename TreeType::Mat&& dataset, - std::vector& oldFromNew, - typename std::enable_if_t< - tree::TreeTraits::RearrangesDataset, TreeType - >* = 0) -{ - return new TreeType(std::move(dataset), oldFromNew); -} - -//! Call the tree constructor that does not do mapping. 
-template -TreeType* BuildTree( - typename TreeType::Mat&& dataset, - const std::vector& /* oldFromNew */, - const typename std::enable_if_t< - !tree::TreeTraits::RearrangesDataset, TreeType - >* = 0) -{ - return new TreeType(std::move(dataset)); + return new TreeType(std::forward(dataset)); } } // namespace aux diff --git a/src/mlpack/prereqs.hpp b/src/mlpack/prereqs.hpp index e46fa859ab6..42a49d637cb 100644 --- a/src/mlpack/prereqs.hpp +++ b/src/mlpack/prereqs.hpp @@ -34,6 +34,7 @@ #include #include #include +#include // But if it's not defined, we'll do it. #ifndef M_PI diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt index 8f18594a9dc..475764580ef 100644 --- a/src/mlpack/tests/CMakeLists.txt +++ b/src/mlpack/tests/CMakeLists.txt @@ -12,6 +12,7 @@ add_executable(mlpack_test armadillo_svd_test.cpp aug_lagrangian_test.cpp binarize_test.cpp + block_krylov_svd_test.cpp cf_test.cpp cli_test.cpp convolution_test.cpp @@ -82,6 +83,7 @@ add_executable(mlpack_test serialization.hpp serialization.cpp serialization_test.cpp + smorms3_test.cpp softmax_regression_test.cpp sort_policy_test.cpp sparse_autoencoder_test.cpp diff --git a/src/mlpack/tests/adam_test.cpp b/src/mlpack/tests/adam_test.cpp index 88d839fd5cf..a88e3618f8a 100644 --- a/src/mlpack/tests/adam_test.cpp +++ b/src/mlpack/tests/adam_test.cpp @@ -36,7 +36,8 @@ BOOST_AUTO_TEST_SUITE(AdamTest); BOOST_AUTO_TEST_CASE(SimpleAdamTestFunction) { SGDTestFunction f; - Adam optimizer(f, 1e-3, 0.9, 0.999, 1e-8, 5000000, 1e-9, true); + Adam optimizer(f, 1e-3, 0.9, 0.999, 1e-8, 5000000, 1e-9, + true); arma::mat coordinates = f.GetInitialPoint(); optimizer.Optimize(coordinates); @@ -52,8 +53,8 @@ BOOST_AUTO_TEST_CASE(SimpleAdamTestFunction) BOOST_AUTO_TEST_CASE(SimpleAdaMaxTestFunction) { SGDTestFunction f; - Adam optimizer(f, 2e-3, 0.9, 0.999, 1e-8, 5000000, 1e-9, true - ,true); + AdaMax optimizer(f, 2e-3, 0.9, 0.999, 1e-8, 5000000, 1e-9, + true); arma::mat coordinates = f.GetInitialPoint(); optimizer.Optimize(coordinates); @@ -174,8 +175,9 @@ BOOST_AUTO_TEST_CASE(AdaMaxLogisticRegressionTest) LogisticRegression<> lr(shuffledData.n_rows, 0.5); LogisticRegressionFunction<> lrf(shuffledData, shuffledResponses, 0.5); - Adam > adamax(lrf, 1e-3, 0.9, 0.999, 1e-8, 5000000, - 1e-9, true, true); + AdaMax > adamax(lrf, 1e-3, 0.9, 0.999, 1e-8, + 5000000, 1e-9, true); + lr.Train(adamax); // Ensure that the error is close to zero. diff --git a/src/mlpack/tests/block_krylov_svd_test.cpp b/src/mlpack/tests/block_krylov_svd_test.cpp new file mode 100644 index 00000000000..fa608596fff --- /dev/null +++ b/src/mlpack/tests/block_krylov_svd_test.cpp @@ -0,0 +1,112 @@ +/** + * @file block_krylov_svd_test.cpp + * @author Marcus Edel + * + * Test file for the Randomized Block Krylov SVD class. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ + +#include +#include + +#include +#include "test_tools.hpp" + +BOOST_AUTO_TEST_SUITE(BlockKrylovSVDTest); + +using namespace mlpack; + +// Generate a low rank matrix with bell-shaped singular values. 
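+// The matrix is assembled as U * S * V^T from random orthonormal factors U
+// and V; the diagonal of S is (1 - strength) * exp(-(i / rank)^2) plus a
+// slowly decaying tail of strength * exp(-0.1 * i / rank).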
+void CreateNoisyLowRankMatrix(arma::mat& data, + const size_t rows, + const size_t cols, + const size_t rank, + const double strength) +{ + arma::mat R, U, V; + const size_t n = std::min(rows, cols); + + arma::qr_econ(U, R, arma::randn(rows, n)); + arma::qr_econ(V, R, arma::randn(cols, n)); + + arma::vec ids = arma::linspace(0, n - 1, n); + + arma::vec lowRank = ((1 - strength) * + arma::exp(-1.0 * arma::pow((ids / rank), 2))); + arma::vec tail = strength * arma::exp(-0.1 * ids / rank); + + arma::mat s = arma::zeros(n, n); + s.diag() = lowRank + tail; + data = (U * s) * V.t(); +} + +/** + * The reconstruction and singular value errors of the obtained SVD should be + * small. + */ +BOOST_AUTO_TEST_CASE(RandomizedBlockKrylovSVDReconstructionError) +{ + arma::mat U = arma::randn(3, 20); + arma::mat V = arma::randn(10, 3); + + arma::mat R; + arma::qr_econ(U, R, U); + arma::qr_econ(V, R, V); + + arma::mat s = arma::diagmat(arma::vec("1 0.1 0.01")); + + arma::mat data = arma::trans(U * arma::diagmat(s) * V.t()); + + // Center the data into a temporary matrix. + arma::mat centeredData; + math::Center(data, centeredData); + + arma::mat U1, U2, V1, V2; + arma::vec s1, s2, s3; + + arma::svd_econ(U1, s1, V1, centeredData); + + svd::RandomizedBlockKrylovSVD rSVD(20, 10); + rSVD.Apply(centeredData, U2, s2, V2, 3); + + // Use the same number of singular values for the comparison (matrix rank). + s3 = s1.subvec(0, s2.n_elem - 1); + + // The singular value error should be small. + double error = arma::norm(s2 - s3, "frob") / arma::norm(s2, "frob"); + BOOST_REQUIRE_SMALL(error, 1e-5); + + arma::mat reconstruct = U2 * arma::diagmat(s2) * V2.t(); + + // The relative reconstruction error should be small. + error = arma::norm(centeredData - reconstruct, "frob") / + arma::norm(centeredData, "frob"); + BOOST_REQUIRE_SMALL(error, 1e-5); +} + +/* + * Check if the method can handle noisy matrices. + */ +BOOST_AUTO_TEST_CASE(RandomizedBlockKrylovSVDNoisyLowRankTest) +{ + arma::mat data; + CreateNoisyLowRankMatrix(data, 200, 1000, 5, 0.5); + + const size_t rank = 5; + + arma::mat U1, U2, V1, V2; + arma::vec s1, s2, s3; + + arma::svd_econ(U1, s1, V1, data); + + svd::RandomizedBlockKrylovSVD rSVDB(data, U2, s2, V2, 10, rank, 20); + + double error = arma::max(arma::abs(s1.subvec(0, rank) - s2.subvec(0, rank))); + BOOST_REQUIRE_SMALL(error, 1e-2); +} + +BOOST_AUTO_TEST_SUITE_END(); diff --git a/src/mlpack/tests/convolutional_network_test.cpp b/src/mlpack/tests/convolutional_network_test.cpp index efafcc9c24c..ea54a4ab95b 100644 --- a/src/mlpack/tests/convolutional_network_test.cpp +++ b/src/mlpack/tests/convolutional_network_test.cpp @@ -1,6 +1,7 @@ /** * @file convolutional_network_test.cpp * @author Marcus Edel + * @author Abhinav Moudgil * * Tests the convolutional neural network. * @@ -23,7 +24,6 @@ using namespace mlpack; using namespace mlpack::ann; using namespace mlpack::optimization; - BOOST_AUTO_TEST_SUITE(ConvolutionalNetworkTest); /** @@ -47,11 +47,13 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest) { if (i < nPoints / 2) { - Y(i) = 4; + // Assign label "1" to all samples with digit = 4 + Y(i) = 1; } else { - Y(i) = 9; + // Assign label "2" to all samples with digit = 9 + Y(i) = 2; } } @@ -73,60 +75,46 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest) * +---+ +---+ +---+ +---+ +---+ +---+ */ - // It isn't guaranteed that the network will converge in the specified number - // of iterations using random weights. If this works 1 of 5 times, I'm fine - // with that.
All I want to know is that the network is able to escape from - // local minima and to solve the task. - size_t successes = 0; - for (size_t trial = 0; trial < 5; ++trial) + FFN, RandomInitialization> model; + + model.Add >(1, 8, 5, 5, 1, 1, 0, 0, 28, 28); + model.Add >(); + model.Add >(8, 8, 2, 2); + model.Add >(8, 12, 2, 2); + model.Add >(); + model.Add >(2, 2, 2, 2); + model.Add >(192, 20); + model.Add >(); + model.Add >(20, 10); + model.Add >(); + model.Add >(10, 2); + model.Add >(); + + RMSProp opt(model, 0.001, 0.88, 1e-8, 5000, -1); + + model.Train(std::move(X), std::move(Y), opt); + + arma::mat predictionTemp; + model.Predict(X, predictionTemp); + arma::mat prediction = arma::zeros(1, predictionTemp.n_cols); + + for (size_t i = 0; i < predictionTemp.n_cols; ++i) { - FFN, GaussianInitialization> model; - - model.Add >(1, 8, 5, 5, 1, 1, 0, 0, 28, 28); - model.Add >(); - model.Add >(8, 8, 2, 2); - model.Add >(8, 12, 2, 2); - model.Add >(); - model.Add >(2, 2, 2, 2); - model.Add >(192, 20); - model.Add >(); - model.Add >(20, 30); - model.Add >(); - model.Add >(30, 10); - model.Add >(); - - RMSprop opt(model, 0.01, 0.88, 1e-8, 5000, -1); - - model.Train(std::move(X), std::move(Y), opt); - - arma::mat predictionTemp; - model.Predict(X, predictionTemp); - arma::mat prediction = arma::zeros(1, predictionTemp.n_cols); - - for (size_t i = 0; i < predictionTemp.n_cols; ++i) - { - prediction(i) = arma::as_scalar(arma::find( + prediction(i) = arma::as_scalar(arma::find( arma::max(predictionTemp.col(i)) == predictionTemp.col(i), 1)) + 1; - } - - size_t error = 0; - for (size_t i = 0; i < X.n_cols; i++) - { - if (prediction(i) == Y(i)) - { - error++; - } - } + } - double classificationError = 1 - double(error) / X.n_cols; - if (classificationError <= 0.2) + size_t correct = 0; + for (size_t i = 0; i < X.n_cols; i++) + { + if (prediction(i) == Y(i)) { - ++successes; - break; + correct++; } } - BOOST_REQUIRE_GE(successes, 1); + double classificationError = 1 - double(correct) / X.n_cols; + BOOST_REQUIRE_LE(classificationError, 0.2); } BOOST_AUTO_TEST_SUITE_END(); diff --git a/src/mlpack/tests/feedforward_network_test.cpp b/src/mlpack/tests/feedforward_network_test.cpp index 883fe9e6b00..e76e83bfe21 100644 --- a/src/mlpack/tests/feedforward_network_test.cpp +++ b/src/mlpack/tests/feedforward_network_test.cpp @@ -66,7 +66,7 @@ void BuildVanillaNetwork(MatType& trainData, model.Add >(hiddenLayerSize, outputSize); model.Add >(); - RMSprop opt(model, 0.01, 0.88, 1e-8, + RMSProp opt(model, 0.01, 0.88, 1e-8, maxEpochs * trainData.n_cols, -1); model.Train(std::move(trainData), std::move(trainLabels), opt); @@ -194,7 +194,7 @@ void BuildDropoutNetwork(MatType& trainData, model.Add >(hiddenLayerSize, outputSize); model.Add >(); - RMSprop opt(model, 0.01, 0.88, 1e-8, + RMSProp opt(model, 0.01, 0.88, 1e-8, maxEpochs * trainData.n_cols, -1); model.Train(std::move(trainData), std::move(trainLabels), opt); @@ -324,7 +324,7 @@ void BuildDropConnectNetwork(MatType& trainData, model.Add >(hiddenLayerSize, outputSize); model.Add >(); - RMSprop opt(model, 0.01, 0.88, 1e-8, + RMSProp opt(model, 0.01, 0.88, 1e-8, maxEpochs * trainData.n_cols, -1); model.Train(std::move(trainData), std::move(trainLabels), opt); diff --git a/src/mlpack/tests/ksinit_test.cpp b/src/mlpack/tests/ksinit_test.cpp index 1b42d8bfa0a..16ec7457e42 100644 --- a/src/mlpack/tests/ksinit_test.cpp +++ b/src/mlpack/tests/ksinit_test.cpp @@ -85,7 +85,7 @@ void BuildVanillaNetwork(MatType& trainData, model.Add >(); model.Add >(hiddenLayerSize, 
outputSize); - RMSprop opt(model, 0.01, 0.88, 1e-8, + RMSProp opt(model, 0.01, 0.88, 1e-8, maxEpochs * trainData.n_cols, 1e-18); model.Train(std::move(trainData), std::move(trainLabels), opt); diff --git a/src/mlpack/tests/pca_test.cpp b/src/mlpack/tests/pca_test.cpp index 5ec70b15bb9..926fff93118 100644 --- a/src/mlpack/tests/pca_test.cpp +++ b/src/mlpack/tests/pca_test.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include "test_tools.hpp" @@ -31,15 +32,17 @@ using namespace mlpack::distribution; * specified decomposition policy. */ template -void ArmaComparisonPCA() +void ArmaComparisonPCA( + const bool scaleData = false, + const DecompositionPolicy& decomposition = DecompositionPolicy()) { arma::mat coeff, coeff1, score, score1; arma::vec eigVal, eigVal1; arma::mat data = arma::randu(3, 1000); - PCAType exactPCA; - exactPCA.Apply(data, score1, eigVal1, coeff1); + PCAType pcaType(scaleData, decomposition); + pcaType.Apply(data, score1, eigVal1, coeff1); princomp(coeff, score, eigVal, trans(data)); @@ -58,7 +61,9 @@ void ArmaComparisonPCA() * (which should be correct!) using the specified decomposition policy. */ template -void PCADimensionalityReduction() +void PCADimensionalityReduction( + const bool scaleData = false, + const DecompositionPolicy& decomposition = DecompositionPolicy()) { // Fake, simple dataset. The results we will compare against are from MATLAB. mat data("1 0 2 3 9;" @@ -66,7 +71,7 @@ void PCADimensionalityReduction() "6 7 3 1 8"); // Now run PCA to reduce the dimensionality. - PCAType p; + PCAType p(scaleData, decomposition); const double varRetained = p.Apply(data, 2); // Reduce to 2 dimensions. // Compare with correct results. @@ -168,6 +173,16 @@ BOOST_AUTO_TEST_CASE(ArmaComparisonExactPCATest) ArmaComparisonPCA(); } +/** + * Compare the output of our randomized block krylov PCA implementation with + * Armadillo's. + */ +BOOST_AUTO_TEST_CASE(ArmaComparisonRandomizedBlockKrylovPCATest) +{ + RandomizedBlockKrylovSVDPolicy decomposition(5); + ArmaComparisonPCA(false, decomposition); +} + /** * Compare the output of our randomized-SVD PCA implementation with Armadillo's. */ @@ -185,6 +200,17 @@ BOOST_AUTO_TEST_CASE(ExactPCADimensionalityReductionTest) PCADimensionalityReduction(); } +/** + * Test that dimensionality reduction with randomized block krylov PCA works the + * same way MATLAB does (which should be correct!). + */ +BOOST_AUTO_TEST_CASE(RandomizedBlockKrylovPCADimensionalityReductionTest) +{ + RandomizedBlockKrylovSVDPolicy decomposition(5); + PCADimensionalityReduction(false, + decomposition); +} + /** * Test that dimensionality reduction with randomized-svd PCA works the same way * MATLAB does (which should be correct!). diff --git a/src/mlpack/tests/recurrent_network_test.cpp b/src/mlpack/tests/recurrent_network_test.cpp index b726fc84ce9..73df32a209d 100644 --- a/src/mlpack/tests/recurrent_network_test.cpp +++ b/src/mlpack/tests/recurrent_network_test.cpp @@ -359,76 +359,90 @@ void ReberGrammarTestNetwork(bool embedded = false) * . . * ....... */ - const size_t outputSize = 7; - const size_t inputSize = 7; - const size_t rho = trainInput.at(0, 0).n_elem / inputSize; + // It isn't guaranteed that the recurrent network will converge in the + // specified number of iterations using random weights. If this works 1 of 5 + // times, I'm fine with that. All I want to know is that the network is able + // to escape from local minima and to solve the task. 
+ size_t successes = 0; + for (size_t trial = 0; trial < 5; ++trial) + { + const size_t outputSize = 7; + const size_t inputSize = 7; + const size_t rho = trainInput.at(0, 0).n_elem / inputSize; - RNN > model(rho); + RNN > model(rho); - model.Add >(); - model.Add >(inputSize, 20); - model.Add >(20, 7, rho); - model.Add >(7, outputSize); - model.Add >(); + model.Add >(); + model.Add >(inputSize, 20); + model.Add >(20, 7, rho); + model.Add >(7, outputSize); + model.Add >(); - StandardSGD opt(model, 0.1, 2, -50000); + StandardSGD opt(model, 0.1, 2, -50000); - arma::mat inputTemp, labelsTemp; - for (size_t i = 0; i < 40; i++) - { - for (size_t j = 0; j < trainReberGrammarCount; j++) + arma::mat inputTemp, labelsTemp; + for (size_t i = 0; i < 20; i++) { - inputTemp = trainInput.at(0, j); - labelsTemp = trainLabels.at(0, j); + for (size_t j = 0; j < trainReberGrammarCount; j++) + { + inputTemp = trainInput.at(0, j); + labelsTemp = trainLabels.at(0, j); - model.Train(inputTemp, labelsTemp, opt); + model.Train(inputTemp, labelsTemp, opt); + } } - } - double error = 0; + double error = 0; - // Ask the network to predict the next Reber grammar in the given sequence. - for (size_t i = 0; i < testReberGrammarCount; i++) - { - arma::mat output, prediction; - arma::mat input = testInput.at(0, i); + // Ask the network to predict the next Reber grammar in the given sequence. + for (size_t i = 0; i < testReberGrammarCount; i++) + { + arma::mat output, prediction; + arma::mat input = testInput.at(0, i); - model.Predict(input, prediction); - data::Binarize(prediction, output, 0.5); + model.Predict(input, prediction); + data::Binarize(prediction, output, 0.5); - const size_t reberGrammerSize = 7; - std::string inputReber = ""; + const size_t reberGrammerSize = 7; + std::string inputReber = ""; - size_t reberError = 0; - for (size_t j = 0; j < (output.n_elem / reberGrammerSize); j++) - { - if (arma::sum(arma::sum(output.submat(j * reberGrammerSize, 0, (j + 1) * - reberGrammerSize - 1, 0))) != 1) break; + size_t reberError = 0; + for (size_t j = 0; j < (output.n_elem / reberGrammerSize); j++) + { + if (arma::sum(arma::sum(output.submat(j * reberGrammerSize, 0, (j + 1) * + reberGrammerSize - 1, 0))) != 1) break; - char predictedSymbol, inputSymbol; - std::string reberChoices; + char predictedSymbol, inputSymbol; + std::string reberChoices; - ReberReverseTranslation(output.submat(j * reberGrammerSize, 0, (j + 1) * - reberGrammerSize - 1, 0), predictedSymbol); - ReberReverseTranslation(input.submat(j * reberGrammerSize, 0, (j + 1) * - reberGrammerSize - 1, 0), inputSymbol); - inputReber += inputSymbol; + ReberReverseTranslation(output.submat(j * reberGrammerSize, 0, (j + 1) * + reberGrammerSize - 1, 0), predictedSymbol); + ReberReverseTranslation(input.submat(j * reberGrammerSize, 0, (j + 1) * + reberGrammerSize - 1, 0), inputSymbol); + inputReber += inputSymbol; - if (embedded) - GenerateNextEmbeddedReber(transitions, inputReber, reberChoices); - else - GenerateNextReber(transitions, inputReber, reberChoices); + if (embedded) + GenerateNextEmbeddedReber(transitions, inputReber, reberChoices); + else + GenerateNextReber(transitions, inputReber, reberChoices); - if (reberChoices.find(predictedSymbol) != std::string::npos) - reberError++; + if (reberChoices.find(predictedSymbol) != std::string::npos) + reberError++; + } + + if (reberError != (output.n_elem / reberGrammerSize)) + error += 1; } - if (reberError != (output.n_elem / reberGrammerSize)) - error += 1; + error /= testReberGrammarCount; + if (error <= 0.2) + 
{ + ++successes; + break; + } } - error /= testReberGrammarCount; - BOOST_REQUIRE_LE(error, 0.2); + BOOST_REQUIRE_GE(successes, 1); } /** diff --git a/src/mlpack/tests/rmsprop_test.cpp b/src/mlpack/tests/rmsprop_test.cpp index 831df74302d..19275588dff 100644 --- a/src/mlpack/tests/rmsprop_test.cpp +++ b/src/mlpack/tests/rmsprop_test.cpp @@ -27,15 +27,15 @@ using namespace mlpack::optimization::test; using namespace mlpack::distribution; using namespace mlpack::regression; -BOOST_AUTO_TEST_SUITE(RMSpropTest); +BOOST_AUTO_TEST_SUITE(RMSPropTest); /** - * Tests the RMSprop optimizer using a simple test function. + * Tests the RMSProp optimizer using a simple test function. */ -BOOST_AUTO_TEST_CASE(SimpleRMSpropTestFunction) +BOOST_AUTO_TEST_CASE(SimpleRMSPropTestFunction) { SGDTestFunction f; - RMSprop optimizer(f, 1e-3, 0.99, 1e-8, 5000000, 1e-9, true); + RMSProp optimizer(f, 1e-3, 0.99, 1e-8, 5000000, 1e-9, true); arma::mat coordinates = f.GetInitialPoint(); optimizer.Optimize(coordinates); @@ -46,7 +46,7 @@ BOOST_AUTO_TEST_CASE(SimpleRMSpropTestFunction) } /** - * Run RMSprop on logistic regression and make sure the results are acceptable. + * Run RMSProp on logistic regression and make sure the results are acceptable. */ BOOST_AUTO_TEST_CASE(LogisticRegressionTest) { @@ -95,7 +95,7 @@ BOOST_AUTO_TEST_CASE(LogisticRegressionTest) LogisticRegression<> lr(shuffledData.n_rows, 0.5); LogisticRegressionFunction<> lrf(shuffledData, shuffledResponses, 0.5); - RMSprop > rmsprop(lrf); + RMSProp > rmsprop(lrf); lr.Train(rmsprop); // Ensure that the error is close to zero. diff --git a/src/mlpack/tests/smorms3_test.cpp b/src/mlpack/tests/smorms3_test.cpp new file mode 100644 index 00000000000..1d62556312d --- /dev/null +++ b/src/mlpack/tests/smorms3_test.cpp @@ -0,0 +1,109 @@ +/** + * @file smorms3_test.cpp + * @author Vivek Pal + * + * Tests the SMORMS3 optimizer. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#include + +#include +#include +#include + +#include +#include "test_tools.hpp" + +using namespace arma; +using namespace mlpack::optimization; +using namespace mlpack::optimization::test; + +using namespace mlpack::distribution; +using namespace mlpack::regression; + +using namespace mlpack; + +BOOST_AUTO_TEST_SUITE(SMORMS3Test); + +/** + * Tests the SMORMS3 optimizer using a simple test function. + */ +BOOST_AUTO_TEST_CASE(SimpleSMORMS3TestFunction) +{ + SGDTestFunction f; + SMORMS3 s(f, 0.001, 1e-16, 5000000, 1e-9, true); + + arma::mat coordinates = f.GetInitialPoint(); + s.Optimize(coordinates); + + BOOST_REQUIRE_SMALL(coordinates[0], 0.1); + BOOST_REQUIRE_SMALL(coordinates[1], 0.1); + BOOST_REQUIRE_SMALL(coordinates[2], 0.1); +} + +/** + * Run SMORMS3 on logistic regression and make sure the results are acceptable. + */ +BOOST_AUTO_TEST_CASE(SMORMS3LogisticRegressionTest) +{ + // Generate a two-Gaussian dataset. + GaussianDistribution g1(arma::vec("1.0 1.0 1.0"), arma::eye(3, 3)); + GaussianDistribution g2(arma::vec("9.0 9.0 9.0"), arma::eye(3, 3)); + + arma::mat data(3, 1000); + arma::Row responses(1000); + for (size_t i = 0; i < 500; ++i) + { + data.col(i) = g1.Random(); + responses[i] = 0; + } + for (size_t i = 500; i < 1000; ++i) + { + data.col(i) = g2.Random(); + responses[i] = 1; + } + + // Shuffle the dataset. 
+ arma::uvec indices = arma::shuffle(arma::linspace(0, + data.n_cols - 1, data.n_cols)); + arma::mat shuffledData(3, 1000); + arma::Row shuffledResponses(1000); + for (size_t i = 0; i < data.n_cols; ++i) + { + shuffledData.col(i) = data.col(indices[i]); + shuffledResponses[i] = responses[indices[i]]; + } + + // Create a test set. + arma::mat testData(3, 1000); + arma::Row testResponses(1000); + for (size_t i = 0; i < 500; ++i) + { + testData.col(i) = g1.Random(); + testResponses[i] = 0; + } + for (size_t i = 500; i < 1000; ++i) + { + testData.col(i) = g2.Random(); + testResponses[i] = 1; + } + + LogisticRegression<> lr(shuffledData.n_rows, 0.5); + + LogisticRegressionFunction<> lrf(shuffledData, shuffledResponses, 0.5); + SMORMS3 > smorms3(lrf); + lr.Train(smorms3); + + // Ensure that the error is close to zero. + const double acc = lr.ComputeAccuracy(data, responses); + BOOST_REQUIRE_CLOSE(acc, 100.0, 0.3); // 0.3% error tolerance. + + const double testAcc = lr.ComputeAccuracy(testData, testResponses); + BOOST_REQUIRE_CLOSE(testAcc, 100.0, 0.6); // 0.6% error tolerance. +} + +BOOST_AUTO_TEST_SUITE_END(); \ No newline at end of file
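For reference, a minimal usage sketch of the two main additions in this patch: PCA driven by the new randomized block Krylov SVD decomposition policy, and logistic regression trained with the new SMORMS3 optimizer. It simply mirrors the calls exercised by the new tests above; the include paths follow mlpack's source layout, and the random data and labels are placeholders only.

#include <mlpack/core.hpp>
#include <mlpack/core/optimizers/smorms3/smorms3.hpp>
#include <mlpack/methods/pca/pca.hpp>
#include <mlpack/methods/pca/decomposition_policies/randomized_block_krylov_method.hpp>
#include <mlpack/methods/logistic_regression/logistic_regression.hpp>

using namespace mlpack;

int main()
{
  // PCA using the randomized block Krylov SVD decomposition policy.
  arma::mat data = arma::randu<arma::mat>(10, 500);      // 10 dims, 500 points.
  pca::RandomizedBlockKrylovSVDPolicy decomposition(5);  // 5 power iterations.
  pca::PCAType<pca::RandomizedBlockKrylovSVDPolicy> p(false, decomposition);
  arma::mat transformed, eigvec;
  arma::vec eigVal;
  p.Apply(data, transformed, eigVal, eigvec);

  // Logistic regression on placeholder labels, trained with SMORMS3.
  arma::Row<size_t> labels(500);
  labels.head(250).zeros();
  labels.tail(250).ones();
  regression::LogisticRegressionFunction<> lrf(data, labels, 0.5);
  optimization::SMORMS3<regression::LogisticRegressionFunction<>> smorms3(lrf);
  regression::LogisticRegression<> lr(data.n_rows, 0.5);
  lr.Train(smorms3);

  return 0;
}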