Skip to content

Commit

Permalink
Enable std::byte as vec's element type
Browse files Browse the repository at this point in the history
Spec changes are at KhronosGroup/SYCL-Docs#674.
  • Loading branch information
aelovikov-intel committed Feb 26, 2025
1 parent ef8b127 commit 80e6378
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 61 deletions.
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ endfunction()
function(get_no_vec_alias_type OUT_LIST)
set(NO_VEC_ALIAS_LIST "")
list(APPEND NO_VEC_ALIAS_LIST sycl::byte)
list(APPEND NO_VEC_ALIAS_LIST std::byte)

set(${OUT_LIST} ${${OUT_LIST}} ${NO_VEC_ALIAS_LIST} PARENT_SCOPE)
endfunction()
Expand Down
48 changes: 26 additions & 22 deletions tests/common/common_python_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Data:
standard_sizes = [1, 2, 3, 4, 8, 16]
standard_types = [
'bool', 'char', 'short', 'int', 'long', 'long long', 'float',
'double', 'sycl::half'
'double', 'sycl::half', 'std::byte'
]
standard_type_dict = {
(True, 'bool'): 'bool',
Expand All @@ -46,7 +46,8 @@ class Data:
(True, 'long long'): 'long long',
(True, 'float'): 'float',
(True, 'double'): 'double',
(True, 'sycl::half'): 'sycl::half'
(True, 'sycl::half'): 'sycl::half',
(False, 'std::byte'): 'std::byte'
}

