Skip to content

Commit

Permalink
Eliminate dependency on xoshiro
Browse files Browse the repository at this point in the history
Change sorting traits to use Order enum
  • Loading branch information
YuhanLiin committed Jun 4, 2022
1 parent 5715734 commit 82837ff
Show file tree
Hide file tree
Showing 14 changed files with 69 additions and 104 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ ndarray = { version = "0.15", features = ["approx"] }
num-traits = "0.2.0"
thiserror = "1"
rand = { version = "0.8", optional=true }
rand_xoshiro = { version = "0.6", optional=true }

[dev-dependencies]
approx = "0.4"
Expand Down
27 changes: 13 additions & 14 deletions src/eigh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
use ndarray::{s, Array1, Array2, ArrayBase, Data, DataMut, Ix2, NdFloat};

use crate::{
check_square, givens::GivensRotation, index::*, tridiagonal::SymmetricTridiagonal, Result,
check_square, givens::GivensRotation, index::*, tridiagonal::SymmetricTridiagonal, Order,
Result,
};

fn symmetric_eig<A: NdFloat, S: DataMut<Elem = A>>(
Expand Down Expand Up @@ -272,44 +273,42 @@ impl<A: NdFloat, S: Data<Elem = A>> EigValsh for ArrayBase<S, Ix2> {
///
/// Will panic if shape or layout of inputs differ from eigen output, or if input contains NaN.
pub trait EigSort: Sized {
fn sort_eig(self, descending: bool) -> Self;
fn sort_eig(self, order: Order) -> Self;

/// Sort eigendecomposition by the eigenvalues in ascending order
fn sort_eig_asc(self) -> Self {
self.sort_eig(false)
self.sort_eig(Order::Smallest)
}

/// Sort eigendecomposition by the eigenvalues in descending order
fn sort_eig_desc(self) -> Self {
self.sort_eig(true)
self.sort_eig(Order::Largest)
}
}

/// Implementation on output of `EigValsh` traits
impl<A: NdFloat> EigSort for Array1<A> {
fn sort_eig(mut self, descending: bool) -> Self {
fn sort_eig(mut self, order: Order) -> Self {
// Panics on non-standard layouts, which is fine because our eigenvals have standard layout
let slice = self.as_slice_mut().unwrap();
// Panic only happens with NaN values
if descending {
slice.sort_by(|a, b| b.partial_cmp(a).unwrap());
} else {
slice.sort_by(|a, b| a.partial_cmp(b).unwrap());
match order {
Order::Largest => slice.sort_by(|a, b| b.partial_cmp(a).unwrap()),
Order::Smallest => slice.sort_by(|a, b| a.partial_cmp(b).unwrap()),
}
self
}
}

/// Implementation on output of `Eigh` traits
impl<A: NdFloat> EigSort for (Array1<A>, Array2<A>) {
fn sort_eig(self, descending: bool) -> Self {
fn sort_eig(self, order: Order) -> Self {
let (mut vals, vecs) = self;
let mut value_idx: Vec<_> = vals.iter().copied().enumerate().collect();
// Panic only happens with NaN values
if descending {
value_idx.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
} else {
value_idx.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
match order {
Order::Largest => value_idx.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()),
Order::Smallest => value_idx.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()),
}

let mut out = Array2::zeros(vecs.dim());
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ pub(crate) fn check_square<S: RawData>(arr: &ArrayBase<S, Ix2>) -> Result<usize>
}

/// Find largest or smallest eigenvalues
///
/// Corresponds to descending and ascending order
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Order {
Largest,
Expand Down
13 changes: 2 additions & 11 deletions src/lobpcg/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,16 @@ fn sorted_eig<A: NdFloat>(
size: usize,
order: Order,
) -> Result<(Array1<A>, Array2<A>)> {
let n = a.len_of(Axis(0));

let res = match b {
Some(b) => generalized_eig(a, b)?,
_ => a.eigh_into()?,
};

// sort and ensure that signs are deterministic
let (vals, vecs) = res.sort_eig(false);
let (vals, vecs) = res.sort_eig(order);
let s = vecs.row(0).mapv(|x| x.signum());
let vecs = vecs * s;

Ok(match order {
Order::Largest => (
vals.slice_move(s![n-size..; -1]),
vecs.slice_move(s![.., n-size..; -1]),
),
Order::Smallest => (vals.slice_move(s![..size]), vecs.slice_move(s![.., ..size])),
})
Ok((vals.slice_move(s![..size]), vecs.slice_move(s![.., ..size])))
}

/// Masks a matrix with the given `matrix`
Expand Down
44 changes: 16 additions & 28 deletions src/lobpcg/eig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ use num_traits::NumCast;
use rand::Rng;
use std::iter::Sum;

#[cfg(feature = "rand_xoshiro")]
use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256Plus};
//#[cfg(feature="rand_xoshiro")]

#[derive(Debug, Clone)]
/// Truncated eigenproblem solver
///
Expand All @@ -28,12 +24,14 @@ use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256Plus};
///
/// ```rust
/// use ndarray::{arr1, Array2};
/// use ndarray_linalg_rs::{Order, lobpcg::TruncatedEig};
/// use linfa_linalg::{Order, lobpcg::TruncatedEig};
/// use rand::SeedableRng;
/// use rand_xoshiro::Xoshiro256Plus;
///
/// let diag = arr1(&[1., 2., 3., 4., 5.]);
/// let a = Array2::from_diag(&diag);
///
/// let mut eig = TruncatedEig::new_from_seed(a, Order::Largest, 42)
/// let mut eig = TruncatedEig::new_with_rng(a, Order::Largest, Xoshiro256Plus::seed_from_u64(42))
/// .precision(1e-5)
/// .maxiter(500);
///
Expand All @@ -49,22 +47,6 @@ pub struct TruncatedEig<A: NdFloat, R: Rng> {
rng: R,
}

#[cfg(feature = "rand_xoshiro")]
impl<A: NdFloat + Sum> TruncatedEig<A, Xoshiro256Plus> {
/// Create a new truncated eigenproblem solver
///
/// # Properties
/// * `problem`: problem matrix
/// * `order`: ordering of the eigenvalues with [Order](crate::Order)
pub fn new_from_seed(
problem: Array2<A>,
order: Order,
seed: u64,
) -> TruncatedEig<A, Xoshiro256Plus> {
Self::new_with_rng(problem, order, Xoshiro256Plus::seed_from_u64(seed))
}
}

impl<A: NdFloat + Sum, R: Rng> TruncatedEig<A, R> {
/// Create a new truncated eigenproblem solver
///
Expand Down Expand Up @@ -140,12 +122,14 @@ impl<A: NdFloat + Sum, R: Rng> TruncatedEig<A, R> {
///
/// ```rust
/// use ndarray::{arr1, Array2};
/// use ndarray_linalg_rs::{Order, lobpcg::TruncatedEig};
/// use linfa_linalg::{Order, lobpcg::TruncatedEig};
/// use rand::SeedableRng;
/// use rand_xoshiro::Xoshiro256Plus;
///
/// let diag = arr1(&[1., 2., 3., 4., 5.]);
/// let a = Array2::from_diag(&diag);
///
/// let mut eig = TruncatedEig::new_from_seed(a, Order::Largest, 42)
/// let mut eig = TruncatedEig::new_with_rng(a, Order::Largest, Xoshiro256Plus::seed_from_u64(42))
/// .precision(1e-5)
/// .maxiter(500);
///
Expand Down Expand Up @@ -216,12 +200,14 @@ impl<A: NdFloat + Sum, R: Rng> TruncatedEig<A, R> {
///
/// ```rust
/// use ndarray::{arr1, Array2};
/// use ndarray_linalg_rs::{Order, lobpcg::TruncatedEig};
/// use linfa_linalg::{Order, lobpcg::TruncatedEig};
/// use rand::SeedableRng;
/// use rand_xoshiro::Xoshiro256Plus;
///
/// let diag = arr1(&[1., 2., 3., 4., 5.]);
/// let a = Array2::from_diag(&diag);
///
/// let teig = TruncatedEig::new_from_seed(a, Order::Largest, 42)
/// let teig = TruncatedEig::new_with_rng(a, Order::Largest, Xoshiro256Plus::seed_from_u64(42))
/// .precision(1e-5)
/// .maxiter(500);
///
Expand Down Expand Up @@ -306,11 +292,13 @@ impl<A: NdFloat + Sum, R: Rng> Iterator for TruncatedEigIterator<A, R> {
}
}

#[cfg(all(test, feature = "rand_xoshiro"))]
#[cfg(test)]
mod tests {
use super::Order;
use super::TruncatedEig;
use ndarray::{arr1, Array2};
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;

#[test]
fn test_truncated_eig() {
Expand All @@ -320,7 +308,7 @@ mod tests {
]);
let a = Array2::from_diag(&diag);

let teig = TruncatedEig::new_from_seed(a, Order::Largest, 42)
let teig = TruncatedEig::new_with_rng(a, Order::Largest, Xoshiro256Plus::seed_from_u64(42))
.precision(1e-5)
.maxiter(500);

Expand Down
53 changes: 20 additions & 33 deletions src/lobpcg/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ use std::iter::Sum;

use rand::Rng;

#[cfg(feature = "rand_xoshiro")]
use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256Plus};

/// The result of a eigenvalue decomposition, not yet transformed into singular values/vectors
///
/// Provides methods for either calculating just the singular values with reduced cost or the
Expand Down Expand Up @@ -115,23 +112,6 @@ pub struct TruncatedSvd<A: NdFloat, R: Rng> {
rng: R,
}

#[cfg(feature = "rand_xoshiro")]
impl<A: NdFloat + Sum> TruncatedSvd<A, Xoshiro256Plus> {
/// Create a new truncated SVD problem
///
/// # Parameters
/// * `problem`: rectangular matrix which is decomposed
/// * `order`: whether to return large or small (close to zero) singular values
/// * `seed`: seed of the random number generator
pub fn new_from_seed(
problem: Array2<A>,
order: Order,
seed: u64,
) -> TruncatedSvd<A, Xoshiro256Plus> {
Self::new_with_rng(problem, order, Xoshiro256Plus::seed_from_u64(seed))
}
}

impl<A: NdFloat + Sum, R: Rng> TruncatedSvd<A, R> {
/// Create a new truncated SVD problem
///
Expand Down Expand Up @@ -182,12 +162,14 @@ impl<A: NdFloat + Sum, R: Rng> TruncatedSvd<A, R> {
///
/// ```rust
/// use ndarray::{arr1, Array2};
/// use ndarray_linalg_rs::{Order, lobpcg::TruncatedSvd};
/// use linfa_linalg::{Order, lobpcg::TruncatedSvd};
/// use rand::SeedableRng;
/// use rand_xoshiro::Xoshiro256Plus;
///
/// let diag = arr1(&[1., 2., 3., 4., 5.]);
/// let a = Array2::from_diag(&diag);
///
/// let eig = TruncatedSvd::new_from_seed(a, Order::Largest, 42)
/// let eig = TruncatedSvd::new_with_rng(a, Order::Largest, Xoshiro256Plus::seed_from_u64(42))
/// .precision(1e-5)
/// .maxiter(500);
///
Expand Down Expand Up @@ -280,7 +262,7 @@ impl MagnitudeCorrection for f64 {
}
}

#[cfg(all(test, feature = "rand_xoshiro"))]
#[cfg(test)]
mod tests {
use super::Order;
use super::TruncatedSvd;
Expand All @@ -306,7 +288,7 @@ mod tests {
fn test_truncated_svd() {
let a = arr2(&[[3., 2., 2.], [2., 3., -2.]]);

let res = TruncatedSvd::new_from_seed(a, Order::Largest, 42)
let res = TruncatedSvd::new_with_rng(a, Order::Largest, Xoshiro256Plus::seed_from_u64(42))
.precision(1e-5)
.maxiter(10)
.decompose(2)
Expand All @@ -321,11 +303,15 @@ mod tests {
fn test_truncated_svd_random() {
let a: Array2<f64> = random((50, 10));

let res = TruncatedSvd::new_from_seed(a.clone(), Order::Largest, 42)
.precision(1e-5)
.maxiter(10)
.decompose(10)
.unwrap();
let res = TruncatedSvd::new_with_rng(
a.clone(),
Order::Largest,
Xoshiro256Plus::seed_from_u64(42),
)
.precision(1e-5)
.maxiter(10)
.decompose(10)
.unwrap();

let (u, sigma, v_t) = res.values_vectors();
let reconstructed = u.dot(&Array2::from_diag(&sigma).dot(&v_t));
Expand All @@ -349,10 +335,11 @@ mod tests {
// generate normal distribution random data with N >> p
let data = Array2::random_using((1000, 500), StandardNormal, &mut rng) / 1000f64.sqrt();

let res = TruncatedSvd::new_from_seed(data, Order::Largest, 42)
.precision(1e-3)
.decompose(500)
.unwrap();
let res =
TruncatedSvd::new_with_rng(data, Order::Largest, Xoshiro256Plus::seed_from_u64(42))
.precision(1e-3)
.decompose(500)
.unwrap();

let sv = res.values().mapv(|x: f64| x * x);

Expand Down
17 changes: 8 additions & 9 deletions src/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use ndarray::{s, Array1, Array2, ArrayBase, Axis, Data, DataMut, Ix2, NdFloat};

use crate::{
bidiagonal::Bidiagonal, eigh::wilkinson_shift, givens::GivensRotation, index::*, LinalgError,
Result,
Order, Result,
};

fn svd<A: NdFloat, S: DataMut<Elem = A>>(
Expand Down Expand Up @@ -482,29 +482,28 @@ impl<A: NdFloat, S: Data<Elem = A>> SVD for ArrayBase<S, Ix2> {
///
/// Will panic if shape of inputs differs from shape of SVD output, or if input contains NaN.
pub trait SvdSort: Sized {
fn sort_svd(self, descending: bool) -> Self;
fn sort_svd(self, order: Order) -> Self;

/// Sort SVD decomposition by the singular values in ascending order
fn sort_svd_asc(self) -> Self {
self.sort_svd(false)
self.sort_svd(Order::Smallest)
}

/// Sort SVD decomposition by the singular values in descending order
fn sort_svd_desc(self) -> Self {
self.sort_svd(true)
self.sort_svd(Order::Largest)
}
}

/// Implemented on the output of the `SVD` traits
impl<A: NdFloat> SvdSort for (Option<Array2<A>>, Array1<A>, Option<Array2<A>>) {
fn sort_svd(self, descending: bool) -> Self {
fn sort_svd(self, order: Order) -> Self {
let (u, mut s, vt) = self;
let mut value_idx: Vec<_> = s.iter().copied().enumerate().collect();
// Panic only happens with NaN values
if descending {
value_idx.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
} else {
value_idx.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
match order {
Order::Largest => value_idx.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()),
Order::Smallest => value_idx.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()),
}

let apply_ordering = |arr: &Array2<A>, ax, values_idx: &Vec<_>| {
Expand Down
2 changes: 1 addition & 1 deletion tests/bidiagonal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use approx::assert_abs_diff_eq;
use ndarray::prelude::*;
use proptest::prelude::*;

use ndarray_linalg_rs::bidiagonal::*;
use linfa_linalg::bidiagonal::*;

mod common;

Expand Down
2 changes: 1 addition & 1 deletion tests/cholesky.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use approx::assert_abs_diff_eq;
use ndarray::prelude::*;
use proptest::prelude::*;

use ndarray_linalg_rs::{cholesky::*, triangular::*};
use linfa_linalg::{cholesky::*, triangular::*};

mod common;

Expand Down
2 changes: 1 addition & 1 deletion tests/eigh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use approx::assert_abs_diff_eq;
use ndarray::prelude::*;
use proptest::prelude::*;

use ndarray_linalg_rs::eigh::*;
use linfa_linalg::eigh::*;

mod common;

Expand Down
Loading

0 comments on commit 82837ff

Please sign in to comment.