Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose LocalAccessor as kernel argument type #1991

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
6 changes: 6 additions & 0 deletions dpctl/_backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,12 @@ cdef extern from "syclinterface/dpctl_sycl_kernel_bundle_interface.h":


cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":
# Cython-visible declaration of the C struct ``MDLocalAccessor`` from the
# header named by the enclosing ``cdef extern`` block. It carries the
# metadata describing a local accessor kernel argument.
ctypedef struct _md_local_accessor 'MDLocalAccessor':
# number of dimensions of the accessor
size_t ndim
# element data type, expressed as an ``_arg_data_type`` enumerator
_arg_data_type dpctl_type_id
# extent of the first dimension
size_t dim0
# extent of the second dimension
size_t dim1
# extent of the third dimension
size_t dim2
cdef bool DPCTLQueue_AreEq(const DPCTLSyclQueueRef QRef1,
const DPCTLSyclQueueRef QRef2)
cdef DPCTLSyclQueueRef DPCTLQueue_Create(
Expand Down
84 changes: 84 additions & 0 deletions dpctl/_sycl_queue.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ from ._backend cimport ( # noqa: E211
DPCTLSyclEventRef,
_arg_data_type,
_backend_type,
_md_local_accessor,
_queue_property_type,
)
from .memory._memory cimport _Memory
Expand Down Expand Up @@ -121,6 +122,86 @@ cdef class kernel_arg_type_attribute:
return self.attr_value


cdef class LocalAccessor:
    """
    LocalAccessor(ndim, dtype, dim0, dim1, dim2)

    Python class for specifying the dimensionality and type of a
    ``sycl::local_accessor``, to be used as a kernel argument type.

    Args:
        ndim (size_t):
            number of dimensions.
            Can be between one and three.
        dtype (str):
            the data type of the local memory.
            The permitted values are

                `'i1'`, `'i2'`, `'i4'`, `'i8'`:
                    signed integral types int8_t, int16_t, int32_t, int64_t
                `'u1'`, `'u2'`, `'u4'`, `'u8'`:
                    unsigned integral types uint8_t, uint16_t, uint32_t,
                    uint64_t
                `'f4'`, `'f8'`:
                    single- and double-precision floating-point types float
                    and double
        dim0 (size_t):
            Size of the first dimension.
        dim1 (size_t):
            Size of the second dimension.
        dim2 (size_t):
            Size of the third dimension.

    Raises:
        ValueError:
            If the given dimension is not between one and three.
        ValueError:
            If the dtype string is unrecognized.
    """
    # C-level metadata record; its address is what gets handed over as the
    # kernel argument (see ``addressof``).
    cdef _md_local_accessor lacc

    def __cinit__(
        self, size_t ndim, str dtype, size_t dim0, size_t dim1, size_t dim2
    ):
        # Reject an invalid rank before touching the struct.
        if ndim < 1 or ndim > 3:
            raise ValueError(
                "LocalAccessor must have dimension between one and three"
            )

        self.lacc.ndim = ndim
        self.lacc.dim0 = dim0
        self.lacc.dim1 = dim1
        self.lacc.dim2 = dim2

        # Translate the dtype string into the matching enumerator.
        if dtype == 'i1':
            self.lacc.dpctl_type_id = _arg_data_type._INT8_T
        elif dtype == 'u1':
            self.lacc.dpctl_type_id = _arg_data_type._UINT8_T
        elif dtype == 'i2':
            self.lacc.dpctl_type_id = _arg_data_type._INT16_T
        elif dtype == 'u2':
            self.lacc.dpctl_type_id = _arg_data_type._UINT16_T
        elif dtype == 'i4':
            self.lacc.dpctl_type_id = _arg_data_type._INT32_T
        elif dtype == 'u4':
            self.lacc.dpctl_type_id = _arg_data_type._UINT32_T
        elif dtype == 'i8':
            self.lacc.dpctl_type_id = _arg_data_type._INT64_T
        elif dtype == 'u8':
            self.lacc.dpctl_type_id = _arg_data_type._UINT64_T
        elif dtype == 'f4':
            self.lacc.dpctl_type_id = _arg_data_type._FLOAT
        elif dtype == 'f8':
            self.lacc.dpctl_type_id = _arg_data_type._DOUBLE
        else:
            raise ValueError(f"Unrecognized type value: '{dtype}'")

    def __repr__(self):
        # ``ndim`` lives in the C struct, not on the Python object, and is
        # an integer, so it must be formatted rather than concatenated.
        return f"LocalAccessor({self.lacc.ndim})"

    cdef size_t addressof(self):
        """
        Returns the address of the _md_local_accessor for this LocalAccessor
        cast to ``size_t``.

        The address points into this object, so it is only meaningful while
        the object is alive.
        """
        return <size_t>&self.lacc


cdef class _kernel_arg_type:
"""
An enumeration of supported kernel argument types in
Expand Down Expand Up @@ -849,6 +930,9 @@ cdef class SyclQueue(_SyclQueue):
elif isinstance(arg, _Memory):
kargs[idx]= <void*>(<size_t>arg._pointer)
kargty[idx] = _arg_data_type._VOID_PTR
elif isinstance(arg, LocalAccessor):
kargs[idx] = <void*>((<LocalAccessor>arg).addressof())
kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
else:
ret = -1
return ret
Expand Down
Binary file not shown.
Binary file not shown.
44 changes: 44 additions & 0 deletions dpctl/tests/test_sycl_kernel_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""

import ctypes
import os

import numpy as np
import pytest
Expand Down Expand Up @@ -278,3 +279,46 @@ def test_kernel_arg_type():
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_float64)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_local_accessor)


def get_spirv_abspath(fn):
    """Return the path of ``fn`` inside the ``input_files`` directory.

    The ``input_files`` directory is resolved relative to the directory
    containing this module, using the module's absolute path.
    """
    module_path = os.path.abspath(__file__)
    return os.path.join(os.path.dirname(module_path), "input_files", fn)


# the process for generating the .spv files in this test is documented in
# libsyclinterface/tests/test_sycl_queue_submit_local_accessor_arg.cpp
# in a comment starting on line 123
def test_submit_local_accessor_arg():
    """Submit a SPIR-V kernel that takes a ``LocalAccessor`` argument.

    The test is skipped when a Level Zero queue cannot be created or when
    the kernel submission raises ``SyclKernelSubmitError``. Otherwise it
    checks that every element of ``x`` equals its one-based index
    multiplied by ``2 * lws``.
    """
    try:
        q = dpctl.SyclQueue("level_zero")
    except dpctl.SyclQueueCreationError:
        # The message names the backend actually requested above.
        pytest.skip("Level Zero queue could not be created")
    fn = get_spirv_abspath("local_accessor_kernel_inttys_fp32.spv")
    with open(fn, "rb") as f:
        spirv_bytes = f.read()
    prog = dpctl_prog.create_program_from_spirv(q, spirv_bytes)
    krn = prog.get_sycl_kernel("_ZTS14SyclKernel_SLMIlE")
    lws = 32
    gws = lws * 10
    x = dpt.ones(gws, dtype="i8")
    x.sycl_queue.wait()
    try:
        e = q.submit(
            krn,
            [x.usm_data, dpctl._sycl_queue.LocalAccessor(1, "i8", lws, 1, 1)],
            [
                gws,
            ],
            [
                lws,
            ],
        )
        e.wait()
    except dpctl._sycl_queue.SyclKernelSubmitError:
        pytest.skip(f"Kernel submission failed for device {q.sycl_device}")
    expected = dpt.arange(1, x.size + 1, dtype=x.dtype, device=x.device) * (
        2 * lws
    )
    assert dpt.all(x == expected)
Loading