Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks. Learn more about diff comparisons here.
base repository: robertknight/rten
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: f3926369611ff9f342da62e9e4e7b161bf9538f6
Choose a base ref
..
head repository: robertknight/rten
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: c4aebf6c09d6dfd5f6c33211d5121220082a7c53
Choose a head ref
Showing with 32 additions and 34 deletions.
  1. +30 −32 rten-simd/src/iter.rs
  2. +2 −2 rten-simd/src/writer.rs
62 changes: 30 additions & 32 deletions rten-simd/src/iter.rs
Original file line number Diff line number Diff line change
@@ -12,50 +12,44 @@ pub trait SimdIterable {
/// If the input length is not divisible by the SIMD vector width, the
/// iterator yields only the full chunks. The tail is accessible via the
/// iterator's [`tail`](Iter::tail) method.
fn simd_iter<S: Simd<Elem = Self::Elem>, O: NumOps<Self::Elem, Simd = S>>(
&self,
ops: O,
) -> Iter<S, O>;
fn simd_iter<O: NumOps<Self::Elem>>(&self, ops: O) -> Iter<Self::Elem, O>;

/// Iterate over SIMD-sized chunks of the input.
///
/// If the input length is not divisible by the SIMD vector width, the final
/// chunk will be padded with zeros.
fn simd_iter_pad<S: Simd<Elem = Self::Elem>, O: NumOps<Self::Elem, Simd = S>>(
fn simd_iter_pad<O: NumOps<Self::Elem>>(
&self,
ops: O,
) -> impl ExactSizeIterator<Item = S>;
) -> impl ExactSizeIterator<Item = O::Simd>;
}

