Skip to content

Commit 9f3232a

Browse files
authored
Merge pull request #624 from robertknight/extend-init
Add `ExtendInit` utility for updating length of destination buffers for vectorized ops
2 parents 69ecbd8 + f4b203f commit 9f3232a

File tree

4 files changed

+146
-53
lines changed

4 files changed

+146
-53
lines changed

rten-vecmath/src/extend_init.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
use std::mem::MaybeUninit;
2+
3+
/// Extend a buffer by incrementally initializing spare capacity.
///
/// This is implemented for [`Vec<T>`], where it provides a safe API to
/// initialize the spare capacity returned by
/// [`spare_capacity_mut`](Vec::spare_capacity_mut).
pub trait ExtendInit {
    /// Element type in the buffer.
    type Elem;

    /// Extend the buffer by initializing a portion of the buffer's spare
    /// capacity.
    ///
    /// The function `f` is passed the uninitialized portion of the buffer and
    /// must return the portion that it has initialized. The returned slice
    /// must start at the first element of the slice that `f` received.
    ///
    /// `f` is invoked at most once per call, so it may consume values that it
    /// has captured. `extend_init` itself can be called many times, until the
    /// entire buffer has been initialized.
    ///
    /// # Panics
    ///
    /// Panics if `f` returns a slice that is not a prefix of the slice that
    /// was passed to it.
    fn extend_init<F>(&mut self, f: F)
    where
        F: FnOnce(&mut [MaybeUninit<Self::Elem>]) -> &[Self::Elem];
}

impl<T> ExtendInit for Vec<T> {
    type Elem = T;

    fn extend_init<F>(&mut self, f: F)
    where
        F: FnOnce(&mut [MaybeUninit<Self::Elem>]) -> &[Self::Elem],
    {
        let cap = self.spare_capacity_mut();

        // Record where the spare capacity starts and how long it is before
        // handing it to `f`, so the returned slice can be validated afterwards.
        let cap_ptr = cap.as_ptr();
        let cap_len = cap.len();

        let initialized = f(cap);
        assert_eq!(
            initialized.as_ptr(),
            cap_ptr as *const T,
            "returned slice must be a prefix of the input"
        );
        assert!(
            initialized.len() <= cap_len,
            "initialized slice length {} is longer than input {}",
            initialized.len(),
            cap_len
        );
        let n_init = initialized.len();

        // SAFETY: The assertions above verified that `initialized` starts at
        // the beginning of the spare capacity and is no longer than it, so
        // `len + n_init <= capacity`. `f` returned those `n_init` elements as
        // `&[T]`, which means they have been initialized.
        unsafe { self.set_len(self.len() + n_init) }
    }
}
52+
53+
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use super::ExtendInit;

    /// Implementation of `MaybeUninit::fill` from nightly Rust.
    ///
    /// Writes `value` to every element of `xs` and returns the same memory
    /// as an initialized slice.
    fn fill<T: Copy>(xs: &mut [MaybeUninit<T>], value: T) -> &mut [T] {
        for x in xs.iter_mut() {
            x.write(value);
        }
        // SAFETY: Every element of `xs` was written by the loop above, and
        // `MaybeUninit<T>` has the same size and alignment as `T`.
        unsafe { &mut *(xs as *mut [MaybeUninit<T>] as *mut [T]) }
    }

    #[test]
    fn test_extend_init() {
        let mut vec = Vec::with_capacity(7);

        // Initialize only part of the spare capacity.
        vec.extend_init(|uninit| {
            assert_eq!(uninit.len(), 7);
            fill(&mut uninit[..3], 1.)
        });
        assert_eq!(vec.len(), 3);
        assert_eq!(vec, &[1., 1., 1.]);

        // A second call sees only the remaining spare capacity.
        vec.extend_init(|uninit| {
            assert_eq!(uninit.len(), 4);
            fill(uninit, 2.)
        });
        assert_eq!(vec.len(), 7);
        assert_eq!(vec, &[1., 1., 1., 2., 2., 2., 2.]);
    }

    #[test]
    fn test_extend_init_empty_prefix() {
        let mut vec: Vec<f32> = Vec::with_capacity(4);

        // Returning a zero-length prefix is valid and leaves the length as-is.
        vec.extend_init(|uninit| fill(&mut uninit[..0], 1.));
        assert!(vec.is_empty());
    }

    #[test]
    #[should_panic(expected = "returned slice must be a prefix of the input")]
    fn test_extend_init_rejects_non_prefix() {
        let mut vec: Vec<f32> = Vec::with_capacity(4);

        // The returned slice starts at the second element, so it is not a
        // prefix of the spare capacity.
        vec.extend_init(|uninit| fill(&mut uninit[1..], 1.));
    }
}

