From 6c8b0d5e3d0a9968a8a0b747f792abe332823fee Mon Sep 17 00:00:00 2001 From: Kai Schmidt Date: Thu, 21 Dec 2023 15:30:44 -0800 Subject: [PATCH] change how new shape is used --- site/src/editor/utils.rs | 2 +- src/algorithm/dyadic/combine.rs | 18 ++++++++--------- src/algorithm/dyadic/mod.rs | 18 ++++++++--------- src/algorithm/dyadic/structure.rs | 6 +++--- src/algorithm/loops.rs | 4 ++-- src/algorithm/mod.rs | 12 ++++++------ src/algorithm/monadic.rs | 4 ++-- src/algorithm/pervade.rs | 2 +- src/algorithm/table.rs | 2 +- src/algorithm/zip.rs | 14 +++++++------- src/array.rs | 6 +----- src/function.rs | 2 +- src/run.rs | 2 +- src/shape.rs | 32 ++++++++++++++++++++++++++++++- src/sys.rs | 6 +++--- src/value.rs | 24 +++++++++-------------- 16 files changed, 87 insertions(+), 67 deletions(-) diff --git a/site/src/editor/utils.rs b/site/src/editor/utils.rs index 8439269cb..0e280472d 100644 --- a/site/src/editor/utils.rs +++ b/site/src/editor/utils.rs @@ -810,7 +810,7 @@ fn run_code_single(code: &str) -> Vec { } // Try to convert the value to a gif if let Ok(bytes) = value_to_gif_bytes(&value, 16.0) { - match value.shape() { + match value.shape().dims() { &[f, h, w] | &[f, h, w, _] if h >= MIN_AUTO_IMAGE_DIM && w >= MIN_AUTO_IMAGE_DIM && f >= 5 => { diff --git a/src/algorithm/dyadic/combine.rs b/src/algorithm/dyadic/combine.rs index d137d16f7..53cac17fe 100644 --- a/src/algorithm/dyadic/combine.rs +++ b/src/algorithm/dyadic/combine.rs @@ -233,11 +233,11 @@ impl Array { other.rank() )))); } - if self.shape() != &other.shape()[1..] { + if self.shape() != other.shape()[1..] { return Err(C::fill_error(ctx.error(format!( "Cannot join arrays of shapes {} and {}{e}", - self.format_shape(), - other.format_shape() + self.shape(), + other.shape() )))); } other.shape @@ -249,7 +249,7 @@ impl Array { self } Ordering::Greater => { - if other.shape() == [0] { + if other.shape() == 0 { return Ok(self); } self.append(other, ctx)?; @@ -274,8 +274,8 @@ impl Array { Err(e) if self.shape[1..] != other.shape[1..] => { return Err(C::fill_error(ctx.error(format!( "Cannot join arrays of shapes {} and {}. {e}", - self.format_shape(), - other.format_shape() + self.shape(), + other.shape() )))); } _ => (), @@ -313,7 +313,7 @@ impl Array { if &self.shape()[1..] != other.shape() { return Err(C::fill_error(ctx.error(format!( "Cannot add shape {} row to array with shape {} rows{e}", - other.format_shape(), + other.shape(), FormatShape(&self.shape()[1..]), )))); } @@ -445,8 +445,8 @@ impl Array { Err(e) => { return Err(C::fill_error(ctx.error(format!( "Cannot couple arrays with shapes {} and {}{e}", - self.format_shape(), - other.format_shape() + self.shape(), + other.shape() )))); } } diff --git a/src/algorithm/dyadic/mod.rs b/src/algorithm/dyadic/mod.rs index bad91eca7..d2b5a82f2 100644 --- a/src/algorithm/dyadic/mod.rs +++ b/src/algorithm/dyadic/mod.rs @@ -103,8 +103,8 @@ impl Array { return Err(ctx.error(format!( "Cannot combine arrays with shapes {} and {} \ because shape prefixes {} and {} are not compatible", - a.format_shape(), - b.format_shape(), + a.shape(), + b.shape(), FormatShape(a_prefix), FormatShape(b_prefix) ))); @@ -195,7 +195,7 @@ impl Value { "Cannot unreshape array because its old shape was {}, \ but its new shape is {}, which has a different number of elements", FormatShape(&orig_shape), - self.format_shape() + self.shape() ))) } } @@ -520,7 +520,7 @@ impl Array { Err(e) => { return Err(env.error(format!( "Cannot keep array with shape {} with array of shape {}{e}", - self.format_shape(), + self.shape(), FormatShape(&[amount.len()]) ))); } @@ -532,14 +532,14 @@ impl Array { "Cannot keep array with shape {} with array of shape {}.\ A fill value is available, but keep can only been filled\ if there are fewer counts than rows.", - self.format_shape(), + self.shape(), FormatShape(amount.as_ref()) ) } Err(e) => { format!( "Cannot keep array with shape {} with array of shape {}{e}", - self.format_shape(), + self.shape(), FormatShape(amount.as_ref()) ) } @@ -621,8 +621,8 @@ impl Array { return Err(env.error(format!( "Kept array's shape was changed from {} to {}, \ so the keep cannot be inverted", - into_row.format_shape(), - new_row.format_shape() + into_row.shape(), + new_row.shape() ))); } new_rows.push(new_row); @@ -795,7 +795,7 @@ impl Array { if isize_spec.len() > self.shape.len() { return Err(env.error(format!( "Window size {isize_spec:?} has too many axes for shape {}", - self.format_shape() + self.shape() ))); } let mut size_spec = Vec::with_capacity(isize_spec.len()); diff --git a/src/algorithm/dyadic/structure.rs b/src/algorithm/dyadic/structure.rs index 2c55442a3..9afdb31ad 100644 --- a/src/algorithm/dyadic/structure.rs +++ b/src/algorithm/dyadic/structure.rs @@ -150,7 +150,7 @@ impl Array { return Err(env .error(format!( "Index {i} is out of bounds of length {s} (dimension {d}) in shape {}{e}", - self.format_shape() + self.shape() )) .fill()); } @@ -194,7 +194,7 @@ impl Array { "Attempted to undo pick, but the shape of the selected \ array changed from {} to {}", FormatShape(&expected_shape), - self.format_shape() + self.shape() ))); } let index_row_len: usize = index_shape[1..].iter().product(); @@ -216,7 +216,7 @@ impl Array { "Attempted to undo pick, but the shape of the selected \ array changed from {} to {}", FormatShape(expected_shape), - self.format_shape() + self.shape() ))); } let mut start = 0; diff --git a/src/algorithm/loops.rs b/src/algorithm/loops.rs index 4d9caedeb..44bb4d1b5 100644 --- a/src/algorithm/loops.rs +++ b/src/algorithm/loops.rs @@ -242,7 +242,7 @@ impl Array { if markers.len() != self.row_count() { return Err(env.error(format!( "Cannot partition array of shape {} with markers of length {}", - self.format_shape(), + self.shape(), markers.len() ))); } @@ -291,7 +291,7 @@ impl Array { if indices.len() != self.row_count() { return Err(env.error(format!( "Cannot group array of shape {} with indices of length {}", - self.format_shape(), + self.shape(), indices.len() ))); } diff --git a/src/algorithm/mod.rs b/src/algorithm/mod.rs index e19593100..d956e24db 100644 --- a/src/algorithm/mod.rs +++ b/src/algorithm/mod.rs @@ -227,7 +227,7 @@ where }, Ordering::Equal => { let target_shape = max_shape(a.shape(), b.shape()); - if a.shape() != &*target_shape { + if a.shape() != *target_shape { match ctx.fill() { Ok(fill) => { a.fill_to_shape(&target_shape, fill); @@ -236,7 +236,7 @@ where Err(e) => fill_error = Some(e), } } - if b.shape() != &*target_shape { + if b.shape() != *target_shape { match ctx.fill() { Ok(fill) => { b.fill_to_shape(&target_shape, fill); @@ -253,8 +253,8 @@ where if let Some(e) = fill_error { return Err(C::fill_error(ctx.error(format!( "Shapes {} and {} do not match{e}", - a.format_shape(), - b.format_shape(), + a.shape(), + b.shape(), )))); } @@ -340,9 +340,9 @@ pub fn switch(count: usize, sig: Signature, env: &mut Uiua) -> UiuaResult { return Err(env.error(format!( "The function's select's shape {} is not compatible \ with the argument {}'s shape {}", - selector.format_shape(), + selector.shape(), i + 1, - arg.format_shape(), + arg.shape(), ))); } let row_shape = Shape::from(&arg.shape()[selector.rank()..]); diff --git a/src/algorithm/monadic.rs b/src/algorithm/monadic.rs index bda42a096..5fd70f75d 100644 --- a/src/algorithm/monadic.rs +++ b/src/algorithm/monadic.rs @@ -63,7 +63,7 @@ impl Value { } /// Attempt to parse the value into a number pub fn parse_num(&self, env: &Uiua) -> UiuaResult { - Ok(match (self, self.shape()) { + Ok(match (self, self.shape().dims()) { (Value::Char(arr), [] | [_]) => { let mut s: String = arr.data.iter().copied().collect(); if s.contains('¯') { @@ -867,7 +867,7 @@ impl Value { } /// `invert` `where` pub fn inverse_where(&self, env: &Uiua) -> UiuaResult { - Ok(match self.shape() { + Ok(match self.shape().dims() { [] | [_] => { let indices = self.as_nats(env, "Argument to inverse where must be a list of naturals")?; diff --git a/src/algorithm/pervade.rs b/src/algorithm/pervade.rs index 79817f6b6..f54fedecc 100644 --- a/src/algorithm/pervade.rs +++ b/src/algorithm/pervade.rs @@ -163,7 +163,7 @@ where // Fill fill_array_shapes(&mut a, &mut b, env)?; // Pervade - let shape = Shape::from(a.shape().max(b.shape())); + let shape = a.shape().max(b.shape()).clone(); let mut data = CowSlice::with_capacity(a.element_count().max(b.element_count())); bin_pervade_recursive(&a, &b, &mut data, env, f).map_err(Into::into)?; Ok(Array::new(shape, data)) diff --git a/src/algorithm/table.rs b/src/algorithm/table.rs index 0efa74b77..63c88fe75 100644 --- a/src/algorithm/table.rs +++ b/src/algorithm/table.rs @@ -209,7 +209,7 @@ fn generic_table(f: Function, xs: Value, ys: Value, env: &mut Uiua) -> UiuaResul Primitive::Table.format() ))); } - let mut new_shape = Shape::from(xs.shape()); + let mut new_shape = xs.shape().clone(); new_shape.extend_from_slice(ys.shape()); let outputs = sig.outputs; let mut items = multi_output( diff --git a/src/algorithm/zip.rs b/src/algorithm/zip.rs index f6fb59022..b55f99bc0 100644 --- a/src/algorithm/zip.rs +++ b/src/algorithm/zip.rs @@ -4,7 +4,7 @@ use std::slice; use crate::{ algorithm::pervade::bin_pervade_generic, function::Function, value::Value, FormatShape, - ImplPrimitive, Instr, Primitive, Shape, Uiua, UiuaResult, + ImplPrimitive, Instr, Primitive, Uiua, UiuaResult, }; use super::{multi_output, MultiOutput}; @@ -174,7 +174,7 @@ fn each1(f: Function, xs: Value, env: &mut Uiua) -> UiuaResult { } else { let outputs = f.signature().outputs; let mut new_values = multi_output(outputs, Vec::with_capacity(xs.element_count())); - let new_shape = Shape::from(xs.shape()); + let new_shape = xs.shape().clone(); let is_empty = outputs > 0 && xs.row_count() == 0; if is_empty { env.push(xs.proxy_scalar(env)); @@ -212,8 +212,8 @@ fn each2(f: Function, xs: Value, ys: Value, env: &mut Uiua) -> UiuaResult { "Cannot {} arrays with shapes {} and {} because their \ shape prefixes {} and {} are different", Primitive::Each.format(), - xs.format_shape(), - ys.format_shape(), + xs.shape(), + ys.shape(), FormatShape(&xs.shape()[..min_rank]), FormatShape(&ys.shape()[..min_rank]) ))); @@ -302,8 +302,8 @@ fn eachn(f: Function, args: Vec, env: &mut Uiua) -> UiuaResult { "The shapes in each of 3 or more arrays must all match, \ but shapes {} and {} cannot be {}ed together. \ If you want more flexibility, use rows.", - win[0].format_shape(), - win[1].format_shape(), + win[0].shape(), + win[1].shape(), Primitive::Each.format() ))); } @@ -312,7 +312,7 @@ fn eachn(f: Function, args: Vec, env: &mut Uiua) -> UiuaResult { let is_empty = outputs > 0 && args.iter().any(|v| v.row_count() == 0); let elem_count = args[0].element_count() + is_empty as usize; let mut new_values = multi_output(outputs, Vec::with_capacity(elem_count)); - let new_shape = Shape::from(args[0].shape()); + let new_shape = args[0].shape().clone(); if is_empty { for arg in args.into_iter().rev() { env.push(arg.proxy_scalar(env)); diff --git a/src/array.rs b/src/array.rs index 13585e3ad..9c96fdfe6 100644 --- a/src/array.rs +++ b/src/array.rs @@ -154,7 +154,7 @@ impl Array { self.shape.len() } /// Get the shape of the array - pub fn shape(&self) -> &[usize] { + pub fn shape(&self) -> &Shape { &self.shape } /// Get the metadata of the array @@ -176,10 +176,6 @@ impl Array { self.meta_mut().flags.reset(); } } - /// Get a formattable shape of the array - pub fn format_shape(&self) -> FormatShape<'_> { - FormatShape(self.shape()) - } /// Get an iterator over the row slices of the array pub fn row_slices(&self) -> impl ExactSizeIterator + DoubleEndedIterator { (0..self.row_count()).map(move |row| self.row_slice(row)) diff --git a/src/function.rs b/src/function.rs index 5fed35580..d370d6681 100644 --- a/src/function.rs +++ b/src/function.rs @@ -191,7 +191,7 @@ impl fmt::Debug for Instr { if val.element_count() < 50 && val.shape().len() <= 1 { write!(f, "push {val:?}") } else { - write!(f, "push {} array", val.format_shape()) + write!(f, "push {} array", val.shape()) } } _ => write!(f, "{self}"), diff --git a/src/run.rs b/src/run.rs index 3cd2ed52a..97a965956 100644 --- a/src/run.rs +++ b/src/run.rs @@ -1130,7 +1130,7 @@ code: if !fill.shape().is_empty() { return Err(self.error(format!( "Fill values must be scalar or an empty list, but its shape is {}", - fill.format_shape() + fill.shape() ))); } self.rt.fill_stack.push(match fill { diff --git a/src/shape.rs b/src/shape.rs index 1321632df..29f0c94da 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -149,7 +149,13 @@ impl Extend for Shape { impl PartialEq for Shape { fn eq(&self, other: &usize) -> bool { - self == &[*other] + self == [*other] + } +} + +impl PartialEq for &Shape { + fn eq(&self, other: &usize) -> bool { + *self == [*other] } } @@ -159,14 +165,38 @@ impl PartialEq<[usize; N]> for Shape { } } +impl PartialEq<[usize; N]> for &Shape { + fn eq(&self, other: &[usize; N]) -> bool { + *self == other.as_slice() + } +} + impl PartialEq<[usize]> for Shape { fn eq(&self, other: &[usize]) -> bool { self.dims == other } } +impl PartialEq<[usize]> for &Shape { + fn eq(&self, other: &[usize]) -> bool { + *self == other + } +} + impl PartialEq<&[usize]> for Shape { fn eq(&self, other: &&[usize]) -> bool { self.dims == *other } } + +impl PartialEq for &[usize] { + fn eq(&self, other: &Shape) -> bool { + other == self + } +} + +impl PartialEq for [usize] { + fn eq(&self, other: &Shape) -> bool { + other == self + } +} diff --git a/src/sys.rs b/src/sys.rs index 076ea6d37..f0d7129f1 100644 --- a/src/sys.rs +++ b/src/sys.rs @@ -1242,7 +1242,7 @@ impl SysOp { let samples = samples.as_num_array().ok_or_else(|| { stream_env.error("Audio stream function must return a numeric array") })?; - match samples.shape() { + match samples.shape().dims() { [_] => Ok(samples.data.iter().map(|&x| [x, x]).collect()), [_, 2] => Ok(samples .data @@ -1493,7 +1493,7 @@ pub fn value_to_image(value: &Value) -> Result { _ => return Err("Image must be a numeric array".into()), }; #[allow(clippy::match_ref_pats)] - let [height, width, px_size] = match value.shape() { + let [height, width, px_size] = match value.shape().dims() { &[a, b] => [a, b, 1], &[a, b, c] => [a, b, c], _ => unreachable!("Shape checked above"), @@ -1586,7 +1586,7 @@ pub fn value_to_audio_channels(audio: &Value) -> Result>, String> { if channels.len() > 5 { return Err(format!( "Audio can have at most 5 channels, but its shape is {}", - audio.format_shape() + audio.shape() )); } diff --git a/src/value.rs b/src/value.rs index 028b4ab4a..58d442a09 100644 --- a/src/value.rs +++ b/src/value.rs @@ -268,10 +268,6 @@ impl Value { Self::Box(array) => array.first_dim_zero().into(), } } - /// Get a formattable representation of the shape - pub fn format_shape(&self) -> FormatShape { - FormatShape(self.shape()) - } /// Get the rank pub fn rank(&self) -> usize { self.shape().len() @@ -316,7 +312,7 @@ impl Value { &mut *(self as *mut Self as *mut Repr) } /// Get the shape of the value - pub fn shape(&self) -> &[usize] { + pub fn shape(&self) -> &Shape { &unsafe { self.repr() }.arr.shape } /// Get a mutable reference to the shape @@ -837,10 +833,9 @@ impl Value { Ok(match self { Value::Num(nums) => { if !test_shape(self.shape()) { - return Err(env.error(format!( - "{requirement}, but its shape is {}", - nums.format_shape() - ))); + return Err( + env.error(format!("{requirement}, but its shape is {}", nums.shape())) + ); } let mut result = EcoVec::with_capacity(nums.element_count()); for &num in nums.data() { @@ -849,15 +844,14 @@ impl Value { } result.push(convert_num(num)); } - Array::new(self.shape(), result) + Array::new(self.shape().clone(), result) } #[cfg(feature = "bytes")] Value::Byte(bytes) => { if !test_shape(self.shape()) { - return Err(env.error(format!( - "{requirement}, but its shape is {}", - bytes.format_shape() - ))); + return Err( + env.error(format!("{requirement}, but its shape is {}", bytes.shape())) + ); } let mut result = EcoVec::with_capacity(bytes.element_count()); for &byte in bytes.data() { @@ -867,7 +861,7 @@ impl Value { } result.push(convert_num(num)); } - Array::new(self.shape(), result) + Array::new(self.shape().clone(), result) } value => { return Err(env.error(format!(