diff --git a/Cargo.toml b/Cargo.toml index d2124d1..b1ef902 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ itertools = "0.10" rand_distr = "0.4" thiserror = "1.0" matrixmultiply = "0.3" -approx = "0.5" +approx = { version = "0.5", features=["num-complex"] } [dev-dependencies] criterion = { version = "0.3", features = ["html_reports"] } diff --git a/src/lib.rs b/src/lib.rs index 8673548..49dd926 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ pub mod matrix; pub mod matrix_ref; pub mod scalar_mult; pub mod addition; +pub mod tools; //pub mod traits; //pub mod slice_matrix; diff --git a/src/matrix/random.rs b/src/matrix/random.rs index a978b7e..2cd5b72 100644 --- a/src/matrix/random.rs +++ b/src/matrix/random.rs @@ -4,7 +4,8 @@ use crate::data_container::DataContainerMut; use crate::traits::*; use crate::types::*; use rand::prelude::*; -use rand_distr::{Distribution, StandardNormal}; +use rand_distr::StandardNormal; +use crate::tools::*; use super::GenericBaseMatrixMut; @@ -30,31 +31,3 @@ rand_impl!(f64); rand_impl!(c32); rand_impl!(c64); -// Random number implementations for the scalar types -trait RandScalar: Scalar { - fn random_scalar>(rng: &mut R, dist: &D) -> Self; -} - -impl RandScalar for f32 { - fn random_scalar>(rng: &mut R, dist: &D) -> Self { - dist.sample(rng) - } -} - -impl RandScalar for f64 { - fn random_scalar>(rng: &mut R, dist: &D) -> Self { - dist.sample(rng) - } -} - -impl RandScalar for c32 { - fn random_scalar>(rng: &mut R, dist: &D) -> Self { - c32::new(dist.sample(rng), dist.sample(rng)) - } -} - -impl RandScalar for c64 { - fn random_scalar>(rng: &mut R, dist: &D) -> Self { - c64::new(dist.sample(rng), dist.sample(rng)) - } -} diff --git a/src/matrix_multiply.rs b/src/matrix_multiply.rs index fb705f9..63cb506 100644 --- a/src/matrix_multiply.rs +++ b/src/matrix_multiply.rs @@ -15,39 +15,28 @@ pub trait Dot { fn dot(&self, rhs: &Rhs) -> Self::Output; } -pub trait MatMul { - type Item: Scalar; - fn matmul< - L1: LayoutType, - L2: LayoutType, - L3: LayoutType, - Data1: DataContainer, - Data2: DataContainer, - Data3: DataContainerMut, - >( - alpha: Self::Item, - mat_a: &Matrix< - Self::Item, - BaseMatrix, - L1, - Dynamic, - Dynamic, - >, - mat_b: &Matrix< - Self::Item, - BaseMatrix, - L2, - Dynamic, - Dynamic, - >, - beta: Self::Item, - mat_c: &mut Matrix< - Self::Item, - BaseMatrix, - L3, - Dynamic, - Dynamic, - >, +pub trait MatMul< + Item: Scalar, + L1: LayoutType, + L2: LayoutType, + L3: LayoutType, + Data1: DataContainer, + Data2: DataContainer, + Data3: DataContainerMut, + RS1: SizeIdentifier, + RS2: SizeIdentifier, + RS3: SizeIdentifier, + CS1: SizeIdentifier, + CS2: SizeIdentifier, + CS3: SizeIdentifier, +> +{ + fn matmul( + alpha: Item, + mat_a: &Matrix, L1, RS1, CS1>, + mat_b: &Matrix, L2, RS2, CS2>, + beta: Item, + mat_c: &mut Matrix, L3, RS3, CS3>, ); } @@ -115,18 +104,35 @@ macro_rules! matmul_real { ($Scalar:ty, $Blas:ident) => { - impl MatMul for $Scalar { + impl< + L1: LayoutType, + L2: LayoutType, + L3: LayoutType, + Data1: DataContainer, + Data2: DataContainer, + Data3: DataContainerMut +> - type Item = $Scalar; - fn matmul< - L1: LayoutType, - L2: LayoutType, - L3: LayoutType, - Data1: DataContainer, - Data2: DataContainer, - Data3: DataContainerMut, - >( + MatMul< + $Scalar, + L1, + L2, + L3, + Data1, + Data2, + Data3, + Dynamic, + Dynamic, + Dynamic, + Dynamic, + Dynamic, + Dynamic> + + + for $Scalar { + + fn matmul( alpha: $Scalar, mat_a: &Matrix<$Scalar, BaseMatrix<$Scalar, Data1, L1, Dynamic, Dynamic>, L1, Dynamic, Dynamic>, mat_b: &Matrix<$Scalar, BaseMatrix<$Scalar, Data2, L2, Dynamic, Dynamic>, L2, Dynamic, Dynamic>, @@ -185,18 +191,35 @@ macro_rules! matmul_complex { ($Scalar:ty, $Real:ty, $Blas:ident) => { - impl MatMul for $Scalar { + impl< + L1: LayoutType, + L2: LayoutType, + L3: LayoutType, + Data1: DataContainer, + Data2: DataContainer, + Data3: DataContainerMut +> - type Item = $Scalar; - fn matmul< - L1: LayoutType, - L2: LayoutType, - L3: LayoutType, - Data1: DataContainer, - Data2: DataContainer, - Data3: DataContainerMut, - >( + MatMul< + $Scalar, + L1, + L2, + L3, + Data1, + Data2, + Data3, + Dynamic, + Dynamic, + Dynamic, + Dynamic, + Dynamic, + Dynamic> + + + for $Scalar { + + fn matmul( alpha: $Scalar, mat_a: &Matrix<$Scalar, BaseMatrix<$Scalar, Data1, L1, Dynamic, Dynamic>, L1, Dynamic, Dynamic>, mat_b: &Matrix<$Scalar, BaseMatrix<$Scalar, Data2, L2, Dynamic, Dynamic>, L2, Dynamic, Dynamic>, @@ -270,7 +293,9 @@ dot_impl!(c64); mod test { use super::*; - use approx::{self, assert_relative_eq}; + use crate::tools::RandScalar; + use approx::assert_ulps_eq; + use rand_distr::StandardNormal; use rand::prelude::*; @@ -304,527 +329,43 @@ mod test { } } - #[test] - fn test_matmul_f64() { - let mut mat_a = MatrixD::::zeros_from_dim(4, 6); - let mut mat_b = MatrixD::::zeros_from_dim(6, 5); - let mut mat_c_actual = MatrixD::::zeros_from_dim(4, 5); - let mut mat_c_expect = MatrixD::::zeros_from_dim(4, 5); + macro_rules! matmul_test { + ($Scalar:ty, $fname:ident) => { + #[test] + fn $fname() { + let mut mat_a = MatrixD::<$Scalar, RowMajor>::zeros_from_dim(4, 6); + let mut mat_b = MatrixD::<$Scalar, RowMajor>::zeros_from_dim(6, 5); + let mut mat_c_actual = MatrixD::<$Scalar, RowMajor>::zeros_from_dim(4, 5); + let mut mat_c_expect = MatrixD::<$Scalar, RowMajor>::zeros_from_dim(4, 5); - let mut rng = rand::rngs::StdRng::seed_from_u64(0); + let dist = StandardNormal; - mat_a.fill_from_rand_standard_normal(&mut rng); - mat_b.fill_from_rand_standard_normal(&mut rng); - mat_c_actual.fill_from_rand_standard_normal(&mut rng); + let mut rng = rand::rngs::StdRng::seed_from_u64(0); - for index in 0..mat_c_actual.layout().number_of_elements() { - *mat_c_expect.get1d_mut(index) = mat_c_actual.get1d(index); - } + mat_a.fill_from_rand_standard_normal(&mut rng); + mat_b.fill_from_rand_standard_normal(&mut rng); + mat_c_actual.fill_from_rand_standard_normal(&mut rng); - let alpha = 1.0; - let beta = 0.0; + for index in 0..mat_c_actual.layout().number_of_elements() { + *mat_c_expect.get1d_mut(index) = mat_c_actual.get1d(index); + } - matmul_expect(alpha, &mat_a, &mat_b, beta, &mut mat_c_expect); - f64::matmul(alpha, &mat_a, &mat_b, beta, &mut mat_c_actual); + let alpha = <$Scalar>::random_scalar(&mut rng, &dist); + let beta = <$Scalar>::random_scalar(&mut rng, &dist); - for index in 0..mat_c_expect.layout().number_of_elements() { - assert_relative_eq!(mat_c_actual.get1d(index), mat_c_expect.get1d(index)); - } + matmul_expect(alpha, &mat_a, &mat_b, beta, &mut mat_c_expect); + <$Scalar>::matmul(alpha, &mat_a, &mat_b, beta, &mut mat_c_actual); + + for index in 0..mat_c_expect.layout().number_of_elements() { + let val1 = mat_c_actual.get1d(index); + let val2 = mat_c_expect.get1d(index); + assert_ulps_eq!(&val1, &val2, max_ulps = 100); + } + } + }; } + matmul_test!(f32, test_matmul_f32); + matmul_test!(f64, test_matmul_f64); + matmul_test!(c32, test_matmul_c32); + matmul_test!(c64, test_matmul_c64); } - -// impl<'a, 'b, Item: Scalar> Dot> for CMatrixD<'b, Item> { -// type Output = ColVectorD<'static, Item>; - -// fn dot(&self, rhs: &ColVectorD<'a, Item>) -> Self::Output { - -// let dim = self.dim(); - -// let mut res = ColVectorD<'static, Item>; - -// for row_index in 0..dim.0 { -// for col_index in 0..dim.1 { - -// } -// } - -// } - -// } - -// macro_rules! mat_mat_dot_impl_real { -// ($Scalar:ty, $Blas:ident, $MatType:ident) => { -// impl<'a, 'b> Dot<$MatType<'a, $Scalar>> for $MatType<'b, $Scalar> { -// type Output = $MatType<'static, $Scalar>; -// /// Return the product of this matrix with another matrix. -// fn dot(&self, rhs: &$MatType<'a, $Scalar>) -> Self::Output { -// let dim1 = self.dim(); -// let dim2 = rhs.dim(); - -// assert_eq!( -// dim1.1, dim2.0, -// "Matrix multiply incompatible dimensions: A = {:#?}, B = {:#?}", -// dim1, dim2 -// ); - -// let m = dim1.0; -// let k = dim1.1; -// let n = dim2.1; - -// let mut res = $MatType::<$Scalar>::from_dimension(m, n); - -// let rsa = self.row_stride() as isize; -// let csa = self.column_stride() as isize; -// let rsb = rhs.row_stride() as isize; -// let csb = rhs.column_stride() as isize; -// let rsc = res.row_stride() as isize; -// let csc = res.column_stride() as isize; - -// unsafe { -// $Blas( -// m, -// k, -// n, -// num::cast::(1.0).unwrap(), -// self.as_ptr(), -// rsa, -// csa, -// rhs.as_ptr(), -// rsb, -// csb, -// num::cast::(0.0).unwrap(), -// res.as_mut_ptr(), -// rsc, -// csc, -// ); -// } - -// res -// } -// } -// }; -// } - -// macro_rules! mat_mat_dot_impl_complex { -// ($Scalar:ty, $Real:ty, $Blas:ident, $MatType:ident) => { -// impl<'a, 'b> Dot<$MatType<'a, $Scalar>> for $MatType<'b, $Scalar> { -// type Output = $MatType<'static, $Scalar>; -// /// Return the product of this matrix with another matrix. -// fn dot(&self, rhs: &$MatType<'a, $Scalar>) -> Self::Output { -// let dim1 = self.dim(); -// let dim2 = rhs.dim(); - -// assert_eq!( -// dim1.1, dim2.0, -// "Matrix multiply incompatible dimensions: A = {:#?}, B = {:#?}", -// dim1, dim2 -// ); - -// let m = dim1.0; -// let k = dim1.1; -// let n = dim2.1; - -// let mut res = $MatType::<$Scalar>::from_dimension(m, n); - -// let rsa = self.row_stride() as isize; -// let csa = self.column_stride() as isize; -// let rsb = rhs.row_stride() as isize; -// let csb = rhs.column_stride() as isize; -// let rsc = res.row_stride() as isize; -// let csc = res.column_stride() as isize; - -// let one: [$Real; 2] = [1.0, 0.0]; -// let zero: [$Real; 2] = [0.0, 0.0]; - -// unsafe { -// $Blas( -// CGemmOption::Standard, -// CGemmOption::Standard, -// m, -// k, -// n, -// one, -// self.as_ptr() as *const [$Real; 2], -// rsa, -// csa, -// rhs.as_ptr() as *const [$Real; 2], -// rsb, -// csb, -// zero, -// res.as_mut_ptr() as *mut [$Real; 2], -// rsc, -// csc, -// ); -// } - -// res -// } -// } - -// }; -// } - -// mat_mat_dot_impl_real!(f32, sgemm, FMatrixD); -// mat_mat_dot_impl_real!(f32, sgemm, CMatrixD); -// mat_mat_dot_impl_real!(f64, dgemm, CMatrixD); -// mat_mat_dot_impl_real!(f64, dgemm, FMatrixD); -// mat_mat_dot_impl_complex!(c32, f32, cgemm, CMatrixD); -// mat_mat_dot_impl_complex!(c32, f32, cgemm, FMatrixD); -// mat_mat_dot_impl_complex!(c64, f64, zgemm, CMatrixD); -// mat_mat_dot_impl_complex!(c64, f64, zgemm, FMatrixD); - -// #[cfg(test)] -// mod test { -// use approx::assert_relative_eq; - -// use super::*; -// use crate::mat; - -// #[test] -// fn dot_product_real_double_c() { -// let dim1 = (2, 3); -// let dim2 = (3, 4); - -// let mut mat1 = mat![f64, dim1, CLayout]; -// let mut mat2 = mat![f64, dim2, CLayout]; - -// let mut count = 0; -// for row in 0..dim1.0 { -// for col in 0..dim1.1 { -// *mat1.get_mut(row, col) = count as f64; -// count += 1; -// } -// } - -// for row in 0..dim2.0 { -// for col in 0..dim2.1 { -// *mat2.get_mut(row, col) = count as f64; -// count += 1; -// } -// } - -// let mut expected = mat![f64, (dim1.0, dim2.1), CLayout]; - -// *expected.get_mut(0, 0) = 38.0; -// *expected.get_mut(0, 1) = 41.0; -// *expected.get_mut(0, 2) = 44.0; -// *expected.get_mut(0, 3) = 47.0; -// *expected.get_mut(1, 0) = 128.0; -// *expected.get_mut(1, 1) = 140.0; -// *expected.get_mut(1, 2) = 152.0; -// *expected.get_mut(1, 3) = 164.0; - -// let actual = mat1.dot(&mat2); - -// for row in 0..dim1.0 { -// for col in 0..dim2.1 { -// assert_relative_eq!(actual.get(row, col), expected.get(row, col)); -// } -// } -// } - -// #[test] -// fn dot_product_real_double_f() { -// let dim1 = (2, 3); -// let dim2 = (3, 4); - -// let mut mat1 = mat![f64, dim1, FLayout]; -// let mut mat2 = mat![f64, dim2, FLayout]; - -// let mut count = 0; -// for row in 0..dim1.0 { -// for col in 0..dim1.1 { -// *mat1.get_mut(row, col) = count as f64; -// count += 1; -// } -// } - -// for row in 0..dim2.0 { -// for col in 0..dim2.1 { -// *mat2.get_mut(row, col) = count as f64; -// count += 1; -// } -// } - -// let mut expected = mat![f64, (dim1.0, dim2.1), CLayout]; - -// *expected.get_mut(0, 0) = 38.0; -// *expected.get_mut(0, 1) = 41.0; -// *expected.get_mut(0, 2) = 44.0; -// *expected.get_mut(0, 3) = 47.0; -// *expected.get_mut(1, 0) = 128.0; -// *expected.get_mut(1, 1) = 140.0; -// *expected.get_mut(1, 2) = 152.0; -// *expected.get_mut(1, 3) = 164.0; - -// let actual = mat1.dot(&mat2); - -// for row in 0..dim1.0 { -// for col in 0..dim2.1 { -// assert_relative_eq!(actual.get(row, col), expected.get(row, col)); -// } -// } -// } - -// #[test] -// fn dot_product_real_single_c() { -// let dim1 = (2, 3); -// let dim2 = (3, 4); - -// let mut mat1 = mat![f32, dim1, CLayout]; -// let mut mat2 = mat![f32, dim2, CLayout]; - -// let mut count = 0; -// for row in 0..dim1.0 { -// for col in 0..dim1.1 { -// *mat1.get_mut(row, col) = count as f32; -// count += 1; -// } -// } - -// for row in 0..dim2.0 { -// for col in 0..dim2.1 { -// *mat2.get_mut(row, col) = count as f32; -// count += 1; -// } -// } - -// let mut expected = mat![f32, (dim1.0, dim2.1), CLayout]; - -// *expected.get_mut(0, 0) = 38.0; -// *expected.get_mut(0, 1) = 41.0; -// *expected.get_mut(0, 2) = 44.0; -// *expected.get_mut(0, 3) = 47.0; -// *expected.get_mut(1, 0) = 128.0; -// *expected.get_mut(1, 1) = 140.0; -// *expected.get_mut(1, 2) = 152.0; -// *expected.get_mut(1, 3) = 164.0; - -// let actual = mat1.dot(&mat2); - -// for row in 0..dim1.0 { -// for col in 0..dim2.1 { -// assert_relative_eq!(actual.get(row, col), expected.get(row, col)); -// } -// } -// } - -// #[test] -// fn dot_product_real_single_f() { -// let dim1 = (2, 3); -// let dim2 = (3, 4); - -// let mut mat1 = mat![f32, dim1, FLayout]; -// let mut mat2 = mat![f32, dim2, FLayout]; - -// let mut count = 0; -// for row in 0..dim1.0 { -// for col in 0..dim1.1 { -// *mat1.get_mut(row, col) = count as f32; -// count += 1; -// } -// } - -// for row in 0..dim2.0 { -// for col in 0..dim2.1 { -// *mat2.get_mut(row, col) = count as f32; -// count += 1; -// } -// } - -// let mut expected = mat![f32, (dim1.0, dim2.1), CLayout]; - -// *expected.get_mut(0, 0) = 38.0; -// *expected.get_mut(0, 1) = 41.0; -// *expected.get_mut(0, 2) = 44.0; -// *expected.get_mut(0, 3) = 47.0; -// *expected.get_mut(1, 0) = 128.0; -// *expected.get_mut(1, 1) = 140.0; -// *expected.get_mut(1, 2) = 152.0; -// *expected.get_mut(1, 3) = 164.0; - -// let actual = mat1.dot(&mat2); - -// for row in 0..dim1.0 { -// for col in 0..dim2.1 { -// assert_relative_eq!(actual.get(row, col), expected.get(row, col)); -// } -// } -// } - -// #[test] -// fn dot_product_complex_double_c() { -// let dim1 = (2, 3); -// let dim2 = (3, 4); - -// let mut mat1 = mat![c64, dim1, CLayout]; -// let mut mat2 = mat![c64, dim2, CLayout]; - -// let mut count = 0; -// for row in 0..dim1.0 { -// for col in 0..dim1.1 { -// *mat1.get_mut(row, col) = c64::new(1.0, 1.0) * c64::new(count as f64, 0.0); -// count += 1; -// } -// } - -// for row in 0..dim2.0 { -// for col in 0..dim2.1 { -// *mat2.get_mut(row, col) = c64::new(1.0, 2.0) * c64::new(count as f64, 0.0); -// count += 1; -// } -// } - -// let mut expected = mat![c64, (dim1.0, dim2.1), CLayout]; - -// *expected.get_mut(0, 0) = c64::new(-38.0, 114.0); -// *expected.get_mut(0, 1) = c64::new(-41.0, 123.0); -// *expected.get_mut(0, 2) = c64::new(-44.0, 132.0); -// *expected.get_mut(0, 3) = c64::new(-47.0, 141.0); -// *expected.get_mut(1, 0) = c64::new(-128.0, 384.0); -// *expected.get_mut(1, 1) = c64::new(-140.0, 420.0); -// *expected.get_mut(1, 2) = c64::new(-152.0, 456.0); -// *expected.get_mut(1, 3) = c64::new(-164.0, 492.0); - -// let actual = mat1.dot(&mat2); - -// for row in 0..dim1.0 { -// for col in 0..dim2.1 { -// assert_relative_eq!(actual.get(row, col).re, expected.get(row, col).re); -// assert_relative_eq!(actual.get(row, col).re, expected.get(row, col).re); -// } -// } -// } - -// #[test] -// fn dot_product_complex_double_f() { -// let dim1 = (2, 3); -// let dim2 = (3, 4); - -// let mut mat1 = mat![c64, dim1, FLayout]; -// let mut mat2 = mat![c64, dim2, FLayout]; - -// let mut count = 0; -// for row in 0..dim1.0 { -// for col in 0..dim1.1 { -// *mat1.get_mut(row, col) = c64::new(1.0, 1.0) * c64::new(count as f64, 0.0); -// count += 1; -// } -// } - -// for row in 0..dim2.0 { -// for col in 0..dim2.1 { -// *mat2.get_mut(row, col) = c64::new(1.0, 2.0) * c64::new(count as f64, 0.0); -// count += 1; -// } -// } - -// let mut expected = mat![c64, (dim1.0, dim2.1), FLayout]; - -// *expected.get_mut(0, 0) = c64::new(-38.0, 114.0); -// *expected.get_mut(0, 1) = c64::new(-41.0, 123.0); -// *expected.get_mut(0, 2) = c64::new(-44.0, 132.0); -// *expected.get_mut(0, 3) = c64::new(-47.0, 141.0); -// *expected.get_mut(1, 0) = c64::new(-128.0, 384.0); -// *expected.get_mut(1, 1) = c64::new(-140.0, 420.0); -// *expected.get_mut(1, 2) = c64::new(-152.0, 456.0); -// *expected.get_mut(1, 3) = c64::new(-164.0, 492.0); - -// let actual = mat1.dot(&mat2); - -// for row in 0..dim1.0 { -// for col in 0..dim2.1 { -// assert_relative_eq!(actual.get(row, col).re, expected.get(row, col).re); -// assert_relative_eq!(actual.get(row, col).re, expected.get(row, col).re); -// } -// } -// } - -// #[test] -// fn dot_product_complex_single_c() { -// let dim1 = (2, 3); -// let dim2 = (3, 4); - -// let mut mat1 = mat![c32, dim1, CLayout]; -// let mut mat2 = mat![c32, dim2, CLayout]; - -// let mut count = 0; -// for row in 0..dim1.0 { -// for col in 0..dim1.1 { -// *mat1.get_mut(row, col) = c32::new(1.0, 1.0) * c32::new(count as f32, 0.0); -// count += 1; -// } -// } - -// for row in 0..dim2.0 { -// for col in 0..dim2.1 { -// *mat2.get_mut(row, col) = c32::new(1.0, 2.0) * c32::new(count as f32, 0.0); -// count += 1; -// } -// } - -// let mut expected = mat![c32, (dim1.0, dim2.1), CLayout]; - -// *expected.get_mut(0, 0) = c32::new(-38.0, 114.0); -// *expected.get_mut(0, 1) = c32::new(-41.0, 123.0); -// *expected.get_mut(0, 2) = c32::new(-44.0, 132.0); -// *expected.get_mut(0, 3) = c32::new(-47.0, 141.0); -// *expected.get_mut(1, 0) = c32::new(-128.0, 384.0); -// *expected.get_mut(1, 1) = c32::new(-140.0, 420.0); -// *expected.get_mut(1, 2) = c32::new(-152.0, 456.0); -// *expected.get_mut(1, 3) = c32::new(-164.0, 492.0); - -// let actual = mat1.dot(&mat2); - -// for row in 0..dim1.0 { -// for col in 0..dim2.1 { -// assert_relative_eq!(actual.get(row, col).re, expected.get(row, col).re); -// assert_relative_eq!(actual.get(row, col).re, expected.get(row, col).re); -// } -// } -// } - -// #[test] -// fn dot_product_complex_single_f() { -// let dim1 = (2, 3); -// let dim2 = (3, 4); - -// let mut mat1 = mat![c32, dim1, FLayout]; -// let mut mat2 = mat![c32, dim2, FLayout]; - -// let mut count = 0; -// for row in 0..dim1.0 { -// for col in 0..dim1.1 { -// *mat1.get_mut(row, col) = c32::new(1.0, 1.0) * c32::new(count as f32, 0.0); -// count += 1; -// } -// } - -// for row in 0..dim2.0 { -// for col in 0..dim2.1 { -// *mat2.get_mut(row, col) = c32::new(1.0, 2.0) * c32::new(count as f32, 0.0); -// count += 1; -// } -// } - -// let mut expected = mat![c32, (dim1.0, dim2.1), FLayout]; - -// *expected.get_mut(0, 1) = c32::new(-41.0, 123.0); -// *expected.get_mut(0, 0) = c32::new(-38.0, 114.0); -// *expected.get_mut(0, 2) = c32::new(-44.0, 132.0); -// *expected.get_mut(0, 3) = c32::new(-47.0, 141.0); -// *expected.get_mut(1, 0) = c32::new(-128.0, 384.0); -// *expected.get_mut(1, 1) = c32::new(-140.0, 420.0); -// *expected.get_mut(1, 2) = c32::new(-152.0, 456.0); -// *expected.get_mut(1, 3) = c32::new(-164.0, 492.0); - -// let actual = mat1.dot(&mat2); - -// for row in 0..dim1.0 { -// for col in 0..dim2.1 { -// assert_relative_eq!(actual.get(row, col).re, expected.get(row, col).re); -// assert_relative_eq!(actual.get(row, col).re, expected.get(row, col).re); -// } -// } -// } -// } diff --git a/src/tools.rs b/src/tools.rs new file mode 100644 index 0000000..1a5272f --- /dev/null +++ b/src/tools.rs @@ -0,0 +1,34 @@ +//! Various useful tools + +use crate::types::*; +use rand::prelude::*; +use rand_distr::Distribution; + +// Random number implementations for the scalar types +pub trait RandScalar: Scalar { + fn random_scalar>(rng: &mut R, dist: &D) -> Self; +} + +impl RandScalar for f32 { + fn random_scalar>(rng: &mut R, dist: &D) -> Self { + dist.sample(rng) + } +} + +impl RandScalar for f64 { + fn random_scalar>(rng: &mut R, dist: &D) -> Self { + dist.sample(rng) + } +} + +impl RandScalar for c32 { + fn random_scalar>(rng: &mut R, dist: &D) -> Self { + c32::new(dist.sample(rng), dist.sample(rng)) + } +} + +impl RandScalar for c64 { + fn random_scalar>(rng: &mut R, dist: &D) -> Self { + c64::new(dist.sample(rng), dist.sample(rng)) + } +} diff --git a/src/types.rs b/src/types.rs index 8480ef8..0062d43 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,4 +1,6 @@ //! Scalar types used by the library pub use cauchy::{Scalar, c32, c64}; -pub type IndexType = usize; \ No newline at end of file +pub type IndexType = usize; + +