-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for work_group_memory extension #1984
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,6 +54,9 @@ from ._backend cimport ( # noqa: E211 | |
DPCTLSyclContextRef, | ||
DPCTLSyclDeviceSelectorRef, | ||
DPCTLSyclEventRef, | ||
DPCTLWorkGroupMemory_Available, | ||
DPCTLWorkGroupMemory_Create, | ||
DPCTLWorkGroupMemory_Delete, | ||
_arg_data_type, | ||
_backend_type, | ||
_queue_property_type, | ||
|
@@ -250,6 +253,15 @@ cdef class _kernel_arg_type: | |
_arg_data_type._LOCAL_ACCESSOR | ||
) | ||
|
||
@property | ||
def dpctl_work_group_memory(self): | ||
cdef str p_name = "dpctl_work_group_memory" | ||
return kernel_arg_type_attribute( | ||
self._name, | ||
p_name, | ||
_arg_data_type._WORK_GROUP_MEMORY | ||
) | ||
|
||
|
||
kernel_arg_type = _kernel_arg_type() | ||
|
||
|
@@ -849,6 +861,9 @@ cdef class SyclQueue(_SyclQueue): | |
elif isinstance(arg, _Memory): | ||
kargs[idx]= <void*>(<size_t>arg._pointer) | ||
kargty[idx] = _arg_data_type._VOID_PTR | ||
elif isinstance(arg, WorkGroupMemory): | ||
kargs[idx] = <void*>(<size_t>arg._ref) | ||
kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY | ||
else: | ||
ret = -1 | ||
return ret | ||
|
@@ -1524,3 +1539,41 @@ cdef api SyclQueue SyclQueue_Make(DPCTLSyclQueueRef QRef): | |
""" | ||
cdef DPCTLSyclQueueRef copied_QRef = DPCTLQueue_Copy(QRef) | ||
return SyclQueue._create(copied_QRef) | ||
|
||
cdef class _WorkGroupMemory: | ||
def __dealloc__(self): | ||
if(self._mem_ref): | ||
DPCTLWorkGroupMemory_Delete(self._mem_ref) | ||
|
||
cdef class WorkGroupMemory: | ||
""" | ||
WorkGroupMemory(nbytes) | ||
Python class representing the ``work_group_memory`` class from the | ||
Workgroup Memory oneAPI SYCL extension for low-overhead allocation of local | ||
memory shared by the workitems in a workgroup. | ||
|
||
This is based on a DPC++ SYCL extension and only available in newer | ||
versions. Use ``is_available()`` to check availability in your build. | ||
|
||
Args: | ||
nbytes (int) | ||
number of bytes to allocate in local memory. | ||
Expected to be positive. | ||
""" | ||
def __cinit__(self, Py_ssize_t nbytes): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you see a reason to provide another constructor that takes a given type and a given number of elements? The extension includes could be as simple as a a call which wraps conversion from only question is if there's much of a need or use for it |
||
if not DPCTLWorkGroupMemory_Available(): | ||
raise RuntimeError("Workgroup memory extension not available") | ||
|
||
self._mem_ref = DPCTLWorkGroupMemory_Create(nbytes) | ||
|
||
"""Check whether the work_group_memory extension is available""" | ||
@staticmethod | ||
def is_available(): | ||
return DPCTLWorkGroupMemory_Available() | ||
|
||
property _ref: | ||
"""Returns the address of the C API ``DPCTLWorkGroupMemoryRef`` | ||
pointer as a ``size_t``. | ||
""" | ||
def __get__(self): | ||
return <size_t>self._mem_ref |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Data Parallel Control (dpctl) | ||
# | ||
# Copyright 2020-2025 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Defines unit test cases for the work_group_memory in a SYCL kernel""" | ||
|
||
import os | ||
|
||
import pytest | ||
|
||
import dpctl | ||
import dpctl.tensor | ||
|
||
|
||
def get_spirv_abspath(fn): | ||
curr_dir = os.path.dirname(os.path.abspath(__file__)) | ||
spirv_file = os.path.join(curr_dir, "input_files", fn) | ||
return spirv_file | ||
|
||
|
||
# The kernel in the SPIR-V file used in this test was generated from the | ||
# following SYCL source code: | ||
# #include <sycl/sycl.hpp> | ||
# using namespace sycl; | ||
# namespace syclexp = sycl::ext::oneapi::experimental; | ||
# namespace syclext = sycl::ext::oneapi; | ||
# using data_t = int32_t; | ||
# | ||
# extern "C" SYCL_EXTERNAL | ||
# SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>)) | ||
# void local_mem_kernel(data_t* in, data_t* out, | ||
# syclexp::work_group_memory<data_t> mem){ | ||
# auto* local_mem = &mem; | ||
# auto item = syclext::this_work_item::get_nd_item<1>(); | ||
# size_t global_id = item.get_global_linear_id(); | ||
# size_t local_id = item.get_local_linear_id(); | ||
# local_mem[local_id] = in[global_id]; | ||
# out[global_id] = local_mem[local_id]; | ||
# } | ||
|
||
|
||
def test_submit_work_group_memory(): | ||
if not dpctl.WorkGroupMemory.is_available(): | ||
pytest.skip("Work group memory extension not supported") | ||
|
||
try: | ||
q = dpctl.SyclQueue("level_zero") | ||
except dpctl.SyclQueueCreationError: | ||
pytest.skip("LevelZero queue could not be created") | ||
spirv_file = get_spirv_abspath("work-group-memory-kernel.spv") | ||
with open(spirv_file, "br") as spv: | ||
spv_bytes = spv.read() | ||
prog = dpctl.program.create_program_from_spirv(q, spv_bytes) | ||
kernel = prog.get_sycl_kernel("__sycl_kernel_local_mem_kernel") | ||
local_size = 16 | ||
global_size = local_size * 8 | ||
|
||
x = dpctl.tensor.ones(global_size, dtype="int32") | ||
y = dpctl.tensor.zeros(global_size, dtype="int32") | ||
x.sycl_queue.wait() | ||
y.sycl_queue.wait() | ||
|
||
try: | ||
q.submit( | ||
kernel, | ||
[ | ||
x.usm_data, | ||
y.usm_data, | ||
dpctl.WorkGroupMemory(local_size * x.itemsize), | ||
], | ||
[global_size], | ||
[local_size], | ||
) | ||
q.wait() | ||
except dpctl._sycl_queue.SyclKernelSubmitError: | ||
pytest.skip(f"Kernel submission to {q.sycl_device} failed") | ||
|
||
assert dpctl.tensor.all(x == y) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Data Parallel Control (dpctl) | ||
# | ||
# Copyright 2020-2025 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Defines unit test cases for the work_group_memory in an OpenCL kernel""" | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import dpctl | ||
import dpctl.tensor | ||
|
||
ocl_kernel_src = """ | ||
__kernel void local_mem_kernel(__global float *input, __global float *output, | ||
__local float *local_data) { | ||
int gid = get_global_id(0); | ||
int lid = get_local_id(0); | ||
|
||
// Load input data into local memory | ||
local_data[lid] = input[gid]; | ||
|
||
// Store the data in the output array | ||
output[gid] = local_data[lid]; | ||
} | ||
""" | ||
|
||
|
||
def test_submit_work_group_memory_opencl(): | ||
if not dpctl.WorkGroupMemory.is_available(): | ||
pytest.skip("Work group memory extension not supported") | ||
|
||
try: | ||
q = dpctl.SyclQueue("opencl") | ||
except dpctl.SyclQueueCreationError: | ||
pytest.skip("OpenCL queue could not be created") | ||
|
||
prog = dpctl.program.create_program_from_source(q, ocl_kernel_src) | ||
kernel = prog.get_sycl_kernel("local_mem_kernel") | ||
local_size = 16 | ||
global_size = local_size * 8 | ||
|
||
x_dev = dpctl.memory.MemoryUSMDevice(global_size * 4, queue=q) | ||
y_dev = dpctl.memory.MemoryUSMDevice(global_size * 4, queue=q) | ||
|
||
x = np.ones(global_size, dtype="float32") | ||
y = np.zeros(global_size, dtype="float32") | ||
q.memcpy(x_dev, x, x_dev.nbytes) | ||
q.memcpy(y_dev, y, y_dev.nbytes) | ||
|
||
try: | ||
q.submit( | ||
kernel, | ||
[ | ||
x_dev, | ||
y_dev, | ||
dpctl.WorkGroupMemory(local_size * x.itemsize), | ||
], | ||
[global_size], | ||
[local_size], | ||
) | ||
q.wait() | ||
except dpctl._sycl_queue.SyclKernelSubmitError: | ||
pytest.fail("Foo") | ||
pytest.skip(f"Kernel submission to {q.sycl_device} failed") | ||
|
||
q.memcpy(y, y_dev, y_dev.nbytes) | ||
|
||
assert np.all(x == y) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
File no longer exists I believe