Skip to content

Commit

Permalink
add fill defaults for reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
kaikalii committed Jan 11, 2024
1 parent d10706f commit 97c55c8
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 58 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ This version is not yet released. If you are reading this on the website, then t
- [`insert`](https://uiua.org/docs/insert)
- [`remove`](https://uiua.org/docs/remove)
- Add experimental [`bind`](https://uiua.org/docs/bind) modifier, which binds local values within a function
- [`fill` ``](https://uiua.org/docs/fill) can now be used to specify default accumulators for [`reduce` `/`](https://uiua.org/docs/reduce)
### Interpreter
- The internal byte array type is now used in more places, which should improve performance a bit
- Lots of bug and crash fixes
Expand Down
109 changes: 70 additions & 39 deletions src/algorithm/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ pub fn reduce(env: &mut Uiua) -> UiuaResult {
let xs = env.pop(1)?;

match (f.as_flipped_primitive(env), xs) {
(Some((Primitive::Join, false)), mut xs) if !env.unpack_boxes() => {
(Some((Primitive::Join, false)), mut xs)
if !env.unpack_boxes() && env.box_fill().is_err() =>
{
if xs.rank() < 2 {
env.push(xs);
return Ok(());
Expand All @@ -38,28 +40,41 @@ pub fn reduce(env: &mut Uiua) -> UiuaResult {
}
}
#[cfg(feature = "bytes")]
(Some((prim, flipped)), Value::Byte(bytes)) => env.push(match prim {
Primitive::Add => fast_reduce(bytes.convert(), 0.0, add::num_num),
Primitive::Sub if flipped => fast_reduce(bytes.convert(), 0.0, flip(sub::num_num)),
Primitive::Sub => fast_reduce(bytes.convert(), 0.0, sub::num_num),
Primitive::Mul => fast_reduce(bytes.convert(), 1.0, mul::num_num),
Primitive::Div if flipped => fast_reduce(bytes.convert(), 1.0, flip(div::num_num)),
Primitive::Div => fast_reduce(bytes.convert(), 1.0, div::num_num),
Primitive::Mod if flipped => fast_reduce(bytes.convert(), 1.0, flip(modulus::num_num)),
Primitive::Mod => fast_reduce(bytes.convert(), 1.0, modulus::num_num),
Primitive::Atan if flipped => fast_reduce(bytes.convert(), 0.0, flip(atan2::num_num)),
Primitive::Atan => fast_reduce(bytes.convert(), 0.0, atan2::num_num),
Primitive::Max => fast_reduce(bytes.convert(), f64::NEG_INFINITY, max::num_num),
Primitive::Min => fast_reduce(bytes.convert(), f64::INFINITY, min::num_num),
_ => return generic_reduce(f, Value::Byte(bytes), env),
}),
(Some((prim, flipped)), Value::Byte(bytes)) => {
let fill = env.num_fill().ok();
env.push(match prim {
Primitive::Add => fast_reduce(bytes.convert(), 0.0, fill, add::num_num),
Primitive::Sub if flipped => {
fast_reduce(bytes.convert(), 0.0, fill, flip(sub::num_num))
}
Primitive::Sub => fast_reduce(bytes.convert(), 0.0, fill, sub::num_num),
Primitive::Mul => fast_reduce(bytes.convert(), 1.0, fill, mul::num_num),
Primitive::Div if flipped => {
fast_reduce(bytes.convert(), 1.0, fill, flip(div::num_num))
}
Primitive::Div => fast_reduce(bytes.convert(), 1.0, fill, div::num_num),
Primitive::Mod if flipped => {
fast_reduce(bytes.convert(), 1.0, fill, flip(modulus::num_num))
}
Primitive::Mod => fast_reduce(bytes.convert(), 1.0, fill, modulus::num_num),
Primitive::Atan if flipped => {
fast_reduce(bytes.convert(), 0.0, fill, flip(atan2::num_num))
}
Primitive::Atan => fast_reduce(bytes.convert(), 0.0, fill, atan2::num_num),
Primitive::Max => {
fast_reduce(bytes.convert(), f64::NEG_INFINITY, fill, max::num_num)
}
Primitive::Min => fast_reduce(bytes.convert(), f64::INFINITY, fill, min::num_num),
_ => return generic_reduce(f, Value::Byte(bytes), env),
})
}
(_, xs) => generic_reduce(f, xs, env)?,
}
Ok(())
}

macro_rules! reduce_math {
($fname:ident, $ty:ty, $f:ident) => {
($fname:ident, $ty:ty, $f:ident, $fill:ident) => {
#[allow(clippy::result_large_err)]
fn $fname(
prim: Primitive,
Expand All @@ -70,42 +85,51 @@ macro_rules! reduce_math {
where
$ty: From<f64>,
{
let fill = env.$fill().ok();
env.push(match prim {
Primitive::Add => fast_reduce(xs, 0.0.into(), add::$f),
Primitive::Sub if flipped => fast_reduce(xs, 0.0.into(), flip(sub::$f)),
Primitive::Sub => fast_reduce(xs, 0.0.into(), sub::$f),
Primitive::Mul => fast_reduce(xs, 1.0.into(), mul::$f),
Primitive::Div if flipped => fast_reduce(xs, 1.0.into(), flip(div::$f)),
Primitive::Div => fast_reduce(xs, 1.0.into(), div::$f),
Primitive::Mod if flipped => fast_reduce(xs, 1.0.into(), flip(modulus::$f)),
Primitive::Mod => fast_reduce(xs, 1.0.into(), modulus::$f),
Primitive::Atan if flipped => fast_reduce(xs, 0.0.into(), flip(atan2::$f)),
Primitive::Atan => fast_reduce(xs, 0.0.into(), atan2::$f),
Primitive::Max => fast_reduce(xs, f64::NEG_INFINITY.into(), max::$f),
Primitive::Min => fast_reduce(xs, f64::INFINITY.into(), min::$f),
Primitive::Add => fast_reduce(xs, 0.0.into(), fill, add::$f),
Primitive::Sub if flipped => fast_reduce(xs, 0.0.into(), fill, flip(sub::$f)),
Primitive::Sub => fast_reduce(xs, 0.0.into(), fill, sub::$f),
Primitive::Mul => fast_reduce(xs, 1.0.into(), fill, mul::$f),
Primitive::Div if flipped => fast_reduce(xs, 1.0.into(), fill, flip(div::$f)),
Primitive::Div => fast_reduce(xs, 1.0.into(), fill, div::$f),
Primitive::Mod if flipped => fast_reduce(xs, 1.0.into(), fill, flip(modulus::$f)),
Primitive::Mod => fast_reduce(xs, 1.0.into(), fill, modulus::$f),
Primitive::Atan if flipped => fast_reduce(xs, 0.0.into(), fill, flip(atan2::$f)),
Primitive::Atan => fast_reduce(xs, 0.0.into(), fill, atan2::$f),
Primitive::Max => fast_reduce(xs, f64::NEG_INFINITY.into(), fill, max::$f),
Primitive::Min => fast_reduce(xs, f64::INFINITY.into(), fill, min::$f),
_ => return Err(xs),
});
Ok(())
}
};
}

reduce_math!(reduce_nums, f64, num_num);
reduce_math!(reduce_nums, f64, num_num, num_fill);
reduce_math!(reduce_coms, crate::Complex, com_x, complex_fill);

reduce_math!(reduce_coms, crate::Complex, com_x);

pub fn fast_reduce<T>(mut arr: Array<T>, identity: T, f: impl Fn(T, T) -> T) -> Array<T>
pub fn fast_reduce<T>(
mut arr: Array<T>,
identity: T,
default: Option<T>,
f: impl Fn(T, T) -> T,
) -> Array<T>
where
T: ArrayValue + Copy,
{
match arr.shape.len() {
0 => arr,
1 => {
let data = arr.data.as_mut_slice();
let reduced = data.iter().copied().reduce(f);
let reduced = default.into_iter().chain(data.iter().copied()).reduce(f);
if let Some(reduced) = reduced {
data[0] = reduced;
arr.data.truncate(1);
if data.is_empty() {
arr.data.extend(Some(reduced));
} else {
data[0] = reduced;
arr.data.truncate(1);
}
} else {
arr.data.extend(Some(identity));
}
Expand All @@ -126,6 +150,11 @@ where
}
let sliced = arr.data.as_mut_slice();
let (acc, rest) = sliced.split_at_mut(row_len);
if let Some(default) = default {
for acc in &mut *acc {
*acc = f(default, *acc);
}
}
rest.chunks_exact(row_len).fold(acc, |acc, row| {
for (a, b) in acc.iter_mut().zip(row) {
*a = f(*a, *b);
Expand All @@ -150,9 +179,11 @@ fn generic_reduce(f: Function, xs: Value, env: &mut Uiua) -> UiuaResult {
}
(2, 1) => {
let mut rows = xs.into_rows();
let mut acc = rows.next().ok_or_else(|| {
env.error(format!("Cannot {} empty array", Primitive::Reduce.format()))
})?;
let mut acc = (env.box_fill().ok().map(|b| b.0))
.or_else(|| rows.next())
.ok_or_else(|| {
env.error(format!("Cannot {} empty array", Primitive::Reduce.format()))
})?;
if env.unpack_boxes() {
acc.unpack();
}
Expand Down
6 changes: 6 additions & 0 deletions src/primitive/defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,12 @@ primitive!(
/// ex: /↧ []
/// ex: /∠ []
/// ex! /⊡ []
///
/// A default value can be set with [fill].
/// ex: /↥ []
/// ex: ⬚5/↥ []
/// ex: /↥ [1 2 3]
/// ex: ⬚5/↥ [1 2 3]
(1[1], Reduce, AggregatingModifier, ("reduce", '/')),
/// Apply a function to aggregate arrays
///
Expand Down
30 changes: 11 additions & 19 deletions src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ enum Fill {
Complex(Complex),
Char(char),
Box(Boxed),
None,
}

#[derive(Clone)]
Expand Down Expand Up @@ -1157,27 +1156,20 @@ code:
/// Do something with the fill context set
pub(crate) fn with_fill(
&mut self,
fill: Value,
mut fill: Value,
in_ctx: impl FnOnce(&mut Self) -> UiuaResult,
) -> UiuaResult {
if fill.shape() == [0] {
self.rt.fill_stack.push(Fill::None)
} else {
if !fill.shape().is_empty() {
return Err(self.error(format!(
"Fill values must be scalar or an empty list, but its shape is {}",
fill.shape()
)));
}
self.rt.fill_stack.push(match fill {
Value::Num(n) => Fill::Num(n.data.into_iter().next().unwrap()),
#[cfg(feature = "bytes")]
Value::Byte(b) => Fill::Num(b.data.into_iter().next().unwrap() as f64),
Value::Char(c) => Fill::Char(c.data.into_iter().next().unwrap()),
Value::Box(b) => Fill::Box(b.data.into_iter().next().unwrap()),
Value::Complex(c) => Fill::Complex(c.data.into_iter().next().unwrap()),
});
if !fill.shape().is_empty() {
fill = Array::from(Boxed(fill)).into();
}
self.rt.fill_stack.push(match fill {
Value::Num(n) => Fill::Num(n.data.into_iter().next().unwrap()),
#[cfg(feature = "bytes")]
Value::Byte(b) => Fill::Num(b.data.into_iter().next().unwrap() as f64),
Value::Char(c) => Fill::Char(c.data.into_iter().next().unwrap()),
Value::Box(b) => Fill::Box(b.data.into_iter().next().unwrap()),
Value::Complex(c) => Fill::Complex(c.data.into_iter().next().unwrap()),
});
let res = in_ctx(self);
self.rt.fill_stack.pop();
res
Expand Down
4 changes: 4 additions & 0 deletions tests/loops.ua
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
⍤⊃⋅∘≍ [1 1 1] /× ↯0_3 0
⍤⊃⋅∘≍ ⊃/-(⊢⇌\-) [1 2 3 4 5]
⍤⊃⋅∘≍ ⊃/(-:)(⊢⇌\(-:)) [1 2 3 4 5]
⍤⊃⋅∘≍ ¯∞ /↥ []
⍤⊃⋅∘≍ 5 ⬚5/↥ []
⍤⊃⋅∘≍ 1 /↥ [1]
⍤⊃⋅∘≍ 5 ⬚5/↥ [1]

⍤⊃⋅∘≍ [1 3 6 10] \+[1 2 3 4]
⍤⊃⋅∘≍ [1_0_0 1_2_0 1_2_3] ⬚0\⊂ [1 2 3]
Expand Down

0 comments on commit 97c55c8

Please sign in to comment.