Skip to content

Commit

Permalink
Use smart pointer utilities from dpctl
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian committed Jan 8, 2025
1 parent 763fc25 commit a54f420
Showing 1 changed file with 12 additions and 48 deletions.
60 changes: 12 additions & 48 deletions dpnp/backend/extensions/indexing/choose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,6 @@ static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]

namespace py = pybind11;

/*
Returns an std::unique_ptr wrapping a USM allocation and deleter.
Must still be manually freed by host_task when allocation is needed
for duration of asynchronous kernel execution.
*/
template <typename T>
auto usm_unique_ptr(std::size_t sz, sycl::queue &q)
{
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
auto deleter = [&q](T *usm) { sycl_free_noexcept(usm, q); };

return std::unique_ptr<T, decltype(deleter)>(sycl::malloc_device<T>(sz, q),
deleter);
}

std::vector<sycl::event>
_populate_choose_kernel_params(sycl::queue &exec_q,
std::vector<sycl::event> &host_task_events,
Expand Down Expand Up @@ -305,11 +289,8 @@ std::pair<sycl::event, sycl::event>
std::to_string(src_type_id));
}

auto packed_chc_ptrs = usm_unique_ptr<char *>(n_chcs, exec_q);
if (packed_chc_ptrs.get() == nullptr) {
throw std::runtime_error(
"Unable to allocate packed_chc_ptrs device memory");
}
auto packed_chc_ptrs =
dpctl::tensor::alloc_utils::smart_malloc_device<char *>(n_chcs, exec_q);

// packed_shapes_strides = [common shape,
// src.strides,
Expand All @@ -318,17 +299,12 @@ std::pair<sycl::event, sycl::event>
// ...,
// chcs[n_chcs].strides]
auto packed_shapes_strides =
usm_unique_ptr<py::ssize_t>((3 + n_chcs) * sh_nelems, exec_q);
if (packed_shapes_strides.get() == nullptr) {
throw std::runtime_error(
"Unable to allocate packed_shapes_strides device memory");
}
dpctl::tensor::alloc_utils::smart_malloc_device<py::ssize_t>(
(3 + n_chcs) * sh_nelems, exec_q);

auto packed_chc_offsets = usm_unique_ptr<py::ssize_t>(n_chcs, exec_q);
if (packed_chc_offsets.get() == nullptr) {
throw std::runtime_error(
"Unable to allocate packed_chc_offsets device memory");
}
auto packed_chc_offsets =
dpctl::tensor::alloc_utils::smart_malloc_device<py::ssize_t>(n_chcs,
exec_q);

std::vector<sycl::event> host_task_events;
host_task_events.reserve(2);
Expand Down Expand Up @@ -370,23 +346,11 @@ std::pair<sycl::event, sycl::event>
src_data, dst_data, packed_chc_ptrs.get(), src_offset, dst_offset,
packed_chc_offsets.get(), all_deps);

// release usm_unique_ptrs
auto chc_ptrs_ = packed_chc_ptrs.release();
auto shapes_strides_ = packed_shapes_strides.release();
auto chc_offsets_ = packed_chc_offsets.release();

// free packed temporaries
sycl::event temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(choose_generic_ev);
const auto &ctx = exec_q.get_context();

using dpctl::tensor::alloc_utils::sycl_free_noexcept;
cgh.host_task([chc_ptrs_, shapes_strides_, chc_offsets_, ctx]() {
sycl_free_noexcept(chc_ptrs_, ctx);
sycl_free_noexcept(shapes_strides_, ctx);
sycl_free_noexcept(chc_offsets_, ctx);
});
});
// async_smart_free releases owners
sycl::event temporaries_cleanup_ev =
dpctl::tensor::alloc_utils::async_smart_free(
exec_q, {choose_generic_ev}, packed_chc_ptrs, packed_shapes_strides,
packed_chc_offsets);

host_task_events.push_back(temporaries_cleanup_ev);

Expand Down

0 comments on commit a54f420

Please sign in to comment.