
Commit

[GPU-Plugin] Multi-GPU gpu_id bug fixes for grow_gpu_hist and grow_gpu methods, and additional documentation for the gpu plugin. (dmlc#2463)
pseudotensor authored and RAMitchell committed Jun 30, 2017
1 parent 91dae84 commit 6b28717
Showing 21 changed files with 574 additions and 445 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -15,7 +15,7 @@
*.Rcheck
*.rds
*.tar.gz
*txt*
#*txt*
*conf
*buffer
*model
76 changes: 45 additions & 31 deletions CMakeLists.txt
@@ -3,11 +3,8 @@ project (xgboost)
find_package(OpenMP)

option(PLUGIN_UPDATER_GPU "Build GPU accelerated tree construction plugin")
set(GPU_COMPUTE_VER 35;50;52;60;61 CACHE STRING
"Space separated list of compute versions to be built against")
if(PLUGIN_UPDATER_GPU)
cmake_minimum_required (VERSION 3.5)
find_package(CUDA REQUIRED)
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
@@ -83,6 +80,14 @@ set(RABIT_SOURCES
rabit/src/c_api.cc
)

set(NCCL_SOURCES
nccl/src/*.cu
)
set(UPDATER_GPU_SOURCES
plugin/updater_gpu/src/*.cu
plugin/updater_gpu/src/exact/*.cu
)

add_subdirectory(dmlc-core)

add_library(rabit STATIC ${RABIT_SOURCES})
@@ -102,35 +107,44 @@ endif()
set(LINK_LIBRARIES dmlccore rabit)

if(PLUGIN_UPDATER_GPU)
# nccl
set(LINK_LIBRARIES ${LINK_LIBRARIES} nccl)
add_subdirectory(nccl)
set(NCCL_DIRECTORY ${PROJECT_SOURCE_DIR}/nccl)
include_directories(${NCCL_DIRECTORY}/src)
set(LINK_LIBRARIES ${LINK_LIBRARIES} ${CUDA_LIBRARIES})
#Find cub
set(CUB_DIRECTORY ${PROJECT_SOURCE_DIR}/cub/)
include_directories(${CUB_DIRECTORY})
#Find googletest
set(GTEST_DIRECTORY "${CACHE_PREFIX}" CACHE PATH "Googletest directory")
include_directories(${GTEST_DIRECTORY}/include)
#gencode flags
set(GENCODE_FLAGS "")
foreach(ver ${GPU_COMPUTE_VER})
set(GENCODE_FLAGS "${GENCODE_FLAGS}-gencode arch=compute_${ver},code=sm_${ver};")
endforeach()
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;${GENCODE_FLAGS};-lineinfo;")
if(NOT MSVC)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC")
endif()
set(CUDA_SOURCES
plugin/updater_gpu/src/updater_gpu.cu
plugin/updater_gpu/src/gpu_hist_builder.cu
)
# use below for forcing specific arch
cuda_compile(CUDA_OBJS ${CUDA_SOURCES} ${CUDA_NVCC_FLAGS})
find_package(CUDA REQUIRED)

# nccl
set(LINK_LIBRARIES ${LINK_LIBRARIES} nccl)
add_subdirectory(nccl)
set(NCCL_DIRECTORY ${PROJECT_SOURCE_DIR}/nccl)
include_directories(${NCCL_DIRECTORY}/src)

#Find cub
set(CUB_DIRECTORY ${PROJECT_SOURCE_DIR}/cub/)
include_directories(${CUB_DIRECTORY})

#Find googletest
set(GTEST_DIRECTORY "${CACHE_PREFIX}" CACHE PATH "Googletest directory")
include_directories(${GTEST_DIRECTORY}/include)

#gencode flags
set(GPU_COMPUTE_VER 35;50;52;60;61 CACHE STRING
"Space separated list of compute versions to be built against")

set(GENCODE_FLAGS "")
foreach(ver ${GPU_COMPUTE_VER})
set(GENCODE_FLAGS "${GENCODE_FLAGS}-gencode arch=compute_${ver},code=sm_${ver};")
endforeach()
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};--expt-extended-lambda;${GENCODE_FLAGS};-lineinfo;")
if(NOT MSVC)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS};-Xcompiler -fPIC")
endif()
set(CUDA_SOURCES
plugin/updater_gpu/src/updater_gpu.cu
plugin/updater_gpu/src/gpu_hist_builder.cu
)
# use below for forcing specific arch
cuda_compile(CUDA_OBJS ${CUDA_SOURCES} ${CUDA_NVCC_FLAGS})


else()
set(CUDA_OBJS "")
set(CUDA_OBJS "")
endif()

add_library(objxgboost OBJECT ${SOURCES})
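The gencode loop in the CMakeLists.txt hunk above builds one -gencode pair per entry of GPU_COMPUTE_VER. A small Python sketch of the string it ends up producing (illustrative only; CMake does this at configure time, not Python):

```python
# Illustrative sketch of the flags the CMake foreach loop assembles from GPU_COMPUTE_VER.
compute_vers = ["35", "50", "52", "60", "61"]

gencode_flags = "".join(
    "-gencode arch=compute_{v},code=sm_{v};".format(v=v) for v in compute_vers
)
print(gencode_flags)
# -gencode arch=compute_35,code=sm_35;-gencode arch=compute_50,code=sm_50;...
```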
15 changes: 10 additions & 5 deletions Makefile
@@ -96,7 +96,7 @@ endif
CFLAGS += $(OPENMP_FLAGS)

# for using GPUs
GPU_COMPUTE_VER ?= 50 52 60 61
GPU_COMPUTE_VER ?= 35 50 52 60 61
NVCC = nvcc
INCLUDES = -Iinclude -I$(DMLC_CORE)/include -I$(RABIT)/include
INCLUDES += -I$(CUB_PATH)
@@ -106,14 +106,13 @@ NVCC_FLAGS = --std=c++11 $(CODE) $(INCLUDES) -lineinfo --expt-extended-lambda
NVCC_FLAGS += -Xcompiler=$(OPENMP_FLAGS) -Xcompiler=-fPIC
ifeq ($(PLUGIN_UPDATER_GPU),ON)
CUDA_ROOT = $(shell dirname $(shell dirname $(shell which $(NVCC))))
INCLUDES += -I$(CUDA_ROOT)/include
LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart
INCLUDES += -I$(CUDA_ROOT)/include -Inccl/src/
LDFLAGS += -L$(CUDA_ROOT)/lib64 -lcudart -lcudadevrt -Lnccl/build/lib/ -lnccl_static -lm -ldl -lrt
endif

# specify tensor path
.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java pylint


all: lib/libxgboost.a $(XGBOOST_DYLIB) xgboost

$(DMLC_CORE)/libdmlc.a: $(wildcard $(DMLC_CORE)/src/*.cc $(DMLC_CORE)/src/*/*.cc)
@@ -143,7 +142,7 @@ build/%.o: src/%.cc
$(CXX) -c $(CFLAGS) $< -o $@

# order of this rule matters wrt %.cc rule below!
build_plugin/%.o: plugin/%.cu
build_plugin/%.o: plugin/%.cu build_nccl
@mkdir -p $(@D)
$(NVCC) -c $(NVCC_FLAGS) $< -o $@

@@ -152,6 +151,11 @@ build_plugin/%.o: plugin/%.cc
$(CXX) $(CFLAGS) -MM -MT build_plugin/$*.o $< >build_plugin/$*.d
$(CXX) -c $(CFLAGS) $< -o $@

build_nccl:
@mkdir -p build/include
cd build/include ; ln -sf ../../nccl/src/nccl.h .
cd nccl ; make -j ; cd ..

# This should be equivalent to $(ALL_OBJ) except for build/cli_main.o
amalgamation/xgboost-all0.o: amalgamation/xgboost-all0.cc
$(CXX) -c $(CFLAGS) $< -o $@
@@ -173,6 +177,7 @@ jvm-packages/lib/libxgboost4j.so: jvm-packages/xgboost4j/src/native/xgboost4j.cp
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(JAVAINCFLAGS) -shared -o $@ $(filter %.cpp %.o %.a, $^) $(LDFLAGS)


xgboost: $(CLI_OBJ) $(ALL_DEP)
$(CXX) $(CFLAGS) -o $@ $(filter %.o %.a, $^) $(LDFLAGS)

2 changes: 1 addition & 1 deletion cub
Submodule cub updated 54 files
+1 −0 .settings/.gitignore
+1 −1 .settings/language.settings.xml
+12 −0 CHANGE_LOG.TXT
+5 −0 common.mk
+4 −4 cub/agent/agent_histogram.cuh
+20 −18 cub/agent/agent_radix_sort_downsweep.cuh
+6 −6 cub/agent/agent_radix_sort_upsweep.cuh
+3 −3 cub/agent/agent_reduce.cuh
+13 −8 cub/agent/agent_reduce_by_key.cuh
+6 −9 cub/agent/agent_rle.cuh
+4 −4 cub/agent/agent_scan.cuh
+1 −1 cub/agent/agent_segment_fixup.cuh
+9 −9 cub/agent/agent_select_if.cuh
+10 −10 cub/agent/agent_spmv_csrt.cuh
+17 −17 cub/agent/agent_spmv_orig.cuh
+5 −5 cub/agent/agent_spmv_row_based.cuh
+5 −5 cub/agent/single_pass_scan_operators.cuh
+8 −8 cub/block/block_adjacent_difference.cuh
+8 −8 cub/block/block_discontinuity.cuh
+23 −23 cub/block/block_exchange.cuh
+1 −1 cub/block/block_histogram.cuh
+18 −6 cub/block/block_load.cuh
+2 −2 cub/block/block_radix_rank.cuh
+6 −6 cub/block/block_radix_sort.cuh
+16 −16 cub/block/block_scan.cuh
+4 −4 cub/block/block_shuffle.cuh
+9 −3 cub/block/block_store.cuh
+3 −3 cub/block/specializations/block_histogram_sort.cuh
+1 −1 cub/block/specializations/block_reduce_raking.cuh
+2 −2 cub/block/specializations/block_reduce_raking_commutative_only.cuh
+1 −1 cub/block/specializations/block_reduce_warp_reductions.cuh
+18 −17 cub/block/specializations/block_scan_raking.cuh
+3 −3 cub/block/specializations/block_scan_warp_scans.cuh
+5 −5 cub/block/specializations/block_scan_warp_scans2.cuh
+10 −10 cub/block/specializations/block_scan_warp_scans3.cuh
+0 −2 cub/device/device_reduce.cuh
+855 −855 cub/device/device_segmented_radix_sort.cuh
+8 −8 cub/device/dispatch/dispatch_radix_sort.cuh
+4 −4 cub/grid/grid_barrier.cuh
+0 −16 cub/thread/thread_load.cuh
+4 −1 cub/thread/thread_operators.cuh
+0 −10 cub/thread/thread_store.cuh
+3 −0 cub/util_arch.cuh
+25 −25 cub/util_debug.cuh
+347 −347 cub/util_device.cuh
+159 −221 cub/util_ptx.cuh
+115 −39 cub/warp/specializations/warp_reduce_shfl.cuh
+19 −3 cub/warp/specializations/warp_reduce_smem.cuh
+103 −15 cub/warp/specializations/warp_scan_shfl.cuh
+41 −2 cub/warp/specializations/warp_scan_smem.cuh
+1 −0 test/Makefile
+9 −1 test/test_device_histogram.cu
+1 −0 test/test_device_scan.cu
+0 −4 test/test_warp_reduce.cu
2 changes: 1 addition & 1 deletion nccl
Submodule nccl updated 0 files
28 changes: 25 additions & 3 deletions plugin/updater_gpu/README.md
@@ -63,6 +63,22 @@ submodule: The plugin also depends on CUB 1.6.4 - https://nvlabs.github.io/cub/

submodule: NVIDIA NCCL from https://github.com/NVIDIA/nccl with windows port allowed by git@github.com:h2oai/nccl.git

## Download the full repo and all submodules into a path <mypath> of your choice (possibly empty)

git clone --recursive https://github.com/dmlc/xgboost.git <mypath>

## Download with shallow submodules for a much quicker download

With git 2.9.0+ (assumes only HEAD is needed for every submodule, which is not currently true for dmlc-core and rabit):

git clone --recursive --shallow-submodules https://github.com/dmlc/xgboost.git <mypath>

With git older than 2.9.0 (only cub, the largest submodule, is cloned shallow):

git clone https://github.com/dmlc/xgboost.git <mypath>
cd <mypath>
bash plugin/updater_gpu/gitshallow_submodules.sh

## Build

From the command line on Linux starting from the xgboost directory:
Expand All @@ -84,12 +100,18 @@ $ mkdir build
$ cd build
$ cmake .. -G"Visual Studio 14 2015 Win64" -DPLUGIN_UPDATER_GPU=ON
```
Cmake will generate an xgboost.sln solution file in the build directory. Build this solution in release mode as a x64 build.
CMake will create an xgboost.sln solution file in the build directory. Build this solution in Release mode as an x64 build.

Visual Studio Community 2015, which is supported by the CUDA toolkit (http://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/#axzz4isREr2nS), can be downloaded from: https://my.visualstudio.com/Downloads?q=Visual%20Studio%20Community%202015 . You may also be able to use a later version of Visual Studio, depending on whether the CUDA toolkit supports it. Note that MinGW cannot be used with CUDA.

### For other nccl libraries

On some systems the NCCL library is specific to the platform (e.g. IBM Power or nvidia-docker) and can enable the use of NVLink (between GPUs, or even between GPUs and system memory). In that case, avoid the static NCCL library by changing "STATIC" to "SHARED" in nccl/CMakeLists.txt and then deleting the shared NCCL library that gets built, so that the system one is used.

### For Developers!



In case you want to build only for specific GPU(s), e.g. GP100 and GP102,
whose compute capabilities are 60 and 61 respectively:
```bash
@@ -101,12 +123,12 @@ By default, the versions will include support for all GPUs in Maxwell and Pascal
Now, it also supports the usual 'make' flow to build gpu-enabled tree construction plugins. It's currently only tested on Linux. From the xgboost directory
```bash
# make sure CUDA SDK bin directory is in the 'PATH' env variable
$ make PLUGIN_UPDATER_GPU=ON
$ make -j PLUGIN_UPDATER_GPU=ON
```

Similar to cmake, if you want to build only for a specific GPU(s):
```bash
$ make PLUGIN_UPDATER_GPU=ON GPU_COMPUTE_VER="60 61"
$ make -j PLUGIN_UPDATER_GPU=ON GPU_COMPUTE_VER="60 61"
```

### For Developers!
6 changes: 4 additions & 2 deletions plugin/updater_gpu/benchmark/benchmark.py
@@ -16,6 +16,8 @@ def run_benchmark(args, gpu_algorithm, cpu_algorithm):
param = {'objective': 'binary:logistic',
'max_depth': 6,
'silent': 1,
'n_gpus': 1,
'gpu_id': 0,
'eval_metric': 'auc'}

param['tree_method'] = gpu_algorithm
@@ -41,9 +43,9 @@ def run_benchmark(args, gpu_algorithm, cpu_algorithm):

if 'gpu_hist' in args.algorithm:
run_benchmark(args, args.algorithm, 'hist')
if 'gpu_exact' in args.algorithm:
elif 'gpu_exact' in args.algorithm:
run_benchmark(args, args.algorithm, 'exact')
if 'all' in args.algorithm:
elif 'all' in args.algorithm:
run_benchmark(args, 'gpu_exact', 'exact')
run_benchmark(args, 'gpu_hist', 'hist')
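The two parameters added above ('gpu_id' and 'n_gpus') are the ones exercised by the multi-GPU fixes in this commit. A minimal training sketch using them; this assumes xgboost was built with PLUGIN_UPDATER_GPU=ON, and the data and parameter values are illustrative only:

```python
# Minimal sketch: GPU histogram training starting from a non-default device.
# Assumes a GPU-enabled xgboost build; data and parameter values are illustrative.
import numpy as np
import xgboost as xgb

X = np.random.rand(10000, 50)
y = np.random.randint(2, size=10000)
dtrain = xgb.DMatrix(X, label=y)

param = {
    'objective': 'binary:logistic',
    'tree_method': 'gpu_hist',  # GPU histogram tree construction
    'gpu_id': 1,                # start on the second GPU (the case the gpu_id fixes target)
    'n_gpus': 1,                # number of GPUs to use starting from gpu_id
    'max_depth': 6,
    'eval_metric': 'auc',
    'silent': 1,
}
bst = xgb.train(param, dtrain, num_boost_round=50)
```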

12 changes: 12 additions & 0 deletions plugin/updater_gpu/gitshallow_submodules.sh
@@ -0,0 +1,12 @@
#!/bin/bash
git submodule init
for i in $(git submodule | awk '{print $2}'); do
spath=$(git config -f .gitmodules --get submodule.$i.path)
surl=$(git config -f .gitmodules --get submodule.$i.url)
if [ $spath == "cub" ]
then
git submodule update --depth 3 $spath
else
git submodule update $spath
fi
done
74 changes: 38 additions & 36 deletions plugin/updater_gpu/src/common.cuh
@@ -2,16 +2,16 @@
* Copyright 2017 XGBoost contributors
*/
#pragma once
#include <cstdio>
#include <stdexcept>
#include <string>
#include <vector>
#include "../../../src/common/random.h"
#include "../../../src/tree/param.h"
#include "device_helpers.cuh"
#include "types.cuh"
#include <string>
#include <stdexcept>
#include <cstdio>
#include "cub/cub.cuh"
#include "device_helpers.cuh"
#include "device_helpers.cuh"
#include "types.cuh"

namespace xgboost {
namespace tree {
@@ -172,8 +172,8 @@ inline void subsample_gpair(dh::dvec<gpu_gpair>* p_gpair, float subsample) {
}

inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
int n = colsample * features.size();
CHECK_GT(n, 0);
CHECK_GT(features.size(), 0);
int n = std::max(1, static_cast<int>(colsample * features.size()));

std::shuffle(features.begin(), features.end(), common::GlobalRandom());
features.resize(n);
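The col_sample change above replaces plain truncation with max(1, ...) so that a small colsample can no longer select zero features. A quick sketch of the arithmetic with hypothetical values:

```python
# Hypothetical illustration of the col_sample fix: truncation alone can yield 0 columns.
n_features = 5
colsample = 0.1

old_n = int(colsample * n_features)          # int(0.5) == 0, so the old CHECK_GT(n, 0) fails
new_n = max(1, int(colsample * n_features))  # always keeps at least one feature

print(old_n, new_n)  # prints: 0 1
```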
@@ -202,17 +202,18 @@ struct GpairCallbackOp {
* @param offsets the segments
*/
template <typename T1, typename T2>
void segmentedSort(dh::CubMemory &tmp_mem, dh::dvec2<T1> &keys, dh::dvec2<T2> &vals,
int nVals, int nSegs, dh::dvec<int> &offsets, int start=0,
int end=sizeof(T1)*8) {
void segmentedSort(dh::CubMemory& tmp_mem, dh::dvec2<T1>& keys,
dh::dvec2<T2>& vals, int nVals, int nSegs,
dh::dvec<int>& offsets, int start = 0,
int end = sizeof(T1) * 8) {
size_t tmpSize;
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
NULL, tmpSize, keys.buff(), vals.buff(), nVals, nSegs,
offsets.data(), offsets.data()+1, start, end));
NULL, tmpSize, keys.buff(), vals.buff(), nVals, nSegs, offsets.data(),
offsets.data() + 1, start, end));
tmp_mem.LazyAllocate(tmpSize);
dh::safe_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
tmp_mem.d_temp_storage, tmpSize, keys.buff(), vals.buff(),
nVals, nSegs, offsets.data(), offsets.data()+1, start, end));
tmp_mem.d_temp_storage, tmpSize, keys.buff(), vals.buff(), nVals, nSegs,
offsets.data(), offsets.data() + 1, start, end));
}

/**
@@ -223,11 +224,11 @@ void segmentedSort(dh::CubMemory &tmp_mem, dh::dvec2<T1> &keys, dh::dvec2<T2> &v
* @param nVals number of elements in the input array
*/
template <typename T>
void sumReduction(dh::CubMemory &tmp_mem, dh::dvec<T> &in, dh::dvec<T> &out,
void sumReduction(dh::CubMemory& tmp_mem, dh::dvec<T>& in, dh::dvec<T>& out,
int nVals) {
size_t tmpSize;
dh::safe_cuda(cub::DeviceReduce::Sum(NULL, tmpSize, in.data(), out.data(),
nVals));
dh::safe_cuda(
cub::DeviceReduce::Sum(NULL, tmpSize, in.data(), out.data(), nVals));
tmp_mem.LazyAllocate(tmpSize);
dh::safe_cuda(cub::DeviceReduce::Sum(tmp_mem.d_temp_storage, tmpSize,
in.data(), out.data(), nVals));
@@ -239,9 +240,10 @@ void sumReduction(dh::CubMemory &tmp_mem, dh::dvec<T> &in, dh::dvec<T> &out,
* @param len number of elements in the buffer
* @param def default value to be filled
*/
template <typename T, int BlkDim=256, int ItemsPerThread=4>
template <typename T, int BlkDim = 256, int ItemsPerThread = 4>
void fillConst(int device_idx, T* out, int len, T def) {
dh::launch_n<ItemsPerThread,BlkDim>(device_idx, len, [=] __device__(int i) { out[i] = def; });
dh::launch_n<ItemsPerThread, BlkDim>(device_idx, len,
[=] __device__(int i) { out[i] = def; });
}

/**
@@ -253,17 +255,17 @@
* @param instId gather indices
* @param nVals length of the buffers
*/
template <typename T1, typename T2, int BlkDim=256, int ItemsPerThread=4>
void gather(int device_idx, T1* out1, const T1* in1, T2* out2, const T2* in2, const int* instId,
int nVals) {
dh::launch_n<ItemsPerThread,BlkDim>
(device_idx, nVals, [=] __device__(int i) {
int iid = instId[i];
T1 v1 = in1[iid];
T2 v2 = in2[iid];
out1[i] = v1;
out2[i] = v2;
});
template <typename T1, typename T2, int BlkDim = 256, int ItemsPerThread = 4>
void gather(int device_idx, T1* out1, const T1* in1, T2* out2, const T2* in2,
const int* instId, int nVals) {
dh::launch_n<ItemsPerThread, BlkDim>(device_idx, nVals,
[=] __device__(int i) {
int iid = instId[i];
T1 v1 = in1[iid];
T2 v2 = in2[iid];
out1[i] = v1;
out2[i] = v2;
});
}

/**
@@ -273,13 +275,13 @@ void gather(int device_idx, T1* out1, const T1* in1, T2* out2, const T2* in2, co
* @param instId gather indices
* @param nVals length of the buffers
*/
template <typename T, int BlkDim=256, int ItemsPerThread=4>
template <typename T, int BlkDim = 256, int ItemsPerThread = 4>
void gather(int device_idx, T* out, const T* in, const int* instId, int nVals) {
dh::launch_n<ItemsPerThread,BlkDim>
(device_idx, nVals, [=] __device__(int i) {
int iid = instId[i];
out[i] = in[iid];
});
dh::launch_n<ItemsPerThread, BlkDim>(device_idx, nVals,
[=] __device__(int i) {
int iid = instId[i];
out[i] = in[iid];
});
}

} // namespace tree
