From ee50e189e8171e80e1db340f8a4e16445046ff82 Mon Sep 17 00:00:00 2001 From: Kai Schmidt Date: Sun, 16 Jun 2024 15:04:25 -0700 Subject: [PATCH] give an explicit implementation for un shape --- src/algorithm/invert.rs | 22 ++-------------------- src/algorithm/monadic.rs | 16 ++++++++++++++++ src/primitive/defs.rs | 1 + src/primitive/mod.rs | 2 ++ 4 files changed, 21 insertions(+), 20 deletions(-) diff --git a/src/algorithm/invert.rs b/src/algorithm/invert.rs index d2c72192..938bcdaf 100644 --- a/src/algorithm/invert.rs +++ b/src/algorithm/invert.rs @@ -95,6 +95,7 @@ fn prim_inverse(prim: Primitive, span: usize) -> Option { Utf => Instr::ImplPrim(UnUtf, span), Parse => Instr::ImplPrim(UnParse, span), Fix => Instr::ImplPrim(UnFix, span), + Shape => Instr::ImplPrim(UnShape, span), Map => Instr::ImplPrim(UnMap, span), Trace => Instr::ImplPrim(TraceN(1, true), span), Stack => Instr::ImplPrim(UnStack, span), @@ -130,6 +131,7 @@ fn impl_prim_inverse(prim: ImplPrimitive, span: usize) -> Option { UnCouple => Instr::Prim(Couple, span), UnParse => Instr::Prim(Parse, span), UnFix => Instr::Prim(Fix, span), + UnShape => Instr::Prim(Shape, span), UnMap => Instr::Prim(Map, span), UnStack => Instr::Prim(Stack, span), UnJoin => Instr::Prim(Join, span), @@ -190,7 +192,6 @@ static INVERT_PATTERNS: &[&dyn InvertPattern] = { &InvertPatternFn(invert_dup_pattern, "dup"), &InvertPatternFn(invert_stack_swizzle_pattern, "stack swizzle"), &InvertPatternFn(invert_select_pattern, "select"), - &InvertPatternFn(invert_shape_pattern, "shape"), &pat!(Sqrt, (Dup, Mul)), &pat!((Dup, Add), (2, Div)), &([Dup, Mul], [Sqrt]), @@ -918,25 +919,6 @@ fn invert_stack_swizzle_pattern<'a>( Some((input, instrs)) } -fn invert_shape_pattern<'a>( - input: &'a [Instr], - comp: &mut Compiler, -) -> Option<(&'a [Instr], EcoVec)> { - let [Instr::Prim(Primitive::Shape, span), input @ ..] = input else { - return None; - }; - let mul = make_fn(eco_vec![Instr::Prim(Primitive::Mul, *span)], *span, comp)?; - let instrs = eco_vec![ - Instr::Prim(Primitive::Dup, *span), - Instr::PushFunc(mul), - Instr::Prim(Primitive::Reduce, *span), - Instr::Prim(Primitive::Range, *span), - Instr::Prim(Primitive::Flip, *span), - Instr::Prim(Primitive::Reshape, *span), - ]; - Some((input, instrs)) -} - fn invert_select_pattern<'a>( input: &'a [Instr], _: &mut Compiler, diff --git a/src/algorithm/monadic.rs b/src/algorithm/monadic.rs index 4cce0d83..0615dece 100644 --- a/src/algorithm/monadic.rs +++ b/src/algorithm/monadic.rs @@ -213,6 +213,22 @@ impl Value { Err(data) => Array::new(shape, data).into(), }) } + pub(crate) fn unshape(&self, env: &Uiua) -> UiuaResult { + let ishape = self.as_ints( + env, + "Shape should be a single integer or a list of integers", + )?; + let shape = Shape::from_iter(ishape.iter().map(|n| n.unsigned_abs())); + let elems: usize = validate_size::(shape.iter().copied(), env)?; + let data = EcoVec::from_iter((0..elems).map(|i| i as f64)); + let mut arr = Array::new(shape, data); + for (i, s) in ishape.into_iter().enumerate() { + if s < 0 { + arr.reverse_depth(i); + } + } + Ok(arr.into()) + } } fn range(shape: &[isize], env: &Uiua) -> UiuaResult, CowSlice>> { diff --git a/src/primitive/defs.rs b/src/primitive/defs.rs index 906487a0..de4fb2a7 100644 --- a/src/primitive/defs.rs +++ b/src/primitive/defs.rs @@ -2701,6 +2701,7 @@ impl_primitive!( (1(2), UnComplex), (1, UnParse), (1, UnFix), + (1, UnShape), (1[1], UnScan), (1(2), UnMap), (0(0), UnStack, Impure), diff --git a/src/primitive/mod.rs b/src/primitive/mod.rs index 30bb2605..ac8b055e 100644 --- a/src/primitive/mod.rs +++ b/src/primitive/mod.rs @@ -168,6 +168,7 @@ impl fmt::Display for ImplPrimitive { UnUtf => write!(f, "{Un}{Utf}"), UnParse => write!(f, "{Un}{Parse}"), UnFix => write!(f, "{Un}{Fix}"), + UnShape => write!(f, "{Un}{Shape}"), UnJoin | UnJoinPattern => write!(f, "{Un}{Join}"), UnKeep => write!(f, "{Un}{Keep}"), UnScan => write!(f, "{Un}{Scan}"), @@ -989,6 +990,7 @@ impl ImplPrimitive { } ImplPrimitive::UnParse => env.monadic_ref_env(Value::unparse)?, ImplPrimitive::UnFix => env.monadic_mut_env(Value::unfix)?, + ImplPrimitive::UnShape => env.monadic_ref_env(Value::unshape)?, ImplPrimitive::UndoFix => env.monadic_mut(Value::undo_fix)?, ImplPrimitive::UnScan => reduce::unscan(env)?, ImplPrimitive::TraceN(n, inverse) => trace_n(env, *n, *inverse)?,