Skip to content

Commit

Permalink
Merge pull request #1985 from sommerlukas/nd-memory-copy
Browse files Browse the repository at this point in the history
Copy from/to multidimensional buffers
  • Loading branch information
oleksandr-pavlyk authored Feb 7, 2025
2 parents a332dd5 + a8bc75e commit 5529d66
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 7 deletions.
39 changes: 32 additions & 7 deletions dpctl/_sycl_queue.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,15 @@ import ctypes
from .enum_types import backend_type

from cpython cimport pycapsule
from cpython.buffer cimport PyObject_CheckBuffer
from cpython.buffer cimport (
Py_buffer,
PyBUF_ANY_CONTIGUOUS,
PyBUF_SIMPLE,
PyBUF_WRITABLE,
PyBuffer_Release,
PyObject_CheckBuffer,
PyObject_GetBuffer,
)
from cpython.ref cimport Py_DECREF, Py_INCREF, PyObject
from libc.stdlib cimport free, malloc

Expand Down Expand Up @@ -338,14 +346,20 @@ cdef DPCTLSyclEventRef _memcpy_impl(
cdef void *c_dst_ptr = NULL
cdef void *c_src_ptr = NULL
cdef DPCTLSyclEventRef ERef = NULL
cdef const unsigned char[::1] src_host_buf = None
cdef unsigned char[::1] dst_host_buf = None
cdef Py_buffer src_buf_view
cdef Py_buffer dst_buf_view
cdef bint src_is_buf = False
cdef bint dst_is_buf = False
cdef int ret_code = 0

if isinstance(src, _Memory):
c_src_ptr = <void*>(<_Memory>src).get_data_ptr()
elif _is_buffer(src):
src_host_buf = src
c_src_ptr = <void *>&src_host_buf[0]
ret_code = PyObject_GetBuffer(src, &src_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
if ret_code != 0: # pragma: no cover
raise RuntimeError("Could not access buffer")
c_src_ptr = src_buf_view.buf
src_is_buf = True
else:
raise TypeError(
"Parameter `src` should have either type "
Expand All @@ -356,8 +370,13 @@ cdef DPCTLSyclEventRef _memcpy_impl(
if isinstance(dst, _Memory):
c_dst_ptr = <void*>(<_Memory>dst).get_data_ptr()
elif _is_buffer(dst):
dst_host_buf = dst
c_dst_ptr = <void *>&dst_host_buf[0]
ret_code = PyObject_GetBuffer(dst, &dst_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS | PyBUF_WRITABLE)
if ret_code != 0: # pragma: no cover
if src_is_buf:
PyBuffer_Release(&src_buf_view)
raise RuntimeError("Could not access buffer")
c_dst_ptr = dst_buf_view.buf
dst_is_buf = True
else:
raise TypeError(
"Parameter `dst` should have either type "
Expand All @@ -376,6 +395,12 @@ cdef DPCTLSyclEventRef _memcpy_impl(
dep_events,
dep_events_count
)

if src_is_buf:
PyBuffer_Release(&src_buf_view)
if dst_is_buf:
PyBuffer_Release(&dst_buf_view)

return ERef


Expand Down
39 changes: 39 additions & 0 deletions dpctl/tests/test_sycl_queue_memcpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Defines unit test cases for the SyclQueue.memcpy.
"""

import numpy as np
import pytest

import dpctl
Expand Down Expand Up @@ -97,6 +98,44 @@ def test_memcpy_copy_host_to_host():
assert dst_buf == src_buf


def test_2D_memcpy_copy_host_to_usm():
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Default constructor for SyclQueue failed")
usm_obj = _create_memory(q)

n = 12
canary = bytearray([i for i in range(n)])
host_obj = np.frombuffer(canary, dtype=np.uint8).reshape(3, 4)

q.memcpy(usm_obj, host_obj, len(canary))

mv2 = memoryview(usm_obj)

assert mv2[: len(canary)] == canary


def test_2D_memcpy_copy_usm_to_host():
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Default constructor for SyclQueue failed")
usm_obj = _create_memory(q)
mv2 = memoryview(usm_obj)

n = 12
shape = (3, 4)
for id in range(n):
mv2[id] = id

host_obj = np.ones(shape, dtype=np.uint8)

q.memcpy(host_obj, usm_obj, n)

assert np.array_equal(host_obj, np.arange(n, dtype=np.uint8).reshape(shape))


def test_memcpy_async():
try:
q = dpctl.SyclQueue()
Expand Down

0 comments on commit 5529d66

Please sign in to comment.