Skip to content

Commit b373abd

Browse files
authored
Merge pull request #626 from robertknight/interleave-i8-i16
Implement i8 and i16 interleave in new SIMD API
2 parents 9c1b36f + e043a3e commit b373abd

File tree

7 files changed

+283
-37
lines changed

7 files changed

+283
-37
lines changed

rten-simd/src/safe.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ pub mod isa {
156156
pub use dispatch::{SimdOp, SimdUnaryOp};
157157
pub use iter::{Iter, SimdIterable};
158158
pub use vec::{
159-
Elem, Extend, FloatOps, Isa, Mask, MaskOps, NarrowSaturate, NumOps, SignedIntOps, Simd,
159+
Elem, Extend, FloatOps, Interleave, Isa, Mask, MaskOps, NarrowSaturate, NumOps, SignedIntOps,
160+
Simd,
160161
};
161162
pub use writer::SliceWriter;
162163

rten-simd/src/safe/arch/aarch64.rs

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@ use std::arch::aarch64::{
1212
vmulq_s8, vmulq_u16, vmulq_u8, vnegq_f32, vnegq_s16, vnegq_s32, vnegq_s8, vqmovn_s32,
1313
vqmovun_s16, vshlq_n_s16, vshlq_n_s32, vshlq_n_s8, vst1q_f32, vst1q_s16, vst1q_s32, vst1q_s8,
1414
vst1q_u16, vst1q_u8, vsubq_f32, vsubq_s16, vsubq_s32, vsubq_s8, vsubq_u16, vsubq_u8,
15+
vzip1q_s16, vzip1q_s8, vzip2q_s16, vzip2q_s8,
1516
};
1617
use std::mem::transmute;
1718

1819
use crate::safe::{
19-
Extend, FloatOps, Isa, Mask, MaskOps, NarrowSaturate, NumOps, SignedIntOps, Simd,
20+
Extend, FloatOps, Interleave, Isa, Mask, MaskOps, NarrowSaturate, NumOps, SignedIntOps, Simd,
2021
};
2122

2223
#[derive(Copy, Clone)]
@@ -52,11 +53,15 @@ unsafe impl Isa for ArmNeonIsa {
5253
self,
5354
) -> impl SignedIntOps<Self::I16>
5455
+ NarrowSaturate<Self::I16, Self::U8>
55-
+ Extend<Self::I16, Output = Self::I32> {
56+
+ Extend<Self::I16, Output = Self::I32>
57+
+ Interleave<Self::I16> {
5658
self
5759
}
5860

59-
fn i8(self) -> impl SignedIntOps<Self::I8> + Extend<Self::I8, Output = Self::I16> {
61+
fn i8(
62+
self,
63+
) -> impl SignedIntOps<Self::I8> + Extend<Self::I8, Output = Self::I16> + Interleave<Self::I8>
64+
{
6065
self
6166
}
6267

@@ -424,6 +429,18 @@ impl Extend<int16x8_t> for ArmNeonIsa {
424429
}
425430
}
426431

432+
impl Interleave<int16x8_t> for ArmNeonIsa {
433+
#[inline]
434+
fn interleave_low(self, a: int16x8_t, b: int16x8_t) -> int16x8_t {
435+
unsafe { vzip1q_s16(a, b) }
436+
}
437+
438+
#[inline]
439+
fn interleave_high(self, a: int16x8_t, b: int16x8_t) -> int16x8_t {
440+
unsafe { vzip2q_s16(a, b) }
441+
}
442+
}
443+
427444
unsafe impl NumOps<int8x16_t> for ArmNeonIsa {
428445
simd_ops_common!(int8x16_t, uint8x16_t);
429446

@@ -519,6 +536,18 @@ impl Extend<int8x16_t> for ArmNeonIsa {
519536
}
520537
}
521538

539+
impl Interleave<int8x16_t> for ArmNeonIsa {
540+
#[inline]
541+
fn interleave_low(self, a: int8x16_t, b: int8x16_t) -> int8x16_t {
542+
unsafe { vzip1q_s8(a, b) }
543+
}
544+
545+
#[inline]
546+
fn interleave_high(self, a: int8x16_t, b: int8x16_t) -> int8x16_t {
547+
unsafe { vzip2q_s8(a, b) }
548+
}
549+
}
550+
522551
unsafe impl NumOps<uint8x16_t> for ArmNeonIsa {
523552
simd_ops_common!(uint8x16_t, uint8x16_t);
524553

rten-simd/src/safe/arch/generic.rs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::array;
22
use std::mem::transmute;
33

44
use crate::safe::{
5-
Extend, FloatOps, Isa, Mask, MaskOps, NarrowSaturate, NumOps, SignedIntOps, Simd,
5+
Extend, FloatOps, Interleave, Isa, Mask, MaskOps, NarrowSaturate, NumOps, SignedIntOps, Simd,
66
};
77

88
// Size of SIMD vector in 32-bit lanes.
@@ -74,11 +74,15 @@ unsafe impl Isa for GenericIsa {
7474
self,
7575
) -> impl SignedIntOps<Self::I16>
7676
+ NarrowSaturate<Self::I16, Self::U8>
77-
+ Extend<Self::I16, Output = Self::I32> {
77+
+ Extend<Self::I16, Output = Self::I32>
78+
+ Interleave<Self::I16> {
7879
self
7980
}
8081

81-
fn i8(self) -> impl SignedIntOps<Self::I8> + Extend<Self::I8, Output = Self::I16> {
82+
fn i8(
83+
self,
84+
) -> impl SignedIntOps<Self::I8> + Extend<Self::I8, Output = Self::I16> + Interleave<Self::I8>
85+
{
8286
self
8387
}
8488

@@ -302,6 +306,30 @@ macro_rules! impl_extend {
302306
impl_extend!(I8x16, I16x8);
303307
impl_extend!(I16x8, I32x4);
304308

309+
macro_rules! impl_interleave {
310+
($simd:ty) => {
311+
impl Interleave<$simd> for GenericIsa {
312+
fn interleave_low(self, a: $simd, b: $simd) -> $simd {
313+
array::from_fn(|i| if i % 2 == 0 { a.0[i / 2] } else { b.0[i / 2] }).into()
314+
}
315+
316+
fn interleave_high(self, a: $simd, b: $simd) -> $simd {
317+
let start = a.0.len() / 2;
318+
array::from_fn(|i| {
319+
if i % 2 == 0 {
320+
a.0[start + i / 2]
321+
} else {
322+
b.0[start + i / 2]
323+
}
324+
})
325+
.into()
326+
}
327+
}
328+
};
329+
}
330+
impl_interleave!(I8x16);
331+
impl_interleave!(I16x8);
332+
305333
macro_rules! impl_simd_unsigned_int_ops {
306334
($simd:ident, $elem:ty, $len:expr, $mask:ident) => {
307335
unsafe impl NumOps<$simd> for GenericIsa {

rten-simd/src/safe/arch/wasm32.rs

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::arch::wasm32::{
33
f32x4_lt, f32x4_max, f32x4_min, f32x4_mul, f32x4_nearest, f32x4_neg, f32x4_splat, f32x4_sub,
44
i16x8_add, i16x8_eq, i16x8_extend_high_i8x16, i16x8_extend_low_i8x16, i16x8_extmul_high_i8x16,
55
i16x8_extmul_low_i8x16, i16x8_ge, i16x8_gt, i16x8_mul, i16x8_narrow_i32x4, i16x8_neg,
6-
i16x8_shl, i16x8_splat, i16x8_sub, i32x4_add, i32x4_eq, i32x4_extend_high_i16x8,
6+
i16x8_shl, i16x8_shuffle, i16x8_splat, i16x8_sub, i32x4_add, i32x4_eq, i32x4_extend_high_i16x8,
77
i32x4_extend_low_i16x8, i32x4_ge, i32x4_gt, i32x4_mul, i32x4_neg, i32x4_shl, i32x4_shuffle,
88
i32x4_splat, i32x4_sub, i32x4_trunc_sat_f32x4, i8x16_add, i8x16_eq, i8x16_ge, i8x16_gt,
99
i8x16_neg, i8x16_shl, i8x16_shuffle, i8x16_splat, i8x16_sub, u16x8_add, u16x8_eq,
@@ -15,7 +15,7 @@ use std::mem::transmute;
1515

1616
use super::{lanes, simd_type};
1717
use crate::safe::{
18-
Extend, FloatOps, Isa, Mask, MaskOps, NarrowSaturate, NumOps, SignedIntOps, Simd,
18+
Extend, FloatOps, Interleave, Isa, Mask, MaskOps, NarrowSaturate, NumOps, SignedIntOps, Simd,
1919
};
2020

2121
simd_type!(F32x4, v128, f32, M32, Wasm32Isa);
@@ -59,11 +59,15 @@ unsafe impl Isa for Wasm32Isa {
5959
self,
6060
) -> impl SignedIntOps<Self::I16>
6161
+ NarrowSaturate<Self::I16, Self::U8>
62-
+ Extend<Self::I16, Output = Self::I32> {
62+
+ Extend<Self::I16, Output = Self::I32>
63+
+ Interleave<Self::I16> {
6364
self
6465
}
6566

66-
fn i8(self) -> impl SignedIntOps<Self::I8> + Extend<Self::I8, Output = Self::I16> {
67+
fn i8(
68+
self,
69+
) -> impl SignedIntOps<Self::I8> + Extend<Self::I8, Output = Self::I16> + Interleave<Self::I8>
70+
{
6771
self
6872
}
6973

@@ -372,6 +376,18 @@ impl Extend<I16x8> for Wasm32Isa {
372376
}
373377
}
374378

379+
impl Interleave<I16x8> for Wasm32Isa {
380+
#[inline]
381+
fn interleave_low(self, a: I16x8, b: I16x8) -> I16x8 {
382+
i16x8_shuffle::<0, 8, 1, 9, 2, 10, 3, 11>(a.0, b.0).into()
383+
}
384+
385+
#[inline]
386+
fn interleave_high(self, a: I16x8, b: I16x8) -> I16x8 {
387+
i16x8_shuffle::<4, 12, 5, 13, 6, 14, 7, 15>(a.0, b.0).into()
388+
}
389+
}
390+
375391
impl NarrowSaturate<I16x8, U8x16> for Wasm32Isa {
376392
#[inline]
377393
fn narrow_saturate(self, low: I16x8, high: I16x8) -> U8x16 {
@@ -450,6 +466,19 @@ impl Extend<I8x16> for Wasm32Isa {
450466
}
451467
}
452468

469+
impl Interleave<I8x16> for Wasm32Isa {
470+
#[inline]
471+
fn interleave_low(self, a: I8x16, b: I8x16) -> I8x16 {
472+
i8x16_shuffle::<0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23>(a.0, b.0).into()
473+
}
474+
475+
#[inline]
476+
fn interleave_high(self, a: I8x16, b: I8x16) -> I8x16 {
477+
i8x16_shuffle::<8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31>(a.0, b.0)
478+
.into()
479+
}
480+
}
481+
453482
unsafe impl NumOps<U8x16> for Wasm32Isa {
454483
simd_ops_common!(U8x16, M8, i8);
455484

rten-simd/src/safe/arch/x86_64/avx2.rs

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,27 @@ use std::arch::x86_64::{
55
_mm256_cmpeq_epi32, _mm256_cmpeq_epi8, _mm256_cmpgt_epi16, _mm256_cmpgt_epi32,
66
_mm256_cmpgt_epi8, _mm256_cvtepi16_epi32, _mm256_cvtepi8_epi16, _mm256_cvtepu8_epi16,
77
_mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_div_ps, _mm256_extractf128_ps,
8-
_mm256_extracti128_si256, _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_loadu_si256,
9-
_mm256_maskload_epi32, _mm256_maskload_ps, _mm256_maskstore_epi32, _mm256_maskstore_ps,
10-
_mm256_max_ps, _mm256_min_ps, _mm256_movemask_epi8, _mm256_mul_ps, _mm256_mullo_epi16,
11-
_mm256_mullo_epi32, _mm256_or_si256, _mm256_packs_epi32, _mm256_packus_epi16,
12-
_mm256_permute4x64_epi64, _mm256_set1_epi16, _mm256_set1_epi32, _mm256_set1_epi8,
13-
_mm256_set1_ps, _mm256_setr_m128i, _mm256_setzero_si256, _mm256_slli_epi16, _mm256_slli_epi32,
14-
_mm256_storeu_ps, _mm256_storeu_si256, _mm256_sub_epi16, _mm256_sub_epi32, _mm256_sub_epi8,
15-
_mm256_sub_ps, _mm256_xor_ps, _mm256_xor_si256, _mm_add_ps, _mm_cvtss_f32, _mm_movehl_ps,
16-
_mm_prefetch, _mm_setr_epi8, _mm_shuffle_epi8, _mm_shuffle_ps, _mm_unpacklo_epi64, _CMP_EQ_OQ,
17-
_CMP_GE_OQ, _CMP_GT_OQ, _CMP_LE_OQ, _CMP_LT_OQ, _MM_HINT_ET0, _MM_HINT_T0,
8+
_mm256_extracti128_si256, _mm256_fmadd_ps, _mm256_insertf128_si256, _mm256_loadu_ps,
9+
_mm256_loadu_si256, _mm256_maskload_epi32, _mm256_maskload_ps, _mm256_maskstore_epi32,
10+
_mm256_maskstore_ps, _mm256_max_ps, _mm256_min_ps, _mm256_movemask_epi8, _mm256_mul_ps,
11+
_mm256_mullo_epi16, _mm256_mullo_epi32, _mm256_or_si256, _mm256_packs_epi32,
12+
_mm256_packus_epi16, _mm256_permute2x128_si256, _mm256_permute4x64_epi64, _mm256_set1_epi16,
13+
_mm256_set1_epi32, _mm256_set1_epi8, _mm256_set1_ps, _mm256_setr_m128i, _mm256_setzero_si256,
14+
_mm256_slli_epi16, _mm256_slli_epi32, _mm256_storeu_ps, _mm256_storeu_si256, _mm256_sub_epi16,
15+
_mm256_sub_epi32, _mm256_sub_epi8, _mm256_sub_ps, _mm256_unpackhi_epi16, _mm256_unpackhi_epi8,
16+
_mm256_unpacklo_epi16, _mm256_unpacklo_epi8, _mm256_xor_ps, _mm256_xor_si256, _mm_add_ps,
17+
_mm_cvtss_f32, _mm_movehl_ps, _mm_prefetch, _mm_setr_epi8, _mm_shuffle_epi8, _mm_shuffle_ps,
18+
_mm_unpacklo_epi64, _CMP_EQ_OQ, _CMP_GE_OQ, _CMP_GT_OQ, _CMP_LE_OQ, _CMP_LT_OQ, _MM_HINT_ET0,
19+
_MM_HINT_T0,
1820
};
1921
use std::is_x86_feature_detected;
2022
use std::mem::transmute;
2123

2224
use super::super::{lanes, simd_type};
2325
use crate::safe::vec::{Extend, Narrow};
24-
use crate::safe::{FloatOps, Isa, Mask, MaskOps, NarrowSaturate, NumOps, SignedIntOps, Simd};
26+
use crate::safe::{
27+
FloatOps, Interleave, Isa, Mask, MaskOps, NarrowSaturate, NumOps, SignedIntOps, Simd,
28+
};
2529

2630
simd_type!(F32x8, __m256, f32, F32x8, Avx2Isa);
2731
simd_type!(I32x8, __m256i, i32, I32x8, Avx2Isa);
@@ -67,11 +71,15 @@ unsafe impl Isa for Avx2Isa {
6771
self,
6872
) -> impl SignedIntOps<Self::I16>
6973
+ NarrowSaturate<Self::I16, Self::U8>
70-
+ Extend<Self::I16, Output = Self::I32> {
74+
+ Extend<Self::I16, Output = Self::I32>
75+
+ Interleave<Self::I16> {
7176
self
7277
}
7378

74-
fn i8(self) -> impl SignedIntOps<Self::I8> + Extend<Self::I8, Output = Self::I16> {
79+
fn i8(
80+
self,
81+
) -> impl SignedIntOps<Self::I8> + Extend<Self::I8, Output = Self::I16> + Interleave<Self::I8>
82+
{
7583
self
7684
}
7785

@@ -472,6 +480,30 @@ impl NarrowSaturate<I16x16, U8x32> for Avx2Isa {
472480
}
473481
}
474482

483+
impl Interleave<I16x16> for Avx2Isa {
484+
#[inline]
485+
fn interleave_low(self, a: I16x16, b: I16x16) -> I16x16 {
486+
unsafe {
487+
// AB{N} = Interleaved Nth 64-bit block.
488+
let lo = _mm256_unpacklo_epi16(a.0, b.0); // AB0 AB2
489+
let hi = _mm256_unpackhi_epi16(a.0, b.0); // AB1 AB3
490+
_mm256_insertf128_si256(lo, _mm256_castsi256_si128(hi), 1) // AB0 AB1
491+
}
492+
.into()
493+
}
494+
495+
#[inline]
496+
fn interleave_high(self, a: I16x16, b: I16x16) -> I16x16 {
497+
unsafe {
498+
// AB{N} = Interleaved Nth 64-bit block.
499+
let lo = _mm256_unpacklo_epi16(a.0, b.0); // AB0 AB2
500+
let hi = _mm256_unpackhi_epi16(a.0, b.0); // AB1 AB3
501+
_mm256_permute2x128_si256(lo, hi, 0x31) // AB2 AB3
502+
}
503+
.into()
504+
}
505+
}
506+
475507
unsafe impl NumOps<I8x32> for Avx2Isa {
476508
simd_ops_common!(I8x32, I8x32);
477509

@@ -589,6 +621,30 @@ impl SignedIntOps<I8x32> for Avx2Isa {
589621
}
590622
}
591623

624+
impl Interleave<I8x32> for Avx2Isa {
625+
#[inline]
626+
fn interleave_low(self, a: I8x32, b: I8x32) -> I8x32 {
627+
unsafe {
628+
// AB{N} = Interleaved Nth 64-bit block.
629+
let lo = _mm256_unpacklo_epi8(a.0, b.0); // AB0 AB2
630+
let hi = _mm256_unpackhi_epi8(a.0, b.0); // AB1 AB3
631+
_mm256_insertf128_si256(lo, _mm256_castsi256_si128(hi), 1) // AB0 AB1
632+
}
633+
.into()
634+
}
635+
636+
#[inline]
637+
fn interleave_high(self, a: I8x32, b: I8x32) -> I8x32 {
638+
unsafe {
639+
// AB{N} = Interleaved Nth 64-bit block.
640+
let lo = _mm256_unpacklo_epi8(a.0, b.0); // AB0 AB2
641+
let hi = _mm256_unpackhi_epi8(a.0, b.0); // AB1 AB3
642+
_mm256_permute2x128_si256(lo, hi, 0x31) // AB2 AB3
643+
}
644+
.into()
645+
}
646+
}
647+
592648
unsafe impl NumOps<U8x32> for Avx2Isa {
593649
simd_ops_common!(U8x32, I8x32);
594650

0 commit comments

Comments
 (0)