From 256d8f1924b79194c271b94da40a3cce60047b2b Mon Sep 17 00:00:00 2001 From: Tomoki Ohtsuki Date: Mon, 11 Jan 2021 23:09:04 +0900 Subject: [PATCH] C++-parallelized implementation of SLIM. (#29) * The first version for C++ slim * L2 only version. * It basically works. * optimized performance * Adding a doc * Doc fixed * added a test for cpp slim --- CMakeLists.txt | 10 +- cpp_source/util.cpp | 8 ++ cpp_source/util.hpp | 166 +++++++++++++++++++++++ docs/source/api_reference.rst | 2 + examples/movielens/movielens_20m_cold.py | 6 +- irspack/optimizers/_optimizers.py | 4 +- irspack/recommenders/ials.py | 3 +- irspack/recommenders/slim.py | 99 +++++++++----- irspack/utils/_util_cpp.pyi | 82 +++++------ tests/recommenders/test_slim.py | 47 +++++++ 10 files changed, 330 insertions(+), 97 deletions(-) create mode 100644 tests/recommenders/test_slim.py diff --git a/CMakeLists.txt b/CMakeLists.txt index d9aa4d4..e4efdab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,8 +12,8 @@ set(CPACK_PROJECT_VERSION ${PROJECT_VERSION}) include(CPack) add_subdirectory(pybind11) -pybind11_add_module(irspack.recommenders._ials cpp_source/als/wrapper.cpp) -pybind11_add_module(irspack.recommenders._knn cpp_source/knn/wrapper.cpp) -pybind11_add_module(irspack.utils._util_cpp cpp_source/util.cpp) -pybind11_add_module(irspack._evapuator cpp_source/evaluator.cpp) -pybind11_add_module(irspack._rwr cpp_source/rws.cpp) +#pybind11_add_module(irspack.recommenders._ials cpp_source/als/wrapper.cpp) +#pybind11_add_module(irspack.recommenders._knn cpp_source/knn/wrapper.cpp) +pybind11_add_module(_util_cpp cpp_source/util.cpp) +#pybind11_add_module(irspack._evapuator cpp_source/evaluator.cpp) +#pybind11_add_module(irspack._rwr cpp_source/rws.cpp) diff --git a/cpp_source/util.cpp b/cpp_source/util.cpp index 787fc7f..932b5ba 100644 --- a/cpp_source/util.cpp +++ b/cpp_source/util.cpp @@ -23,4 +23,12 @@ PYBIND11_MODULE(_util_cpp, m) { py::arg("X"), py::arg("k1") = 1.2, py::arg("b") = 0.75); 
m.def("tf_idf_weight", &sparse_util::tf_idf_weight, py::arg("X"), py::arg("smooth") = true); + + m.def("slim_weight_allow_negative", &sparse_util::SLIM, + py::arg("X"), py::arg("n_threads"), py::arg("n_iter"), + py::arg("l2_coeff"), py::arg("l1_coeff")); + + m.def("slim_weight_positive_only", &sparse_util::SLIM, + py::arg("X"), py::arg("n_threads"), py::arg("n_iter"), + py::arg("l2_coeff"), py::arg("l1_coeff")); } diff --git a/cpp_source/util.hpp b/cpp_source/util.hpp index 899e45b..25d9985 100644 --- a/cpp_source/util.hpp +++ b/cpp_source/util.hpp @@ -1,12 +1,17 @@ #pragma once #include #include +#include +#include +#include #include +#include #include #include #include #include #include +#include #include "argcheck.hpp" @@ -16,9 +21,20 @@ namespace sparse_util { template using CSRMatrix = Eigen::SparseMatrix; +template +using RowMajorMatrix = + Eigen::Matrix; + +template +using ColMajorMatrix = + Eigen::Matrix; + template using DenseVector = Eigen::Matrix; +template +using DenseColVector = Eigen::Matrix; + template using CSCMatrix = Eigen::SparseMatrix; @@ -170,5 +186,155 @@ CSRMatrix remove_diagonal(const CSRMatrix &X) { return result; } +template ::size> +inline CSCMatrix SLIM(const CSRMatrix &X, size_t n_threads, + size_t n_iter, Real l2_coeff, Real l1_coeff) { + check_arg(n_threads > 0, "n_threads must be > 0."); + check_arg(n_iter > 0, "n_iter must be > 0."); + check_arg(l2_coeff >= 0, "l2_coeff must be > 0."); + check_arg(l1_coeff >= 0, "l1_coeff must be > 0."); + using MatrixType = + Eigen::Matrix; + using VectorType = Eigen::Matrix; + + // CSRMatrix X_csr(X); + CSCMatrix X_csc(X); + X_csc.makeCompressed(); + using TripletType = Eigen::Triplet; + using CSCIter = typename CSCMatrix::InnerIterator; + std::vector>> workers; + std::atomic cursor(0); + for (size_t th = 0; th < n_threads; th++) { + workers.emplace_back(std::async(std::launch::async, [th, &cursor, &X_csc, + l2_coeff, l1_coeff, + n_iter] { + const int64_t F = X_csc.cols(); + MatrixType 
remnants(X_csc.rows(), block_size); + MatrixType coeffs(F, block_size); + VectorType linear(block_size); + VectorType linear_plus(block_size); + VectorType linear_minus(block_size); + + std::vector local_resuts; + while (true) { + int64_t current_cursor = cursor.fetch_add(block_size); + if (current_cursor >= F) { + break; + } + int64_t block_begin = current_cursor; + int64_t block_end = std::min(block_begin + block_size, F); + int64_t valid_block_size = block_end - block_begin; + remnants.array() = 0; + coeffs.array() = 0; + + for (int64_t f_cursor = block_begin; f_cursor < block_end; f_cursor++) { + const int64_t internal_col_position = f_cursor - block_begin; + for (CSCIter iter(X_csc, f_cursor); iter; ++iter) { + remnants(iter.row(), internal_col_position) = -iter.value(); + } + } + + for (size_t cd_iteration = 0; cd_iteration < n_iter; cd_iteration++) { + for (int64_t feature_index = 0; feature_index < F; feature_index++) { + linear.array() = static_cast(0.0); + Real x2_sum = static_cast(0.0); + for (CSCIter nnz_iter(X_csc, feature_index); nnz_iter; ++nnz_iter) { + Real x = nnz_iter.value(); + + const int64_t row = nnz_iter.row(); + x2_sum += x * x; + /* + loss = \sum_u (remnant_u - w^old _f X_uf + w^new _f X_uf ) ^2 + z_new = + CONST + + \sum_u X_{uf} ^2 w^new_f ^2 + + 2 * w^new_f \sum_u X_{uf} ( remnant_u - X_{uf} w^{old}_f ) + + LINEAR_COEFF = + \sum_u X_{uf} ( remnant_u ) - + - \sum _u ( X_{uf} ^2) w^{old}_f + + */ + remnants.row(row).noalias() -= x * coeffs.row(feature_index); + linear.noalias() += x * remnants.row(row); + } + + Real quadratic = x2_sum + l2_coeff; + linear_plus.array() = (-linear.array() - l1_coeff) / quadratic; + linear_minus.array() = (-linear.array() + l1_coeff) / quadratic; + // linear_plus /= quadratic; + + Real *ptr_location = coeffs.data() + feature_index * block_size; + Real *lp_ptr = linear_plus.data(); + Real *lm_ptr = linear_minus.data(); + + for (int64_t inner_cursor_position = 0; + inner_cursor_position < block_size; 
inner_cursor_position++) { + Real lplus = *(lp_ptr++); + Real lminus = *(lm_ptr++); + int64_t original_cursor_position = + inner_cursor_position + block_begin; + if (original_cursor_position == feature_index) { + *(ptr_location++) = 0.0; + continue; + } + if (positive_only) { + if (lplus > 0) { + *(ptr_location++) = lplus; + } else { + *(ptr_location++) = static_cast(0.0); + } + + } else { + if (lplus > 0) { + *(ptr_location++) = lplus; + } else { + if (lminus < 0) { + *(ptr_location++) = lminus; + } else { + *(ptr_location++) = static_cast(0.0); + } + } + } // allow negative block + } + + for (CSCIter nnz_iter(X_csc, feature_index); nnz_iter; ++nnz_iter) { + Real x = nnz_iter.value(); + const int64_t row = nnz_iter.row(); + remnants.row(row).noalias() += x * coeffs.row(feature_index); + } + } + } + + for (int64_t f = 0; f < F; f++) { + for (int64_t inner_cursor_position = 0; + inner_cursor_position < valid_block_size; + inner_cursor_position++) { + int64_t original_location = inner_cursor_position + block_begin; + Real c = coeffs(f, inner_cursor_position); + if (c != 0.0) { + local_resuts.emplace_back(f, original_location, c); + } + } + } + } + return local_resuts; + })); + } + std::vector nnzs; + for (auto &fres : workers) { + auto result = fres.get(); + for (const auto &w : result) { + nnzs.emplace_back(w); + } + } + + CSCMatrix result(X.cols(), X.cols()); + result.setFromTriplets(nnzs.begin(), nnzs.end()); + result.makeCompressed(); + return result; +} + } // namespace sparse_util } // namespace irspack diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst index 47ae8e1..1805234 100644 --- a/docs/source/api_reference.rst +++ b/docs/source/api_reference.rst @@ -35,6 +35,7 @@ Recommenders AsymmetricCosineKNNRecommender JaccardKNNRecommender TverskyIndexKNNRecommender + SLIMRecommender A LightFM wrapper for BPR matrix factorization (requires a separate installation of `lightFM `_). 
@@ -73,6 +74,7 @@ Optimizers AsymmetricCosineKNNOptimizer JaccardKNNOptimizer TverskyIndexKNNOptimizer + SLIMOptimizer MultVAEOptimizer .. currentmodule:: irspack.split diff --git a/examples/movielens/movielens_20m_cold.py b/examples/movielens/movielens_20m_cold.py index c76b397..4c42e52 100644 --- a/examples/movielens/movielens_20m_cold.py +++ b/examples/movielens/movielens_20m_cold.py @@ -88,7 +88,7 @@ dim_z=200, enc_hidden_dims=600, kl_anneal_goal=0.2 ), # nothing to tune, use the parameters used in the paper. ), - # (SLIMOptimizer, 40), + # (SLIMOptimizer, 40, dict()), # Note: this is a heavy one. ] for optimizer_class, n_trials, config in test_configs: recommender_name = optimizer_class.recommender_class.__name__ @@ -98,9 +98,7 @@ metric="ndcg", fixed_params=config, ) - (best_param, validation_result_df) = optimizer.optimize( - n_trials=n_trials, timeout=14400 - ) + (best_param, validation_result_df) = optimizer.optimize(n_trials=n_trials) validation_result_df["recommender_name"] = recommender_name validation_results.append(validation_result_df) pd.concat(validation_results).to_csv(f"validation_scores.csv") diff --git a/irspack/optimizers/_optimizers.py b/irspack/optimizers/_optimizers.py index 790309f..8cf2137 100644 --- a/irspack/optimizers/_optimizers.py +++ b/irspack/optimizers/_optimizers.py @@ -185,8 +185,8 @@ class RandomWalkWithRestartOptimizer(BaseOptimizer): class SLIMOptimizer(BaseOptimizer): default_tune_range = [ - UniformSuggestion("alpha", 0, 1), - LogUniformSuggestion("l1_ratio", 1e-6, 1), + LogUniformSuggestion("alpha", 1e-5, 1), + UniformSuggestion("l1_ratio", 0, 1), ] recommender_class = SLIMRecommender diff --git a/irspack/recommenders/ials.py b/irspack/recommenders/ials.py index 4d004fc..0fd06b0 100644 --- a/irspack/recommenders/ials.py +++ b/irspack/recommenders/ials.py @@ -99,7 +99,8 @@ class IALSRecommender( Frequency of validation score measurement (if any). Defaults to 5. 
score_degradation_max (int, optional): Maximal number of allowed score degradation. Defaults to 5. - n_threads (Optional[int], optional): Specifies the number of threads to use for the computation. + n_threads (Optional[int], optional): + Specifies the number of threads to use for the computation. If ``None``, the environment variable ``"IRSPACK_NUM_THREADS_DEFAULT"`` will be looked up, and if there is no such an environment variable, it will be set to 1. Defaults to None. max_epoch (int, optional): diff --git a/irspack/recommenders/slim.py b/irspack/recommenders/slim.py index caaf1ea..1cd0f0a 100644 --- a/irspack/recommenders/slim.py +++ b/irspack/recommenders/slim.py @@ -1,50 +1,81 @@ +from typing import Optional + from scipy import sparse as sps from sklearn.linear_model import ElasticNet -from ..definitions import InteractionMatrix -from .base import BaseSimilarityRecommender - - -def slim_weight(X: InteractionMatrix, alpha: float, l1_ratio: float) -> sps.csr_matrix: - model = ElasticNet( - fit_intercept=False, - positive=True, - copy_X=False, - precompute=True, - selection="random", - max_iter=100, - tol=1e-4, - alpha=alpha, - l1_ratio=l1_ratio, - ) - coeff_all = [] - A: sps.csc_matrix = X.tocsc() - for i in range(X.shape[1]): - if i % 1000 == 0: - print(f"Slim Iteration: {i}") - start_pos = int(A.indptr[i]) - end_pos = int(A.indptr[i + 1]) - current_item_data_backup = A.data[start_pos:end_pos].copy() - target = A[:, i].toarray().ravel() - A.data[start_pos:end_pos] = 0.0 - model.fit(A, target) - coeff_all.append(model.sparse_coef_) - A.data[start_pos:end_pos] = current_item_data_backup - return sps.vstack(coeff_all, format="csr") +from irspack.definitions import InteractionMatrix +from irspack.recommenders.base import BaseSimilarityRecommender +from irspack.utils import get_n_threads +from irspack.utils._util_cpp import ( + slim_weight_allow_negative, + slim_weight_positive_only, +) class SLIMRecommender(BaseSimilarityRecommender): + """`SLIM `_ with 
ElasticNet-type loss function: + + .. math :: + + \\mathrm{loss} = \\frac{1}{2} ||X - XB|| ^2 _F + \\frac{\\alpha (1 - l_1) U}{2} ||B|| ^2 _F + \\alpha l_1 U |B| + + The implementation relies on a simple (parallelized) cyclic-coordinate descent method. + + Currently, this does not support: + + - shuffling of item indices + - elaborate convergence check + + Args: + X_train_all: + Input interaction matrix. + alpha: + Determines the strength of L1/L2 regularization (see above). Defaults to 0.05. + l1_ratio: + Determines the strength of L1 regularization relative to alpha. Defaults to 0.01. + positive_only: + Whether we constrain the weight matrix to be non-negative. Defaults to True. + n_iter: + The number of coordinate-descent iterations. Defaults to 10. + n_threads: + Specifies the number of threads to use for the computation. + If ``None``, the environment variable ``"IRSPACK_NUM_THREADS_DEFAULT"`` will be looked up, + and if there is no such environment variable, it will be set to 1. Defaults to None. 
+ """ + def __init__( self, X_train_all: InteractionMatrix, alpha: float = 0.05, l1_ratio: float = 0.01, + positive_only: bool = True, + n_iter: int = 10, + n_threads: Optional[int] = None, ): - super(SLIMRecommender, self).__init__(X_train_all) + super().__init__(X_train_all) self.alpha = alpha self.l1_ratio = l1_ratio + self.positive_only = positive_only + self.n_threads = get_n_threads(n_threads) + self.n_iter = n_iter def _learn(self) -> None: - self.W_ = slim_weight( - self.X_train_all, alpha=self.alpha, l1_ratio=self.l1_ratio - ) + l2_coeff = self.n_users * self.alpha * (1 - self.l1_ratio) + l1_coeff = self.n_users * self.alpha * self.l1_ratio + + if self.positive_only: + self.W_ = slim_weight_positive_only( + self.X_train_all, + n_threads=self.n_threads, + n_iter=self.n_iter, + l2_coeff=l2_coeff, + l1_coeff=l1_coeff, + ) + else: + self.W_ = slim_weight_allow_negative( + self.X_train_all, + n_threads=self.n_threads, + n_iter=self.n_iter, + l2_coeff=l2_coeff, + l1_coeff=l1_coeff, + ) diff --git a/irspack/utils/_util_cpp.pyi b/irspack/utils/_util_cpp.pyi index 1053164..0f2fc31 100644 --- a/irspack/utils/_util_cpp.pyi +++ b/irspack/utils/_util_cpp.pyi @@ -1,57 +1,37 @@ +from numpy import float32 +import irspack.utils._util_cpp +from typing import * from typing import Iterable as iterable from typing import Iterator as iterator -from typing import * - -from numpy import float32, float64 - -import irspack.utils._util_cpp - +from numpy import float64 _Shape = Tuple[int, ...] 
import scipy.sparse - -__all__ = [ - "okapi_BM_25_weight", - "remove_diagonal", - "rowwise_train_test_split_d", - "rowwise_train_test_split_f", - "rowwise_train_test_split_i", - "sparse_mm_threaded", - "tf_idf_weight", +__all__ = [ +"okapi_BM_25_weight", +"remove_diagonal", +"rowwise_train_test_split_d", +"rowwise_train_test_split_f", +"rowwise_train_test_split_i", +"slim_weight_allow_negative", +"slim_weight_positive_only", +"sparse_mm_threaded", +"tf_idf_weight" ] - -def okapi_BM_25_weight( - X: scipy.sparse.csr_matrix[float64], k1: float = 1.2, b: float = 0.75 -) -> scipy.sparse.csr_matrix[float64]: - pass - -def remove_diagonal( - arg0: scipy.sparse.csr_matrix[float64], -) -> scipy.sparse.csr_matrix[float64]: - pass - -def rowwise_train_test_split_d( - arg0: scipy.sparse.csr_matrix[float64], arg1: float, arg2: int -) -> Tuple[scipy.sparse.csr_matrix[float64], scipy.sparse.csr_matrix[float64]]: - pass - -def rowwise_train_test_split_f( - arg0: scipy.sparse.csr_matrix[float32], arg1: float, arg2: int -) -> Tuple[scipy.sparse.csr_matrix[float32], scipy.sparse.csr_matrix[float32]]: - pass - -def rowwise_train_test_split_i( - arg0: scipy.sparse.csr_matrix[float32], arg1: float, arg2: int -) -> Tuple[scipy.sparse.csr_matrix[float32], scipy.sparse.csr_matrix[float32]]: - pass - -def sparse_mm_threaded( - arg0: scipy.sparse.csr_matrix[float64], - arg1: scipy.sparse.csc_matrix[float64], - arg2: int, -) -> scipy.sparse.csr_matrix[float64]: - pass - -def tf_idf_weight( - X: scipy.sparse.csr_matrix[float64], smooth: bool = True -) -> scipy.sparse.csr_matrix[float64]: +def okapi_BM_25_weight(X: scipy.sparse.csr_matrix[float64], k1: float = 1.2, b: float = 0.75) -> scipy.sparse.csr_matrix[float64]: + pass +def remove_diagonal(arg0: scipy.sparse.csr_matrix[float64]) -> scipy.sparse.csr_matrix[float64]: + pass +def rowwise_train_test_split_d(arg0: scipy.sparse.csr_matrix[float64], arg1: float, arg2: int) -> Tuple[scipy.sparse.csr_matrix[float64], 
scipy.sparse.csr_matrix[float64]]: + pass +def rowwise_train_test_split_f(arg0: scipy.sparse.csr_matrix[float32], arg1: float, arg2: int) -> Tuple[scipy.sparse.csr_matrix[float32], scipy.sparse.csr_matrix[float32]]: + pass +def rowwise_train_test_split_i(arg0: scipy.sparse.csr_matrix[float32], arg1: float, arg2: int) -> Tuple[scipy.sparse.csr_matrix[float32], scipy.sparse.csr_matrix[float32]]: + pass +def slim_weight_allow_negative(X: scipy.sparse.csr_matrix[float32], n_threads: int, n_iter: int, l2_coeff: float, l1_coeff: float) -> scipy.sparse.csc_matrix[float32]: + pass +def slim_weight_positive_only(X: scipy.sparse.csr_matrix[float32], n_threads: int, n_iter: int, l2_coeff: float, l1_coeff: float) -> scipy.sparse.csc_matrix[float32]: + pass +def sparse_mm_threaded(arg0: scipy.sparse.csr_matrix[float64], arg1: scipy.sparse.csc_matrix[float64], arg2: int) -> scipy.sparse.csr_matrix[float64]: + pass +def tf_idf_weight(X: scipy.sparse.csr_matrix[float64], smooth: bool = True) -> scipy.sparse.csr_matrix[float64]: pass diff --git a/tests/recommenders/test_slim.py b/tests/recommenders/test_slim.py new file mode 100644 index 0000000..fd63476 --- /dev/null +++ b/tests/recommenders/test_slim.py @@ -0,0 +1,47 @@ +from typing import Dict + +import numpy as np +import scipy.sparse as sps +from sklearn.linear_model import ElasticNet + +from irspack.recommenders import SLIMRecommender + + +def test_slim_positive(test_interaction_data: Dict[str, sps.csr_matrix]) -> None: + alpha = 0.1 + l1_ratio = 0.5 + X = test_interaction_data["X_small"] + rec = SLIMRecommender( + X, alpha=alpha, l1_ratio=l1_ratio, positive_only=True, n_iter=10, n_threads=8 + ) + rec.learn() + + enet = ElasticNet( + alpha=alpha, l1_ratio=l1_ratio, fit_intercept=False, positive=True, max_iter=10 + ) + for iind in range(rec.W.shape[1]): + m = rec.W[:, iind].toarray().ravel() + Xcp = X.toarray() + y = X[:, iind].toarray().ravel() + Xcp[:, iind] = 0.0 + enet.fit(Xcp, y) + np.testing.assert_allclose(enet.coef_, 
m, rtol=1e-2) + + +def test_slim_allow_negative(test_interaction_data: Dict[str, sps.csr_matrix]) -> None: + alpha = 0.1 + l1_ratio = 0.5 + X = test_interaction_data["X_small"] + rec = SLIMRecommender( + X, alpha=alpha, l1_ratio=l1_ratio, positive_only=False, n_iter=10, n_threads=8 + ) + rec.learn() + + enet = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, fit_intercept=False, max_iter=10) + for iind in range(rec.W.shape[1]): + m = rec.W[:, iind].toarray().ravel() + Xcp = X.toarray() + y = X[:, iind].toarray().ravel() + Xcp[:, iind] = 0.0 + enet.fit(Xcp, y) + np.testing.assert_allclose(enet.coef_, m, rtol=1e-2)