From 30586b076be031b0703870ef6e858acb63e8fa0b Mon Sep 17 00:00:00 2001 From: Kai Schmidt Date: Mon, 17 Jun 2024 13:48:42 -0700 Subject: [PATCH] add orient inverses --- src/algorithm/dyadic/mod.rs | 43 +++++++++++++++++++++++++++++++++++-- src/algorithm/invert.rs | 7 ++++++ src/primitive/defs.rs | 1 + src/primitive/mod.rs | 7 ++++++ 4 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/algorithm/dyadic/mod.rs b/src/algorithm/dyadic/mod.rs index b4cd7d70..32a30ce1 100644 --- a/src/algorithm/dyadic/mod.rs +++ b/src/algorithm/dyadic/mod.rs @@ -1728,7 +1728,11 @@ impl Value { return Err(env.error("Orient indices must be unique")); } } - let mut orientation: Vec = (0..target.rank()).collect(); + target.orient_impl(&undices); + Ok(()) + } + fn orient_impl(&mut self, undices: &[usize]) { + let mut orientation: Vec = (0..self.rank()).collect(); let mut depth_rotations: Vec<(usize, i32)> = Vec::new(); for (i, &u) in undices.iter().enumerate() { let j = orientation.iter().position(|&o| o == u).unwrap(); @@ -1743,8 +1747,43 @@ impl Value { depth_rotations.push((i, -1)); } for (depth, amnt) in depth_rotations { - target.transpose_depth(depth, amnt); + self.transpose_depth(depth, amnt); + } + } + pub(crate) fn undo_orient(&self, target: &mut Self, env: &Uiua) -> UiuaResult { + let indices = self.as_ints(env, "Unorient indices must be integers")?; + let mut undices = Vec::with_capacity(indices.len()); + for i in indices { + let u = i.unsigned_abs(); + if u >= target.rank() { + return Err(env.error(format!( + "Cannot unorient axis {i} in array of rank {}", + target.rank() + ))); + } + if i >= 0 { + undices.push(u); + } else { + undices.push(target.rank() - u); + } + } + if undices.len() > target.rank() { + return Err(env.error(format!( + "Cannot unorient array of rank {} with {} indices", + target.rank(), + undices.len() + ))); + } + for i in 0..target.rank() { + if !undices.contains(&i) { + undices.push(i); + } + } + let mut inverted_undices = undices.clone(); + for (i, u) in undices.into_iter().enumerate() { + inverted_undices[u] = i; } + target.orient_impl(&inverted_undices); Ok(()) } } diff --git a/src/algorithm/invert.rs b/src/algorithm/invert.rs index 938bcdaf..b5828e8f 100644 --- a/src/algorithm/invert.rs +++ b/src/algorithm/invert.rs @@ -225,6 +225,7 @@ static ON_INVERT_PATTERNS: &[&dyn InvertPattern] = { &pat!((Flip, Log), (Flip, 1, Flip, Div, Pow)), &([Min], [Min]), &([Max], [Max]), + &([Orient], [UndoOrient]), &pat!( Join, ( @@ -535,6 +536,12 @@ pub(crate) fn under_instrs( (Dup, Shape, PushToUnder(1), Where), (PopUnder(1), UndoWhere) ), + // Orient + &maybe_val!(pat!( + Orient, + (CopyToUnder(1), Orient), + (PopUnder(1), UndoOrient) + )), // System stuff &pat!(Now, (Now, PushToUnder(1)), (PopUnder(1), Now, Flip, Sub)), &maybe_val!(store1copy!(Sys(SysOp::FOpen), Sys(SysOp::Close))), diff --git a/src/primitive/defs.rs b/src/primitive/defs.rs index 4264d583..11d085ba 100644 --- a/src/primitive/defs.rs +++ b/src/primitive/defs.rs @@ -2747,6 +2747,7 @@ impl_primitive!( (3, UndoRerank), (2, UndoReshape), (2, UndoWhere), + (2, UndoOrient), (3(2), UndoJoin), (1[1], UndoPartition1), (3, UndpPartition2), diff --git a/src/primitive/mod.rs b/src/primitive/mod.rs index 1dc411cc..346e8b0c 100644 --- a/src/primitive/mod.rs +++ b/src/primitive/mod.rs @@ -184,6 +184,7 @@ impl fmt::Display for ImplPrimitive { UndoSelect => write!(f, "{Under}{Select}"), UndoPick => write!(f, "{Under}{Pick}"), UndoWhere => write!(f, "{Under}{Where}"), + UndoOrient => write!(f, "{Under}{Orient}"), UndoInsert => write!(f, "{Under}{Insert}"), UndoRemove => write!(f, "{Under}{Remove}"), UndoPartition1 | UndpPartition2 => write!(f, "{Under}{Partition}"), @@ -920,6 +921,12 @@ impl ImplPrimitive { let mask = indices.undo_where(&shape, env)?; env.push(mask); } + ImplPrimitive::UndoOrient => { + let indices = env.pop(1)?; + let mut array = env.pop(2)?; + indices.undo_orient(&mut array, env)?; + env.push(array); + } ImplPrimitive::UndoRerank => { let rank = env.pop(1)?; let shape = env.pop(2)?;