rten-vecmath/src/lib.rs

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,26 @@
88
//! The operations are implemented by structs which implement the SIMD operation
99
//! traits from [rten-simd](rten_simd). To apply an operation to data, first
1010
//! construct the operation using the struct from this crate, then use a
11-
//! dispatch method from the [`SimdOp`](rten_simd::dispatch::SimdOp) or
12-
//! [`SimdUnaryOp`](rten_simd::dispatch::SimdUnaryOp) traits to execute the
13-
//! operation using the preferred SIMD instruction set.
11+
//! dispatch method from the [`SimdOp`](rten_simd::safe::SimdOp) or
12+
//! [`SimdUnaryOp`](rten_simd::safe::SimdUnaryOp) traits to execute
13+
//! the operation.
1414
//!
15-
//! ## In-place versus mutating operations
15+
//! ## In-place and non in-place operations
1616
//!
1717
//! Some operations support both updating data in place or reading input from
1818
//! one slice and writing to another. For unary operations this is controlled by
19-
//! dispatching with either [`map`](rten_simd::dispatch::SimdUnaryOp::map) or
20-
//! [`map_mut`](rten_simd::dispatch::SimdUnaryOp::map_mut). For other operations
19+
//! dispatching with either [`map`](rten_simd::safe::SimdUnaryOp::map) or
20+
//! [`map_mut`](rten_simd::safe::SimdUnaryOp::map_mut). For other operations
2121
//! this is handled by exposing different constructors for the in-place and
2222
//! non in-place cases, such as [`Softmax::new_mut`] and [`Softmax::new`].
2323
//!
24+
//! For operations which use a separate source and destination, the destination
25+
//! is expected to be an uninitialized slice (`[MaybeUninit<T>]`). This allows
26+
//! the caller to control allocation of the buffer and avoid the overhead of
27+
//! initializing elements which the operation will overwrite. The [`ExtendInit`]
28+
//! trait provides a safe API for the common task of filling a new `Vec` with
29+
//! the result of the operation.
30+
//!
2431
//! ## Examples
2532
//!
2633
//! ### Applying a vectorized unary function
@@ -44,6 +51,8 @@
4451
//!
4552
//! ### Applying softmax in place
4653
//!
54+
//! This example applies the softmax function in-place to a mutable slice.
55+
//!
4756
//! ```
4857
//! use rten_simd::safe::SimdOp;
4958
//! use rten_vecmath::Softmax;
@@ -55,20 +64,24 @@
5564
//! ### Applying softmax with separate input and output buffers
5665
//!
5766
//! This example reads data from an input and writes to an uninitialized output
58-
//! buffer. The softmax operation returns the initialized slice.
67+
//! buffer (`&mut [MaybeUninit<f32>]`), obtained from the uninitialized portion
68+
//! of a `Vec<f32>`. To update the length of the `Vec<f32>` after it is
69+
//! initialized, the helper `ExtendInit` trait is used.
5970
//!
6071
//! ```
6172
//! use rten_simd::safe::SimdOp;
62-
//! use rten_vecmath::Softmax;
73+
//! use rten_vecmath::{Softmax, ExtendInit};
6374
//!
6475
//! let data = [1., 0.5, 2.0];
6576
//! let mut output = Vec::with_capacity(data.len());
66-
//! let output_uninit = &mut output.spare_capacity_mut()[..data.len()];
67-
//! let output_init = Softmax::new(&data, output_uninit).dispatch();
68-
//!
69-
//! // Safety: The softmax operation initialized all output elements.
70-
//! let init_len = output_init.len();
71-
//! unsafe { output.set_len(init_len) };
77+
//! output.extend_init(|output_uninit| {
78+
//! // `output_uninit` is the uninitialized part of `output`, as returned by
79+
//! // `output.spare_capacity_mut()`.
80+
//! //
81+
//! // The `dispatch` call initializes it and returns the initialized slice.
82+
//! Softmax::new(&data, output_uninit).dispatch()
83+
//! });
84+
//! assert_eq!(output.len(), 3);
7285
//! ```
7386
//!
7487
//! ### Computing the sum of a list of floats
@@ -96,6 +109,8 @@ mod ulp;
96109
#[cfg(test)]
97110
mod testing;
98111

112+
mod extend_init;
113+
99114
// Unary functions.
100115
pub use erf::{Erf, Gelu};
101116
pub use exp::{Exp, Sigmoid, Silu, Swish};
@@ -107,3 +122,6 @@ pub use min_max::MinMax;
107122
pub use normalize::{Normalize, NormalizeOptions};
108123
pub use softmax::Softmax;
109124
pub use sum::{Sum, SumSquare, SumSquareSub};
125+
126+
// Utilities
127+
pub use extend_init::ExtendInit;

src/ops/norm.rs

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use rten_simd::safe::SimdOp;
55
use rten_tensor::prelude::*;
66
use rten_tensor::{NdTensorView, Tensor, TensorView};
77
use rten_vecmath as vecmath;
8+
use rten_vecmath::ExtendInit;
89

910
use crate::ops::static_dims;
1011
use crate::ops::{resolve_axis, InputList, IntoOpResult, OpError, Operator, Output, OutputList};
@@ -70,23 +71,23 @@ impl Default for NormalizeOptions<'_> {
7071
}
7172
}
7273

