
Commit 110b8d9

Merge pull request #627 from robertknight/simd-xor
Implement bitwise NOT and XOR in new SIMD API
2 parents: b373abd + e51d1cc

7 files changed: 198 additions, 31 deletions


rten-simd/src/safe.rs (10 additions, 1 deletion)
```diff
@@ -173,8 +173,17 @@ macro_rules! assert_simd_eq {
     };
 }
 
+/// Test that two [`Simd`] vectors are not equal according to a [`PartialEq`]
+/// comparison of their array representations.
+#[cfg(test)]
+macro_rules! assert_simd_ne {
+    ($x:expr, $y:expr) => {
+        assert_ne!($x.to_array(), $y.to_array());
+    };
+}
+
 #[cfg(test)]
-pub(crate) use assert_simd_eq;
+pub(crate) use {assert_simd_eq, assert_simd_ne};
 
 #[cfg(test)]
 mod tests {
```
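To make the new macro's behavior concrete, here is a self-contained sketch with a toy `Vec4` type standing in for a real [`Simd`] vector. `Vec4` is invented for this illustration and is not part of rten-simd; all the macro requires is a `to_array` method whose output supports `PartialEq`:

```rust
// Toy stand-in for a Simd vector (not part of rten-simd).
#[derive(Copy, Clone)]
struct Vec4([i32; 4]);

impl Vec4 {
    fn to_array(self) -> [i32; 4] {
        self.0
    }
}

// Same expansion as the macro added in this commit.
macro_rules! assert_simd_ne {
    ($x:expr, $y:expr) => {
        assert_ne!($x.to_array(), $y.to_array());
    };
}

fn main() {
    let x = Vec4([0, 1, 2, 3]);
    let y = Vec4([0, 1, 2, 4]);
    assert_simd_ne!(x, y); // passes: the vectors differ in the last lane
}
```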

rten-simd/src/safe/arch/aarch64.rs (28 additions, 6 deletions)
```diff
@@ -6,12 +6,12 @@ use std::arch::aarch64::{
     vcgeq_u16, vcgeq_u8, vcgtq_f32, vcgtq_s16, vcgtq_s32, vcgtq_s8, vcgtq_u16, vcgtq_u8, vcleq_f32,
     vcleq_s16, vcleq_s8, vcleq_u16, vcleq_u8, vcltq_f32, vcltq_s16, vcltq_s8, vcltq_u16, vcltq_u8,
     vcombine_s16, vcombine_u8, vcvtnq_s32_f32, vcvtq_s32_f32, vdivq_f32, vdupq_n_f32, vdupq_n_s16,
-    vdupq_n_s32, vdupq_n_s8, vdupq_n_u16, vdupq_n_u8, vfmaq_f32, vget_low_s16, vget_low_s8,
-    vld1q_f32, vld1q_s16, vld1q_s32, vld1q_s8, vld1q_u16, vld1q_u32, vld1q_u8, vmaxq_f32,
-    vminq_f32, vmovl_high_s16, vmovl_high_s8, vmovl_s16, vmovl_s8, vmulq_f32, vmulq_s16, vmulq_s32,
-    vmulq_s8, vmulq_u16, vmulq_u8, vnegq_f32, vnegq_s16, vnegq_s32, vnegq_s8, vqmovn_s32,
-    vqmovun_s16, vshlq_n_s16, vshlq_n_s32, vshlq_n_s8, vst1q_f32, vst1q_s16, vst1q_s32, vst1q_s8,
-    vst1q_u16, vst1q_u8, vsubq_f32, vsubq_s16, vsubq_s32, vsubq_s8, vsubq_u16, vsubq_u8,
+    vdupq_n_s32, vdupq_n_s8, vdupq_n_u16, vdupq_n_u8, veorq_u32, vfmaq_f32, vget_low_s16,
+    vget_low_s8, vld1q_f32, vld1q_s16, vld1q_s32, vld1q_s8, vld1q_u16, vld1q_u32, vld1q_u8,
+    vmaxq_f32, vminq_f32, vmovl_high_s16, vmovl_high_s8, vmovl_s16, vmovl_s8, vmulq_f32, vmulq_s16,
+    vmulq_s32, vmulq_s8, vmulq_u16, vmulq_u8, vmvnq_u32, vnegq_f32, vnegq_s16, vnegq_s32, vnegq_s8,
+    vqmovn_s32, vqmovun_s16, vshlq_n_s16, vshlq_n_s32, vshlq_n_s8, vst1q_f32, vst1q_s16, vst1q_s32,
+    vst1q_s8, vst1q_u16, vst1q_u8, vsubq_f32, vsubq_s16, vsubq_s32, vsubq_s8, vsubq_u16, vsubq_u8,
     vzip1q_s16, vzip1q_s8, vzip2q_s16, vzip2q_s8,
 };
 use std::mem::transmute;
@@ -113,6 +113,28 @@ macro_rules! simd_ops_common {
         fn mask_ops(self) -> impl MaskOps<$mask> {
             self
         }
+
+        // Since bitwise ops work on individual bits, we can use the same
+        // implementation regardless of numeric type.
+
+        #[inline]
+        fn xor(self, x: $simd, y: $simd) -> $simd {
+            unsafe {
+                let x = transmute::<$simd, uint32x4_t>(x);
+                let y = transmute::<$simd, uint32x4_t>(y);
+                let tmp = veorq_u32(x, y);
+                transmute::<uint32x4_t, $simd>(tmp)
+            }
+        }
+
+        #[inline]
+        fn not(self, x: $simd) -> $simd {
+            unsafe {
+                let x = transmute::<$simd, uint32x4_t>(x);
+                let tmp = vmvnq_u32(x);
+                transmute::<uint32x4_t, $simd>(tmp)
+            }
+        }
     };
 }
```
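NEON exposes bitwise XOR (`veorq_*`) and NOT (`vmvnq_*`) only on integer vectors, so the implementation above first reinterprets each 128-bit vector as `uint32x4_t`; the cast costs nothing at runtime. A standalone sketch of the same trick using NEON's `vreinterpretq` intrinsics instead of `transmute` (illustration only, not the crate's code):

```rust
// XOR on f32 lanes is just XOR on their bits; the vreinterpretq casts
// compile to nothing.
#[cfg(target_arch = "aarch64")]
fn xor_f32x4(
    x: std::arch::aarch64::float32x4_t,
    y: std::arch::aarch64::float32x4_t,
) -> std::arch::aarch64::float32x4_t {
    use std::arch::aarch64::{veorq_u32, vreinterpretq_f32_u32, vreinterpretq_u32_f32};
    unsafe { vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(x), vreinterpretq_u32_f32(y))) }
}
```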

rten-simd/src/safe/arch/generic.rs (26 additions, 0 deletions)
```diff
@@ -225,8 +225,32 @@ macro_rules! simd_ops_common {
     };
 }
 
+macro_rules! simd_int_ops_common {
+    ($simd:ty) => {
+        #[inline]
+        fn xor(self, x: $simd, y: $simd) -> $simd {
+            array::from_fn(|i| x.0[i] ^ y.0[i]).into()
+        }
+
+        #[inline]
+        fn not(self, x: $simd) -> $simd {
+            array::from_fn(|i| !x.0[i]).into()
+        }
+    };
+}
+
 unsafe impl NumOps<F32x4> for GenericIsa {
     simd_ops_common!(F32x4, f32, 4, M32);
+
+    #[inline]
+    fn xor(self, x: F32x4, y: F32x4) -> F32x4 {
+        array::from_fn(|i| f32::from_bits(x.0[i].to_bits() ^ y.0[i].to_bits())).into()
+    }
+
+    #[inline]
+    fn not(self, x: F32x4) -> F32x4 {
+        array::from_fn(|i| f32::from_bits(!x.0[i].to_bits())).into()
+    }
 }
 
 impl FloatOps<F32x4> for GenericIsa {
@@ -267,6 +291,7 @@ macro_rules! impl_simd_signed_int_ops {
     ($simd:ident, $elem:ty, $len:expr, $mask:ident) => {
         unsafe impl NumOps<$simd> for GenericIsa {
             simd_ops_common!($simd, $elem, $len, $mask);
+            simd_int_ops_common!($simd);
         }
 
         impl SignedIntOps<$simd> for GenericIsa {
@@ -334,6 +359,7 @@ macro_rules! impl_simd_unsigned_int_ops {
     ($simd:ident, $elem:ty, $len:expr, $mask:ident) => {
         unsafe impl NumOps<$simd> for GenericIsa {
             simd_ops_common!($simd, $elem, $len, $mask);
+            simd_int_ops_common!($simd);
        }
    };
 }
```
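Rust defines `^` and `!` for integers but not for `f32`, so the float fallback round-trips through `to_bits`/`from_bits`. A scalar sketch of the same semantics:

```rust
// Scalar analogue of the generic F32x4 fallback: bitwise ops on f32 act on
// the raw bit pattern and never involve float arithmetic.
fn f32_xor(x: f32, y: f32) -> f32 {
    f32::from_bits(x.to_bits() ^ y.to_bits())
}

fn f32_not(x: f32) -> f32 {
    f32::from_bits(!x.to_bits())
}

fn main() {
    // XOR of identical bit patterns yields all-zero bits, i.e. +0.0.
    assert_eq!(f32_xor(3.5, 3.5).to_bits(), 0);
    // NOT is an involution on the bit pattern.
    assert_eq!(f32_not(f32_not(3.5)), 3.5);
}
```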

rten-simd/src/safe/arch/wasm32.rs (12 additions, 1 deletion)
```diff
@@ -9,7 +9,8 @@ use std::arch::wasm32::{
     i8x16_neg, i8x16_shl, i8x16_shuffle, i8x16_splat, i8x16_sub, u16x8_add, u16x8_eq,
     u16x8_extmul_high_u8x16, u16x8_extmul_low_u8x16, u16x8_ge, u16x8_gt, u16x8_mul, u16x8_splat,
     u16x8_sub, u8x16_add, u8x16_eq, u8x16_ge, u8x16_gt, u8x16_narrow_i16x8, u8x16_shuffle,
-    u8x16_splat, u8x16_sub, v128, v128_and, v128_bitselect, v128_load, v128_store,
+    u8x16_splat, u8x16_sub, v128, v128_and, v128_bitselect, v128_load, v128_not, v128_store,
+    v128_xor,
 };
 use std::mem::transmute;
@@ -141,6 +142,16 @@ macro_rules! simd_ops_common {
         fn select(self, x: $simd, y: $simd, mask: <$simd as Simd>::Mask) -> $simd {
             $simd(v128_bitselect(x.0, y.0, mask.0))
         }
+
+        #[inline]
+        fn xor(self, x: $simd, y: $simd) -> $simd {
+            v128_xor(x.0, y.0).into()
+        }
+
+        #[inline]
+        fn not(self, x: $simd) -> $simd {
+            v128_not(x.0).into()
+        }
     };
 }
```
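WebAssembly's `v128` type is untyped, so a single pair of intrinsics covers float and integer vectors alike and no reinterpretation is needed. A minimal sketch (build with `-Ctarget-feature=+simd128`):

```rust
// v128_not flips every bit of the 128-bit vector, regardless of how its
// lanes are later interpreted.
#[cfg(target_arch = "wasm32")]
fn toggle_all_bits(x: std::arch::wasm32::v128) -> std::arch::wasm32::v128 {
    std::arch::wasm32::v128_not(x)
}
```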

rten-simd/src/safe/arch/x86_64/avx2.rs (47 additions, 17 deletions)
```diff
@@ -1,22 +1,22 @@
 use std::arch::x86_64::{
     __m128i, __m256, __m256i, _mm256_add_epi16, _mm256_add_epi32, _mm256_add_epi8, _mm256_add_ps,
-    _mm256_and_ps, _mm256_and_si256, _mm256_andnot_ps, _mm256_blendv_epi8, _mm256_blendv_ps,
-    _mm256_castps256_ps128, _mm256_castsi256_si128, _mm256_cmp_ps, _mm256_cmpeq_epi16,
-    _mm256_cmpeq_epi32, _mm256_cmpeq_epi8, _mm256_cmpgt_epi16, _mm256_cmpgt_epi32,
-    _mm256_cmpgt_epi8, _mm256_cvtepi16_epi32, _mm256_cvtepi8_epi16, _mm256_cvtepu8_epi16,
-    _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_div_ps, _mm256_extractf128_ps,
-    _mm256_extracti128_si256, _mm256_fmadd_ps, _mm256_insertf128_si256, _mm256_loadu_ps,
-    _mm256_loadu_si256, _mm256_maskload_epi32, _mm256_maskload_ps, _mm256_maskstore_epi32,
-    _mm256_maskstore_ps, _mm256_max_ps, _mm256_min_ps, _mm256_movemask_epi8, _mm256_mul_ps,
-    _mm256_mullo_epi16, _mm256_mullo_epi32, _mm256_or_si256, _mm256_packs_epi32,
-    _mm256_packus_epi16, _mm256_permute2x128_si256, _mm256_permute4x64_epi64, _mm256_set1_epi16,
-    _mm256_set1_epi32, _mm256_set1_epi8, _mm256_set1_ps, _mm256_setr_m128i, _mm256_setzero_si256,
-    _mm256_slli_epi16, _mm256_slli_epi32, _mm256_storeu_ps, _mm256_storeu_si256, _mm256_sub_epi16,
-    _mm256_sub_epi32, _mm256_sub_epi8, _mm256_sub_ps, _mm256_unpackhi_epi16, _mm256_unpackhi_epi8,
-    _mm256_unpacklo_epi16, _mm256_unpacklo_epi8, _mm256_xor_ps, _mm256_xor_si256, _mm_add_ps,
-    _mm_cvtss_f32, _mm_movehl_ps, _mm_prefetch, _mm_setr_epi8, _mm_shuffle_epi8, _mm_shuffle_ps,
-    _mm_unpacklo_epi64, _CMP_EQ_OQ, _CMP_GE_OQ, _CMP_GT_OQ, _CMP_LE_OQ, _CMP_LT_OQ, _MM_HINT_ET0,
-    _MM_HINT_T0,
+    _mm256_and_ps, _mm256_and_si256, _mm256_andnot_ps, _mm256_andnot_si256, _mm256_blendv_epi8,
+    _mm256_blendv_ps, _mm256_castps256_ps128, _mm256_castsi256_si128, _mm256_cmp_ps,
+    _mm256_cmpeq_epi16, _mm256_cmpeq_epi32, _mm256_cmpeq_epi8, _mm256_cmpgt_epi16,
+    _mm256_cmpgt_epi32, _mm256_cmpgt_epi8, _mm256_cvtepi16_epi32, _mm256_cvtepi8_epi16,
+    _mm256_cvtepu8_epi16, _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_div_ps,
+    _mm256_extractf128_ps, _mm256_extracti128_si256, _mm256_fmadd_ps, _mm256_insertf128_si256,
+    _mm256_loadu_ps, _mm256_loadu_si256, _mm256_maskload_epi32, _mm256_maskload_ps,
+    _mm256_maskstore_epi32, _mm256_maskstore_ps, _mm256_max_ps, _mm256_min_ps,
+    _mm256_movemask_epi8, _mm256_mul_ps, _mm256_mullo_epi16, _mm256_mullo_epi32, _mm256_or_si256,
+    _mm256_packs_epi32, _mm256_packus_epi16, _mm256_permute2x128_si256, _mm256_permute4x64_epi64,
+    _mm256_set1_epi16, _mm256_set1_epi32, _mm256_set1_epi8, _mm256_set1_ps, _mm256_setr_m128i,
+    _mm256_setzero_si256, _mm256_slli_epi16, _mm256_slli_epi32, _mm256_storeu_ps,
+    _mm256_storeu_si256, _mm256_sub_epi16, _mm256_sub_epi32, _mm256_sub_epi8, _mm256_sub_ps,
+    _mm256_unpackhi_epi16, _mm256_unpackhi_epi8, _mm256_unpacklo_epi16, _mm256_unpacklo_epi8,
+    _mm256_xor_ps, _mm256_xor_si256, _mm_add_ps, _mm_cvtss_f32, _mm_movehl_ps, _mm_prefetch,
+    _mm_setr_epi8, _mm_shuffle_epi8, _mm_shuffle_ps, _mm_unpacklo_epi64, _CMP_EQ_OQ, _CMP_GE_OQ,
+    _CMP_GT_OQ, _CMP_LE_OQ, _CMP_LT_OQ, _MM_HINT_ET0, _MM_HINT_T0,
 };
 use std::is_x86_feature_detected;
 use std::mem::transmute;
@@ -116,6 +116,20 @@ macro_rules! simd_ops_common {
     };
 }
 
+macro_rules! simd_int_ops_common {
+    ($simd:ty) => {
+        #[inline]
+        fn xor(self, x: $simd, y: $simd) -> $simd {
+            unsafe { _mm256_xor_si256(x.0, y.0) }.into()
+        }
+
+        #[inline]
+        fn not(self, x: $simd) -> $simd {
+            unsafe { _mm256_andnot_si256(x.0, _mm256_set1_epi8(-1)) }.into()
+        }
+    };
+}
+
 unsafe impl NumOps<F32x8> for Avx2Isa {
     simd_ops_common!(F32x8, F32x8);
```
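AVX2 has no dedicated vector NOT instruction, so `not` relies on the AND-NOT identity: `_mm256_andnot_si256(a, b)` computes `!a & b`, which reduces to `!a` when `b` is all ones. A standalone sketch of the identity:

```rust
// !x == andnot(x, all_ones); _mm256_set1_epi8(-1) produces the all-ones vector.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn not_256(x: std::arch::x86_64::__m256i) -> std::arch::x86_64::__m256i {
    use std::arch::x86_64::{_mm256_andnot_si256, _mm256_set1_epi8};
    _mm256_andnot_si256(x, _mm256_set1_epi8(-1))
}
```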

```diff
@@ -180,6 +194,17 @@
         unsafe { _mm256_max_ps(x.0, y.0) }.into()
     }
 
+    #[inline]
+    fn xor(self, x: F32x8, y: F32x8) -> F32x8 {
+        unsafe { _mm256_xor_ps(x.0, y.0) }.into()
+    }
+
+    #[inline]
+    fn not(self, x: F32x8) -> F32x8 {
+        let all_ones: F32x8 = self.splat(f32::from_bits(0xFFFFFFFF));
+        unsafe { _mm256_andnot_ps(x.0, all_ones.0) }.into()
+    }
+
     #[inline]
     fn splat(self, x: f32) -> F32x8 {
         unsafe { _mm256_set1_ps(x) }.into()
```
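For `F32x8` the commit uses the `_ps` variants rather than reinterpreting to `__m256i`; a plausible motivation (an inference, not stated in the commit) is keeping the data in the floating-point execution domain. A classic use of float-domain XOR, sketched assuming AVX support:

```rust
// Negate eight f32 lanes by XOR-ing with the sign-bit mask: -0.0 is
// 0x80000000 in every lane. Illustration only, not part of the commit.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn negate8(x: std::arch::x86_64::__m256) -> std::arch::x86_64::__m256 {
    use std::arch::x86_64::{_mm256_set1_ps, _mm256_xor_ps};
    _mm256_xor_ps(x, _mm256_set1_ps(-0.0))
}
```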
```diff
@@ -259,6 +284,7 @@
 
 unsafe impl NumOps<I32x8> for Avx2Isa {
     simd_ops_common!(I32x8, I32x8);
+    simd_int_ops_common!(I32x8);
 
     #[inline]
     fn first_n_mask(self, n: usize) -> I32x8 {
@@ -362,6 +388,7 @@
 
 unsafe impl NumOps<I16x16> for Avx2Isa {
     simd_ops_common!(I16x16, I16x16);
+    simd_int_ops_common!(I16x16);
 
     #[inline]
     fn first_n_mask(self, n: usize) -> I16x16 {
@@ -506,6 +533,7 @@
 
 unsafe impl NumOps<I8x32> for Avx2Isa {
     simd_ops_common!(I8x32, I8x32);
+    simd_int_ops_common!(I8x32);
 
     #[inline]
     fn first_n_mask(self, n: usize) -> I8x32 {
@@ -647,6 +675,7 @@
 
 unsafe impl NumOps<U8x32> for Avx2Isa {
     simd_ops_common!(U8x32, I8x32);
+    simd_int_ops_common!(U8x32);
 
     #[inline]
     fn first_n_mask(self, n: usize) -> I8x32 {
@@ -755,6 +784,7 @@
 
 unsafe impl NumOps<U16x16> for Avx2Isa {
     simd_ops_common!(U16x16, I16x16);
+    simd_int_ops_common!(U16x16);
 
     #[inline]
     fn first_n_mask(self, n: usize) -> I16x16 {
```

rten-simd/src/safe/arch/x86_64/avx512.rs (33 additions, 3 deletions)
```diff
@@ -1,6 +1,6 @@
 use std::arch::x86_64::{
     __m512, __m512i, __mmask16, __mmask32, __mmask64, _mm512_add_epi16, _mm512_add_epi32,
-    _mm512_add_epi8, _mm512_add_ps, _mm512_andnot_ps, _mm512_castsi256_si512,
+    _mm512_add_epi8, _mm512_add_ps, _mm512_andnot_ps, _mm512_andnot_si512, _mm512_castsi256_si512,
     _mm512_cmp_epi16_mask, _mm512_cmp_epi32_mask, _mm512_cmp_epu16_mask, _mm512_cmp_ps_mask,
     _mm512_cmpeq_epi8_mask, _mm512_cmpeq_epu8_mask, _mm512_cmpge_epi8_mask, _mm512_cmpge_epu8_mask,
     _mm512_cmpgt_epi8_mask, _mm512_cmpgt_epu8_mask, _mm512_cvtepi16_epi32, _mm512_cvtepi16_epi8,
@@ -16,8 +16,8 @@ use std::arch::x86_64::{
     _mm512_setzero_si512, _mm512_sllv_epi16, _mm512_sllv_epi32, _mm512_storeu_ps,
     _mm512_storeu_si512, _mm512_sub_epi16, _mm512_sub_epi32, _mm512_sub_epi8, _mm512_sub_ps,
     _mm512_unpackhi_epi16, _mm512_unpackhi_epi8, _mm512_unpacklo_epi16, _mm512_unpacklo_epi8,
-    _mm512_xor_ps, _mm_prefetch, _CMP_EQ_OQ, _CMP_GE_OQ, _CMP_GT_OQ, _CMP_LE_OQ, _CMP_LT_OQ,
-    _MM_CMPINT_EQ, _MM_CMPINT_NLE, _MM_CMPINT_NLT, _MM_HINT_ET0, _MM_HINT_T0,
+    _mm512_xor_ps, _mm512_xor_si512, _mm_prefetch, _CMP_EQ_OQ, _CMP_GE_OQ, _CMP_GT_OQ, _CMP_LE_OQ,
+    _CMP_LT_OQ, _MM_CMPINT_EQ, _MM_CMPINT_NLE, _MM_CMPINT_NLT, _MM_HINT_ET0, _MM_HINT_T0,
 };
 use std::mem::transmute;
@@ -125,6 +125,20 @@ macro_rules! simd_ops_common {
     };
 }
 
+macro_rules! simd_int_ops_common {
+    ($simd:ty) => {
+        #[inline]
+        fn xor(self, x: $simd, y: $simd) -> $simd {
+            unsafe { _mm512_xor_si512(x.0, y.0) }.into()
+        }
+
+        #[inline]
+        fn not(self, x: $simd) -> $simd {
+            unsafe { _mm512_andnot_si512(x.0, _mm512_set1_epi8(-1)) }.into()
+        }
+    };
+}
+
 unsafe impl NumOps<F32x16> for Avx512Isa {
     simd_ops_common!(F32x16, __mmask16);
```

```diff
@@ -183,6 +197,17 @@
         unsafe { _mm512_max_ps(x.0, y.0) }.into()
     }
 
+    #[inline]
+    fn xor(self, x: F32x16, y: F32x16) -> F32x16 {
+        unsafe { _mm512_xor_ps(x.0, y.0) }.into()
+    }
+
+    #[inline]
+    fn not(self, x: F32x16) -> F32x16 {
+        let all_ones: F32x16 = self.splat(f32::from_bits(0xFFFFFFFF));
+        unsafe { _mm512_andnot_ps(x.0, all_ones.0) }.into()
+    }
+
     #[inline]
     fn splat(self, x: f32) -> F32x16 {
         unsafe { _mm512_set1_ps(x) }.into()
```
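Note that the all-ones constant here (and in the AVX2 version above) is built as `f32::from_bits(0xFFFFFFFF)`, which is a NaN bit pattern. That is harmless: `andnot_ps` operates on raw bits, and the constant never passes through float arithmetic. A quick scalar check:

```rust
// The all-ones f32 constant is a NaN, but only its bits matter.
fn main() {
    let all_ones = f32::from_bits(0xFFFF_FFFF);
    assert!(all_ones.is_nan());
    assert_eq!(all_ones.to_bits(), u32::MAX);
}
```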
```diff
@@ -250,6 +275,7 @@ impl FloatOps<F32x16> for Avx512Isa {
 
 unsafe impl NumOps<I32x16> for Avx512Isa {
     simd_ops_common!(I32x16, __mmask16);
+    simd_int_ops_common!(I32x16);
 
     #[inline]
     fn add(self, x: I32x16, y: I32x16) -> I32x16 {
@@ -343,6 +369,7 @@
 
 unsafe impl NumOps<I16x32> for Avx512Isa {
     simd_ops_common!(I16x32, __mmask32);
+    simd_int_ops_common!(I16x32);
 
     #[inline]
     fn add(self, x: I16x32, y: I16x32) -> I16x32 {
@@ -463,6 +490,7 @@
 
 unsafe impl NumOps<I8x64> for Avx512Isa {
     simd_ops_common!(I8x64, __mmask64);
+    simd_int_ops_common!(I8x64);
 
     #[inline]
     fn add(self, x: I8x64, y: I8x64) -> I8x64 {
@@ -581,6 +609,7 @@
 
 unsafe impl NumOps<U8x64> for Avx512Isa {
     simd_ops_common!(U8x64, __mmask64);
+    simd_int_ops_common!(U8x64);
 
     #[inline]
     fn add(self, x: U8x64, y: U8x64) -> U8x64 {
@@ -728,6 +757,7 @@
 
 unsafe impl NumOps<U16x32> for Avx512Isa {
     simd_ops_common!(U16x32, __mmask32);
+    simd_int_ops_common!(U16x32);
 
     #[inline]
     fn add(self, x: U16x32, y: U16x32) -> U16x32 {
```
