Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for work_group_memory extension #1984

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ per-file-ignores =
dpctl/utils/_compute_follows_data.pyx: E999, E225, E227
dpctl/utils/_onetrace_context.py: E501, W505
dpctl/tensor/_array_api.py: E501, W505
dpctl/experimental/_work_group_memory.pyx: E999
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

File no longer exists I believe

examples/cython/sycl_buffer/syclbuffer/_syclbuffer.pyx: E999, E225, E402
examples/cython/usm_memory/blackscholes/_blackscholes_usm.pyx: E999, E225, E226, E402
examples/cython/use_dpctl_sycl/use_dpctl_sycl/_cython_api.pyx: E999, E225, E226, E402
2 changes: 2 additions & 0 deletions dpctl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
SyclKernelSubmitError,
SyclQueue,
SyclQueueCreationError,
WorkGroupMemory,
)
from ._sycl_queue_manager import get_device_cached_queue
from ._sycl_timer import SyclTimer
Expand Down Expand Up @@ -100,6 +101,7 @@
"SyclKernelInvalidRangeError",
"SyclKernelSubmitError",
"SyclQueueCreationError",
"WorkGroupMemory",
]
__all__ += [
"get_device_cached_queue",
Expand Down
17 changes: 16 additions & 1 deletion dpctl/_backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ cdef extern from "syclinterface/dpctl_sycl_enum_types.h":
_FLOAT 'DPCTL_FLOAT32_T',
_DOUBLE 'DPCTL_FLOAT64_T',
_VOID_PTR 'DPCTL_VOID_PTR',
_LOCAL_ACCESSOR 'DPCTL_LOCAL_ACCESSOR'
_LOCAL_ACCESSOR 'DPCTL_LOCAL_ACCESSOR',
_WORK_GROUP_MEMORY 'DPCTL_WORK_GROUP_MEMORY'

ctypedef enum _queue_property_type 'DPCTLQueuePropertyType':
_DEFAULT_PROPERTY 'DPCTL_DEFAULT_PROPERTY'
Expand Down Expand Up @@ -468,3 +469,17 @@ cdef extern from "syclinterface/dpctl_sycl_usm_interface.h":
cdef DPCTLSyclDeviceRef DPCTLUSM_GetPointerDevice(
DPCTLSyclUSMRef MRef,
DPCTLSyclContextRef CRef)

# C declarations taken from the dpctl extension interface header; they back
# the WorkGroupMemory Python type.
cdef extern from "syclinterface/dpctl_sycl_extension_interface.h":
# Struct declared without a body here, together with its typedef alias.
cdef struct RawWorkGroupMemoryTy
ctypedef RawWorkGroupMemoryTy RawWorkGroupMemory

# Struct declared without a body; the *Ref typedef is a pointer to it.
cdef struct DPCTLOpaqueWorkGroupMemory
ctypedef DPCTLOpaqueWorkGroupMemory *DPCTLSyclWorkGroupMemoryRef;

# Takes a size in bytes and returns a reference.
cdef DPCTLSyclWorkGroupMemoryRef DPCTLWorkGroupMemory_Create(size_t nbytes);

# Takes a reference and returns nothing.
cdef void DPCTLWorkGroupMemory_Delete(
DPCTLSyclWorkGroupMemoryRef Ref);

# Takes no arguments and returns a boolean (bint) flag.
cdef bint DPCTLWorkGroupMemory_Available();
17 changes: 16 additions & 1 deletion dpctl/_sycl_queue.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@

from libcpp cimport bool as cpp_bool

from ._backend cimport DPCTLSyclDeviceRef, DPCTLSyclQueueRef, _arg_data_type
from ._backend cimport (
DPCTLSyclDeviceRef,
DPCTLSyclQueueRef,
DPCTLSyclWorkGroupMemoryRef,
_arg_data_type,
)
from ._sycl_context cimport SyclContext
from ._sycl_device cimport SyclDevice
from ._sycl_event cimport SyclEvent
Expand Down Expand Up @@ -98,3 +103,13 @@ cdef public api class SyclQueue (_SyclQueue) [
cpdef prefetch(self, ptr, size_t count=*)
cpdef mem_advise(self, ptr, size_t count, int mem)
cpdef SyclEvent submit_barrier(self, dependent_events=*)

# Base extension type: its only C-level field is the work-group-memory
# reference.
cdef public api class _WorkGroupMemory [
object Py_WorkGroupMemoryObject, type Py_WorkGroupMemoryType
]:
cdef DPCTLSyclWorkGroupMemoryRef _mem_ref

# Public extension type derived from _WorkGroupMemory; it declares no
# additional C-level fields.
cdef public api class WorkGroupMemory(_WorkGroupMemory) [
object PyWorkGroupMemoryObject, type PyWorkGroupMemoryType
]:
pass
53 changes: 53 additions & 0 deletions dpctl/_sycl_queue.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ from ._backend cimport ( # noqa: E211
DPCTLSyclContextRef,
DPCTLSyclDeviceSelectorRef,
DPCTLSyclEventRef,
DPCTLWorkGroupMemory_Available,
DPCTLWorkGroupMemory_Create,
DPCTLWorkGroupMemory_Delete,
_arg_data_type,
_backend_type,
_queue_property_type,
Expand Down Expand Up @@ -250,6 +253,15 @@ cdef class _kernel_arg_type:
_arg_data_type._LOCAL_ACCESSOR
)

# Kernel-argument type attribute for work-group-memory arguments: builds a
# ``kernel_arg_type_attribute`` from this object's name, the property name and
# the ``_WORK_GROUP_MEMORY`` enumerator.
@property
def dpctl_work_group_memory(self):
cdef str p_name = "dpctl_work_group_memory"
return kernel_arg_type_attribute(
self._name,
p_name,
_arg_data_type._WORK_GROUP_MEMORY
)


kernel_arg_type = _kernel_arg_type()

Expand Down Expand Up @@ -849,6 +861,9 @@ cdef class SyclQueue(_SyclQueue):
elif isinstance(arg, _Memory):
kargs[idx]= <void*>(<size_t>arg._pointer)
kargty[idx] = _arg_data_type._VOID_PTR
elif isinstance(arg, WorkGroupMemory):
kargs[idx] = <void*>(<size_t>arg._ref)
kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
else:
ret = -1
return ret
Expand Down Expand Up @@ -1524,3 +1539,41 @@ cdef api SyclQueue SyclQueue_Make(DPCTLSyclQueueRef QRef):
"""
cdef DPCTLSyclQueueRef copied_QRef = DPCTLQueue_Copy(QRef)
return SyclQueue._create(copied_QRef)

cdef class _WorkGroupMemory:
    """Holder of the C-level work-group-memory reference ``_mem_ref``.

    The reference is handed to ``DPCTLWorkGroupMemory_Delete`` when the
    object is deallocated.
    """

    def __dealloc__(self):
        # Nothing to release if the reference was never set.
        if self._mem_ref is not NULL:
            DPCTLWorkGroupMemory_Delete(self._mem_ref)

cdef class WorkGroupMemory:
"""
WorkGroupMemory(nbytes)
Python class representing the ``work_group_memory`` class from the
Workgroup Memory oneAPI SYCL extension for low-overhead allocation of local
memory shared by the workitems in a workgroup.

This is based on a DPC++ SYCL extension and only available in newer
versions. Use ``is_available()`` to check availability in your build.

Args:
nbytes (int)
number of bytes to allocate in local memory.
Expected to be positive.
"""
def __cinit__(self, Py_ssize_t nbytes):
Copy link
Collaborator

@ndgrigorian ndgrigorian Feb 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you see a reason to provide another constructor that takes a given type and a given number of elements? The extension includes work_group_memory(size_t num, handler& cgh)

could be as simple as a call which wraps conversion from num and dtype to nbytes, and could be handled similarly to SyclQueue (*args and check the length, when it's two, use num and dtype overload, when it's one, nbytes)

only question is if there's much of a need or use for it

if not DPCTLWorkGroupMemory_Available():
raise RuntimeError("Workgroup memory extension not available")

self._mem_ref = DPCTLWorkGroupMemory_Create(nbytes)

@staticmethod
def is_available():
    """Check whether the work_group_memory extension is available.

    Returns:
        bool:
            Value reported by ``DPCTLWorkGroupMemory_Available()``.
    """
    return DPCTLWorkGroupMemory_Available()

@property
def _ref(self):
    """Returns the address of the C API ``DPCTLSyclWorkGroupMemoryRef``
    pointer as a ``size_t``.
    """
    return <size_t>self._mem_ref
6 changes: 4 additions & 2 deletions dpctl/apis/include/dpctl_capi.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
#pragma once

// clang-format off
// Ordering of includes is important here. dpctl_sycl_types defines types
// used by dpctl's Python C-API headers.
// Ordering of includes is important here. dpctl_sycl_types and
// dpctl_sycl_extension_interface define types used by dpctl's Python
// C-API headers.
#include "syclinterface/dpctl_sycl_types.h"
#include "syclinterface/dpctl_sycl_extension_interface.h"
#ifdef __cplusplus
#define CYTHON_EXTERN_C extern "C"
#else
Expand Down
13 changes: 13 additions & 0 deletions dpctl/sycl.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ cdef extern from "sycl/sycl.hpp" namespace "sycl":
"sycl::kernel_bundle<sycl::bundle_state::executable>":
pass

# Struct from the extension interface header, declared without a body here,
# together with its typedef alias.
cdef extern from "syclinterface/dpctl_sycl_extension_interface.h":
cdef struct RawWorkGroupMemoryTy
ctypedef RawWorkGroupMemoryTy RawWorkGroupMemory

cdef extern from "syclinterface/dpctl_sycl_type_casters.hpp" \
namespace "dpctl::syclinterface":
# queue
Expand All @@ -67,3 +71,12 @@ cdef extern from "syclinterface/dpctl_sycl_type_casters.hpp" \
"dpctl::syclinterface::wrap<sycl::event>" (const event *)
cdef event * unwrap_event "dpctl::syclinterface::unwrap<sycl::event>" (
dpctl_backend.DPCTLSyclEventRef)

# work group memory extension
cdef dpctl_backend.DPCTLSyclWorkGroupMemoryRef wrap_work_group_memory \
"dpctl::syclinterface::wrap<RawWorkGroupMemory>" \
(const RawWorkGroupMemory *)

cdef RawWorkGroupMemory * unwrap_work_group_memory \
"dpctl::syclinterface::unwrap<RawWorkGroupMemory>" (
dpctl_backend.DPCTLSyclWorkGroupMemoryRef)
Binary file not shown.
1 change: 1 addition & 0 deletions dpctl/tests/test_sycl_kernel_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,4 @@ def test_kernel_arg_type():
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_float64)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_local_accessor)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_work_group_memory)
90 changes: 90 additions & 0 deletions dpctl/tests/test_work_group_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines unit test cases for the work_group_memory in a SYCL kernel"""

import os

import pytest

import dpctl
import dpctl.tensor


def get_spirv_abspath(fn):
    """Return the absolute path of *fn* inside the ``input_files`` directory
    located next to this module."""
    module_path = os.path.abspath(__file__)
    return os.path.join(os.path.dirname(module_path), "input_files", fn)


# The kernel in the SPIR-V file used in this test was generated from the
# following SYCL source code:
# #include <sycl/sycl.hpp>
# using namespace sycl;
# namespace syclexp = sycl::ext::oneapi::experimental;
# namespace syclext = sycl::ext::oneapi;
# using data_t = int32_t;
#
# extern "C" SYCL_EXTERNAL
# SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
# void local_mem_kernel(data_t* in, data_t* out,
# syclexp::work_group_memory<data_t> mem){
# auto* local_mem = &mem;
# auto item = syclext::this_work_item::get_nd_item<1>();
# size_t global_id = item.get_global_linear_id();
# size_t local_id = item.get_local_linear_id();
# local_mem[local_id] = in[global_id];
# out[global_id] = local_mem[local_id];
# }


def test_submit_work_group_memory():
    """Submit a SPIR-V kernel that takes a ``WorkGroupMemory`` argument.

    The kernel (source shown in the comment above) copies each input
    element through work-group local memory into the output, so the output
    must equal the input afterwards.  The test is skipped when the
    extension is unavailable, when no Level Zero queue can be created, or
    when kernel submission fails.
    """
    if not dpctl.WorkGroupMemory.is_available():
        pytest.skip("Work group memory extension not supported")

    try:
        q = dpctl.SyclQueue("level_zero")
    except dpctl.SyclQueueCreationError:
        pytest.skip("LevelZero queue could not be created")
    spirv_file = get_spirv_abspath("work-group-memory-kernel.spv")
    with open(spirv_file, "rb") as spv:
        spv_bytes = spv.read()
    prog = dpctl.program.create_program_from_spirv(q, spv_bytes)
    kernel = prog.get_sycl_kernel("__sycl_kernel_local_mem_kernel")
    local_size = 16
    global_size = local_size * 8

    # Allocate on the queue the kernel is submitted to, rather than on the
    # default queue, so the USM pointers belong to the kernel's queue.
    x = dpctl.tensor.ones(global_size, dtype="int32", sycl_queue=q)
    y = dpctl.tensor.zeros(global_size, dtype="int32", sycl_queue=q)
    x.sycl_queue.wait()
    y.sycl_queue.wait()

    try:
        q.submit(
            kernel,
            [
                x.usm_data,
                y.usm_data,
                # One int32 slot of local memory per work-item in the group.
                dpctl.WorkGroupMemory(local_size * x.itemsize),
            ],
            [global_size],
            [local_size],
        )
        q.wait()
    except dpctl.SyclKernelSubmitError:
        pytest.skip(f"Kernel submission to {q.sycl_device} failed")

    assert dpctl.tensor.all(x == y)
80 changes: 80 additions & 0 deletions dpctl/tests/test_work_group_memory_opencl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines unit test cases for the work_group_memory in an OpenCL kernel"""

import numpy as np
import pytest

import dpctl
import dpctl.tensor

# OpenCL C source built by the test below via ``create_program_from_source``:
# each work-item stages one float from ``input`` in the ``__local`` buffer and
# writes it to ``output``.
ocl_kernel_src = """
__kernel void local_mem_kernel(__global float *input, __global float *output,
__local float *local_data) {
int gid = get_global_id(0);
int lid = get_local_id(0);

// Load input data into local memory
local_data[lid] = input[gid];

// Store the data in the output array
output[gid] = local_data[lid];
}
"""


def test_submit_work_group_memory_opencl():
    """Submit an OpenCL kernel whose ``__local`` pointer argument is
    supplied as a ``WorkGroupMemory`` object.

    The kernel copies each input element through local memory into the
    output, so the host copy of the output must equal the input.  The test
    is skipped when the extension is unavailable, when no OpenCL queue can
    be created, or when kernel submission fails.
    """
    if not dpctl.WorkGroupMemory.is_available():
        pytest.skip("Work group memory extension not supported")

    try:
        q = dpctl.SyclQueue("opencl")
    except dpctl.SyclQueueCreationError:
        pytest.skip("OpenCL queue could not be created")

    prog = dpctl.program.create_program_from_source(q, ocl_kernel_src)
    kernel = prog.get_sycl_kernel("local_mem_kernel")
    local_size = 16
    global_size = local_size * 8

    x = np.ones(global_size, dtype="float32")
    y = np.zeros(global_size, dtype="float32")

    # Size the device buffers from the host arrays instead of hard-coding
    # the 4-byte item size.
    x_dev = dpctl.memory.MemoryUSMDevice(x.nbytes, queue=q)
    y_dev = dpctl.memory.MemoryUSMDevice(y.nbytes, queue=q)
    q.memcpy(x_dev, x, x_dev.nbytes)
    q.memcpy(y_dev, y, y_dev.nbytes)

    try:
        q.submit(
            kernel,
            [
                x_dev,
                y_dev,
                # One float32 slot of local memory per work-item in the group.
                dpctl.WorkGroupMemory(local_size * x.itemsize),
            ],
            [global_size],
            [local_size],
        )
        q.wait()
    except dpctl.SyclKernelSubmitError:
        # Same policy as the SPIR-V variant of this test: a device that
        # rejects the submission is skipped, not reported as a failure.
        pytest.skip(f"Kernel submission to {q.sycl_device} failed")

    q.memcpy(y, y_dev, y_dev.nbytes)

    assert np.all(x == y)
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ typedef enum
DPCTL_FLOAT64_T,
DPCTL_VOID_PTR,
DPCTL_LOCAL_ACCESSOR,
DPCTL_WORK_GROUP_MEMORY,
DPCTL_UNSUPPORTED_KERNEL_ARG
} DPCTLKernelArgType;

Expand Down
Loading
Loading