fixed_width_types = [
Expand Down Expand Up @@ -92,7 +93,8 @@ class Data:
'bool': 'false',
'float': '0.0f',
'double': '0.0',
'sycl::half': '0.0f'
'sycl::half': '0.0f',
'std::byte': 'std::byte{0}'
})
vec_name_dict = {
1: 'One',
Expand Down Expand Up @@ -245,7 +247,8 @@ def remove_namespaces_whitespaces(type_str):
Clear type name from namespaces and whitespaces
"""
return type_str.replace('sycl::', '').replace(
' ', '_').replace('std::', '')
' ', '_').replace('std::byte', 'std_byte').replace(
'std::', '')

def wrap_with_kernel(type_str, kernel_name, test_name, test_string):
"""
Expand Down Expand Up @@ -338,7 +341,7 @@ def wrap_with_extension_checks(type_str, test_string):
return test_string


def append_fp_postfix(type_str, input_val_list):
def make_fp_or_byte_explicit(type_str, input_val_list):
"""Generates and returns a new list from the input, with .0f or .0 appended
to each value in the list if type_str is 'float', 'double' or 'sycl::half'"""
result_val_list = []
Expand All @@ -348,6 +351,8 @@ def append_fp_postfix(type_str, input_val_list):
result_val_list.append(val + '.0f')
elif type_str == 'double':
result_val_list.append(val + '.0')
elif type_str == 'std::byte':
result_val_list.append('std::byte{{{}}}'.format(val))
else:
result_val_list.append(val)
return result_val_list
Expand All @@ -359,7 +364,7 @@ def generate_value_list(type_str, size):
vec_val_list = []
for val in Data.vals_list_dict[size]:
vec_val_list.append(val)
vec_val_list = append_fp_postfix(type_str, vec_val_list)
vec_val_list = make_fp_or_byte_explicit(type_str, vec_val_list)
vec_val_string = ', '.join(vec_val_list)
return str(vec_val_string)

Expand Down Expand Up @@ -444,6 +449,8 @@ def get_types():
if (base_type == 'float' or base_type == 'double' or base_type == 'bool'
or base_type == 'sycl::half') and sign is False:
continue
if base_type == 'std::byte' and sign is True:
continue
types.append(Data.standard_type_dict[(sign, base_type)])

for base_type in Data.fixed_width_types:
Expand Down Expand Up @@ -581,7 +588,6 @@ def substitute_swizzles_templates(type_str, size, index_subset, value_subset, co
for index, value in zip(index_subset, value_subset):
index_list.append(index)
val_list.append(value)
val_list = append_fp_postfix(type_str, val_list)
index_string = ''.join(index_list)
test_string = SwizzleData.swizzle_template.substitute(
name=Data.vec_name_dict[size],
Expand Down Expand Up @@ -624,6 +630,7 @@ def substitute_swizzles_templates(type_str, size, index_subset, value_subset, co

def gen_swizzle_test(type_str, convert_type_str, as_type_str, size, num_batches, batch_index):
string = ''
val_list = make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size])
if size > 4:
test_string = SwizzleData.swizzle_full_test_template.substitute(
name=Data.vec_name_dict[size],
Expand All @@ -639,17 +646,15 @@ def gen_swizzle_test(type_str, convert_type_str, as_type_str, size, num_batches,
swap_pairs(Data.swizzle_elem_list_dict[size])),
reverse_order_reversed_pair_swiz_indexes=', '.join(
swap_pairs(Data.swizzle_elem_list_dict[size][::-1])),
in_order_vals=', '.join(Data.vals_list_dict[size]),
reversed_vals=', '.join(Data.vals_list_dict[size][::-1]),
in_order_pair_vals=', '.join(
swap_pairs(Data.vals_list_dict[size])),
reverse_order_pair_vals=', '.join(
swap_pairs(Data.vals_list_dict[size][::-1])))
in_order_vals=', '.join(val_list),
reversed_vals=', '.join(val_list[::-1]),
in_order_pair_vals=', '.join(swap_pairs(val_list)),
reverse_order_pair_vals=', '.join(swap_pairs(val_list[::-1])))
string += wrap_with_swizzle_kernel(
type_str, str(size), ', '.join(Data.vals_list_dict[size]),
', '.join(Data.vals_list_dict[size][::-1]),
', '.join(swap_pairs(Data.vals_list_dict[size])),
', '.join(swap_pairs(Data.vals_list_dict[size][::-1])),
type_str, str(size), ', '.join(val_list),
', '.join(val_list[::-1]),
', '.join(swap_pairs(val_list)),
', '.join(swap_pairs(val_list[::-1])),
'ELEM_KERNEL_' + type_str + str(size) +
''.join(Data.swizzle_elem_list_dict[size][:size]).replace(
'sycl::elem::', ''),
Expand All @@ -672,7 +677,7 @@ def gen_swizzle_test(type_str, convert_type_str, as_type_str, size, num_batches,
product(
Data.swizzle_xyzw_list_dict[size][:size],
repeat=length),
product(Data.vals_list_dict[size][:size], repeat=length)):
product(val_list[:size], repeat=length)):
total_tests += 1
batch_size = ceil(total_tests / num_batches)
cur_index = 0
Expand All @@ -682,7 +687,7 @@ def gen_swizzle_test(type_str, convert_type_str, as_type_str, size, num_batches,
product(
Data.swizzle_xyzw_list_dict[size][:size],
repeat=length),
product(Data.vals_list_dict[size][:size], repeat=length)):
product(val_list[:size], repeat=length)):
cur_batch = floor(cur_index / batch_size)
if cur_batch > batch_index:
break
Expand All @@ -700,7 +705,7 @@ def gen_swizzle_test(type_str, convert_type_str, as_type_str, size, num_batches,
Data.swizzle_rgba_list_dict[size][:size],
repeat=length),
product(
Data.vals_list_dict[size][:size], repeat=length)):
val_list[:size], repeat=length)):
total_tests += 1
batch_size = ceil(total_tests / num_batches)
cur_index = 0
Expand All @@ -710,8 +715,7 @@ def gen_swizzle_test(type_str, convert_type_str, as_type_str, size, num_batches,
product(
Data.swizzle_rgba_list_dict[size][:size],
repeat=length),
product(
Data.vals_list_dict[size][:size], repeat=length)):
product(val_list[:size], repeat=length)):
cur_batch = floor(cur_index / batch_size)
if cur_batch > batch_index:
break
Expand Down
16 changes: 8 additions & 8 deletions tests/common/common_vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ bool check_lo_hi_odd_even(sycl::vec<vecType, N> inputVec, vecType* vals) {
// lo()
{
sycl::vec<vecType, mid> loVec{inputVec.lo()};
vecType loVals[mid] = {0};
vecType loVals[mid] = {vecType{0}};
for (size_t i = 0; i < mid; i++) {
loVals[i] = vals[i];
}
Expand All @@ -514,7 +514,7 @@ bool check_lo_hi_odd_even(sycl::vec<vecType, N> inputVec, vecType* vals) {
{
sycl::vec<vecType, mid> loVec;
DO_OPERATION_ON_SWIZZLE(N, inputVec, loVec, lo());
vecType loVals[mid] = {0};
vecType loVals[mid] = {vecType{0}};
for (size_t i = 0; i < mid; i++) {
loVals[i] = vals[i];
}
Expand All @@ -528,7 +528,7 @@ bool check_lo_hi_odd_even(sycl::vec<vecType, N> inputVec, vecType* vals) {
{
// hi()
sycl::vec<vecType, mid> hiVec{inputVec.hi()};
vecType hiVals[mid] = {0};
vecType hiVals[mid] = {vecType{0}};
for (size_t i = 0; i < mid; i++) {
hiVals[i] = vals[i + mid];
}
Expand All @@ -540,7 +540,7 @@ bool check_lo_hi_odd_even(sycl::vec<vecType, N> inputVec, vecType* vals) {
// hi()
sycl::vec<vecType, mid> hiVec;
DO_OPERATION_ON_SWIZZLE(N, inputVec, hiVec, hi());
vecType hiVals[mid] = {0};
vecType hiVals[mid] ={vecType{0}};
for (size_t i = 0; i < mid; i++) {
hiVals[i] = vals[i + mid];
}
Expand All @@ -555,7 +555,7 @@ bool check_lo_hi_odd_even(sycl::vec<vecType, N> inputVec, vecType* vals) {
{
// odd()
sycl::vec<vecType, mid> oddVec{inputVec.odd()};
vecType oddVals[mid] = {0};
vecType oddVals[mid] ={vecType{0}};
for (size_t i = 0; i < mid; ++i) {
oddVals[i] = vals[i * 2 + 1];
}
Expand All @@ -567,7 +567,7 @@ bool check_lo_hi_odd_even(sycl::vec<vecType, N> inputVec, vecType* vals) {
// odd()
sycl::vec<vecType, mid> oddVec;
DO_OPERATION_ON_SWIZZLE(N, inputVec, oddVec, odd());
vecType oddVals[mid] = {0};
vecType oddVals[mid] ={vecType{0}};
for (size_t i = 0; i < mid; ++i) {
oddVals[i] = vals[i * 2 + 1];
}
Expand All @@ -579,7 +579,7 @@ bool check_lo_hi_odd_even(sycl::vec<vecType, N> inputVec, vecType* vals) {
// even()
{
sycl::vec<vecType, mid> evenVec{inputVec.even()};
vecType evenVals[mid] = {0};
vecType evenVals[mid] ={vecType{0}};
for (size_t i = 0; i < mid; ++i) {
evenVals[i] = vals[i * 2];
}
Expand All @@ -590,7 +590,7 @@ bool check_lo_hi_odd_even(sycl::vec<vecType, N> inputVec, vecType* vals) {
{
sycl::vec<vecType, mid> evenVec;
DO_OPERATION_ON_SWIZZLE(N, inputVec, evenVec, even());
vecType evenVals[mid] = {0};
vecType evenVals[mid] ={vecType{0}};
for (size_t i = 0; i < mid; ++i) {
evenVals[i] = vals[i * 2];
}
Expand Down
7 changes: 4 additions & 3 deletions tests/vector_api/generate_vector_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import argparse
from string import Template
sys.path.append('../common/')
from common_python_vec import (Data, ReverseData, append_fp_postfix, wrap_with_kernel,
from common_python_vec import (Data, ReverseData, make_fp_or_byte_explicit, wrap_with_kernel,
wrap_with_test_func, make_func_call,
write_source_file, get_types, cast_to_bool)

Expand Down Expand Up @@ -77,8 +77,9 @@


def gen_checks(type_str, size):
vals_list = append_fp_postfix(type_str, Data.vals_list_dict[size])
if 'double' in type_str or 'half' in type_str or 'float' in type_str:
vals_list = make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size])
if type_str in ('double', 'float', 'half'):
# Want fractions/sign...
vals_list = Data.vals_list_dict_float[size]
reverse_vals_list = vals_list[::-1]
kernel_name = 'KERNEL_API_' + type_str + str(size)
Expand Down
8 changes: 4 additions & 4 deletions tests/vector_constructors/generate_vector_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
sys.path.append('../common/')
from common_python_vec import (Data, ReverseData, wrap_with_kernel, wrap_with_test_func,
make_func_call, write_source_file, get_types, cast_to_bool,
append_fp_postfix)
make_fp_or_byte_explicit)

TEST_NAME = 'CONSTRUCTORS'

Expand All @@ -43,7 +43,7 @@
""")

explicit_constructor_vec_template = Template(
""" const ${type} val = ${val};
""" const ${type} val{${val}};
${type} vals[] = {${vals}};
auto test = sycl::vec<${type}, ${size}>(val);
if (!check_equal_type_bool<sycl::vec<${type}, ${size}>>(test)) {
Expand Down Expand Up @@ -151,8 +151,8 @@ def generate_mixed(type_str, size):
return ''
input_vec_size = size//2
values_size = size - input_vec_size
vals_list1 = append_fp_postfix(type_str, Data.vals_list_dict[size][:input_vec_size])
vals_list2 = append_fp_postfix(type_str, Data.vals_list_dict[size][-values_size:])
vals_list1 = make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size][:input_vec_size])
vals_list2 = make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size][-values_size:])
test_string = mixed_constructor_vec_template.substitute(
type=type_str,
size=size,
Expand Down
26 changes: 13 additions & 13 deletions tests/vector_load_store/generate_vector_load_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import argparse
from string import Template
sys.path.append('../common/')
from common_python_vec import (Data, append_fp_postfix, make_func_call,
from common_python_vec import (Data, make_fp_or_byte_explicit, make_func_call,
wrap_with_test_func, write_source_file,
wrap_with_extension_checks, get_types,
remove_namespaces_whitespaces, cast_to_bool)
Expand Down Expand Up @@ -411,9 +411,9 @@ def gen_load_store_test(type_str, size):
size=size,
val=Data.value_default_dict[type_str],
in_order_vals=', '.join(
append_fp_postfix(type_str, Data.vals_list_dict[size])),
make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size])),
reverse_order_vals=', '.join(
append_fp_postfix(type_str, Data.vals_list_dict[size][::-1])),
make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size][::-1])),
kernelName=gen_kernel_name(type_str, size, decoration + '_global'),
swizVals=', '.join(Data.swizzle_elem_list_dict[size]),
decorated=('sycl::access::decorated::' + decoration))
Expand All @@ -423,9 +423,9 @@ def gen_load_store_test(type_str, size):
size=size,
val=Data.value_default_dict[type_str],
in_order_vals=', '.join(
append_fp_postfix(type_str, Data.vals_list_dict[size])),
make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size])),
reverse_order_vals=', '.join(
append_fp_postfix(type_str, Data.vals_list_dict[size][::-1])),
make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size][::-1])),
kernelName=gen_kernel_name(type_str, size, decoration + '_local'),
swizVals=', '.join(Data.swizzle_elem_list_dict[size]),
decorated=('sycl::access::decorated::' + decoration))
Expand All @@ -435,9 +435,9 @@ def gen_load_store_test(type_str, size):
size=size,
val=Data.value_default_dict[type_str],
in_order_vals=', '.join(
append_fp_postfix(type_str, Data.vals_list_dict[size])),
make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size])),
reverse_order_vals=', '.join(
append_fp_postfix(type_str, Data.vals_list_dict[size][::-1])),
make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size][::-1])),
kernelName=gen_kernel_name(type_str, size, decoration + '_private'),
swizVals=', '.join(Data.swizzle_elem_list_dict[size]),
decorated=('sycl::access::decorated::' + decoration))
Expand All @@ -450,9 +450,9 @@ def gen_load_store_test(type_str, size):
size=size,
val=Data.value_default_dict[type_str],
in_order_vals=', '.join(
append_fp_postfix(type_str, Data.vals_list_dict[size])),
make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size])),
reverse_order_vals=', '.join(
append_fp_postfix(type_str, Data.vals_list_dict[size][::-1])),
make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size][::-1])),
kernelName=gen_kernel_name(type_str, size, 'raw_global'),
swizVals=', '.join(Data.swizzle_elem_list_dict[size]))
test_string += local_raw_ptr_load_store_test_template.substitute(
Expand All @@ -461,9 +461,9 @@ def gen_load_store_test(type_str, size):
size=size,
val=Data.value_default_dict[type_str],
in_order_vals=', '.join(
append_fp_postfix(type_str, Data.vals_list_dict[size])),
make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size])),
reverse_order_vals=', '.join(
append_fp_postfix(type_str, Data.vals_list_dict[size][::-1])),
make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size][::-1])),
kernelName=gen_kernel_name(type_str, size, 'raw_local'),
swizVals=', '.join(Data.swizzle_elem_list_dict[size]))
test_string += private_raw_ptr_load_store_test_template.substitute(
Expand All @@ -472,9 +472,9 @@ def gen_load_store_test(type_str, size):
size=size,
val=Data.value_default_dict[type_str],
in_order_vals=', '.join(
append_fp_postfix(type_str, Data.vals_list_dict[size])),
make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size])),
reverse_order_vals=', '.join(
append_fp_postfix(type_str, Data.vals_list_dict[size][::-1])),
make_fp_or_byte_explicit(type_str, Data.vals_list_dict[size][::-1])),
kernelName=gen_kernel_name(type_str, size, 'raw_private'),
swizVals=', '.join(Data.swizzle_elem_list_dict[size]))
return wrap_with_test_func(TEST_NAME, type_str,
Expand Down
Loading

0 comments on commit 80e6378

Please sign in to comment.