Skip to content

Commit

Permalink
change how new shape is used
Browse files Browse the repository at this point in the history
  • Loading branch information
kaikalii committed Dec 21, 2023
1 parent 28b4e1b commit 6c8b0d5
Show file tree
Hide file tree
Showing 16 changed files with 87 additions and 67 deletions.
2 changes: 1 addition & 1 deletion site/src/editor/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ fn run_code_single(code: &str) -> Vec<OutputItem> {
}
// Try to convert the value to a gif
if let Ok(bytes) = value_to_gif_bytes(&value, 16.0) {
match value.shape() {
match value.shape().dims() {
&[f, h, w] | &[f, h, w, _]
if h >= MIN_AUTO_IMAGE_DIM && w >= MIN_AUTO_IMAGE_DIM && f >= 5 =>
{
Expand Down
18 changes: 9 additions & 9 deletions src/algorithm/dyadic/combine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,11 @@ impl<T: ArrayValue> Array<T> {
other.rank()
))));
}
if self.shape() != &other.shape()[1..] {
if self.shape() != other.shape()[1..] {
return Err(C::fill_error(ctx.error(format!(
"Cannot join arrays of shapes {} and {}{e}",
self.format_shape(),
other.format_shape()
self.shape(),
other.shape()
))));
}
other.shape
Expand All @@ -249,7 +249,7 @@ impl<T: ArrayValue> Array<T> {
self
}
Ordering::Greater => {
if other.shape() == [0] {
if other.shape() == 0 {
return Ok(self);
}
self.append(other, ctx)?;
Expand All @@ -274,8 +274,8 @@ impl<T: ArrayValue> Array<T> {
Err(e) if self.shape[1..] != other.shape[1..] => {
return Err(C::fill_error(ctx.error(format!(
"Cannot join arrays of shapes {} and {}. {e}",
self.format_shape(),
other.format_shape()
self.shape(),
other.shape()
))));
}
_ => (),
Expand Down Expand Up @@ -313,7 +313,7 @@ impl<T: ArrayValue> Array<T> {
if &self.shape()[1..] != other.shape() {
return Err(C::fill_error(ctx.error(format!(
"Cannot add shape {} row to array with shape {} rows{e}",
other.format_shape(),
other.shape(),
FormatShape(&self.shape()[1..]),
))));
}
Expand Down Expand Up @@ -445,8 +445,8 @@ impl<T: ArrayValue> Array<T> {
Err(e) => {
return Err(C::fill_error(ctx.error(format!(
"Cannot couple arrays with shapes {} and {}{e}",
self.format_shape(),
other.format_shape()
self.shape(),
other.shape()
))));
}
}
Expand Down
18 changes: 9 additions & 9 deletions src/algorithm/dyadic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ impl<T: Clone + std::fmt::Debug> Array<T> {
return Err(ctx.error(format!(
"Cannot combine arrays with shapes {} and {} \
because shape prefixes {} and {} are not compatible",
a.format_shape(),
b.format_shape(),
a.shape(),
b.shape(),
FormatShape(a_prefix),
FormatShape(b_prefix)
)));
Expand Down Expand Up @@ -195,7 +195,7 @@ impl Value {
"Cannot unreshape array because its old shape was {}, \
but its new shape is {}, which has a different number of elements",
FormatShape(&orig_shape),
self.format_shape()
self.shape()
)))
}
}
Expand Down Expand Up @@ -520,7 +520,7 @@ impl<T: ArrayValue> Array<T> {
Err(e) => {
return Err(env.error(format!(
"Cannot keep array with shape {} with array of shape {}{e}",
self.format_shape(),
self.shape(),
FormatShape(&[amount.len()])
)));
}
Expand All @@ -532,14 +532,14 @@ impl<T: ArrayValue> Array<T> {
"Cannot keep array with shape {} with array of shape {}.\
A fill value is available, but keep can only been filled\
if there are fewer counts than rows.",
self.format_shape(),
self.shape(),
FormatShape(amount.as_ref())
)
}
Err(e) => {
format!(
"Cannot keep array with shape {} with array of shape {}{e}",
self.format_shape(),
self.shape(),
FormatShape(amount.as_ref())
)
}
Expand Down Expand Up @@ -621,8 +621,8 @@ impl<T: ArrayValue> Array<T> {
return Err(env.error(format!(
"Kept array's shape was changed from {} to {}, \
so the keep cannot be inverted",
into_row.format_shape(),
new_row.format_shape()
into_row.shape(),
new_row.shape()
)));
}
new_rows.push(new_row);
Expand Down Expand Up @@ -795,7 +795,7 @@ impl<T: ArrayValue> Array<T> {
if isize_spec.len() > self.shape.len() {
return Err(env.error(format!(
"Window size {isize_spec:?} has too many axes for shape {}",
self.format_shape()
self.shape()
)));
}
let mut size_spec = Vec::with_capacity(isize_spec.len());
Expand Down
6 changes: 3 additions & 3 deletions src/algorithm/dyadic/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ impl<T: ArrayValue> Array<T> {
return Err(env
.error(format!(
"Index {i} is out of bounds of length {s} (dimension {d}) in shape {}{e}",
self.format_shape()
self.shape()
))
.fill());
}
Expand Down Expand Up @@ -194,7 +194,7 @@ impl<T: ArrayValue> Array<T> {
"Attempted to undo pick, but the shape of the selected \
array changed from {} to {}",
FormatShape(&expected_shape),
self.format_shape()
self.shape()
)));
}
let index_row_len: usize = index_shape[1..].iter().product();
Expand All @@ -216,7 +216,7 @@ impl<T: ArrayValue> Array<T> {
"Attempted to undo pick, but the shape of the selected \
array changed from {} to {}",
FormatShape(expected_shape),
self.format_shape()
self.shape()
)));
}
let mut start = 0;
Expand Down
4 changes: 2 additions & 2 deletions src/algorithm/loops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ impl<T: ArrayValue> Array<T> {
if markers.len() != self.row_count() {
return Err(env.error(format!(
"Cannot partition array of shape {} with markers of length {}",
self.format_shape(),
self.shape(),
markers.len()
)));
}
Expand Down Expand Up @@ -291,7 +291,7 @@ impl<T: ArrayValue> Array<T> {
if indices.len() != self.row_count() {
return Err(env.error(format!(
"Cannot group array of shape {} with indices of length {}",
self.format_shape(),
self.shape(),
indices.len()
)));
}
Expand Down
12 changes: 6 additions & 6 deletions src/algorithm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ where
},
Ordering::Equal => {
let target_shape = max_shape(a.shape(), b.shape());
if a.shape() != &*target_shape {
if a.shape() != *target_shape {
match ctx.fill() {
Ok(fill) => {
a.fill_to_shape(&target_shape, fill);
Expand All @@ -236,7 +236,7 @@ where
Err(e) => fill_error = Some(e),
}
}
if b.shape() != &*target_shape {
if b.shape() != *target_shape {
match ctx.fill() {
Ok(fill) => {
b.fill_to_shape(&target_shape, fill);
Expand All @@ -253,8 +253,8 @@ where
if let Some(e) = fill_error {
return Err(C::fill_error(ctx.error(format!(
"Shapes {} and {} do not match{e}",
a.format_shape(),
b.format_shape(),
a.shape(),
b.shape(),
))));
}

Expand Down Expand Up @@ -340,9 +340,9 @@ pub fn switch(count: usize, sig: Signature, env: &mut Uiua) -> UiuaResult {
return Err(env.error(format!(
"The function's select's shape {} is not compatible \
with the argument {}'s shape {}",
selector.format_shape(),
selector.shape(),
i + 1,
arg.format_shape(),
arg.shape(),
)));
}
let row_shape = Shape::from(&arg.shape()[selector.rank()..]);
Expand Down
4 changes: 2 additions & 2 deletions src/algorithm/monadic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl Value {
}
/// Attempt to parse the value into a number
pub fn parse_num(&self, env: &Uiua) -> UiuaResult<Self> {
Ok(match (self, self.shape()) {
Ok(match (self, self.shape().dims()) {
(Value::Char(arr), [] | [_]) => {
let mut s: String = arr.data.iter().copied().collect();
if s.contains('¯') {
Expand Down Expand Up @@ -867,7 +867,7 @@ impl Value {
}
/// `invert` `where`
pub fn inverse_where(&self, env: &Uiua) -> UiuaResult<Self> {
Ok(match self.shape() {
Ok(match self.shape().dims() {
[] | [_] => {
let indices =
self.as_nats(env, "Argument to inverse where must be a list of naturals")?;
Expand Down
2 changes: 1 addition & 1 deletion src/algorithm/pervade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ where
// Fill
fill_array_shapes(&mut a, &mut b, env)?;
// Pervade
let shape = Shape::from(a.shape().max(b.shape()));
let shape = a.shape().max(b.shape()).clone();
let mut data = CowSlice::with_capacity(a.element_count().max(b.element_count()));
bin_pervade_recursive(&a, &b, &mut data, env, f).map_err(Into::into)?;
Ok(Array::new(shape, data))
Expand Down
2 changes: 1 addition & 1 deletion src/algorithm/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ fn generic_table(f: Function, xs: Value, ys: Value, env: &mut Uiua) -> UiuaResul
Primitive::Table.format()
)));
}
let mut new_shape = Shape::from(xs.shape());
let mut new_shape = xs.shape().clone();
new_shape.extend_from_slice(ys.shape());
let outputs = sig.outputs;
let mut items = multi_output(
Expand Down
14 changes: 7 additions & 7 deletions src/algorithm/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::slice;

use crate::{
algorithm::pervade::bin_pervade_generic, function::Function, value::Value, FormatShape,
ImplPrimitive, Instr, Primitive, Shape, Uiua, UiuaResult,
ImplPrimitive, Instr, Primitive, Uiua, UiuaResult,
};

use super::{multi_output, MultiOutput};
Expand Down Expand Up @@ -174,7 +174,7 @@ fn each1(f: Function, xs: Value, env: &mut Uiua) -> UiuaResult {
} else {
let outputs = f.signature().outputs;
let mut new_values = multi_output(outputs, Vec::with_capacity(xs.element_count()));
let new_shape = Shape::from(xs.shape());
let new_shape = xs.shape().clone();
let is_empty = outputs > 0 && xs.row_count() == 0;
if is_empty {
env.push(xs.proxy_scalar(env));
Expand Down Expand Up @@ -212,8 +212,8 @@ fn each2(f: Function, xs: Value, ys: Value, env: &mut Uiua) -> UiuaResult {
"Cannot {} arrays with shapes {} and {} because their \
shape prefixes {} and {} are different",
Primitive::Each.format(),
xs.format_shape(),
ys.format_shape(),
xs.shape(),
ys.shape(),
FormatShape(&xs.shape()[..min_rank]),
FormatShape(&ys.shape()[..min_rank])
)));
Expand Down Expand Up @@ -302,8 +302,8 @@ fn eachn(f: Function, args: Vec<Value>, env: &mut Uiua) -> UiuaResult {
"The shapes in each of 3 or more arrays must all match, \
but shapes {} and {} cannot be {}ed together. \
If you want more flexibility, use rows.",
win[0].format_shape(),
win[1].format_shape(),
win[0].shape(),
win[1].shape(),
Primitive::Each.format()
)));
}
Expand All @@ -312,7 +312,7 @@ fn eachn(f: Function, args: Vec<Value>, env: &mut Uiua) -> UiuaResult {
let is_empty = outputs > 0 && args.iter().any(|v| v.row_count() == 0);
let elem_count = args[0].element_count() + is_empty as usize;
let mut new_values = multi_output(outputs, Vec::with_capacity(elem_count));
let new_shape = Shape::from(args[0].shape());
let new_shape = args[0].shape().clone();
if is_empty {
for arg in args.into_iter().rev() {
env.push(arg.proxy_scalar(env));
Expand Down
6 changes: 1 addition & 5 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ impl<T> Array<T> {
self.shape.len()
}
/// Get the shape of the array
pub fn shape(&self) -> &[usize] {
pub fn shape(&self) -> &Shape {
&self.shape
}
/// Get the metadata of the array
Expand All @@ -176,10 +176,6 @@ impl<T> Array<T> {
self.meta_mut().flags.reset();
}
}
/// Get a formattable shape of the array
pub fn format_shape(&self) -> FormatShape<'_> {
FormatShape(self.shape())
}
/// Get an iterator over the row slices of the array
pub fn row_slices(&self) -> impl ExactSizeIterator<Item = &[T]> + DoubleEndedIterator {
(0..self.row_count()).map(move |row| self.row_slice(row))
Expand Down
2 changes: 1 addition & 1 deletion src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ impl fmt::Debug for Instr {
if val.element_count() < 50 && val.shape().len() <= 1 {
write!(f, "push {val:?}")
} else {
write!(f, "push {} array", val.format_shape())
write!(f, "push {} array", val.shape())
}
}
_ => write!(f, "{self}"),
Expand Down
2 changes: 1 addition & 1 deletion src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ code:
if !fill.shape().is_empty() {
return Err(self.error(format!(
"Fill values must be scalar or an empty list, but its shape is {}",
fill.format_shape()
fill.shape()
)));
}
self.rt.fill_stack.push(match fill {
Expand Down
32 changes: 31 additions & 1 deletion src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,13 @@ impl Extend<usize> for Shape {

impl PartialEq<usize> for Shape {
fn eq(&self, other: &usize) -> bool {
self == &[*other]
self == [*other]
}
}

impl PartialEq<usize> for &Shape {
fn eq(&self, other: &usize) -> bool {
*self == [*other]
}
}

Expand All @@ -159,14 +165,38 @@ impl<const N: usize> PartialEq<[usize; N]> for Shape {
}
}

impl<const N: usize> PartialEq<[usize; N]> for &Shape {
fn eq(&self, other: &[usize; N]) -> bool {
*self == other.as_slice()
}
}

impl PartialEq<[usize]> for Shape {
fn eq(&self, other: &[usize]) -> bool {
self.dims == other
}
}

impl PartialEq<[usize]> for &Shape {
fn eq(&self, other: &[usize]) -> bool {
*self == other
}
}

impl PartialEq<&[usize]> for Shape {
fn eq(&self, other: &&[usize]) -> bool {
self.dims == *other
}
}

impl PartialEq<Shape> for &[usize] {
fn eq(&self, other: &Shape) -> bool {
other == self
}
}

impl PartialEq<Shape> for [usize] {
fn eq(&self, other: &Shape) -> bool {
other == self
}
}
Loading

0 comments on commit 6c8b0d5

Please sign in to comment.