Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Adding mdspan calls to ivf raft
Browse files: browse the repository at this point in the history
Signed-off-by: Mickael Ide <[email protected]>
  • Loading branch information
lowener committed Apr 14, 2023
1 parent c05c876 commit 7a67021
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 40 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ list(APPEND KNOWHERE_LINKER_LIBS prometheus-cpp::core prometheus-cpp::pull prome
add_library(knowhere SHARED ${KNOWHERE_SRCS})
add_dependencies(knowhere ${KNOWHERE_LINKER_LIBS})
if(WITH_RAFT)
list(APPEND KNOWHERE_LINKER_LIBS raft::raft)
list(APPEND KNOWHERE_LINKER_LIBS raft::raft raft::compiled)
endif()
target_link_libraries(knowhere PUBLIC ${KNOWHERE_LINKER_LIBS})
target_include_directories(
Expand Down
26 changes: 15 additions & 11 deletions cmake/libs/libraft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@ set(RAFT_FORK "rapidsai")
set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}")

function(find_and_configure_raft)
set(oneValueArgs VERSION FORK PINNED_TAG)
set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY)
cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}"
${ARGN})

set(RAFT_COMPONENTS "")
if(PKG_COMPILE_LIBRARY)
string(APPEND RAFT_COMPONENTS " compiled")
endif()
# -----------------------------------------------------
# Invoke CPM find_package()
# -----------------------------------------------------
Expand All @@ -44,12 +48,8 @@ function(find_and_configure_raft)
${PKG_VERSION}
GLOBAL_TARGETS
raft::raft
BUILD_EXPORT_SET
faiss-exports
INSTALL_EXPORT_SET
faiss-exports
COMPONENTS
"distance nn"
${RAFT_COMPONENTS}
CPM_ARGS
GIT_REPOSITORY
https://github.com/${PKG_FORK}/raft.git
Expand All @@ -60,13 +60,17 @@ function(find_and_configure_raft)
OPTIONS
"BUILD_TESTS OFF"
"BUILD_BENCH OFF"
"RAFT_COMPILE_LIBRARIES OFF"
"RAFT_COMPILE_NN_LIBRARY OFF"
"RAFT_USE_FAISS_STATIC OFF" # Turn this on to build FAISS into your binary
"RAFT_ENABLE_NN_DEPENDENCIES OFF")
"RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}"
"RAFT_USE_FAISS_STATIC OFF") # Turn this on to build FAISS into your binary

if(raft_ADDED)
message(VERBOSE "KNOWHERE: Using RAFT located in ${raft_SOURCE_DIR}")
else()
message(VERBOSE "KNOWHERE: Using RAFT located in ${raft_DIR}")
endif()
endfunction()

# Change pinned tag here to test a commit in CI To use a different RAFT locally,
# set the CMake variable CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${RAFT_VERSION}.00 FORK ${RAFT_FORK} PINNED_TAG
${RAFT_PINNED_TAG})
${RAFT_PINNED_TAG} COMPILE_LIBRARY OFF)
8 changes: 4 additions & 4 deletions cmake/utils/fetch_rapids.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# License for the specific language governing permissions and limitations under
# the License.

set(RAPIDS_VERSION "23.02")
set(RAPIDS_VERSION "23.04")

if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake)
if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake)
file(
DOWNLOAD
https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake
${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake)
${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake)
endif()
include(${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake)
include(${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake)
54 changes: 30 additions & 24 deletions src/index/ivf_raft/ivf_raft.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
#include "thrust/execution_policy.h"
#include "thrust/sequence.h"

#ifdef RAFT_COMPILED
#include <raft/neighbors/specializations.cuh>
#endif

namespace knowhere {

namespace raft_res_pool {
Expand Down Expand Up @@ -248,9 +252,9 @@ class RaftIvfIndexNode : public IndexNode {
auto* data = reinterpret_cast<float const*>(dataset.GetTensor());

auto stream = res_->get_stream();
auto data_gpu = rmm::device_uvector<float>(rows * dim, stream);
RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data(), data, data_gpu.size() * sizeof(float), cudaMemcpyDefault,
stream.value()));
auto data_gpu = raft::make_device_matrix<float, std::int64_t>(*res_, rows, dim);
RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data_handle(), data, data_gpu.size() * sizeof(float),
cudaMemcpyDefault, stream.value()));
if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
auto build_params = raft::neighbors::ivf_flat::index_params{};
build_params.metric = metric.value();
Expand All @@ -259,7 +263,7 @@ class RaftIvfIndexNode : public IndexNode {
build_params.kmeans_trainset_fraction = ivf_raft_cfg.kmeans_trainset_fraction;
build_params.adaptive_centers = ivf_raft_cfg.adaptive_centers;
gpu_index_ = raft::neighbors::ivf_flat::build<float, std::int64_t>(*res_, build_params,
data_gpu.data(), rows, dim);
data_gpu.view());
} else if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
auto build_params = raft::neighbors::ivf_pq::index_params{};
build_params.metric = metric.value();
Expand All @@ -276,7 +280,7 @@ class RaftIvfIndexNode : public IndexNode {
build_params.codebook_kind = codebook_kind.value();
build_params.force_random_rotation = ivf_raft_cfg.force_random_rotation;
gpu_index_ = raft::neighbors::ivf_pq::build<float, std::int64_t>(*res_, build_params,
data_gpu.data(), rows, dim);
data_gpu.view());
} else {
static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
}
Expand Down Expand Up @@ -312,19 +316,21 @@ class RaftIvfIndexNode : public IndexNode {
auto stream = res_->get_stream();
// TODO(wphicks): Clean up transfer with raft
// buffer objects when available
auto data_gpu = rmm::device_uvector<float>(rows * dim, stream);
RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data(), data, data_gpu.size() * sizeof(float), cudaMemcpyDefault,
stream.value()));
auto data_gpu = raft::make_device_matrix<float, std::int64_t>(*res_, rows, dim);
RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data_handle(), data, data_gpu.size() * sizeof(float),
cudaMemcpyDefault, stream.value()));

