From 530217f2e73c82a32cce9545a83886ec7b3c78c6 Mon Sep 17 00:00:00 2001 From: Timo Betcke Date: Wed, 27 Jul 2022 19:04:13 +0100 Subject: [PATCH] WIP: Matmul unit tests --- src/matrix/base_methods.rs | 35 ++++++++----------- src/matrix/random.rs | 5 +-- src/matrix_multiply.rs | 71 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 87 insertions(+), 24 deletions(-) diff --git a/src/matrix/base_methods.rs b/src/matrix/base_methods.rs index 601a908..c2a3634 100644 --- a/src/matrix/base_methods.rs +++ b/src/matrix/base_methods.rs @@ -1,27 +1,22 @@ //! Methods on Matrix types that require an underlying base matrix. -use crate::matrix::Matrix; +use crate::data_container::DataContainerMut; use crate::traits::*; use crate::types::*; -use crate::base_matrix::*; -use crate::data_container::{DataContainerMut}; -use rand::prelude::*; -use rand_distr::StandardNormal; use super::GenericBaseMatrixMut; -impl> - GenericBaseMatrixMut{ - - pub fn for_each(&mut self, mut f: F) { - for index in 0..self.layout().number_of_elements() { - unsafe {f(self.get1d_unchecked_mut(index))} - } - } - - } - +impl< + Item: Scalar, + L: LayoutType, + RS: SizeIdentifier, + CS: SizeIdentifier, + Data: DataContainerMut, + > GenericBaseMatrixMut +{ + pub fn for_each(&mut self, mut f: F) { + for index in 0..self.layout().number_of_elements() { + unsafe { f(self.get1d_unchecked_mut(index)) } + } + } +} diff --git a/src/matrix/random.rs b/src/matrix/random.rs index 6133d1e..a978b7e 100644 --- a/src/matrix/random.rs +++ b/src/matrix/random.rs @@ -1,8 +1,6 @@ //! Methods for the creation of random matrices. -use crate::base_matrix::*; use crate::data_container::DataContainerMut; -use crate::matrix::Matrix; use crate::traits::*; use crate::types::*; use rand::prelude::*; @@ -27,7 +25,10 @@ macro_rules! rand_impl { }; } +rand_impl!(f32); rand_impl!(f64); +rand_impl!(c32); +rand_impl!(c64); // Random number implementations for the scalar types trait RandScalar: Scalar { diff --git a/src/matrix_multiply.rs b/src/matrix_multiply.rs index 117ed06..fb705f9 100644 --- a/src/matrix_multiply.rs +++ b/src/matrix_multiply.rs @@ -153,7 +153,7 @@ macro_rules! matmul_real { let rsb = mat_b.layout().stride().0 as isize; let csb = mat_b.layout().stride().1 as isize; let rsc = mat_c.layout().stride().0 as isize; - let csc = mat_c.layout().stride().0 as isize; + let csc = mat_c.layout().stride().1 as isize; unsafe { $Blas( @@ -223,7 +223,7 @@ macro_rules! matmul_complex { let rsb = mat_b.layout().stride().0 as isize; let csb = mat_b.layout().stride().1 as isize; let rsc = mat_c.layout().stride().0 as isize; - let csc = mat_c.layout().stride().0 as isize; + let csc = mat_c.layout().stride().1 as isize; let alpha = [alpha.re(), alpha.im()]; let beta = [beta.re(), beta.im()]; @@ -266,6 +266,73 @@ dot_impl!(f32); dot_impl!(c32); dot_impl!(c64); +#[cfg(test)] +mod test { + + use super::*; + use approx::{self, assert_relative_eq}; + + use rand::prelude::*; + + fn matmul_expect< + Item: Scalar, + L1: LayoutType, + L2: LayoutType, + L3: LayoutType, + Data1: DataContainer, + Data2: DataContainer, + Data3: DataContainerMut, + >( + alpha: Item, + mat_a: &GenericBaseMatrix, + mat_b: &GenericBaseMatrix, + beta: Item, + mat_c: &mut GenericBaseMatrix, + ) { + let m = mat_a.layout().dim().0; + let k = mat_a.layout().dim().1; + let n = mat_b.layout().dim().1; + + for m_index in 0..m { + for n_index in 0..n { + *mat_c.get_mut(m_index, n_index) *= beta; + for k_index in 0..k { + *mat_c.get_mut(m_index, n_index) += + alpha * mat_a.get(m_index, k_index) * mat_b.get(k_index, n_index); + } + } + } + } + + #[test] + fn test_matmul_f64() { + let mut mat_a = MatrixD::::zeros_from_dim(4, 6); + let mut mat_b = MatrixD::::zeros_from_dim(6, 5); + let mut mat_c_actual = MatrixD::::zeros_from_dim(4, 5); + let mut mat_c_expect = MatrixD::::zeros_from_dim(4, 5); + + let mut rng = rand::rngs::StdRng::seed_from_u64(0); + + mat_a.fill_from_rand_standard_normal(&mut rng); + mat_b.fill_from_rand_standard_normal(&mut rng); + mat_c_actual.fill_from_rand_standard_normal(&mut rng); + + for index in 0..mat_c_actual.layout().number_of_elements() { + *mat_c_expect.get1d_mut(index) = mat_c_actual.get1d(index); + } + + let alpha = 1.0; + let beta = 0.0; + + matmul_expect(alpha, &mat_a, &mat_b, beta, &mut mat_c_expect); + f64::matmul(alpha, &mat_a, &mat_b, beta, &mut mat_c_actual); + + for index in 0..mat_c_expect.layout().number_of_elements() { + assert_relative_eq!(mat_c_actual.get1d(index), mat_c_expect.get1d(index)); + } + } +} + // impl<'a, 'b, Item: Scalar> Dot> for CMatrixD<'b, Item> { // type Output = ColVectorD<'static, Item>;