From 149ccbe9140d59c381c2c6527c6a814587ca8aa9 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 3 Mar 2025 15:50:09 +0800 Subject: [PATCH 01/11] Add cats to the host adapter. --- R-package/src/Makevars.in | 1 + include/xgboost/data.h | 39 ++++- include/xgboost/predictor.h | 3 +- ops/script/lint_python.py | 2 + python-package/xgboost/_data_utils.py | 202 ++++++++++++++++++++-- python-package/xgboost/core.py | 97 +++++++++-- python-package/xgboost/data.py | 80 +++++---- python-package/xgboost/testing/ordinal.py | 127 ++++++++++++++ src/c_api/c_api.cc | 52 ++++++ src/data/adapter.cc | 53 +++++- src/data/adapter.h | 119 +++++++++++-- src/data/data.cc | 21 ++- src/data/proxy_dmatrix.h | 6 + src/data/simple_dmatrix.cc | 15 +- src/data/simple_dmatrix.h | 6 +- src/objective/lambdarank_obj.cu | 2 + tests/python/test_ordinal.py | 14 ++ tests/python/test_with_pandas.py | 2 +- 18 files changed, 749 insertions(+), 92 deletions(-) create mode 100644 python-package/xgboost/testing/ordinal.py create mode 100644 tests/python/test_ordinal.py diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 5bd8f6f9e775..67e7a86b0033 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -66,6 +66,7 @@ OBJECTS= \ $(PKGROOT)/src/gbm/gblinear_model.o \ $(PKGROOT)/src/data/adapter.o \ $(PKGROOT)/src/data/array_interface.o \ + $(PKGROOT)/src/data/cat_container.o \ $(PKGROOT)/src/data/simple_dmatrix.o \ $(PKGROOT)/src/data/data.o \ $(PKGROOT)/src/data/sparse_page_raw_format.o \ diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 7d363893eb25..7f7ffcbddad2 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -1,5 +1,5 @@ /** - * Copyright 2015-2024, XGBoost Contributors + * Copyright 2015-2025, XGBoost Contributors * \file data.h * \brief The input data structure of xgboost. * \author Tianqi Chen @@ -8,8 +8,8 @@ #define XGBOOST_DATA_H_ #include -#include -#include +#include // for Stream +#include // for Handler #include #include #include @@ -42,8 +42,11 @@ enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 }; enum class DataSplitMode : int { kRow = 0, kCol = 1 }; -/*! - * \brief Meta information about dataset, always sit in memory. +// Forward declaration of the container used by the meta info. +struct CatContainer; + +/** + * @brief Meta information about dataset, always sit in memory. */ class MetaInfo { public: @@ -100,9 +103,9 @@ class MetaInfo { */ HostDeviceVector feature_weights; - /*! \brief default constructor */ - MetaInfo() = default; + MetaInfo(); MetaInfo(MetaInfo&& that) = default; + MetaInfo(MetaInfo const& that) = delete; MetaInfo& operator=(MetaInfo&& that) = default; MetaInfo& operator=(MetaInfo const& that) = delete; @@ -205,6 +208,16 @@ class MetaInfo { * @brief Flag for whether the DMatrix has categorical features. */ bool HasCategorical() const { return has_categorical_; } + /** + * @brief Getters for categories. + */ + [[nodiscard]] CatContainer const* Cats() const; + [[nodiscard]] CatContainer* Cats(); + [[nodiscard]] std::shared_ptr CatsShared() const; + /** + * @brief Setter for categories. + */ + void Cats(std::shared_ptr cats); private: void SetInfoFromHost(Context const* ctx, StringView key, Json arr); @@ -213,6 +226,8 @@ class MetaInfo { /*! \brief argsort of labels */ mutable std::vector label_order_cache_; bool has_categorical_{false}; + + std::shared_ptr cats_; }; /*! \brief Element from a sparse vector */ @@ -691,7 +706,15 @@ class DMatrix { * @param slice_id Index of the current slice * @return DMatrix containing the slice of columns */ - virtual DMatrix *SliceCol(int num_slices, int slice_id) = 0; + virtual DMatrix* SliceCol(int num_slices, int slice_id) = 0; + /** + * @brief Accessor for the string representation of the categories. + */ + CatContainer const* Cats() const { return this->CatsShared().get(); } + [[nodiscard]] virtual std::shared_ptr CatsShared() const { + LOG(FATAL) << "Not implemented for the current DMatrix type."; + return nullptr; + } protected: virtual BatchSet GetRowBatches() = 0; diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index ad89e54891c6..020e0a59d1e8 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -5,7 +5,8 @@ * performs predictions for a gradient booster. */ #pragma once -#include +#include // for FunctionRegEntryBase +#include // for bst_tree_t #include // for DMatrixCache #include // for Context #include diff --git a/ops/script/lint_python.py b/ops/script/lint_python.py index 0545c4b55644..c4267accebbb 100644 --- a/ops/script/lint_python.py +++ b/ops/script/lint_python.py @@ -27,6 +27,7 @@ class LintersPaths: "tests/python/test_early_stopping.py", "tests/python/test_multi_target.py", "tests/python/test_objectives.py", + "tests/python/test_ordinal.py", "tests/python/test_predict.py", "tests/python/test_quantile_dmatrix.py", "tests/python/test_tracker.py", @@ -101,6 +102,7 @@ class LintersPaths: "tests/python-gpu/load_pickle.py", "tests/python-gpu/test_gpu_training_continuation.py", "tests/python/test_model_io.py", + "tests/python/test_ordinal.py", "tests/test_distributed/test_federated/", "tests/test_distributed/test_gpu_federated/", "tests/test_distributed/test_with_dask/test_ranking.py", diff --git a/python-package/xgboost/_data_utils.py b/python-package/xgboost/_data_utils.py index 5229287f59e7..c405f82d09e1 100644 --- a/python-package/xgboost/_data_utils.py +++ b/python-package/xgboost/_data_utils.py @@ -2,15 +2,35 @@ import copy import ctypes +import functools import json -from typing import Literal, Optional, Protocol, Tuple, Type, TypedDict, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Literal, + Optional, + Protocol, + Tuple, + Type, + TypedDict, + TypeGuard, + Union, + cast, + overload, +) import numpy as np -from ._typing import CNumericPtr, DataType, NumpyOrCupy -from .compat import import_cupy +from ._typing import CNumericPtr, DataType, NumpyDType, NumpyOrCupy +from .compat import import_cupy, lazy_isinstance + +if TYPE_CHECKING: + import pandas as pd + import pyarrow as pa +# Used for accepting inputs for numpy and cupy arrays class _ArrayLikeArg(Protocol): @property def __array_interface__(self) -> "ArrayInf": ... @@ -44,19 +64,27 @@ def shape(self) -> Tuple[int, int]: }, ) +StringArray = TypedDict("StringArray", {"offsets": ArrayInf, "values": ArrayInf}) + def array_hasobject(data: DataType) -> bool: """Whether the numpy array has object dtype.""" return hasattr(data.dtype, "hasobject") and data.dtype.hasobject -def cuda_array_interface(data: DataType) -> bytes: - """Make cuda array interface str.""" +def cuda_array_interface_dict(data: _CudaArrayLikeArg) -> ArrayInf: + """Returns a dictionary storing the CUDA array interface.""" if array_hasobject(data): raise ValueError("Input data contains `object` dtype. Expecting numeric data.") - interface = data.__cuda_array_interface__ - if "mask" in interface: - interface["mask"] = interface["mask"].__cuda_array_interface__ + ainf = data.__cuda_array_interface__ + if "mask" in ainf: + ainf["mask"] = ainf["mask"].__cuda_array_interface__ # type: ignore + return cast(ArrayInf, ainf) + + +def cuda_array_interface(data: _CudaArrayLikeArg) -> bytes: + """Make cuda array interface str.""" + interface = cuda_array_interface_dict(data) interface_str = bytes(json.dumps(interface), "utf-8") return interface_str @@ -107,6 +135,12 @@ def __cuda_array_interface__(self, interface: ArrayInf) -> None: return out +# Default constant value for CUDA per-thread stream. +STREAM_PER_THREAD = 2 + + +# Typing is not strict as we there are subtle differences between CUDA array interface +# and array interface. We handle them uniformly for now. def make_array_interface( ptr: Union[CNumericPtr, int], shape: Tuple[int, ...], @@ -134,21 +168,157 @@ def make_array_interface( return array array["data"] = (addr, True) - if is_cuda: - array["stream"] = 2 + if is_cuda and "stream" not in array: + array["stream"] = STREAM_PER_THREAD array["shape"] = shape array["strides"] = None return array -def array_interface_dict(data: np.ndarray) -> ArrayInf: - """Convert array interface into a Python dictionary.""" +def is_arrow_dict(data: Any) -> TypeGuard["pa.DictionaryArray"]: + """Is this an arrow dictionary array?""" + return lazy_isinstance(data, "pyarrow.lib", "DictionaryArray") + + +class PdCatAccessor(Protocol): + """Protocol for pandas cat accessor.""" + + @property + def categories( # pylint: disable=missing-function-docstring + self, + ) -> "pd.Index": ... + + @property + def codes(self) -> "pd.Series": ... # pylint: disable=missing-function-docstring + + @property + def dtype(self) -> np.dtype: ... # pylint: disable=missing-function-docstring + + def to_arrow( # pylint: disable=missing-function-docstring + self, + ) -> Union["pa.StringArray", "pa.IntegerArray"]: ... + + @property + def __cuda_array_interface__(self) -> ArrayInf: ... + + +def _is_pd_cat(data: Any) -> TypeGuard[PdCatAccessor]: + # Test pd.Series.cat, not pd.Series + return hasattr(data, "categories") and hasattr(data, "codes") + + +@functools.cache +def _arrow_typestr() -> Dict["pa.DataType", str]: + import pyarrow as pa + + mapping = { + pa.int8(): " Tuple[np.ndarray, str]: + """Convert a numpy string array to an arrow string array.""" + lenarr = np.vectorize(len) + offsets = np.cumsum(np.concatenate([np.array([0], dtype=np.int64), lenarr(strarr)])) + values = strarr.sum() + assert "\0" not in values # arrow string array doesn't need null terminal + return offsets.astype(np.int32), values + + +def _ensure_np_dtype( + data: DataType, dtype: Optional[NumpyDType] +) -> Tuple[np.ndarray, Optional[NumpyDType]]: + """Ensure the np array has correct type and is contiguous.""" + if array_hasobject(data) or data.dtype in [np.float16, np.bool_]: + dtype = np.float32 + data = data.astype(dtype, copy=False) + if not data.flags.aligned: + data = np.require(data, requirements="A") + return data, dtype + + +@overload +def array_interface_dict(data: np.ndarray) -> ArrayInf: ... + + +@overload +def array_interface_dict( + data: PdCatAccessor, +) -> Tuple[StringArray, ArrayInf, Tuple]: ... + + +@overload +def array_interface_dict( + data: "pa.DictionaryArray", +) -> Tuple[StringArray, ArrayInf, Tuple]: ... + + +def array_interface_dict( # pylint: disable=too-many-locals + data: Union[np.ndarray, PdCatAccessor], +) -> Union[ArrayInf, Tuple[StringArray, ArrayInf, Optional[Tuple]]]: + """Returns an array interface from the input.""" + # Handle categorical values + if _is_pd_cat(data): + cats = data.categories + # pandas uses -1 to represent missing values for categorical features + codes = data.codes.replace(-1, np.nan) + + if np.issubdtype(cats.dtype, np.floating) or np.issubdtype( + cats.dtype, np.integer + ): + # Numeric index type + name_values = cats.values + jarr_values = array_interface_dict(name_values) + code_values = codes.values + jarr_codes = array_interface_dict(code_values) + return jarr_values, jarr_codes, (name_values, code_values) + + # String index type + name_offsets, name_values = npstr_to_arrow_strarr(cats.values) + name_offsets, _ = _ensure_np_dtype(name_offsets, np.int32) + joffsets = array_interface_dict(name_offsets) + bvalues = name_values.encode("utf-8") + ptr = ctypes.c_void_p.from_buffer(ctypes.c_char_p(bvalues)).value + assert ptr is not None + + jvalues: ArrayInf = { + "data": (ptr, True), + "typestr": "|i1", + "shape": (len(name_values),), + "strides": None, + "version": 3, + "mask": None, + } + jnames: StringArray = {"offsets": joffsets, "values": jvalues} + + code_values = codes.values + jcodes = array_interface_dict(code_values) + + buf = ( + name_offsets, + name_values, + bvalues, + code_values, + ) # store temporary values + return jnames, jcodes, buf + + # Handle numeric values + assert isinstance(data, np.ndarray) if array_hasobject(data): raise ValueError("Input data contains `object` dtype. Expecting numeric data.") - arrinf = data.__array_interface__ - if "mask" in arrinf: - arrinf["mask"] = arrinf["mask"].__array_interface__ - return cast(ArrayInf, arrinf) + ainf = data.__array_interface__ + if "mask" in ainf: + ainf["mask"] = ainf["mask"].__array_interface__ + return cast(ArrayInf, ainf) def array_interface(data: np.ndarray) -> bytes: diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index b6c3414339d2..d669a5181250 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -16,6 +16,7 @@ from inspect import Parameter, signature from types import EllipsisType from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -65,9 +66,19 @@ TransformedData, c_bst_ulong, ) -from .compat import PANDAS_INSTALLED, DataFrame, import_polars, py_str +from .compat import ( + PANDAS_INSTALLED, + DataFrame, + import_polars, + import_pyarrow, + is_pyarrow_available, + py_str, +) from .libpath import find_lib_path, is_sphinx_build +if TYPE_CHECKING: + import pyarrow as pa + class XGBoostError(ValueError): """Error thrown by xgboost trainer.""" @@ -321,9 +332,9 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None: if device and device.find(":") != -1: if device != "sycl:gpu": raise ValueError( - "Distributed training doesn't support selecting device ordinal as GPUs are" - " managed by the distributed frameworks. use `device=cuda` or `device=gpu`" - " instead." + "Distributed training doesn't support selecting device ordinal as GPUs" + " are managed by the distributed frameworks. use `device=cuda` or" + " `device=gpu` instead." ) if kwargs.get("booster", None) == "gblinear": @@ -781,8 +792,8 @@ def __init__( types. Note that, if passing an iterator, it **will cache data on disk**, and note - that fields like ``label`` will be concatenated in-memory from multiple calls - to the iterator. + that fields like ``label`` will be concatenated in-memory from multiple + calls to the iterator. label : Label of the training data. weight : @@ -1240,6 +1251,70 @@ def get_quantile_cut(self) -> Tuple[np.ndarray, np.ndarray]: assert data.dtype == np.float32 return indptr, data + def get_categories(self) -> Optional[Dict[str, "pa.DictionaryArray"]]: + """Get the categories in the dataset. Return `None` if there's no categorical + features. + + .. warning:: + + This function is still working in progress. + + .. versionadded:: 3.1.0 + + """ + if not is_pyarrow_available(): + raise ImportError("`pyarrow` is required for exporting categories.") + + if TYPE_CHECKING: + import pyarrow as pa + else: + pa = import_pyarrow() + + n_features = self.num_col() + fnames = self.feature_names + if fnames is None: + fnames = [str(i) for i in range(n_features)] + + results: Dict[str, "pa.DictionaryArray"] = {} + + ret = ctypes.c_char_p() + _check_call(_LIB.XGBDMatrixGetCategories(self.handle, ctypes.byref(ret))) + if ret.value is None: + return None + + retstr = ret.value.decode() # pylint: disable=no-member + jcats = json.loads(retstr) + assert isinstance(jcats, list) and len(jcats) == n_features + + for fidx in range(n_features): + f_jcats = jcats[fidx] + if f_jcats is None: + # Numeric data + results[fnames[fidx]] = None + continue + + if "offsets" not in f_jcats: + values = from_array_interface(f_jcats) + pa_values = pa.Array.from_pandas(values) + results[fnames[fidx]] = pa_values + continue + + joffsets = f_jcats["offsets"] + jvalues = f_jcats["values"] + offsets = from_array_interface(joffsets, True) + values = from_array_interface(jvalues, True) + pa_offsets = pa.array(offsets).buffers() + pa_values = pa.array(values).buffers() + assert ( + pa_offsets[0] is None and pa_values[0] is None + ), "Should not have null mask." + pa_dict = pa.StringArray.from_buffers( + len(offsets) - 1, pa_offsets[1], pa_values[1] + ) + results[fnames[fidx]] = pa_dict + + return results + def num_row(self) -> int: """Get the number of rows in the DMatrix.""" ret = c_bst_ulong() @@ -1520,7 +1595,8 @@ class QuantileDMatrix(DMatrix, _RefMixIn): X, y = make_regression() X_train, X_test, y_train, y_test = train_test_split(X, y) Xy_train = xgb.QuantileDMatrix(X_train, y_train) - # It's necessary to have the training DMatrix as a reference for valid quantiles. + # It's necessary to have the training DMatrix as a reference for valid + # quantiles. Xy_test = xgb.QuantileDMatrix(X_test, y_test, ref=Xy_train) Parameters @@ -2671,7 +2747,8 @@ def inplace_predict( if validate_features: if not hasattr(data, "shape"): raise TypeError( - "`shape` attribute is required when `validate_features` is True." + "`shape` attribute is required when `validate_features` is True" + f", got: {type(data)}" ) if len(data.shape) != 1 and self.num_features() != data.shape[1]: raise ValueError( @@ -2750,13 +2827,13 @@ def inplace_predict( data, cat_codes, fns, _ = _transform_cudf_df( data, None, None, enable_categorical ) - interfaces_str = _cudf_array_interfaces(data, cat_codes) + array_inf, _ = _cudf_array_interfaces(data, cat_codes) if validate_features: self._validate_features(fns) _check_call( _LIB.XGBoosterPredictFromCudaColumnar( self.handle, - interfaces_str, + array_inf, args, p_handle, ctypes.byref(shape), diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 3ac5a6cc5376..039fa35e88fa 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -25,12 +25,16 @@ from ._data_utils import ( ArrayInf, + PdCatAccessor, TransformedDf, + _ensure_np_dtype, + _is_pd_cat, array_hasobject, array_interface, array_interface_dict, check_cudf_meta, cuda_array_interface, + is_arrow_dict, make_array_interface, ) from ._typing import ( @@ -45,9 +49,13 @@ TransformedData, c_bst_ulong, ) -from .compat import DataFrame -from .compat import Series as PdSeries -from .compat import import_polars, import_pyarrow, is_pyarrow_available, lazy_isinstance +from .compat import ( + DataFrame, + import_polars, + import_pyarrow, + is_pyarrow_available, + lazy_isinstance, +) from .core import ( _LIB, DataIter, @@ -62,6 +70,8 @@ if TYPE_CHECKING: import pyarrow as pa + from pandas import Series as PdSeries + DispatchedDataBackendReturnType = Tuple[ ctypes.c_void_p, Optional[FeatureNames], Optional[FeatureTypes] @@ -231,17 +241,6 @@ def _is_np_array_like(data: DataType) -> TypeGuard[np.ndarray]: return hasattr(data, "__array_interface__") -def _ensure_np_dtype( - data: DataType, dtype: Optional[NumpyDType] -) -> Tuple[np.ndarray, Optional[NumpyDType]]: - if array_hasobject(data) or data.dtype in [np.float16, np.bool_]: - dtype = np.float32 - data = data.astype(dtype, copy=False) - if not data.flags.aligned: - data = np.require(data, requirements="A") - return data, dtype - - def _maybe_np_slice(data: DataType, dtype: Optional[NumpyDType]) -> np.ndarray: """Handle numpy slice. This can be removed if we use __array_interface__.""" try: @@ -536,22 +535,17 @@ def _lazy_load_pd_floats() -> tuple: return Float32Dtype, Float64Dtype -def pandas_transform_data(data: DataFrame) -> List[np.ndarray]: +def pandas_transform_data(data: DataFrame) -> List[Union[np.ndarray, PdCatAccessor]]: """Handle categorical dtype and extension types from pandas.""" Float32Dtype, Float64Dtype = _lazy_load_pd_floats() - result: List[np.ndarray] = [] + result: List[Union[np.ndarray, PdCatAccessor]] = [] np_dtypes = _lazy_has_npdtypes() - def cat_codes(ser: PdSeries) -> np.ndarray: - return _ensure_np_dtype( - ser.cat.codes.astype(np.float32) - .replace(-1.0, np.nan) - .to_numpy(na_value=np.nan), - np.float32, - )[0] + def cat_codes(ser: "PdSeries") -> PdCatAccessor: + return ser.cat - def nu_type(ser: PdSeries) -> np.ndarray: + def nu_type(ser: "PdSeries") -> np.ndarray: # Avoid conversion when possible if isinstance(dtype, Float32Dtype): res_dtype: NumpyDType = np.float32 @@ -563,7 +557,7 @@ def nu_type(ser: PdSeries) -> np.ndarray: ser.to_numpy(dtype=res_dtype, na_value=np.nan), res_dtype )[0] - def oth_type(ser: PdSeries) -> np.ndarray: + def oth_type(ser: "PdSeries") -> np.ndarray: # The dtypes module is added in 1.25. npdtypes = np_dtypes and isinstance( ser.dtype, @@ -612,19 +606,47 @@ def oth_type(ser: PdSeries) -> np.ndarray: class PandasTransformed(TransformedDf): """A storage class for transformed pandas DataFrame.""" - def __init__(self, columns: List[np.ndarray]) -> None: + def __init__( + self, columns: List[Union[np.ndarray, PdCatAccessor, "pa.DictionaryType"]] + ) -> None: self.columns = columns + aitfs = [] + self.temporary_buffers = [] + + # Get the array interface representation for each column. + for col in self.columns: + inf = array_interface_dict(col) + if isinstance(inf, tuple): + # Categorical column + jnames, jcodes, buf = inf + # Store the transformed results to avoid garbage collection. + self.temporary_buffers.append(buf) + aitfs.append([jnames, jcodes]) + else: + # Numeric column + aitfs.append(inf) + + self.aitfs = aitfs + def array_interface(self) -> bytes: """Return a byte string for JSON encoded array interface.""" - aitfs = list(map(array_interface_dict, self.columns)) - sarrays = bytes(json.dumps(aitfs), "utf-8") + sarrays = bytes(json.dumps(self.aitfs), "utf-8") return sarrays @property def shape(self) -> Tuple[int, int]: """Return shape of the transformed DataFrame.""" - return self.columns[0].shape[0], len(self.columns) + if is_arrow_dict(self.columns[0]): + # When input is arrow. (cuDF) + n_samples = len(self.columns[0].indices) + elif _is_pd_cat(self.columns[0]): + # When input is pandas. + n_samples = self.columns[0].codes.shape[0] + else: + # Anything else, TypeGuard is ignored by mypy 1.15.0 for some reason + n_samples = self.columns[0].shape[0] # type: ignore + return n_samples, len(self.columns) def _transform_pandas_df( diff --git a/python-package/xgboost/testing/ordinal.py b/python-package/xgboost/testing/ordinal.py new file mode 100644 index 000000000000..0278de5c8f68 --- /dev/null +++ b/python-package/xgboost/testing/ordinal.py @@ -0,0 +1,127 @@ +# pylint: disable=invalid-name +"""Tests for the ordinal re-coder.""" + +from typing import Any, Tuple, Type + +import numpy as np + +from ..compat import import_cupy +from ..core import DMatrix +from .data import is_pd_cat_dtype, make_categorical + + +def get_df_impl(device: str) -> Tuple[Type, Type]: + """Get data frame implementation based on the ]device.""" + if device == "cpu": + import pandas as pd + + Df = pd.DataFrame + Ser = pd.Series + else: + import cudf + + Df = cudf.DataFrame + Ser = cudf.Series + return Df, Ser + + +def assert_allclose(device: str, a: Any, b: Any) -> None: + """Dispatch the assert_allclose for devices.""" + if device == "cpu": + np.testing.assert_allclose(a, b) + else: + cp = import_cupy() + cp.testing.assert_allclose(a, b) + + +def run_cat_container(device: str) -> None: + """Basic tests for the container class used by the DMatrix.""" + Df, _ = get_df_impl(device) + # Basic test with a single feature + df = Df({"c": ["cdef", "abc"]}, dtype="category") + categories = df.c.cat.categories + + Xy = DMatrix(df, enable_categorical=True) + results = Xy.get_categories() + assert results is not None + assert len(results["c"]) == len(categories) + for i in range(len(results["c"])): + assert str(results["c"][i]) == str(categories[i]), ( + results["c"][i], + categories[i], + ) + + # Test with missing values. + df = Df({"c": ["cdef", None, "abc", "abc"]}, dtype="category") + Xy = DMatrix(df, enable_categorical=True) + + cats = Xy.get_categories() + assert cats is not None + ser = cats["c"].to_pandas() + assert ser.iloc[0] == "abc" + assert ser.iloc[1] == "cdef" + assert ser.size == 2 + + csr = Xy.get_data() + assert csr.data.size == 3 + assert_allclose(device, csr.data, np.array([1.0, 0.0, 0.0])) + assert_allclose(device, csr.indptr, np.array([0, 1, 1, 2, 3])) + assert_allclose(device, csr.indices, np.array([0, 0, 0])) + + # Test with explicit null-terminated strings. + df = Df({"c": ["cdef", None, "abc", "abc\0"]}, dtype="category") + Xy = DMatrix(df, enable_categorical=True) + + +def run_cat_container_mixed() -> None: + """Run checks with mixed types.""" + import pandas as pd + + def check(Xy: DMatrix, X: pd.DataFrame) -> None: + cats = Xy.get_categories() + assert cats is not None + + for fname in X.columns: + if is_pd_cat_dtype(X[fname].dtype): + aw_list = sorted(cats[fname].to_pylist()) + pd_list: list = X[fname].unique().tolist() + if np.nan in pd_list: + pd_list.remove(np.nan) + pd_list = sorted(pd_list) + assert aw_list == pd_list + else: + assert cats[fname] is None + + # full str type + X, y = make_categorical(256, 16, 7, onehot=False, cat_dtype=np.str_) + Xy = DMatrix(X, y, enable_categorical=True) + check(Xy, X) + + # str type, mixed with numerical features + X, y = make_categorical(256, 16, 7, onehot=False, cat_ratio=0.5, cat_dtype=np.str_) + Xy = DMatrix(X, y, enable_categorical=True) + check(Xy, X) + + # str type, mixed with numerical features and missing values + X, y = make_categorical( + 256, 16, 7, onehot=False, cat_ratio=0.5, sparsity=0.5, cat_dtype=np.str_ + ) + Xy = DMatrix(X, y, enable_categorical=True) + check(Xy, X) + + # int type + X, y = make_categorical(256, 16, 7, onehot=False, cat_dtype=np.int64) + Xy = DMatrix(X, y, enable_categorical=True) + check(Xy, X) + + # int type, mixed with numerical features + X, y = make_categorical(256, 16, 7, onehot=False, cat_ratio=0.5, cat_dtype=np.int64) + Xy = DMatrix(X, y, enable_categorical=True) + check(Xy, X) + + # int type, mixed with numerical features and missing values + X, y = make_categorical( + 256, 16, 7, onehot=False, cat_ratio=0.5, sparsity=0.5, cat_dtype=np.int64 + ) + Xy = DMatrix(X, y, enable_categorical=True) + check(Xy, X) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 3fbfbc8b9792..f64bd3324406 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -23,9 +23,11 @@ #include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor #include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte... #include "../data/batch_utils.h" // for MatchingPageBytes, CachePageRatio +#include "../data/cat_container.h" // for CatContainer #include "../data/ellpack_page.h" // for EllpackPage #include "../data/proxy_dmatrix.h" // for DMatrixProxy #include "../data/simple_dmatrix.h" // for SimpleDMatrix +#include "../encoder/types.h" // for Overloaded #include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN #include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM... #include "dmlc/base.h" // for BeginPtr @@ -718,6 +720,56 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field, API_END(); } +XGB_DLL int XGBDMatrixGetCategories(DMatrixHandle handle, char const **out) { + // We can directly use the storage in the cat container instead of allocating temporary storage. + API_BEGIN() + CHECK_HANDLE() + auto const p_fmat = *static_cast *>(handle); + auto const cats = p_fmat->Cats()->HostView(); + + auto &ret_str = p_fmat->GetThreadLocal().ret_str; + xgboost_CHECK_C_ARG_PTR(out); + + if (cats.Empty()) { + *out = nullptr; + } else { + Json jout{Array{}}; + auto n_features = p_fmat->Info().num_col_; + for (decltype(n_features) f_idx = 0; f_idx < n_features; ++f_idx) { + auto const &col = cats[f_idx]; + if (std::visit([](auto &&arg) { return arg.empty(); }, col)) { + get(jout).emplace_back(); + continue; + } + std::visit(enc::Overloaded{[&](enc::CatStrArrayView const &str) { + auto const &offsets = str.offsets; + auto ovec = linalg::MakeVec(offsets.data(), offsets.size()); + auto jovec = linalg::ArrayInterface(ovec); + + auto const &values = str.values; + auto dvec = linalg::MakeVec(values.data(), values.size()); + auto jdvec = linalg::ArrayInterface(dvec); + + get(jout).emplace_back(Object{}); + get(jout).back()["offsets"] = std::move(jovec); + get(jout).back()["values"] = std::move(jdvec); + }, + [&](auto &&values) { + auto vec = linalg::MakeVec(values.data(), values.size()); + auto jvec = linalg::ArrayInterface(vec); + get(jout).emplace_back(std::move(jvec)); + }}, + col); + } + auto str = Json::Dump(jout); + ret_str = std::move(str); + + *out = ret_str.c_str(); + } + + API_END() +} + XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data, xgboost::bst_ulong size, int type) { API_BEGIN(); diff --git a/src/data/adapter.cc b/src/data/adapter.cc index 4fa171c9d14a..656d7e82c877 100644 --- a/src/data/adapter.cc +++ b/src/data/adapter.cc @@ -1,19 +1,66 @@ /** - * Copyright 2019-2023, XGBoost Contributors + * Copyright 2019-2025, XGBoost Contributors */ #include "adapter.h" +#include // for move + #include "../c_api/c_api_error.h" // for API_BEGIN, API_END +#include "array_interface.h" // for ArrayInterface #include "xgboost/c_api.h" +#include "xgboost/logging.h" namespace xgboost::data { +ColumnarAdapter::ColumnarAdapter(StringView columns) { + auto jarray = Json::Load(columns); + CHECK(IsA(jarray)); + auto const& array = get(jarray); + bst_idx_t n_samples{0}; + std::vector cat_segments{0}; + for (auto const& jcol : array) { + std::int32_t n_cats{0}; + if (IsA(jcol)) { + // This is a dictionary type (categorical values). + auto const& first = get(jcol[0]); + if (first.find("offsets") == first.cend()) { + // numeric index + n_cats = GetArrowNumericIndex(DeviceOrd::CPU(), jcol, &this->cats_, &this->columns_, + &this->n_bytes_, &n_samples); + } else { + // string index + n_cats = + GetArrowDictionary(jcol, &this->cats_, &this->columns_, &this->n_bytes_, &n_samples); + } + } else { + // Numeric values + columns_.emplace_back(get(jcol)); + this->cats_.emplace_back(); + this->n_bytes_ += columns_.back().ElementSize() * columns_.back().Shape<0>(); + n_samples = std::max(n_samples, columns_.back().Shape<0>()); + } + cat_segments.push_back(n_cats); + } + std::partial_sum(cat_segments.cbegin(), cat_segments.cend(), cat_segments.begin()); + auto no_overflow = std::is_sorted(cat_segments.cbegin(), cat_segments.cend()); + CHECK(no_overflow) << "Maximum number of categories exceeded."; + + // Check consistency. + bool consistent = columns_.empty() || std::all_of(columns_.cbegin(), columns_.cend(), + [&](ArrayInterface<1> const& array) { + return array.Shape<0>() == n_samples; + }); + this->cat_segments_ = std::move(cat_segments); + CHECK(consistent) << "Size of columns should be the same."; + batch_ = ColumnarAdapterBatch{columns_}; +} + template bool IteratorAdapter::Next() { if ((*next_callback_)( data_handle_, - [](void *handle, XGBoostBatchCSR batch) -> int { + [](void* handle, XGBoostBatchCSR batch) -> int { API_BEGIN(); - static_cast(handle)->SetData(batch); + static_cast(handle)->SetData(batch); API_END(); }, this) != 0) { diff --git a/src/data/adapter.h b/src/data/adapter.h index 0888a2f86b4f..b8ab21017b24 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -13,10 +13,13 @@ #include // for numeric_limits #include // for unique_ptr, make_unique #include // for move +#include // for variant #include // for vector #include "../common/math.h" -#include "array_interface.h" +#include "../encoder/ordinal.h" // for CatStrArrayView +#include "../encoder/types.h" // for TupToVarT +#include "array_interface.h" // for CategoricalIndexArgTypes #include "xgboost/base.h" #include "xgboost/data.h" #include "xgboost/logging.h" @@ -568,41 +571,121 @@ class ColumnarAdapterBatch : public detail::NoMetaInfo { static constexpr bool kIsRowMajor = true; }; +/** + * @brief Get string names and codes for categorical features. + * + * @return The number of categories for the current column. + */ +template +[[nodiscard]] std::size_t GetArrowDictionary(Json jcol, + std::vector* p_cat_columns, + std::vector>* p_columns, + bst_idx_t* p_n_bytes, bst_idx_t* p_n_samples) { + auto& cat_columns = *p_cat_columns; + // arrow StringArray for name of categories + auto const& jnames = get(jcol[0]); + // There are 3 buffers for a StringArray, validity mask, offset, and data. Mask + // and data are represented by a single masked array. + auto const& joffset = get(jnames.at("offsets")); + auto offset = ArrayInterface<1>{joffset}; + auto const& jstr = get(jnames.at("values")); + auto strbuf = ArrayInterface<1>(jstr); + CHECK_EQ(strbuf.type, ArrayInterfaceHandler::kI1); + + auto names = enc::CatStrArrayView{ + common::Span{static_cast(offset.data), offset.Shape<0>()}, + common::Span{reinterpret_cast(strbuf.data), strbuf.n}}; + cat_columns.emplace_back(names); + + // arrow Integer array for encoded categories + auto const& jcodes = get(jcol[1]); + auto codes = ArrayInterface<1>{jcodes}; + p_columns->push_back(codes); + + auto& n_bytes = *p_n_bytes; + n_bytes += codes.ElementSize() * codes.Shape<0>(); + n_bytes += names.SizeBytes(); + + *p_n_samples = std::max(*p_n_samples, codes.Shape<0>()); + return names.size(); +} + +/** + * @brief Get numeric names and codes for categorical features. + * + * @return The number of categories for the current column. + */ +template +[[nodiscard]] std::size_t GetArrowNumericIndex( + DeviceOrd device, Json jcol, std::vector* p_cat_columns, + std::vector>* p_columns, bst_idx_t* p_n_bytes, + bst_idx_t* p_n_samples) { + auto const& first = get(jcol[0]); + auto names = ArrayInterface<1>{first}; + auto& n_bytes = *p_n_bytes; + DispatchDType(names, device, [&](auto t) { + using T = typename decltype(t)::value_type; + constexpr bool kKnownType = enc::MemberOf, enc::CatPrimIndexTypes>::value; + CHECK(kKnownType) << "Unsupported categorical index type."; + auto span = common::Span{t.Values().data(), t.Size()}; + if constexpr (kKnownType) { + p_cat_columns->emplace_back(span); + n_bytes += span.size_bytes(); + } + }); + auto const& jcodes = get(jcol[1]); + auto codes = ArrayInterface<1>{jcodes}; + p_columns->push_back(codes); + + n_bytes += codes.ElementSize() * codes.Shape<0>(); + *p_n_samples = std::max(*p_n_samples, codes.Shape<0>()); + + return names.n; +} + +/** + * @brief Adapter for columnar format (arrow). + * + * Supports for both numeric values and categorical values. + */ class ColumnarAdapter : public detail::SingleBatchDataIter { std::vector> columns_; + std::vector cats_; + std::vector cat_segments_; ColumnarAdapterBatch batch_; + std::size_t n_bytes_{0}; public: - explicit ColumnarAdapter(StringView columns) { - auto jarray = Json::Load(columns); - CHECK(IsA(jarray)); - auto const& array = get(jarray); - for (auto col : array) { - columns_.emplace_back(get(col)); - } - bool consistent = - columns_.empty() || - std::all_of(columns_.cbegin(), columns_.cend(), [&](ArrayInterface<1> const& array) { - return array.Shape<0>() == columns_[0].Shape<0>(); - }); - CHECK(consistent) << "Size of columns should be the same."; - batch_ = ColumnarAdapterBatch{columns_}; - } + /** + * @brief JSON-encoded array of columns. + */ + explicit ColumnarAdapter(StringView columns); [[nodiscard]] ColumnarAdapterBatch const& Value() const override { return batch_; } - [[nodiscard]] std::size_t NumRows() const { + [[nodiscard]] bst_idx_t NumRows() const { if (!columns_.empty()) { return columns_.front().shape[0]; } return 0; } - [[nodiscard]] std::size_t NumColumns() const { + [[nodiscard]] bst_idx_t NumColumns() const { if (!columns_.empty()) { return columns_.size(); } return 0; } + [[nodiscard]] bool HasCategorical() const { + return !std::all_of(this->cats_.cbegin(), this->cats_.cend(), [](auto const& cats) { + return std::visit([](auto&& cats) { return cats.empty(); }, cats); + }); + } + [[nodiscard]] std::size_t SizeBytes() const { return n_bytes_; } + + [[nodiscard]] enc::HostColumnsView Cats() const { + return {this->cats_, this->cat_segments_, + static_cast(this->cat_segments_.back())}; + } }; class FileAdapterBatch { diff --git a/src/data/data.cc b/src/data/data.cc index 1cb4c2bc0385..8f240d415655 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -1,5 +1,5 @@ /** - * Copyright 2015-2024, XGBoost Contributors + * Copyright 2015-2025, XGBoost Contributors * \file data.cc */ #include "xgboost/data.h" @@ -36,6 +36,7 @@ #include "./sparse_page_dmatrix.h" // for SparsePageDMatrix #include "array_interface.h" // for ArrayInterfaceHandler, ArrayInterface, Dispa... #include "batch_utils.h" // for MatchingPageBytes +#include "cat_container.h" // for CatContainer #include "dmlc/base.h" // for BeginPtr #include "dmlc/common.h" // for OMPException #include "dmlc/data.h" // for Parser @@ -199,6 +200,8 @@ namespace xgboost { uint64_t constexpr MetaInfo::kNumField; +MetaInfo::MetaInfo() : cats_{std::make_shared()} {} + // implementation of inline functions void MetaInfo::Clear() { num_row_ = num_col_ = num_nonzero_ = 0; @@ -232,6 +235,9 @@ void MetaInfo::Clear() { */ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { + if (!this->Cats()->Empty()) { + LOG(FATAL) << "Cannot save binary when there are category indices."; + } Version::Save(fo); fo->Write(kNumField); int field_cnt = 0; // make sure we are actually writing kNumField fields @@ -852,6 +858,19 @@ bool MetaInfo::ShouldHaveLabels() const { return !IsVerticalFederated() || collective::GetRank() == 0; } +[[nodiscard]] CatContainer const* MetaInfo::Cats() const { return this->cats_.get(); } +[[nodiscard]] CatContainer* MetaInfo::Cats() { return this->cats_.get(); } + +[[nodiscard]] std::shared_ptr MetaInfo::CatsShared() const { + return this->cats_; +} + +void MetaInfo::Cats(std::shared_ptr cats) { + this->cats_ = std::move(cats); + CHECK_LT(cats_->NumFeatures(), + static_castNumFeatures())>(std::numeric_limits::max())); +} + using DMatrixThreadLocal = dmlc::ThreadLocalStore>; diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h index 06cfbd9946a1..4f067829db95 100644 --- a/src/data/proxy_dmatrix.h +++ b/src/data/proxy_dmatrix.h @@ -254,6 +254,9 @@ namespace cuda_impl { [[nodiscard]] bst_idx_t BatchColumns(DMatrixProxy const*); } // namespace cuda_impl +/** + * @brief Get the number of samples for the current batch. + */ [[nodiscard]] inline bst_idx_t BatchSamples(DMatrixProxy const* proxy) { bool type_error = false; auto n_samples = @@ -264,6 +267,9 @@ namespace cuda_impl { return n_samples; } +/** + * @brief Get the number of features for the current batch. + */ [[nodiscard]] inline bst_feature_t BatchColumns(DMatrixProxy const* proxy) { bool type_error = false; auto n_features = diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index b80e33ae3585..0cdaccad4109 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -1,5 +1,5 @@ /** - * Copyright 2014-2024, XGBoost Contributors + * Copyright 2014-2025, XGBoost Contributors * \file simple_dmatrix.cc * \brief the input data structure for gradient boosting * \author Tianqi Chen @@ -12,13 +12,14 @@ #include #include -#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank, Allgather #include "../collective/allgather.h" +#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank, Allgather #include "../common/error_msg.h" // for InconsistentMaxBin #include "./simple_batch_iterator.h" #include "adapter.h" -#include "batch_utils.h" // for CheckEmpty, RegenGHist -#include "ellpack_page.h" // for EllpackPage +#include "batch_utils.h" // for CheckEmpty, RegenGHist +#include "cat_container.h" // for CatContainer +#include "ellpack_page.h" // for EllpackPage #include "gradient_index.h" #include "xgboost/c_api.h" #include "xgboost/data.h" @@ -287,6 +288,12 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread, info_.num_col_ = adapter->NumColumns(); } + if constexpr (std::is_same_v) { + if (adapter->HasCategorical()) { + info_.Cats(std::make_shared(adapter->Cats())); + } + } + // Must called before sync column this->ReindexFeatures(&ctx, data_split_mode); this->info_.SynchronizeNumberOfColumns(&ctx, data_split_mode); diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h index cef8e01ecdfe..610fcea584c1 100644 --- a/src/data/simple_dmatrix.h +++ b/src/data/simple_dmatrix.h @@ -1,5 +1,5 @@ /** - * Copyright 2015-2023, XGBoost Contributors + * Copyright 2015-2025, XGBoost Contributors * \file simple_dmatrix.h * \brief In-memory version of DMatrix. * \author Tianqi Chen @@ -33,6 +33,10 @@ class SimpleDMatrix : public DMatrix { const MetaInfo& Info() const override; Context const* Ctx() const override { return &fmat_ctx_; } + std::shared_ptr CatsShared() const override { + return this->Info().CatsShared(); + } + DMatrix* Slice(common::Span ridxs) override; DMatrix* SliceCol(int num_slices, int slice_id) override; diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu index 3f076eaff045..eae067a56649 100644 --- a/src/objective/lambdarank_obj.cu +++ b/src/objective/lambdarank_obj.cu @@ -3,6 +3,8 @@ * * \brief CUDA implementation of lambdarank. */ +#include // for DMLC_REGISTRY_FILE_TAG + #include // for fill_n #include // for for_each_n #include // for make_counting_iterator diff --git a/tests/python/test_ordinal.py b/tests/python/test_ordinal.py new file mode 100644 index 000000000000..837ec883a72d --- /dev/null +++ b/tests/python/test_ordinal.py @@ -0,0 +1,14 @@ +import pytest + +from xgboost import testing as tm +from xgboost.testing.ordinal import run_cat_container, run_cat_container_mixed + +pytestmark = pytest.mark.skipif(**tm.no_multiple(tm.no_arrow(), tm.no_pandas())) + + +def test_cat_container() -> None: + run_cat_container("cpu") + + +def test_cat_container_mixed() -> None: + run_cat_container_mixed() diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index 27be831d3f88..8c30bb354b7e 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -245,7 +245,7 @@ def test_pandas_categorical(self, data_split_mode=DataSplitMode.ROW): X, enable_categorical=True ) - assert transformed.columns[0].min() == 0 + assert len(transformed.aitfs[0]) == 2 # test missing value X = pd.DataFrame({"f0": ["a", "b", np.nan]}) From 6d039a46ad116155d3380de6c13c21eff6a9402c Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 3 Mar 2025 19:47:32 +0800 Subject: [PATCH 02/11] macos. --- src/data/adapter.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data/adapter.h b/src/data/adapter.h index b8ab21017b24..5c5c52ddb467 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -606,7 +606,7 @@ template n_bytes += codes.ElementSize() * codes.Shape<0>(); n_bytes += names.SizeBytes(); - *p_n_samples = std::max(*p_n_samples, codes.Shape<0>()); + *p_n_samples = std::max(*p_n_samples, static_cast(codes.Shape<0>())); return names.size(); } From d624690bae891d7fd6e3cd4dbd8aac21b646a8f1 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 3 Mar 2025 19:49:31 +0800 Subject: [PATCH 03/11] macos. --- src/data/adapter.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data/adapter.h b/src/data/adapter.h index 5c5c52ddb467..7b19252c557f 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -638,7 +638,7 @@ template p_columns->push_back(codes); n_bytes += codes.ElementSize() * codes.Shape<0>(); - *p_n_samples = std::max(*p_n_samples, codes.Shape<0>()); + *p_n_samples = std::max(*p_n_samples, static_cast(codes.Shape<0>())); return names.n; } From 69d4b6a6bc47177a7ed437c9dc7178866ac8cfc8 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 3 Mar 2025 19:51:06 +0800 Subject: [PATCH 04/11] Cleanup. --- src/data/adapter.h | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/data/adapter.h b/src/data/adapter.h index 7b19252c557f..ee1088102316 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -669,12 +669,7 @@ class ColumnarAdapter : public detail::SingleBatchDataIter } return 0; } - [[nodiscard]] bst_idx_t NumColumns() const { - if (!columns_.empty()) { - return columns_.size(); - } - return 0; - } + [[nodiscard]] bst_idx_t NumColumns() const { return columns_.size(); } [[nodiscard]] bool HasCategorical() const { return !std::all_of(this->cats_.cbegin(), this->cats_.cend(), [](auto const& cats) { return std::visit([](auto&& cats) { return cats.empty(); }, cats); From 1d43ee8a6f715779511317b8b1c2702435abc4c6 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 3 Mar 2025 19:53:52 +0800 Subject: [PATCH 05/11] macos. --- src/data/adapter.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data/adapter.cc b/src/data/adapter.cc index 656d7e82c877..2df16b91b606 100644 --- a/src/data/adapter.cc +++ b/src/data/adapter.cc @@ -36,7 +36,7 @@ ColumnarAdapter::ColumnarAdapter(StringView columns) { columns_.emplace_back(get(jcol)); this->cats_.emplace_back(); this->n_bytes_ += columns_.back().ElementSize() * columns_.back().Shape<0>(); - n_samples = std::max(n_samples, columns_.back().Shape<0>()); + n_samples = std::max(n_samples, static_cast(columns_.back().Shape<0>())); } cat_segments.push_back(n_cats); } From e9a96602896519d447a4f7beb8a62e22906dc054 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 3 Mar 2025 19:57:58 +0800 Subject: [PATCH 06/11] macos. --- src/data/adapter.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/data/adapter.h b/src/data/adapter.h index ee1088102316..a9e97b3feb1b 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -580,7 +580,7 @@ template [[nodiscard]] std::size_t GetArrowDictionary(Json jcol, std::vector* p_cat_columns, std::vector>* p_columns, - bst_idx_t* p_n_bytes, bst_idx_t* p_n_samples) { + std::size_t* p_n_bytes, bst_idx_t* p_n_samples) { auto& cat_columns = *p_cat_columns; // arrow StringArray for name of categories auto const& jnames = get(jcol[0]); @@ -618,7 +618,7 @@ template template [[nodiscard]] std::size_t GetArrowNumericIndex( DeviceOrd device, Json jcol, std::vector* p_cat_columns, - std::vector>* p_columns, bst_idx_t* p_n_bytes, + std::vector>* p_columns, std::size_t* p_n_bytes, bst_idx_t* p_n_samples) { auto const& first = get(jcol[0]); auto names = ArrayInterface<1>{first}; From 7071ba5819a395a39077c560fc8ac9a55f5bb0c5 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 3 Mar 2025 20:08:36 +0800 Subject: [PATCH 07/11] R. --- R-package/src/Makevars.win.in | 1 + 1 file changed, 1 insertion(+) diff --git a/R-package/src/Makevars.win.in b/R-package/src/Makevars.win.in index 8a86ba97c34d..d84cdb43ce6c 100644 --- a/R-package/src/Makevars.win.in +++ b/R-package/src/Makevars.win.in @@ -65,6 +65,7 @@ OBJECTS= \ $(PKGROOT)/src/gbm/gblinear_model.o \ $(PKGROOT)/src/data/adapter.o \ $(PKGROOT)/src/data/array_interface.o \ + $(PKGROOT)/src/data/cat_container.o \ $(PKGROOT)/src/data/simple_dmatrix.o \ $(PKGROOT)/src/data/data.o \ $(PKGROOT)/src/data/sparse_page_raw_format.o \ From eedb63b5f3b68e00a7bb6676d3514b7ffb316fc6 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 3 Mar 2025 22:01:04 +0800 Subject: [PATCH 08/11] Revert. --- python-package/xgboost/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index d669a5181250..98a00b664bb2 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2827,13 +2827,13 @@ def inplace_predict( data, cat_codes, fns, _ = _transform_cudf_df( data, None, None, enable_categorical ) - array_inf, _ = _cudf_array_interfaces(data, cat_codes) + interfaces_str = _cudf_array_interfaces(data, cat_codes) if validate_features: self._validate_features(fns) _check_call( _LIB.XGBoosterPredictFromCudaColumnar( self.handle, - array_inf, + interfaces_str, args, p_handle, ctypes.byref(shape), From 2ea98d16e8f3b768d28734269af39b23085254e4 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 3 Mar 2025 22:18:03 +0800 Subject: [PATCH 09/11] Bump version. --- CMakeLists.txt | 2 +- R-package/DESCRIPTION | 4 +-- R-package/configure | 26 +++++++++---------- R-package/configure.ac | 2 +- include/xgboost/data.h | 2 +- include/xgboost/version_config.h | 2 +- jvm-packages/pom.xml | 2 +- jvm-packages/xgboost4j-example/pom.xml | 4 +-- jvm-packages/xgboost4j-flink/pom.xml | 4 +-- jvm-packages/xgboost4j-spark-gpu/pom.xml | 2 +- jvm-packages/xgboost4j-spark/pom.xml | 2 +- jvm-packages/xgboost4j/pom.xml | 4 +-- python-package/pyproject.toml | 2 +- python-package/pyproject.toml.in | 2 +- python-package/xgboost/VERSION | 2 +- src/data/data.cc | 33 +++++++++++++++--------- 16 files changed, 52 insertions(+), 43 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8dfe7dd9e048..845347ea1ad6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ if(PLUGIN_SYCL) string(REPLACE " -isystem ${CONDA_PREFIX}/include" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") endif() -project(xgboost LANGUAGES CXX C VERSION 3.0.0) +project(xgboost LANGUAGES CXX C VERSION 3.1.0) include(cmake/Utils.cmake) list(APPEND CMAKE_MODULE_PATH "${xgboost_SOURCE_DIR}/cmake/modules") diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION index 71ce7acef61a..47be99b20dc9 100644 --- a/R-package/DESCRIPTION +++ b/R-package/DESCRIPTION @@ -1,8 +1,8 @@ Package: xgboost Type: Package Title: Extreme Gradient Boosting -Version: 3.0.0.0 -Date: 2024-11-26 +Version: 3.1.0.0 +Date: 2025-03-03 Authors@R: c( person("Tianqi", "Chen", role = c("aut"), email = "tianqi.tchen@gmail.com"), diff --git a/R-package/configure b/R-package/configure index f791b6d84aa7..fc4594db9458 100755 --- a/R-package/configure +++ b/R-package/configure @@ -1,6 +1,6 @@ #! /bin/sh # Guess values for system-dependent variables and create Makefiles. -# Generated by GNU Autoconf 2.71 for xgboost 3.0.0. +# Generated by GNU Autoconf 2.71 for xgboost 3.1.0. # # # Copyright (C) 1992-1996, 1998-2017, 2020-2021 Free Software Foundation, @@ -607,8 +607,8 @@ MAKEFLAGS= # Identity of this package. PACKAGE_NAME='xgboost' PACKAGE_TARNAME='xgboost' -PACKAGE_VERSION='3.0.0' -PACKAGE_STRING='xgboost 3.0.0' +PACKAGE_VERSION='3.1.0' +PACKAGE_STRING='xgboost 3.1.0' PACKAGE_BUGREPORT='' PACKAGE_URL='' @@ -1262,7 +1262,7 @@ if test "$ac_init_help" = "long"; then # Omit some internal or obsolete options to make the list less imposing. # This message is too long to be a string in the A/UX 3.1 sh. cat <<_ACEOF -\`configure' configures xgboost 3.0.0 to adapt to many kinds of systems. +\`configure' configures xgboost 3.1.0 to adapt to many kinds of systems. Usage: $0 [OPTION]... [VAR=VALUE]... @@ -1324,7 +1324,7 @@ fi if test -n "$ac_init_help"; then case $ac_init_help in - short | recursive ) echo "Configuration of xgboost 3.0.0:";; + short | recursive ) echo "Configuration of xgboost 3.1.0:";; esac cat <<\_ACEOF @@ -1407,7 +1407,7 @@ fi test -n "$ac_init_help" && exit $ac_status if $ac_init_version; then cat <<\_ACEOF -xgboost configure 3.0.0 +xgboost configure 3.1.0 generated by GNU Autoconf 2.71 Copyright (C) 2021 Free Software Foundation, Inc. @@ -1668,7 +1668,7 @@ cat >config.log <<_ACEOF This file contains any messages produced by compilers while running configure, to aid debugging if configure makes a mistake. -It was created by xgboost $as_me 3.0.0, which was +It was created by xgboost $as_me 3.1.0, which was generated by GNU Autoconf 2.71. Invocation command line was $ $0$ac_configure_args_raw @@ -2796,11 +2796,11 @@ if test x$ac_prog_cxx_stdcxx = xno then : { printf "%s\n" "$as_me:${as_lineno-$LINENO}: checking for $CXX option to enable C++11 features" >&5 printf %s "checking for $CXX option to enable C++11 features... " >&6; } -if test ${ac_cv_prog_cxx_cxx11+y} +if test ${ac_cv_prog_cxx_11+y} then : printf %s "(cached) " >&6 else $as_nop - ac_cv_prog_cxx_cxx11=no + ac_cv_prog_cxx_11=no ac_save_CXX=$CXX cat confdefs.h - <<_ACEOF >conftest.$ac_ext /* end confdefs.h. */ @@ -2842,11 +2842,11 @@ if test x$ac_prog_cxx_stdcxx = xno then : { printf "%s\n" "$as_me:${as_lineno-$LINENO}: checking for $CXX option to enable C++98 features" >&5 printf %s "checking for $CXX option to enable C++98 features... " >&6; } -if test ${ac_cv_prog_cxx_cxx98+y} +if test ${ac_cv_prog_cxx_98+y} then : printf %s "(cached) " >&6 else $as_nop - ac_cv_prog_cxx_cxx98=no + ac_cv_prog_cxx_98=no ac_save_CXX=$CXX cat confdefs.h - <<_ACEOF >conftest.$ac_ext /* end confdefs.h. */ @@ -3855,7 +3855,7 @@ cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1 # report actual input values of CONFIG_FILES etc. instead of their # values after options handling. ac_log=" -This file was extended by xgboost $as_me 3.0.0, which was +This file was extended by xgboost $as_me 3.1.0, which was generated by GNU Autoconf 2.71. Invocation command line was CONFIG_FILES = $CONFIG_FILES @@ -3919,7 +3919,7 @@ ac_cs_config_escaped=`printf "%s\n" "$ac_cs_config" | sed "s/^ //; s/'/'\\\\\\\\ cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1 ac_cs_config='$ac_cs_config_escaped' ac_cs_version="\\ -xgboost config.status 3.0.0 +xgboost config.status 3.1.0 configured by $0, generated by GNU Autoconf 2.71, with options \\"\$ac_cs_config\\" diff --git a/R-package/configure.ac b/R-package/configure.ac index eb2728c17645..fb5a28b5a95f 100644 --- a/R-package/configure.ac +++ b/R-package/configure.ac @@ -2,7 +2,7 @@ AC_PREREQ(2.69) -AC_INIT([xgboost],[3.0.0],[],[xgboost],[]) +AC_INIT([xgboost],[3.1.0],[],[xgboost],[]) : ${R_HOME=`R RHOME`} if test -z "${R_HOME}"; then diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 7f7ffcbddad2..954ffc586006 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -51,7 +51,7 @@ struct CatContainer; class MetaInfo { public: /*! \brief number of data fields in MetaInfo */ - static constexpr uint64_t kNumField = 12; + static constexpr uint64_t kNumField = 13; /*! \brief number of rows in the data */ bst_idx_t num_row_{0}; // NOLINT diff --git a/include/xgboost/version_config.h b/include/xgboost/version_config.h index 1638c65ecb1e..785984174b2e 100644 --- a/include/xgboost/version_config.h +++ b/include/xgboost/version_config.h @@ -5,7 +5,7 @@ #define XGBOOST_VERSION_CONFIG_H_ #define XGBOOST_VER_MAJOR 3 /* NOLINT */ -#define XGBOOST_VER_MINOR 0 /* NOLINT */ +#define XGBOOST_VER_MINOR 1 /* NOLINT */ #define XGBOOST_VER_PATCH 0 /* NOLINT */ #endif // XGBOOST_VERSION_CONFIG_H_ diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 3854cfa4cadf..f4cadbf9a787 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -6,7 +6,7 @@ ml.dmlc xgboost-jvm_2.12 - 3.0.0-SNAPSHOT + 3.1.0-SNAPSHOT pom XGBoost JVM Package JVM Package for XGBoost diff --git a/jvm-packages/xgboost4j-example/pom.xml b/jvm-packages/xgboost4j-example/pom.xml index e27a03dd5c9b..9a8408124c63 100644 --- a/jvm-packages/xgboost4j-example/pom.xml +++ b/jvm-packages/xgboost4j-example/pom.xml @@ -6,11 +6,11 @@ ml.dmlc xgboost-jvm_2.12 - 3.0.0-SNAPSHOT + 3.1.0-SNAPSHOT xgboost4j-example xgboost4j-example_2.12 - 3.0.0-SNAPSHOT + 3.1.0-SNAPSHOT jar diff --git a/jvm-packages/xgboost4j-flink/pom.xml b/jvm-packages/xgboost4j-flink/pom.xml index 19271b3dde76..96fe0563d499 100644 --- a/jvm-packages/xgboost4j-flink/pom.xml +++ b/jvm-packages/xgboost4j-flink/pom.xml @@ -6,12 +6,12 @@ ml.dmlc xgboost-jvm_2.12 - 3.0.0-SNAPSHOT + 3.1.0-SNAPSHOT xgboost4j-flink xgboost4j-flink_2.12 - 3.0.0-SNAPSHOT + 3.1.0-SNAPSHOT 2.2.0 diff --git a/jvm-packages/xgboost4j-spark-gpu/pom.xml b/jvm-packages/xgboost4j-spark-gpu/pom.xml index c37c583becf7..a4768878f879 100644 --- a/jvm-packages/xgboost4j-spark-gpu/pom.xml +++ b/jvm-packages/xgboost4j-spark-gpu/pom.xml @@ -6,7 +6,7 @@ ml.dmlc xgboost-jvm_2.12 - 3.0.0-SNAPSHOT + 3.1.0-SNAPSHOT xgboost4j-spark-gpu xgboost4j-spark-gpu_2.12 diff --git a/jvm-packages/xgboost4j-spark/pom.xml b/jvm-packages/xgboost4j-spark/pom.xml index f2e6692c71c1..904c97a08bcd 100644 --- a/jvm-packages/xgboost4j-spark/pom.xml +++ b/jvm-packages/xgboost4j-spark/pom.xml @@ -6,7 +6,7 @@ ml.dmlc xgboost-jvm_2.12 - 3.0.0-SNAPSHOT + 3.1.0-SNAPSHOT xgboost4j-spark xgboost4j-spark_2.12 diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml index 77664b9dfd9b..b9c144dd044f 100644 --- a/jvm-packages/xgboost4j/pom.xml +++ b/jvm-packages/xgboost4j/pom.xml @@ -6,11 +6,11 @@ ml.dmlc xgboost-jvm_2.12 - 3.0.0-SNAPSHOT + 3.1.0-SNAPSHOT xgboost4j xgboost4j_2.12 - 3.0.0-SNAPSHOT + 3.1.0-SNAPSHOT jar diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index df272350fcce..d188319dd4a0 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -15,7 +15,7 @@ authors = [ { name = "Hyunsu Cho", email = "chohyu01@cs.washington.edu" }, { name = "Jiaming Yuan", email = "jm.yuan@outlook.com" } ] -version = "3.0.0-dev" +version = "3.1.0-dev" requires-python = ">=3.10" license = { text = "Apache-2.0" } classifiers = [ diff --git a/python-package/pyproject.toml.in b/python-package/pyproject.toml.in index 50c1b470d19e..035e13a68227 100644 --- a/python-package/pyproject.toml.in +++ b/python-package/pyproject.toml.in @@ -14,7 +14,7 @@ authors = [ { name = "Hyunsu Cho", email = "chohyu01@cs.washington.edu" }, { name = "Jiaming Yuan", email = "jm.yuan@outlook.com" } ] -version = "3.0.0-dev" +version = "3.1.0-dev" requires-python = ">=3.10" license = { text = "Apache-2.0" } classifiers = [ diff --git a/python-package/xgboost/VERSION b/python-package/xgboost/VERSION index 2468aa9eae58..0f9d6b15dc04 100644 --- a/python-package/xgboost/VERSION +++ b/python-package/xgboost/VERSION @@ -1 +1 @@ -3.0.0-dev +3.1.0-dev diff --git a/src/data/data.cc b/src/data/data.cc index 8f240d415655..043fc14bd2b5 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -228,6 +228,7 @@ void MetaInfo::Clear() { * | feature_names | kStr | False | ${size} | 1 | ${feature_names} | * | feature_types | kStr | False | ${size} | 1 | ${feature_types} | * | feature_weights | kFloat32 | False | ${size} | 1 | ${feature_weights} | + * | cats | kStr | False | ${size} | 1 | ${cats} | * * Note that the scalar fields (is_scalar=True) will have num_row and num_col missing. * Also notice the difference between the saved name and the name used in `SetInfo': @@ -235,9 +236,6 @@ void MetaInfo::Clear() { */ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { - if (!this->Cats()->Empty()) { - LOG(FATAL) << "Cannot save binary when there are category indices."; - } Version::Save(fo); fo->Write(kNumField); int field_cnt = 0; // make sure we are actually writing kNumField fields @@ -256,14 +254,22 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { SaveVectorField(fo, u8"labels_upper_bound", DataType::kFloat32, {labels_upper_bound_.Size(), 1}, labels_upper_bound_); ++field_cnt; - SaveVectorField(fo, u8"feature_names", DataType::kStr, - {feature_names.size(), 1}, feature_names); ++field_cnt; - SaveVectorField(fo, u8"feature_types", DataType::kStr, - {feature_type_names.size(), 1}, feature_type_names); ++field_cnt; + SaveVectorField(fo, u8"feature_names", DataType::kStr, {feature_names.size(), 1}, feature_names); + ++field_cnt; + SaveVectorField(fo, u8"feature_types", DataType::kStr, {feature_type_names.size(), 1}, + feature_type_names); + ++field_cnt; SaveVectorField(fo, u8"feature_weights", DataType::kFloat32, {feature_weights.Size(), 1}, feature_weights); ++field_cnt; + Json jcats{Object{}}; + this->cats_->Save(&jcats); + std::vector values; + Json::Dump(jcats, &values, std::ios::binary); + SaveVectorField(fo, u8"cats", DataType::kStr, {values.size(), 1}, values); + ++field_cnt; + CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields"; } @@ -309,6 +315,7 @@ const std::vector& MetaInfo::LabelAbsSort(Context const* ctx) const { void MetaInfo::LoadBinary(dmlc::Stream *fi) { auto version = Version::Load(fi); auto major = std::get<0>(version); + auto minor = std::get<1>(version); // MetaInfo is saved in `SparsePageSource'. So the version in MetaInfo represents the // version of DMatrix. std::stringstream msg; @@ -316,11 +323,8 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { << " is no longer supported. " << "Please process and save your data in current version: " << Version::String(Version::Self()) << " again."; - CHECK_GE(major, 1) << msg.str(); - if (major == 1) { - auto minor = std::get<1>(version); - CHECK_GE(minor, 6) << msg.str(); - } + CHECK_GE(major, 3) << msg.str(); + CHECK_GE(minor, 1) << msg.str(); const uint64_t expected_num_field = kNumField; uint64_t num_field { 0 }; @@ -356,6 +360,11 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { LoadVectorField(fi, u8"feature_weights", DataType::kFloat32, &feature_weights); this->has_categorical_ = LoadFeatureType(feature_type_names, &feature_types.HostVector()); + + std::vector values; + LoadVectorField(fi, u8"cats", DataType::kStr, &values); + auto jcats = Json::Load(StringView{values.data(), values.size()}, std::ios::binary); + this->cats_->Load(jcats); } namespace { From 885b977ea39e5b9df228b05cc670fcccf3e2eede Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 3 Mar 2025 22:21:58 +0800 Subject: [PATCH 10/11] Binary. --- python-package/xgboost/testing/ordinal.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python-package/xgboost/testing/ordinal.py b/python-package/xgboost/testing/ordinal.py index 0278de5c8f68..dafd06abec13 100644 --- a/python-package/xgboost/testing/ordinal.py +++ b/python-package/xgboost/testing/ordinal.py @@ -1,6 +1,8 @@ # pylint: disable=invalid-name """Tests for the ordinal re-coder.""" +import os +import tempfile from typing import Any, Tuple, Type import numpy as np @@ -92,6 +94,21 @@ def check(Xy: DMatrix, X: pd.DataFrame) -> None: else: assert cats[fname] is None + with tempfile.TemporaryDirectory() as tmpdir: + fname = os.path.join(tmpdir, "DMatrix.binary") + Xy.save_binary(fname) + + Xy_1 = DMatrix(fname) + cats_1 = Xy_1.get_categories() + assert cats_1 is not None + + for k, v_0 in cats.items(): + v_1 = cats_1[k] + if v_0 is None: + assert v_1 is None + else: + assert v_0.to_pylist() == v_1.to_pylist() + # full str type X, y = make_categorical(256, 16, 7, onehot=False, cat_dtype=np.str_) Xy = DMatrix(X, y, enable_categorical=True) From 775d9ff0e25b7a96218d52965a414b33ae80f003 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 3 Mar 2025 22:32:16 +0800 Subject: [PATCH 11/11] empty --- src/data/cat_container.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/data/cat_container.cc b/src/data/cat_container.cc index ea12b00b5ab8..28c5ad375385 100644 --- a/src/data/cat_container.cc +++ b/src/data/cat_container.cc @@ -209,8 +209,11 @@ void CatContainer::Load(Json const& in) { auto& hf_segments = this->feature_segments_.HostVector(); LoadJson(in["feature_segments"], &hf_segments); - CHECK(!hf_segments.empty()); - this->n_total_cats_ = hf_segments.back(); + if (hf_segments.empty()) { + this->n_total_cats_ = 0; + } else { + this->n_total_cats_ = hf_segments.back(); + } auto& h_sorted_idx = this->sorted_idx_.HostVector(); LoadJson(in["sorted_idx"], &h_sorted_idx);