auto indices = rmm::device_uvector<std::int64_t>(rows, stream);
thrust::sequence(thrust::device, indices.begin(), indices.end(), gpu_index_->size());

if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
raft::neighbors::ivf_flat::extend<float, std::int64_t>(*res_, *gpu_index_, data_gpu.data(),
indices.data(), rows);
raft::neighbors::ivf_flat::extend<float, std::int64_t>(*res_, raft::make_const_mdspan(data_gpu.view()),
std::make_optional(raft::make_device_vector_view<const std::int64_t, std::int64_t>(indices.data(), rows)),
gpu_index_.value());
} else if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
raft::neighbors::ivf_pq::extend<float, std::int64_t>(*res_, *gpu_index_, data_gpu.data(),
indices.data(), rows);
raft::neighbors::ivf_pq::extend<float, std::int64_t>(*res_, raft::make_const_mdspan(data_gpu.view()),
std::make_optional(raft::make_device_matrix_view<const std::int64_t, std::int64_t>(indices.data(), rows, 1)),
gpu_index_.value());
} else {
static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
}
Expand Down Expand Up @@ -356,19 +362,19 @@ class RaftIvfIndexNode : public IndexNode {
auto stream = res_->get_stream();
// TODO(wphicks): Clean up transfer with raft
// buffer objects when available
auto data_gpu = rmm::device_uvector<float>(rows * dim, stream);
RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data(), data, data_gpu.size() * sizeof(float), cudaMemcpyDefault,
stream.value()));
auto data_gpu = raft::make_device_matrix<float, std::int64_t>(*res_, rows, dim);
RAFT_CUDA_TRY(cudaMemcpyAsync(data_gpu.data_handle(), data, data_gpu.size() * sizeof(float),
cudaMemcpyDefault, stream.value()));

auto ids_gpu = rmm::device_uvector<std::int64_t>(output_size, stream);
auto dis_gpu = rmm::device_uvector<float>(output_size, stream);
auto ids_gpu = raft::make_device_matrix<std::int64_t, std::int64_t>(*res_, rows, ivf_raft_cfg.k);
auto dis_gpu = raft::make_device_matrix<float, std::int64_t>(*res_, rows, ivf_raft_cfg.k);

if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
auto search_params = raft::neighbors::ivf_flat::search_params{};
search_params.n_probes = ivf_raft_cfg.nprobe;
raft::neighbors::ivf_flat::search<float, std::int64_t>(*res_, search_params, *gpu_index_,
data_gpu.data(), rows, ivf_raft_cfg.k,
ids_gpu.data(), dis_gpu.data());
raft::make_const_mdspan(data_gpu.view()),
ids_gpu.view(), dis_gpu.view());
} else if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
auto search_params = raft::neighbors::ivf_pq::search_params{};
search_params.n_probes = ivf_raft_cfg.nprobe;
Expand Down Expand Up @@ -396,15 +402,15 @@ class RaftIvfIndexNode : public IndexNode {
}
search_params.internal_distance_dtype = internal_distance_dtype.value();
search_params.preferred_shmem_carveout = search_params.preferred_shmem_carveout;
raft::neighbors::ivf_pq::search<float, std::int64_t>(*res_, search_params, *gpu_index_, data_gpu.data(),
rows, ivf_raft_cfg.k, ids_gpu.data(),
dis_gpu.data());
raft::neighbors::ivf_pq::search<float, std::int64_t>(*res_, search_params, *gpu_index_,
raft::make_const_mdspan(data_gpu.view()),
ids_gpu.view(), dis_gpu.view());
} else {
static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
}
RAFT_CUDA_TRY(cudaMemcpyAsync(ids.get(), ids_gpu.data(), ids_gpu.size() * sizeof(std::int64_t),
RAFT_CUDA_TRY(cudaMemcpyAsync(ids.get(), ids_gpu.data_handle(), ids_gpu.size() * sizeof(std::int64_t),
cudaMemcpyDefault, stream.value()));
RAFT_CUDA_TRY(cudaMemcpyAsync(dis.get(), dis_gpu.data(), dis_gpu.size() * sizeof(float), cudaMemcpyDefault,
RAFT_CUDA_TRY(cudaMemcpyAsync(dis.get(), dis_gpu.data_handle(), dis_gpu.size() * sizeof(float), cudaMemcpyDefault,
stream.value()));
stream.synchronize();
} catch (std::exception& e) {
Expand Down

0 comments on commit 7a67021

Please sign in to comment.