Skip to content

Commit 7767735

Browse files
committed
Use rmm prefetc API so we don't have to handle CUDA 12/13 API differences
1 parent 97773bd commit 7767735

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

cpp/libcugraph_etl/include/hash/concurrent_unordered_map.cuh

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2017-2024, NVIDIA CORPORATION. All rights reserved.
2+
* Copyright (c) 2017-2025, NVIDIA CORPORATION. All rights reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
2626

2727
#include <rmm/cuda_stream_view.hpp>
2828
#include <rmm/mr/device/polymorphic_allocator.hpp>
29+
#include <rmm/prefetch.hpp>
2930

3031
#include <cuda/atomic>
3132
#include <thrust/pair.h>
@@ -83,6 +84,11 @@ constexpr bool is_packable()
8384
std::has_unique_object_representations_v<pair_type>;
8485
}
8586

87+
#if defined(CUDART_VERSION) && CUDART_VERSION >= 13000
88+
cudaMemLocation location{
89+
(device.value() == cudaCpuDeviceId) ? cudaMemLocationTypeHost : cudaMemLocationTypeDevice,
90+
device.value()};
91+
8692
/**
8793
* @brief Allows viewing a pair in a packed representation
8894
*
@@ -472,10 +478,10 @@ class concurrent_unordered_map {
472478
cudaError_t status = cudaPointerGetAttributes(&hashtbl_values_ptr_attributes, m_hashtbl_values);
473479

474480
if (cudaSuccess == status && isPtrManaged(hashtbl_values_ptr_attributes)) {
475-
RAFT_CUDA_TRY(cudaMemPrefetchAsync(
476-
m_hashtbl_values, m_capacity * sizeof(value_type), dev_id, stream.value()));
481+
rmm::prefetch(
482+
m_hashtbl_values, m_capacity * sizeof(value_type), rmm::cuda_device_id{dev_id}, stream);
477483
}
478-
RAFT_CUDA_TRY(cudaMemPrefetchAsync(this, sizeof(*this), dev_id, stream.value()));
484+
rmm::prefetch(this, sizeof(*this), rmm::cuda_device_id{dev_id}, stream);
479485
}
480486

481487
/**
@@ -545,8 +551,8 @@ class concurrent_unordered_map {
545551
if (cudaSuccess == status && isPtrManaged(hashtbl_values_ptr_attributes)) {
546552
int dev_id = 0;
547553
RAFT_CUDA_TRY(cudaGetDevice(&dev_id));
548-
RAFT_CUDA_TRY(cudaMemPrefetchAsync(
549-
m_hashtbl_values, m_capacity * sizeof(value_type), dev_id, stream.value()));
554+
rmm::prefetch(
555+
m_hashtbl_values, m_capacity * sizeof(value_type), rmm::cuda_device_id{dev_id}, stream);
550556
}
551557
}
552558

0 commit comments

Comments
 (0)