|
1 | 1 | /*
|
2 |
| - * Copyright (c) 2017-2024, NVIDIA CORPORATION. All rights reserved. |
| 2 | + * Copyright (c) 2017-2025, NVIDIA CORPORATION. All rights reserved. |
3 | 3 | *
|
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | * you may not use this file except in compliance with the License.
|
|
26 | 26 |
|
27 | 27 | #include <rmm/cuda_stream_view.hpp>
|
28 | 28 | #include <rmm/mr/device/polymorphic_allocator.hpp>
|
| 29 | +#include <rmm/prefetch.hpp> |
29 | 30 |
|
30 | 31 | #include <cuda/atomic>
|
31 | 32 | #include <thrust/pair.h>
|
@@ -83,6 +84,11 @@ constexpr bool is_packable()
|
83 | 84 | std::has_unique_object_representations_v<pair_type>;
|
84 | 85 | }
|
85 | 86 |
|
| 87 | +// NOTE: no file-scope cudaMemLocation is declared here. Picking between |
| 88 | +// cudaMemLocationTypeHost and cudaMemLocationTypeDevice requires a device id, |
| 89 | +// and none exists at this scope; the prefetch call sites in this file pass |
| 90 | +// rmm::cuda_device_id{dev_id} straight to rmm::prefetch instead. |
| 91 | + |
86 | 92 | /**
|
87 | 93 | * @brief Allows viewing a pair in a packed representation
|
88 | 94 | *
|
@@ -472,10 +478,10 @@ class concurrent_unordered_map {
|
472 | 478 | cudaError_t status = cudaPointerGetAttributes(&hashtbl_values_ptr_attributes, m_hashtbl_values);
|
473 | 479 |
|
474 | 480 | if (cudaSuccess == status && isPtrManaged(hashtbl_values_ptr_attributes)) {
|
475 |
| - RAFT_CUDA_TRY(cudaMemPrefetchAsync( |
476 |
| - m_hashtbl_values, m_capacity * sizeof(value_type), dev_id, stream.value())); |
| 481 | + rmm::prefetch( |
| 482 | + m_hashtbl_values, m_capacity * sizeof(value_type), rmm::cuda_device_id{dev_id}, stream); |
477 | 483 | }
|
478 |
| - RAFT_CUDA_TRY(cudaMemPrefetchAsync(this, sizeof(*this), dev_id, stream.value())); |
| 484 | + rmm::prefetch(this, sizeof(*this), rmm::cuda_device_id{dev_id}, stream); |
479 | 485 | }
|
480 | 486 |
|
481 | 487 | /**
|
@@ -545,8 +551,8 @@ class concurrent_unordered_map {
|
545 | 551 | if (cudaSuccess == status && isPtrManaged(hashtbl_values_ptr_attributes)) {
|
546 | 552 | int dev_id = 0;
|
547 | 553 | RAFT_CUDA_TRY(cudaGetDevice(&dev_id));
|
548 |
| - RAFT_CUDA_TRY(cudaMemPrefetchAsync( |
549 |
| - m_hashtbl_values, m_capacity * sizeof(value_type), dev_id, stream.value())); |
| 554 | + rmm::prefetch( |
| 555 | + m_hashtbl_values, m_capacity * sizeof(value_type), rmm::cuda_device_id{dev_id}, stream); |
550 | 556 | }
|
551 | 557 | }
|
552 | 558 |
|
|
0 commit comments