73-
enum NormalizeData<'a> {
74+
enum NormalizeData<'src, 'dst> {
7475
/// Read from a source slice and write normalized data to an output slice
7576
/// of the same length.
76-
SrcDest((&'a [f32], &'a mut [MaybeUninit<f32>])),
77+
SrcDest((&'src [f32], &'dst mut [MaybeUninit<f32>])),
7778

7879
/// Normalize elements of a slice in place.
79-
InPlace(&'a mut [f32]),
80+
InPlace(&'dst mut [f32]),
8081
}
8182

82-
impl<'a> From<&'a mut [f32]> for NormalizeData<'a> {
83-
fn from(val: &'a mut [f32]) -> Self {
83+
impl<'dst> From<&'dst mut [f32]> for NormalizeData<'dst, 'dst> {
84+
fn from(val: &'dst mut [f32]) -> Self {
8485
NormalizeData::InPlace(val)
8586
}
8687
}
8788

88-
impl<'a> From<(&'a [f32], &'a mut [MaybeUninit<f32>])> for NormalizeData<'a> {
89-
fn from(val: (&'a [f32], &'a mut [MaybeUninit<f32>])) -> Self {
89+
impl<'src, 'dst> From<(&'src [f32], &'dst mut [MaybeUninit<f32>])> for NormalizeData<'src, 'dst> {
90+
fn from(val: (&'src [f32], &'dst mut [MaybeUninit<f32>])) -> Self {
9091
NormalizeData::SrcDest(val)
9192
}
9293
}
@@ -95,7 +96,10 @@ impl<'a> From<(&'a [f32], &'a mut [MaybeUninit<f32>])> for NormalizeData<'a> {
9596
/// and bias to the result.
9697
///
9798
/// Returns the normalized elements.
98-
fn normalize_slice<'a>(data: NormalizeData<'a>, opts: NormalizeOptions<'a>) -> &'a mut [f32] {
99+
fn normalize_slice<'src, 'dst>(
100+
data: NormalizeData<'src, 'dst>,
101+
opts: NormalizeOptions<'src>,
102+
) -> &'dst mut [f32] {
99103
let NormalizeOptions {
100104
mean_normalize,
101105
epsilon,
@@ -469,29 +473,19 @@ fn layer_normalization_impl(
469473
let scale = scale.to_contiguous_in(pool);
470474
let scale_data = scale.data().unwrap();
471475

472-
let mut n_init = 0;
473-
for (in_chunk, out_chunk) in input
474-
.data()
475-
.unwrap()
476-
.chunks(chunk_size)
477-
.zip(output.spare_capacity_mut().chunks_mut(chunk_size))
478-
{
479-
normalize_slice(
480-
(in_chunk, out_chunk).into(),
481-
NormalizeOptions {
482-
mean_normalize,
483-
epsilon,
484-
element_scale: Some(scale_data),
485-
element_bias: bias_data,
486-
..Default::default()
487-
},
488-
);
489-
n_init += in_chunk.len();
490-
}
491-
492-
// Safety: We initialized `n_init` elements.
493-
unsafe {
494-
output.set_len(n_init);
476+
for in_chunk in input.data().unwrap().chunks(chunk_size) {
477+
output.extend_init(|out_chunk| {
478+
normalize_slice(
479+
(in_chunk, &mut out_chunk[..chunk_size]).into(),
480+
NormalizeOptions {
481+
mean_normalize,
482+
epsilon,
483+
element_scale: Some(scale_data),
484+
element_bias: bias_data,
485+
..Default::default()
486+
},
487+
)
488+
})
495489
}
496490

497491
Ok(Tensor::from_data(input.shape(), output))

src/ops/quantize.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use rten_simd::safe::SimdOp;
44
use rten_tensor::prelude::*;
55
use rten_tensor::{AssumeInit, NdTensor, NdTensorView, Scalar, Tensor, TensorView};
66
use rten_vecmath as vecmath;
7+
use rten_vecmath::ExtendInit;
78

89
use crate::ops::{
910
resolve_axis, DataType, Input, InputList, IntoOpResult, OpError, Operator, Output, OutputList,
@@ -200,14 +201,9 @@ where
200201

201202
if let Some(data) = input.data() {
202203
let mut buf = pool.alloc(data.len());
203-
let buf_data = &mut buf.spare_capacity_mut()[..data.len()];
204-
205-
Quantize::quantize_slice(data, buf_data, inv_scale, zero_point);
206-
207-
// Safety: `quantize_slice` initialized `data.len()` elements
208-
unsafe {
209-
buf.set_len(data.len());
210-
}
204+
buf.extend_init(|buf_data| {
205+
Quantize::quantize_slice(data, buf_data, inv_scale, zero_point)
206+
});
211207
Ok(Tensor::from_data(input.shape(), buf))
212208
} else {
213209
Ok(input.map_in(pool, |x| x.quantize(inv_scale, zero_point)))

0 commit comments

Comments
 (0)