diff --git a/src/index/ivf_raft/ivf_raft.cu b/src/index/ivf_raft/ivf_raft.cu index 024cc660c..1d1c02c2f 100644 --- a/src/index/ivf_raft/ivf_raft.cu +++ b/src/index/ivf_raft/ivf_raft.cu @@ -21,27 +21,29 @@ constexpr uint32_t cuda_concurrent_size = 16; namespace knowhere { + +static std::shared_ptr +GlobalThreadPoolRaft() { + static std::shared_ptr pool = std::make_shared(cuda_concurrent_size); + return pool; +} KNOWHERE_REGISTER_GLOBAL(GPU_RAFT_IVF_FLAT, [](const Object& object) { return Index::Create( - std::make_unique>(object), - std::make_shared(cuda_concurrent_size)); + std::make_unique>(object), GlobalThreadPoolRaft()); }); KNOWHERE_REGISTER_GLOBAL(GPU_RAFT_IVF_PQ, [](const Object& object) { return Index::Create( - std::make_unique>(object), - std::make_shared(cuda_concurrent_size)); + std::make_unique>(object), GlobalThreadPoolRaft()); }); KNOWHERE_REGISTER_GLOBAL(GPU_IVF_FLAT, [](const Object& object) { return Index::Create( - std::make_unique>(object), - std::make_shared(cuda_concurrent_size)); + std::make_unique>(object), GlobalThreadPoolRaft()); }); KNOWHERE_REGISTER_GLOBAL(GPU_IVF_PQ, [](const Object& object) { return Index::Create( - std::make_unique>(object), - std::make_shared(cuda_concurrent_size)); + std::make_unique>(object), GlobalThreadPoolRaft()); }); } // namespace knowhere