HugeCTR is a GPU-accelerated recommender framework designed for training and inference of large deep learning models.
Design Goals:
- Fast: HugeCTR performs outstandingly in recommendation benchmarks including MLPerf.
- Easy: Whether you are a data scientist or a machine learning practitioner, HugeCTR is easy to pick up thanks to extensive documentation, notebooks, and samples.
- Domain Specific: HugeCTR provides the essentials so that you can efficiently deploy recommender models with very large embeddings.
NOTE: If you have any questions about using HugeCTR, please file an issue or join our Slack channel for more interactive discussions.
- Core Features
- Getting Started
- HugeCTR SDK
- Support and Feedback
- Contributing to HugeCTR
- Additional Resources
HugeCTR supports a variety of features, including the following:
- High-Level abstracted Python interface
- Model parallel training
- Optimized GPU workflow
- Multi-node training
- Mixed precision training
- HugeCTR to ONNX Converter
- Sparse Operation Kit
To learn about our latest enhancements, refer to our release notes.
If you'd like to quickly train a model using the Python interface, do the following:
- Build the HugeCTR Docker image: Starting with version 25.03, HugeCTR provides only the Dockerfile source, so you need to build the image yourself. To build the hugectr image, use the Dockerfile located at tools/dockerfiles/Dockerfile.base with the following command:

      docker build --build-arg RELEASE=true -t hugectr:release -f tools/dockerfiles/Dockerfile.base .

- Start the container with your local host directory (/your/host/dir) mounted by running the following command:

      docker run --gpus=all --rm -it --cap-add SYS_NICE -v /your/host/dir:/your/container/dir -w /your/container/dir -u $(id -u):$(id -g) hugectr:release

  NOTE: The /your/host/dir directory on the host is visible inside the container as /your/container/dir, which is also your starting directory.

  NOTE: HugeCTR uses NCCL to share data between ranks, and NCCL may require shared memory for IPC and pinned (page-locked) system memory resources. It is recommended that you increase these resources by adding the following options to the docker run command:

      --shm-size=1g --ulimit memlock=-1

- Write a simple Python script to generate a synthetic dataset:

      # dcn_parquet_generate.py
      import hugectr
      from hugectr.tools import DataGeneratorParams, DataGenerator

      data_generator_params = DataGeneratorParams(
          format = hugectr.DataReaderType_t.Parquet,
          label_dim = 1,
          dense_dim = 13,
          num_slot = 26,
          i64_input_key = False,
          source = "./dcn_parquet/file_list.txt",
          eval_source = "./dcn_parquet/file_list_test.txt",
          slot_size_array = [39884, 39043, 17289, 7420, 20263, 3, 7120, 1543,
                             39884, 39043, 17289, 7420, 20263, 3, 7120, 1543,
                             63, 63,
                             39884, 39043, 17289, 7420, 20263, 3, 7120, 1543],
          dist_type = hugectr.Distribution_t.PowerLaw,
          power_law_type = hugectr.PowerLaw_t.Short)

      data_generator = DataGenerator(data_generator_params)
      data_generator.generate()

- Generate the Parquet dataset for your DCN model by running the following command:

      python dcn_parquet_generate.py

  NOTE: The generated dataset will reside in the folder ./dcn_parquet, which contains training and evaluation data.

- Write a simple Python script for training:

      # dcn_parquet_train.py
      import hugectr
      from mpi4py import MPI

      solver = hugectr.CreateSolver(
          max_eval_batches = 1280,
          batchsize_eval = 1024,
          batchsize = 1024,
          lr = 0.001,
          vvgpu = [[0]],
          repeat_dataset = True)
      reader = hugectr.DataReaderParams(
          data_reader_type = hugectr.DataReaderType_t.Parquet,
          source = ["./dcn_parquet/file_list.txt"],
          eval_source = "./dcn_parquet/file_list_test.txt",
          slot_size_array = [39884, 39043, 17289, 7420, 20263, 3, 7120, 1543,
                             39884, 39043, 17289, 7420, 20263, 3, 7120, 1543,
                             63, 63,
                             39884, 39043, 17289, 7420, 20263, 3, 7120, 1543])
      optimizer = hugectr.CreateOptimizer(
          optimizer_type = hugectr.Optimizer_t.Adam,
          update_type = hugectr.Update_t.Global)

      model = hugectr.Model(solver, reader, optimizer)
      model.add(hugectr.Input(
          label_dim = 1, label_name = "label",
          dense_dim = 13, dense_name = "dense",
          data_reader_sparse_param_array = [hugectr.DataReaderSparseParam("data1", 1, True, 26)]))
      model.add(hugectr.SparseEmbedding(
          embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
          workspace_size_per_gpu_in_mb = 75,
          embedding_vec_size = 16,
          combiner = "sum",
          sparse_embedding_name = "sparse_embedding1",
          bottom_name = "data1",
          optimizer = optimizer))
      model.add(hugectr.DenseLayer(
          layer_type = hugectr.Layer_t.Reshape,
          bottom_names = ["sparse_embedding1"],
          top_names = ["reshape1"],
          leading_dim=416))
      model.add(hugectr.DenseLayer(
          layer_type = hugectr.Layer_t.Concat,
          bottom_names = ["reshape1", "dense"],
          top_names = ["concat1"]))
      model.add(hugectr.DenseLayer(
          layer_type = hugectr.Layer_t.MultiCross,
          bottom_names = ["concat1"],
          top_names = ["multicross1"],
          num_layers=6))
      model.add(hugectr.DenseLayer(
          layer_type = hugectr.Layer_t.InnerProduct,
          bottom_names = ["concat1"],
          top_names = ["fc1"],
          num_output=1024))
      model.add(hugectr.DenseLayer(
          layer_type = hugectr.Layer_t.ReLU,
          bottom_names = ["fc1"],
          top_names = ["relu1"]))
      model.add(hugectr.DenseLayer(
          layer_type = hugectr.Layer_t.Dropout,
          bottom_names = ["relu1"],
          top_names = ["dropout1"],
          dropout_rate=0.5))
      model.add(hugectr.DenseLayer(
          layer_type = hugectr.Layer_t.Concat,
          bottom_names = ["dropout1", "multicross1"],
          top_names = ["concat2"]))
      model.add(hugectr.DenseLayer(
          layer_type = hugectr.Layer_t.InnerProduct,
          bottom_names = ["concat2"],
          top_names = ["fc2"],
          num_output=1))
      model.add(hugectr.DenseLayer(
          layer_type = hugectr.Layer_t.BinaryCrossEntropyLoss,
          bottom_names = ["fc2", "label"],
          top_names = ["loss"]))

      model.compile()
      model.summary()
      model.graph_to_json(graph_config_file = "dcn.json")
      model.fit(max_iter = 5120, display = 200, eval_interval = 1000,
                snapshot = 5000, snapshot_prefix = "dcn")

  NOTE: Ensure that the paths to the synthetic datasets are correct with respect to this Python script. data_reader_type, check_type, label_dim, dense_dim, and data_reader_sparse_param_array should be consistent with the generated dataset.

- Train the model by running the following command:

      python dcn_parquet_train.py

  NOTE: The evaluation AUC is not expected to be meaningful because the dataset is randomly generated. When training is done, files containing the dumped graph JSON, saved model weights, and optimizer states are generated.

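The consistency requirement in the training step can be checked with plain arithmetic before launching a run. The following is a minimal sketch, not part of HugeCTR: `check_quickstart_shapes` is a hypothetical helper, and it assumes that each of the 26 slots contributes one embedding vector, so the Reshape layer's leading_dim should equal num_slot × embedding_vec_size. The values are copied from the two scripts above.

```python
# Values copied from dcn_parquet_generate.py and dcn_parquet_train.py.
slot_size_array = [39884, 39043, 17289, 7420, 20263, 3, 7120, 1543,
                   39884, 39043, 17289, 7420, 20263, 3, 7120, 1543,
                   63, 63,
                   39884, 39043, 17289, 7420, 20263, 3, 7120, 1543]


def check_quickstart_shapes(slot_sizes, num_slot, dense_dim,
                            embedding_vec_size, leading_dim):
    """Hypothetical helper: cross-check generator and model settings.

    Assumes one embedding vector per slot, as in the DCN model above.
    """
    if len(slot_sizes) != num_slot:
        raise ValueError(
            f"slot_size_array has {len(slot_sizes)} entries, expected {num_slot}")
    reshape_width = num_slot * embedding_vec_size
    if leading_dim != reshape_width:
        raise ValueError(
            f"Reshape leading_dim is {leading_dim}, expected {reshape_width}")
    return {
        "reshape_width": reshape_width,             # width of "reshape1"
        "concat_width": reshape_width + dense_dim,  # width of "concat1"
        "total_cardinality": sum(slot_sizes),       # sum over all slots
    }


shapes = check_quickstart_shapes(slot_size_array, num_slot=26, dense_dim=13,
                                 embedding_vec_size=16, leading_dim=416)
print(shapes)
```

If you change num_slot, embedding_vec_size, or slot_size_array in one script, rerunning a check like this helps catch a mismatch before HugeCTR reports it at model construction time.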
For more information, refer to the HugeCTR User Guide.
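Returning to the container step, the NCCL note recommends adding shared-memory and memlock options to docker run. The sketch below only assembles the full command line in Python so that the options end up in the right place. `build_docker_run` is a hypothetical helper, the image tag and paths are the placeholders used in the steps above, and `os.getuid()` / `os.getgid()` stand in for `$(id -u):$(id -g)` on a Unix host.

```python
import os
import shlex


def build_docker_run(host_dir, container_dir,
                     image="hugectr:release", shm_size="1g"):
    """Hypothetical helper: return the quick-start docker run argv
    with the shared-memory and memlock options included."""
    return [
        "docker", "run", "--gpus=all", "--rm", "-it",
        "--cap-add", "SYS_NICE",
        f"--shm-size={shm_size}",           # shared memory for NCCL IPC
        "--ulimit", "memlock=-1",           # unlimited pinned (page-locked) memory
        "-v", f"{host_dir}:{container_dir}",
        "-w", container_dir,
        "-u", f"{os.getuid()}:{os.getgid()}",
        image,
    ]


cmd = build_docker_run("/your/host/dir", "/your/container/dir")
# Print a copy-pasteable command line; nothing is executed here.
print(shlex.join(cmd))
```

Printing the command instead of running it keeps the sketch side-effect free. Pass the list to `subprocess.run` only after replacing the placeholder paths with real ones.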
We support external developers who can't use HugeCTR directly by exporting important HugeCTR components:
- Sparse Operation Kit (directory | documentation): a Python package that wraps GPU-accelerated operations dedicated to sparse training and inference.
If you encounter any issues or have questions, go to https://github.com/NVIDIA/HugeCTR/issues and submit an issue so that we can provide you with the necessary resolutions and answers. To further advance the HugeCTR Roadmap, we encourage you to share all the details regarding your recommender system pipeline using this survey.
With HugeCTR being an open source project, we welcome contributions from the general public. With your contributions, we can continue to improve HugeCTR's quality and performance. To learn how to contribute, refer to our HugeCTR Contributor Guide.
Webpages |
---|
NVIDIA Merlin |
NVIDIA HugeCTR |
Shijie Liu, Nan Zheng, Hui Kang, Xavier Simmons, Junjie Zhang, Matthias Langer, Wenjing Zhu, Minseok Lee, and Zehuan Wang. "Embedding Optimization for Training Large-scale Deep Learning Recommendation Systems with EMBark." In Proceedings of the 18th ACM Conference on Recommender Systems, pp. 622-632. 2024.
Yingcan Wei, Matthias Langer, Fan Yu, Minseok Lee, Jie Liu, Ji Shi and Zehuan Wang, "A GPU-specialized Inference Parameter Server for Large-Scale Deep Recommendation Models," Proceedings of the 16th ACM Conference on Recommender Systems, pp. 408-419, 2022.
Zehuan Wang, Yingcan Wei, Minseok Lee, Matthias Langer, Fan Yu, Jie Liu, Shijie Liu, Daniel G. Abel, Xu Guo, Jianbing Dong, Ji Shi and Kunlun Li, "Merlin HugeCTR: GPU-accelerated Recommender System Training and Inference," Proceedings of the 16th ACM Conference on Recommender Systems, pp. 534-537, 2022.
Conference / Website | Title | Date | Speaker | Language |
---|---|---|---|---|
ACM RecSys 2022 | A GPU-specialized Inference Parameter Server for Large-Scale Deep Recommendation Models | September 2022 | Matthias Langer | English |
Short Videos Episode 1 | Merlin HugeCTR:GPU 加速的推荐系统框架 | May 2022 | Joey Wang | Chinese |
Short Videos Episode 2 | HugeCTR 分级参数服务器如何加速推理 | May 2022 | Joey Wang | Chinese |
Short Videos Episode 3 | 使用 HugeCTR SOK 加速 TensorFlow 训练 | May 2022 | Gems Guo | Chinese |
GTC Spring 2022 | Merlin HugeCTR: Distributed Hierarchical Inference Parameter Server Using GPU Embedding Cache | March 2022 | Matthias Langer, Yingcan Wei, Yu Fan | English |
APSARA 2021 | GPU 推荐系统 Merlin | Oct 2021 | Joey Wang | Chinese |
GTC Spring 2021 | Learn how Tencent Deployed an Advertising System on the Merlin GPU Recommender Framework | April 2021 | Xiangting Kong, Joey Wang | English |
GTC Spring 2021 | Merlin HugeCTR: Deep Dive Into Performance Optimization | April 2021 | Minseok Lee | English |
GTC Spring 2021 | Integrate HugeCTR Embedding with TensorFlow | April 2021 | Jianbing Dong | English |
GTC China 2020 | MERLIN HUGECTR :深入研究性能优化 | Oct 2020 | Minseok Lee | English |
GTC China 2020 | 性能提升 7 倍 + 的高性能 GPU 广告推荐加速系统的落地实现 | Oct 2020 | Xiangting Kong | Chinese |
GTC China 2020 | 使用 GPU EMBEDDING CACHE 加速 CTR 推理过程 | Oct 2020 | Fan Yu | Chinese |
GTC China 2020 | 将 HUGECTR EMBEDDING 集成于 TENSORFLOW | Oct 2020 | Jianbing Dong | Chinese |
GTC Spring 2020 | HugeCTR: High-Performance Click-Through Rate Estimation Training | March 2020 | Minseok Lee, Joey Wang | English |
GTC China 2019 | HUGECTR: GPU 加速的推荐系统训练 | Oct 2019 | Joey Wang | Chinese |
- HugeCTR Hierarchical Parameter Server (HPS)
- Embedding Cache
The above components have been deprecated since v25.03. Refer to an earlier release if you need these features.