impl<T: Elem> SimdIterable for [T] {
type Elem = T;

#[inline]
fn simd_iter<S: Simd<Elem = T>, O: NumOps<T, Simd = S>>(&self, ops: O) -> Iter<S, O> {
fn simd_iter<O: NumOps<T>>(&self, ops: O) -> Iter<T, O> {
Iter::new(ops, self)
}

#[inline]
fn simd_iter_pad<S: Simd<Elem = T>, O: NumOps<T, Simd = S>>(
&self,
ops: O,
) -> impl ExactSizeIterator<Item = S> {
fn simd_iter_pad<O: NumOps<T>>(&self, ops: O) -> impl ExactSizeIterator<Item = O::Simd> {
IterPad::new(ops, self)
}
}

/// Iterator which yields chunks of a slice as a SIMD vector.
///
/// This type is created by [`SimdIterable::simd_iter`].
pub struct Iter<'a, S: Simd, O: NumOps<S::Elem, Simd = S>> {
pub struct Iter<'a, T: Elem, O: NumOps<T>> {
ops: O,
xs: &'a [S::Elem],
xs: &'a [T],
n_full_chunks: usize,
}

impl<'a, S: Simd, O: NumOps<S::Elem, Simd = S>> Iter<'a, S, O> {
impl<'a, T: Elem, O: NumOps<T>> Iter<'a, T, O> {
#[inline]
fn new(ops: O, xs: &'a [S::Elem]) -> Self {
fn new(ops: O, xs: &'a [T]) -> Self {
let n_full_chunks = xs.len() / ops.len();
Iter {
ops,
@@ -71,7 +65,11 @@ impl<'a, S: Simd, O: NumOps<S::Elem, Simd = S>> Iter<'a, S, O> {
/// multiple of the SIMD vector length, the final vector will be padded with
/// zeros.
#[inline]
pub fn fold<F: FnMut(S, S) -> S>(mut self, mut accum: S, mut fold: F) -> S {
pub fn fold<F: FnMut(O::Simd, O::Simd) -> O::Simd>(
mut self,
mut accum: O::Simd,
mut fold: F,
) -> O::Simd {
for chunk in &mut self {
accum = fold(accum, chunk);
}
@@ -89,9 +87,9 @@ impl<'a, S: Simd, O: NumOps<S::Elem, Simd = S>> Iter<'a, S, O> {
#[inline]
pub fn fold_n<const N: usize>(
mut self,
mut accum: [S; N],
mut fold: impl FnMut([S; N], S) -> [S; N],
) -> [S; N] {
mut accum: [O::Simd; N],
mut fold: impl FnMut([O::Simd; N], O::Simd) -> [O::Simd; N],
) -> [O::Simd; N] {
for chunk in &mut self {
accum = fold(accum, chunk);
}
@@ -112,7 +110,7 @@ impl<'a, S: Simd, O: NumOps<S::Elem, Simd = S>> Iter<'a, S, O> {
/// Elements of the SIMD vector that correspond to positions where the mask
/// is false will be set to zero.
#[inline]
pub fn tail(&self) -> Option<(S, S::Mask)> {
pub fn tail(&self) -> Option<(O::Simd, <O::Simd as Simd>::Mask)> {
let n = self.xs.len();
if n > 0 {
Some(self.ops.load_pad(self.xs))
@@ -122,8 +120,8 @@ impl<'a, S: Simd, O: NumOps<S::Elem, Simd = S>> Iter<'a, S, O> {
}
}

impl<S: Simd, O: NumOps<S::Elem, Simd = S>> Iterator for Iter<'_, S, O> {
type Item = S;
impl<T: Elem, O: NumOps<T>> Iterator for Iter<'_, T, O> {
type Item = O::Simd;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
@@ -146,29 +144,29 @@ impl<S: Simd, O: NumOps<S::Elem, Simd = S>> Iterator for Iter<'_, S, O> {
}
}

impl<S: Simd, O: NumOps<S::Elem, Simd = S>> ExactSizeIterator for Iter<'_, S, O> {}
impl<T: Elem, O: NumOps<T>> ExactSizeIterator for Iter<'_, T, O> {}

impl<S: Simd, O: NumOps<S::Elem, Simd = S>> std::iter::FusedIterator for Iter<'_, S, O> {}
impl<T: Elem, O: NumOps<T>> std::iter::FusedIterator for Iter<'_, T, O> {}

/// Iterator which yields chunks of a slice as a SIMD vector.
///
/// This type is created by [`SimdIterable::simd_iter_pad`].
pub struct IterPad<'a, S: Simd, O: NumOps<S::Elem, Simd = S>> {
iter: Iter<'a, S, O>,
pub struct IterPad<'a, T: Elem, O: NumOps<T>> {
iter: Iter<'a, T, O>,
has_tail: bool,
}

impl<'a, S: Simd, O: NumOps<S::Elem, Simd = S>> IterPad<'a, S, O> {
impl<'a, T: Elem, O: NumOps<T>> IterPad<'a, T, O> {
#[inline]
fn new(ops: O, xs: &'a [S::Elem]) -> Self {
fn new(ops: O, xs: &'a [T]) -> Self {
let iter = Iter::new(ops, xs);
let has_tail = xs.len() % ops.len() != 0;
Self { iter, has_tail }
}
}

impl<S: Simd, O: NumOps<S::Elem, Simd = S>> Iterator for IterPad<'_, S, O> {
type Item = S;
impl<T: Elem, O: NumOps<T>> Iterator for IterPad<'_, T, O> {
type Item = O::Simd;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
@@ -191,9 +189,9 @@ impl<S: Simd, O: NumOps<S::Elem, Simd = S>> Iterator for IterPad<'_, S, O> {
}
}

impl<S: Simd, O: NumOps<S::Elem, Simd = S>> ExactSizeIterator for IterPad<'_, S, O> {}
impl<T: Elem, O: NumOps<T>> ExactSizeIterator for IterPad<'_, T, O> {}

impl<S: Simd, O: NumOps<S::Elem, Simd = S>> std::iter::FusedIterator for IterPad<'_, S, O> {}
impl<T: Elem, O: NumOps<T>> std::iter::FusedIterator for IterPad<'_, T, O> {}

#[cfg(test)]
mod tests {
4 changes: 2 additions & 2 deletions rten-simd/src/writer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::mem::{transmute, MaybeUninit};

use super::{Elem, NumOps, Simd};
use super::{Elem, NumOps};

/// Utility for incrementally filling an uninitialized slice, one SIMD vector
/// at a time.
@@ -19,7 +19,7 @@ impl<'a, T: Elem> SliceWriter<'a, T> {
/// of SIMD vector `xs`.
///
/// Panics if the slice does not have space for `ops.len()` elements.
pub fn write_vec<S: Simd<Elem = T>>(&mut self, ops: impl NumOps<T, Simd = S>, xs: S) {
pub fn write_vec<O: NumOps<T>>(&mut self, ops: O, xs: O::Simd) {
let written = ops.store_uninit(xs, &mut self.buf[self.n_init..]);
self.n_init += written.len();
}