diff --git a/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd.hpp b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd.hpp
index 248bc90daee..0a7276d5acd 100644
--- a/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd.hpp
+++ b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd.hpp
@@ -105,6 +105,8 @@ class MiniBatchSGDType
    * @param updatePolicy Instantiated update policy used to adjust the given
    *        parameters.
    * @param decayPolicy Instantiated decay policy used to adjust the step size.
+   * @param resetPolicy Flag that determines whether update policy parameters
+   *        are reset before every Optimize call.
    */
   MiniBatchSGDType(const size_t batchSize = 1000,
                    const double stepSize = 0.01,
@@ -112,7 +114,8 @@ class MiniBatchSGDType
                    const double tolerance = 1e-5,
                    const bool shuffle = true,
                    const UpdatePolicyType& updatePolicy = UpdatePolicyType(),
-                   const DecayPolicyType& decayPolicy = DecayPolicyType());
+                   const DecayPolicyType& decayPolicy = DecayPolicyType(),
+                   const bool resetPolicy = true);
 
   /**
    * Optimize the given function using mini-batch SGD. The given starting point
@@ -122,10 +125,11 @@ class MiniBatchSGDType
    * @tparam DecomposableFunctionType Type of the function to be optimized.
    * @param function Function to optimize.
    * @param iterate Starting point (will be modified).
    * @return Objective value of the final point.
    */
   template<typename DecomposableFunctionType>
-  double Optimize(DecomposableFunctionType& function, arma::mat& iterate);
+  double Optimize(DecomposableFunctionType& function,
+                  arma::mat& iterate);
 
   //! Get the batch size.
   size_t BatchSize() const { return batchSize; }
@@ -152,6 +156,13 @@ class MiniBatchSGDType
   //! Modify whether or not the individual functions are shuffled.
   bool& Shuffle() { return shuffle; }
 
+  //! Get whether or not the update policy parameters
+  //! are reset before each Optimize call.
+  bool ResetPolicy() const { return resetPolicy; }
+  //! Modify whether or not the update policy parameters
+  //! are reset before each Optimize call.
+  bool& ResetPolicy() { return resetPolicy; }
+
   //! Get the update policy.
   UpdatePolicyType UpdatePolicy() const { return updatePolicy; }
   //! Modify the update policy.
@@ -184,6 +195,10 @@ class MiniBatchSGDType
 
   //! The decay policy used to update the parameters in each iteration.
   DecayPolicyType decayPolicy;
+
+  //! Flag that determines whether update policy parameters
+  //! are reset before every Optimize call.
+  bool resetPolicy;
 };
 
 using MiniBatchSGD = MiniBatchSGDType<VanillaUpdate, NoDecay>;
diff --git a/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
index 776182cb87d..16be3ec0352 100644
--- a/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
+++ b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
@@ -31,14 +31,16 @@ MiniBatchSGDType<
     const double tolerance,
     const bool shuffle,
     const UpdatePolicyType& updatePolicy,
-    const DecayPolicyType& decayPolicy) :
+    const DecayPolicyType& decayPolicy,
+    const bool resetPolicy) :
     batchSize(batchSize),
     stepSize(stepSize),
     maxIterations(maxIterations),
     tolerance(tolerance),
     shuffle(shuffle),
     updatePolicy(updatePolicy),
-    decayPolicy(decayPolicy)
+    decayPolicy(decayPolicy),
+    resetPolicy(resetPolicy)
 { /* Nothing to do. */ }
 
 //! Optimize the function (minimize).
@@ -50,7 +52,8 @@ template<typename DecomposableFunctionType>
 double MiniBatchSGDType<
     UpdatePolicyType,
     DecayPolicyType
->::Optimize(DecomposableFunctionType& function, arma::mat& iterate)
+>::Optimize(DecomposableFunctionType& function,
+            arma::mat& iterate)
 {
   // Find the number of functions.
   const size_t numFunctions = function.NumFunctions();
@@ -75,7 +78,8 @@ double MiniBatchSGDType<
     overallObjective += function.Evaluate(iterate, i);
 
   // Initialize the update policy.
-  updatePolicy.Initialize(iterate.n_rows, iterate.n_cols);
+  if (resetPolicy)
+    updatePolicy.Initialize(iterate.n_rows, iterate.n_cols);
 
   // Now iterate!
   arma::mat gradient(iterate.n_rows, iterate.n_cols);
diff --git a/src/mlpack/core/optimizers/sgd/sgd.hpp b/src/mlpack/core/optimizers/sgd/sgd.hpp
index 4801df158a4..c21517861a9 100644
--- a/src/mlpack/core/optimizers/sgd/sgd.hpp
+++ b/src/mlpack/core/optimizers/sgd/sgd.hpp
@@ -92,12 +92,17 @@ class SGD
    * @param tolerance Maximum absolute tolerance to terminate algorithm.
    * @param shuffle If true, the function order is shuffled; otherwise, each
    *        function is visited in linear order.
+   * @param updatePolicy Instantiated update policy used to adjust the given
+   *        parameters.
+   * @param resetPolicy Flag that determines whether update policy parameters
+   *        are reset before every Optimize call.
    */
   SGD(const double stepSize = 0.01,
       const size_t maxIterations = 100000,
       const double tolerance = 1e-5,
       const bool shuffle = true,
-      const UpdatePolicyType updatePolicy = UpdatePolicyType());
+      const UpdatePolicyType updatePolicy = UpdatePolicyType(),
+      const bool resetPolicy = true);
 
   /**
    * Optimize the given function using stochastic gradient descent. The given
@@ -110,7 +115,8 @@ class SGD
    * @return Objective value of the final point.
    */
   template<typename DecomposableFunctionType>
-  double Optimize(DecomposableFunctionType& function, arma::mat& iterate);
+  double Optimize(DecomposableFunctionType& function,
+                  arma::mat& iterate);
 
   //! Get the step size.
   double StepSize() const { return stepSize; }
@@ -132,6 +138,13 @@ class SGD
   //! Modify whether or not the individual functions are shuffled.
   bool& Shuffle() { return shuffle; }
 
+  //! Get whether or not the update policy parameters
+  //! are reset before each Optimize call.
+  bool ResetPolicy() const { return resetPolicy; }
+  //! Modify whether or not the update policy parameters
+  //! are reset before each Optimize call.
+  bool& ResetPolicy() { return resetPolicy; }
+
   //! Get the update policy.
   const UpdatePolicyType& UpdatePolicy() const { return updatePolicy; }
   //! Modify the update policy.
@@ -153,6 +166,10 @@ class SGD
 
   //! The update policy used to update the parameters in each iteration.
   UpdatePolicyType updatePolicy;
+
+  //! Flag indicating whether the update policy
+  //! should be reset before running optimization.
+  bool resetPolicy;
 };
 
 using StandardSGD = SGD<VanillaUpdate>;
diff --git a/src/mlpack/core/optimizers/sgd/sgd_impl.hpp b/src/mlpack/core/optimizers/sgd/sgd_impl.hpp
index 224b60e445e..4bfa4b64817 100644
--- a/src/mlpack/core/optimizers/sgd/sgd_impl.hpp
+++ b/src/mlpack/core/optimizers/sgd/sgd_impl.hpp
@@ -29,12 +29,14 @@ SGD<UpdatePolicyType>::SGD(
     const size_t maxIterations,
     const double tolerance,
     const bool shuffle,
-    const UpdatePolicyType updatePolicy) :
+    const UpdatePolicyType updatePolicy,
+    const bool resetPolicy) :
     stepSize(stepSize),
     maxIterations(maxIterations),
     tolerance(tolerance),
     shuffle(shuffle),
-    updatePolicy(updatePolicy)
+    updatePolicy(updatePolicy),
+    resetPolicy(resetPolicy)
 { /* Nothing to do. */ }
 
 //! Optimize the function (minimize).
@@ -65,7 +67,8 @@ double SGD<UpdatePolicyType>::Optimize(
     overallObjective += function.Evaluate(iterate, i);
 
   // Initialize the update policy.
-  updatePolicy.Initialize(iterate.n_rows, iterate.n_cols);
+  if (resetPolicy)
+    updatePolicy.Initialize(iterate.n_rows, iterate.n_cols);
 
   // Now iterate!
   arma::mat gradient(iterate.n_rows, iterate.n_cols);
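Reviewer note: a minimal sketch of what the new flag enables, assuming the existing `test::SGDTestFunction` from `sgd/test_function.hpp` (everything else below is the API added in this patch). Because `Optimize()` only calls `Initialize()` when `resetPolicy` is true, the policy must be initialized by at least one such call before the flag is turned off:

```cpp
#include <mlpack/core.hpp>
#include <mlpack/core/optimizers/sgd/sgd.hpp>
#include <mlpack/core/optimizers/sgd/update_policies/momentum_update.hpp>
#include <mlpack/core/optimizers/sgd/test_function.hpp>

using namespace mlpack::optimization;

int main()
{
  test::SGDTestFunction f;
  MomentumUpdate momentumUpdate(0.7);

  // resetPolicy defaults to true, so the first Optimize() call below
  // initializes the momentum buffer as before.
  SGD<MomentumUpdate> s(0.0003, 5000, 1e-9, true, momentumUpdate);

  arma::mat coordinates = f.GetInitialPoint();
  s.Optimize(f, coordinates);

  // Turn the flag off: the second call now warm-starts from the velocity
  // accumulated by the first call instead of reinitializing it.
  s.ResetPolicy() = false;
  s.Optimize(f, coordinates);

  return 0;
}
```

`MiniBatchSGDType` gets the same flag with identical semantics.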
diff --git a/src/mlpack/core/optimizers/sgd/update_policies/gradient_clipping.hpp b/src/mlpack/core/optimizers/sgd/update_policies/gradient_clipping.hpp
new file mode 100644
index 00000000000..68e58a99fed
--- /dev/null
+++ b/src/mlpack/core/optimizers/sgd/update_policies/gradient_clipping.hpp
@@ -0,0 +1,109 @@
+/**
+ * @file gradient_clipping.hpp
+ * @author Konstantin Sidorov
+ *
+ * Gradient clipping update wrapper.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+#ifndef MLPACK_CORE_OPTIMIZERS_SGD_GRADIENT_CLIPPING_HPP
+#define MLPACK_CORE_OPTIMIZERS_SGD_GRADIENT_CLIPPING_HPP
+
+#include <mlpack/prereqs.hpp>
+
+namespace mlpack {
+namespace optimization {
+
+/**
+ * Interface for wrapping around update policies (e.g., VanillaUpdate)
+ * and feeding a clipped gradient to them instead of the normal one.
+ * (Clipping here is implemented as
+ * \f$ g_{\text{clipped}} = \max(g_{\text{min}}, \min(g_{\text{max}}, g)) \f$.)
+ *
+ * @tparam UpdatePolicyType The update policy that should be wrapped.
+ */
+template<typename UpdatePolicyType>
+class GradientClipping
+{
+ public:
+  /**
+   * Constructor for creating a GradientClipping instance.
+   *
+   * @param minGradient Minimum possible value of gradient element.
+   * @param maxGradient Maximum possible value of gradient element.
+   * @param updatePolicy An instance of the UpdatePolicyType
+   *        used for the actual optimization step.
+   */
+  GradientClipping(const double minGradient,
+                   const double maxGradient,
+                   UpdatePolicyType& updatePolicy) :
+      minGradient(minGradient),
+      maxGradient(maxGradient),
+      updatePolicy(updatePolicy)
+  {
+    // Nothing to do here.
+  }
+
+  /**
+   * The Initialize method is called by the SGD optimizer before the start of
+   * the iteration update process. Here we just do whatever initialization
+   * is needed for the actual update policy.
+   *
+   * @param rows Number of rows in the gradient matrix.
+   * @param cols Number of columns in the gradient matrix.
+   */
+  void Initialize(const size_t rows, const size_t cols)
+  {
+    updatePolicy.Initialize(rows, cols);
+  }
+
+  /**
+   * Update step. First, the gradient is clipped, and then the actual update
+   * policy does whatever update it needs to do.
+   *
+   * @param iterate Parameters that minimize the function.
+   * @param stepSize Step size to be used for the given iteration.
+   * @param gradient The gradient matrix.
+   */
+  void Update(arma::mat& iterate,
+              const double stepSize,
+              const arma::mat& gradient)
+  {
+    // First, clip the gradient.
+    arma::mat clippedGradient = arma::clamp(gradient, minGradient, maxGradient);
+    // And only then do the update.
+    updatePolicy.Update(iterate, stepSize, clippedGradient);
+  }
+
+  //! Get the update policy.
+  const UpdatePolicyType& UpdatePolicy() const { return updatePolicy; }
+  //! Modify the update policy.
+  UpdatePolicyType& UpdatePolicy() { return updatePolicy; }
+
+  //! Get the minimum gradient value.
+  double MinGradient() const { return minGradient; }
+  //! Modify the minimum gradient value.
+  double& MinGradient() { return minGradient; }
+
+  //! Get the maximum gradient value.
+  double MaxGradient() const { return maxGradient; }
+  //! Modify the maximum gradient value.
+  double& MaxGradient() { return maxGradient; }
+
+ private:
+  //! Minimum possible value of gradient element.
+  double minGradient;
+
+  //! Maximum possible value of gradient element.
+  double maxGradient;
+
+  //! An instance of the UpdatePolicy used for the actual optimization.
+  UpdatePolicyType updatePolicy;
+};
+
+} // namespace optimization
+} // namespace mlpack
+
+#endif
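A usage sketch (again leaning on `test::SGDTestFunction`; nothing here beyond the headers shown in this patch). Since clipping is a wrapper rather than a flag on each policy, it composes with any update policy — and any optimizer that accepts one — without modifying either:

```cpp
#include <mlpack/core.hpp>
#include <mlpack/core/optimizers/sgd/sgd.hpp>
#include <mlpack/core/optimizers/sgd/update_policies/vanilla_update.hpp>
#include <mlpack/core/optimizers/sgd/update_policies/gradient_clipping.hpp>
#include <mlpack/core/optimizers/sgd/test_function.hpp>

using namespace mlpack::optimization;

int main()
{
  test::SGDTestFunction f;

  // Clamp every gradient coordinate to [-5, 5] before the vanilla step.
  VanillaUpdate vanillaUpdate;
  GradientClipping<VanillaUpdate> clippedUpdate(-5.0, 5.0, vanillaUpdate);

  SGD<GradientClipping<VanillaUpdate>> s(
      0.0003, 2500000, 1e-9, true, clippedUpdate);

  arma::mat coordinates = f.GetInitialPoint();
  s.Optimize(f, coordinates);

  return 0;
}
```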
diff --git a/src/mlpack/methods/ann/layer/CMakeLists.txt b/src/mlpack/methods/ann/layer/CMakeLists.txt
index 17f54d69915..caa1cf112cd 100644
--- a/src/mlpack/methods/ann/layer/CMakeLists.txt
+++ b/src/mlpack/methods/ann/layer/CMakeLists.txt
@@ -14,6 +14,8 @@ set(SOURCES
   constant_impl.hpp
   convolution.hpp
   convolution_impl.hpp
+  cross_entropy_error.hpp
+  cross_entropy_error_impl.hpp
   dropconnect.hpp
   dropconnect_impl.hpp
   dropout.hpp
diff --git a/src/mlpack/methods/ann/layer/cross_entropy_error.hpp b/src/mlpack/methods/ann/layer/cross_entropy_error.hpp
new file mode 100644
index 00000000000..4301e0c4a91
--- /dev/null
+++ b/src/mlpack/methods/ann/layer/cross_entropy_error.hpp
@@ -0,0 +1,111 @@
+/**
+ * @file cross_entropy_error.hpp
+ * @author Konstantin Sidorov
+ *
+ * Definition of the cross-entropy performance function.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+#ifndef MLPACK_METHODS_ANN_LAYER_CROSS_ENTROPY_ERROR_HPP
+#define MLPACK_METHODS_ANN_LAYER_CROSS_ENTROPY_ERROR_HPP
+
+#include <mlpack/prereqs.hpp>
+
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+
+/**
+ * The cross-entropy performance function measures the network's
+ * performance according to the cross-entropy between the input
+ * and target distributions.
+ *
+ * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
+ *         arma::sp_mat or arma::cube).
+ * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
+ *         arma::sp_mat or arma::cube).
+ */
+template <
+    typename InputDataType = arma::mat,
+    typename OutputDataType = arma::mat
+>
+class CrossEntropyError
+{
+ public:
+  /**
+   * Create the CrossEntropyError object.
+   *
+   * @param eps The minimum value used for computing logarithms
+   *        and denominators in a numerically stable way.
+   */
+  CrossEntropyError(double eps = 1e-10);
+
+  /**
+   * Computes the cross-entropy function.
+   *
+   * @param input Input data used for evaluating the specified function.
+   * @param target The target vector.
+   */
+  template<typename eT>
+  double Forward(const arma::Mat<eT>&& input, const arma::Mat<eT>&& target);
+
+  /**
+   * Ordinary feed backward pass of a neural network.
+   *
+   * @param input The propagated input activation.
+   * @param target The target vector.
+   * @param output The calculated error.
+   */
+  template<typename eT>
+  void Backward(const arma::Mat<eT>&& input,
+                const arma::Mat<eT>&& target,
+                arma::Mat<eT>&& output);
+
+  //! Get the input parameter.
+  const InputDataType& InputParameter() const { return inputParameter; }
+  //! Modify the input parameter.
+  InputDataType& InputParameter() { return inputParameter; }
+
+  //! Get the output parameter.
+  const OutputDataType& OutputParameter() const { return outputParameter; }
+  //! Modify the output parameter.
+  OutputDataType& OutputParameter() { return outputParameter; }
+
+  //! Get the delta.
+  const OutputDataType& Delta() const { return delta; }
+  //! Modify the delta.
+  OutputDataType& Delta() { return delta; }
+
+  //! Get the epsilon.
+  double Eps() const { return eps; }
+  //! Modify the epsilon.
+  double& Eps() { return eps; }
+
+  /**
+   * Serialize the layer.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
+ private:
+  //! Locally-stored delta object.
+  OutputDataType delta;
+
+  //! Locally-stored input parameter object.
+  InputDataType inputParameter;
+
+  //! Locally-stored output parameter object.
+  OutputDataType outputParameter;
+
+  //! The minimum value used for computing logarithms and denominators.
+  double eps;
+}; // class CrossEntropyError
+
+} // namespace ann
+} // namespace mlpack
+
+// Include implementation.
+#include "cross_entropy_error_impl.hpp"
+
+#endif
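For reviewers checking `Backward()` against `Forward()`: the backward pass is just the elementwise derivative of the forward objective with respect to the input, with the epsilon kept in both terms:

```latex
L(x, t) = -\sum_i \left[ t_i \log(x_i + \epsilon)
                       + (1 - t_i) \log(1 - x_i + \epsilon) \right]

\frac{\partial L}{\partial x_i}
    = \frac{1 - t_i}{1 - x_i + \epsilon} - \frac{t_i}{x_i + \epsilon}
```

which is exactly the expression computed in `cross_entropy_error_impl.hpp` below.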
diff --git a/src/mlpack/methods/ann/layer/cross_entropy_error_impl.hpp b/src/mlpack/methods/ann/layer/cross_entropy_error_impl.hpp
new file mode 100644
index 00000000000..1aa4b754e6d
--- /dev/null
+++ b/src/mlpack/methods/ann/layer/cross_entropy_error_impl.hpp
@@ -0,0 +1,59 @@
+/**
+ * @file cross_entropy_error_impl.hpp
+ * @author Konstantin Sidorov
+ *
+ * Implementation of the cross-entropy performance function.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+#ifndef MLPACK_METHODS_ANN_LAYER_CROSS_ENTROPY_ERROR_IMPL_HPP
+#define MLPACK_METHODS_ANN_LAYER_CROSS_ENTROPY_ERROR_IMPL_HPP
+
+// In case it hasn't yet been included.
+#include "cross_entropy_error.hpp"
+
+namespace mlpack {
+namespace ann /** Artificial Neural Network. */ {
+
+template<typename InputDataType, typename OutputDataType>
+CrossEntropyError<InputDataType, OutputDataType>::CrossEntropyError(
+    double eps) : eps(eps)
+{
+  // Nothing to do here.
+}
+
+template<typename InputDataType, typename OutputDataType>
+template<typename eT>
+double CrossEntropyError<InputDataType, OutputDataType>::Forward(
+    const arma::Mat<eT>&& input, const arma::Mat<eT>&& target)
+{
+  return -arma::accu(target % arma::log(input + eps) +
+      (1. - target) % arma::log(1. - input + eps));
+}
+
+template<typename InputDataType, typename OutputDataType>
+template<typename eT>
+void CrossEntropyError<InputDataType, OutputDataType>::Backward(
+    const arma::Mat<eT>&& input,
+    const arma::Mat<eT>&& target,
+    arma::Mat<eT>&& output)
+{
+  output = (1. - target) / (1. - input + eps) - target / (input + eps);
+}
+
+template<typename InputDataType, typename OutputDataType>
+template<typename Archive>
+void CrossEntropyError<InputDataType, OutputDataType>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
+{
+  ar & data::CreateNVP(eps, "eps");
+}
+
+} // namespace ann
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/ann/layer/layer_types.hpp b/src/mlpack/methods/ann/layer/layer_types.hpp
index 708838bc01d..82a1662e57c 100644
--- a/src/mlpack/methods/ann/layer/layer_types.hpp
+++ b/src/mlpack/methods/ann/layer/layer_types.hpp
@@ -33,6 +33,7 @@
 #include <mlpack/methods/ann/layer/linear.hpp>
 #include <mlpack/methods/ann/layer/linear_no_bias.hpp>
 #include <mlpack/methods/ann/layer/log_softmax.hpp>
+#include <mlpack/methods/ann/layer/cross_entropy_error.hpp>
 
 // Convolution modules.
 #include <mlpack/methods/ann/layer/convolution_rules/border_modes.hpp>
@@ -89,6 +90,7 @@ using LayerTypes = boost::variant<
     Convolution<NaiveConvolution<ValidConvolution>,
                 NaiveConvolution<FullConvolution>,
                 NaiveConvolution<ValidConvolution>, arma::mat, arma::mat>*,
+    CrossEntropyError<arma::mat, arma::mat>*,
     DropConnect<arma::mat, arma::mat>*,
     Dropout<arma::mat, arma::mat>*,
     ELU<arma::mat, arma::mat>*,
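With the layer registered in `LayerTypes`, it should be usable as an `FFN` output layer in the same way as `NegativeLogLikelihood`. A rough, untested sketch — `FFN`, `Linear`, `SigmoidLayer`, `StandardSGD`, and the `Train()` call are existing mlpack APIs assumed here, not part of this patch:

```cpp
#include <mlpack/core.hpp>
#include <mlpack/core/optimizers/sgd/sgd.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>

using namespace mlpack::ann;
using namespace mlpack::optimization;

int main()
{
  // Toy data: 100 random 10-dimensional points with random 0/1 labels.
  arma::mat data = arma::randu<arma::mat>(10, 100);
  arma::mat labels = arma::round(arma::randu<arma::mat>(1, 100));

  // The sigmoid output keeps predictions inside (0, 1), where the log()
  // terms of the cross-entropy are well-defined.
  FFN<CrossEntropyError<>> model;
  model.Add<Linear<>>(10, 1);
  model.Add<SigmoidLayer<>>();

  StandardSGD opt(0.01, 50000);
  model.Train(data, labels, opt);

  return 0;
}
```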
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index c8b32722573..1c681369c02 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -30,6 +30,7 @@ add_executable(mlpack_test
   feedforward_network_test.cpp
   frankwolfe_test.cpp
   gmm_test.cpp
+  gradient_clipping_test.cpp
   gradient_descent_test.cpp
   hmm_test.cpp
   hoeffding_tree_test.cpp
diff --git a/src/mlpack/tests/ann_layer_test.cpp b/src/mlpack/tests/ann_layer_test.cpp
index 3a0e599b949..71990826012 100644
--- a/src/mlpack/tests/ann_layer_test.cpp
+++ b/src/mlpack/tests/ann_layer_test.cpp
@@ -26,7 +26,7 @@ using namespace mlpack::ann;
 
 BOOST_AUTO_TEST_SUITE(ANNLayerTest);
 
-// Helper function whcih calls the Reset function of the given module.
+// Helper function which calls the Reset function of the given module.
 template<typename T>
 void ResetFunction(
     T& layer,
@@ -531,7 +531,7 @@ BOOST_AUTO_TEST_CASE(JacobianLinearNoBiasLayerTest)
 /**
  * LinearNoBias layer numerically gradient test.
  */
-BOOST_AUTO_TEST_CASE(GradientLinearNoBiadLayerTest)
+BOOST_AUTO_TEST_CASE(GradientLinearNoBiasLayerTest)
 {
   // LinearNoBias function gradient instantiation.
   struct GradientFunction
@@ -926,4 +926,45 @@ BOOST_AUTO_TEST_CASE(SimpleLogSoftmaxLayerTest)
       arma::mat("1.6487; 0.6487") - delta)), 1e-3);
 }
 
+/*
+ * Simple test for the cross-entropy error performance function.
+ */
+BOOST_AUTO_TEST_CASE(SimpleCrossEntropyErrorLayerTest)
+{
+  arma::mat input1, input2, output, target1, target2;
+  CrossEntropyError<> module(1e-6);
+
+  // Test the Forward function on a user-generated input and compare it
+  // against the manually calculated result.
+  input1 = arma::mat("0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5");
+  target1 = arma::zeros(1, 8);
+  double error1 = module.Forward(std::move(input1), std::move(target1));
+  BOOST_REQUIRE_SMALL(error1 - 8 * std::log(2), 2e-5);
+
+  input2 = arma::mat("0 1 1 0 1 0 0 1");
+  target2 = arma::mat("0 1 1 0 1 0 0 1");
+  double error2 = module.Forward(std::move(input2), std::move(target2));
+  BOOST_REQUIRE_SMALL(error2, 1e-5);
+
+  // Test the Backward function.
+  module.Backward(std::move(input1), std::move(target1), std::move(output));
+  for (double el : output) {
+    // For the 0.5 constant vector we should get 1 / (1 - 0.5) = 2 everywhere.
+    BOOST_REQUIRE_SMALL(el - 2, 5e-6);
+  }
+  BOOST_REQUIRE_EQUAL(output.n_rows, input1.n_rows);
+  BOOST_REQUIRE_EQUAL(output.n_cols, input1.n_cols);
+
+  module.Backward(std::move(input2), std::move(target2), std::move(output));
+  for (size_t i = 0; i < 8; ++i) {
+    double el = output.at(0, i);
+    if (input2.at(i) == 0)
+      BOOST_REQUIRE_SMALL(el - 1, 2e-6);
+    else
+      BOOST_REQUIRE_SMALL(el + 1, 2e-6);
+  }
+  BOOST_REQUIRE_EQUAL(output.n_rows, input2.n_rows);
+  BOOST_REQUIRE_EQUAL(output.n_cols, input2.n_cols);
+}
+
 BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/gradient_clipping_test.cpp b/src/mlpack/tests/gradient_clipping_test.cpp
new file mode 100644
index 00000000000..dc8882a7cc0
--- /dev/null
+++ b/src/mlpack/tests/gradient_clipping_test.cpp
@@ -0,0 +1,73 @@
+/**
+ * @file gradient_clipping_test.cpp
+ * @author Konstantin Sidorov
+ *
+ * Test file for gradient clipping.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/sgd/sgd.hpp>
+#include <mlpack/core/optimizers/sgd/update_policies/vanilla_update.hpp>
+#include <mlpack/core/optimizers/sgd/update_policies/momentum_update.hpp>
+#include <mlpack/core/optimizers/sgd/update_policies/gradient_clipping.hpp>
+#include <mlpack/core/optimizers/sgd/test_function.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "test_tools.hpp"
+
+using namespace std;
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::optimization;
+using namespace mlpack::optimization::test;
+
+BOOST_AUTO_TEST_SUITE(GradientClippingTest);
+
+// Test checking that gradient clipping works with the vanilla update.
+BOOST_AUTO_TEST_CASE(ClippedVanillaUpdateTest)
+{
+  VanillaUpdate vanillaUpdate;
+  GradientClipping<VanillaUpdate> update(-3.0, +3.0, vanillaUpdate);
+  update.Initialize(3, 3);
+
+  arma::mat coordinates = arma::zeros(3, 3);
+  // Set stepSize = 1 to make the math easy.
+  double stepSize = 1.0;
+  arma::mat dummyGradient("-6 6 0; 1 2 3; -3 0 4;");
+  update.Update(coordinates, stepSize, dummyGradient);
+  // After clipping, we should get the following coordinates:
+  arma::mat targetCoordinates("3 -3 0; -1 -2 -3; 3 0 -3;");
+  BOOST_REQUIRE_SMALL(arma::abs(coordinates - targetCoordinates).max(), 1e-7);
+}
+
+// Test checking that gradient clipping works with the momentum update.
+BOOST_AUTO_TEST_CASE(ClippedMomentumUpdateTest)
+{
+  // Once again, set momentum = 1 for easy math: the velocity is then
+  // -stepSize * [sum of the clipped gradients seen so far].
+  MomentumUpdate momentumUpdate(1);
+  GradientClipping<MomentumUpdate> update(-3.0, +3.0, momentumUpdate);
+  update.Initialize(3, 3);
+
+  arma::mat coordinates = arma::zeros(3, 3);
+  double stepSize = 1.0;
+  arma::mat dummyGradient("-6 6 0; 1 2 3; -3 0 4;");
+  update.Update(coordinates, stepSize, dummyGradient);
+  arma::mat targetCoordinates("3 -3 0; -1 -2 -3; 3 0 -3;");
+  // On the first Update() call the parameters should just be
+  // equal to (-gradient).
+  BOOST_REQUIRE_SMALL(arma::abs(coordinates - targetCoordinates).max(), 1e-7);
+  update.Update(coordinates, stepSize, dummyGradient);
+  // On the second Update() call the momentum update subtracts the clipped
+  // gradient from the velocity again, so the velocity is now
+  // (-2 * gradient). Adding that to the (-1 * gradient) step taken earlier
+  // yields the 3 * gradient total in the following check.
+  BOOST_REQUIRE_SMALL(
+      arma::abs(coordinates - 3 * targetCoordinates).max(), 1e-7);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/momentum_sgd_test.cpp b/src/mlpack/tests/momentum_sgd_test.cpp
index de99e652c6f..a7c2d675618 100644
--- a/src/mlpack/tests/momentum_sgd_test.cpp
+++ b/src/mlpack/tests/momentum_sgd_test.cpp
@@ -11,6 +11,7 @@
  */
 #include <mlpack/core.hpp>
 #include <mlpack/core/optimizers/sgd/sgd.hpp>
+#include <mlpack/core/optimizers/sgd/update_policies/gradient_clipping.hpp>
 #include <mlpack/core/optimizers/sgd/update_policies/momentum_update.hpp>
 #include <mlpack/core/optimizers/sgd/test_function.hpp>
 #include <boost/test/unit_test.hpp>
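A last sanity check on the arithmetic asserted in `ClippedMomentumUpdateTest`, writing `MomentumUpdate` in the velocity form its comments describe (with clipped gradient g-hat, momentum mu, and step size s):

```latex
v_t = \mu\, v_{t-1} - s\, \hat{g}, \qquad \theta_t = \theta_{t-1} + v_t
```

With mu = s = 1 and theta_0 = 0, this gives v_1 = -g-hat and theta_1 = -g-hat after the first call, then v_2 = -2 g-hat and theta_2 = -3 g-hat after the second — matching the `targetCoordinates` and `3 * targetCoordinates` checks above.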