Skip to content

Commit

Permalink
WIP: Matmul unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tbetcke committed Jul 27, 2022
1 parent 3ea2dd0 commit 530217f
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 24 deletions.
35 changes: 15 additions & 20 deletions src/matrix/base_methods.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
//! Methods on Matrix types that require an underlying base matrix.
use crate::matrix::Matrix;
use crate::data_container::DataContainerMut;
use crate::traits::*;
use crate::types::*;
use crate::base_matrix::*;
use crate::data_container::{DataContainerMut};
use rand::prelude::*;
use rand_distr::StandardNormal;

use super::GenericBaseMatrixMut;

impl<Item: Scalar,
L: LayoutType,
RS: SizeIdentifier,
CS: SizeIdentifier,
Data: DataContainerMut<Item=Item>>
GenericBaseMatrixMut<Item, L, Data, RS, CS>{

pub fn for_each<F: FnMut(&mut Item)>(&mut self, mut f: F) {
for index in 0..self.layout().number_of_elements() {
unsafe {f(self.get1d_unchecked_mut(index))}
}
}

}

impl<
Item: Scalar,
L: LayoutType,
RS: SizeIdentifier,
CS: SizeIdentifier,
Data: DataContainerMut<Item = Item>,
> GenericBaseMatrixMut<Item, L, Data, RS, CS>
{
pub fn for_each<F: FnMut(&mut Item)>(&mut self, mut f: F) {
for index in 0..self.layout().number_of_elements() {
unsafe { f(self.get1d_unchecked_mut(index)) }
}
}
}
5 changes: 3 additions & 2 deletions src/matrix/random.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
//! Methods for the creation of random matrices.
use crate::base_matrix::*;
use crate::data_container::DataContainerMut;
use crate::matrix::Matrix;
use crate::traits::*;
use crate::types::*;
use rand::prelude::*;
Expand All @@ -27,7 +25,10 @@ macro_rules! rand_impl {
};
}

rand_impl!(f32);
rand_impl!(f64);
rand_impl!(c32);
rand_impl!(c64);

// Random number implementations for the scalar types
trait RandScalar: Scalar {
Expand Down
71 changes: 69 additions & 2 deletions src/matrix_multiply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ macro_rules! matmul_real {
let rsb = mat_b.layout().stride().0 as isize;
let csb = mat_b.layout().stride().1 as isize;
let rsc = mat_c.layout().stride().0 as isize;
let csc = mat_c.layout().stride().0 as isize;
let csc = mat_c.layout().stride().1 as isize;

unsafe {
$Blas(
Expand Down Expand Up @@ -223,7 +223,7 @@ macro_rules! matmul_complex {
let rsb = mat_b.layout().stride().0 as isize;
let csb = mat_b.layout().stride().1 as isize;
let rsc = mat_c.layout().stride().0 as isize;
let csc = mat_c.layout().stride().0 as isize;
let csc = mat_c.layout().stride().1 as isize;

let alpha = [alpha.re(), alpha.im()];
let beta = [beta.re(), beta.im()];
Expand Down Expand Up @@ -266,6 +266,73 @@ dot_impl!(f32);
dot_impl!(c32);
dot_impl!(c64);

#[cfg(test)]
mod test {

use super::*;
use approx::{self, assert_relative_eq};

use rand::prelude::*;

fn matmul_expect<
Item: Scalar,
L1: LayoutType,
L2: LayoutType,
L3: LayoutType,
Data1: DataContainer<Item = Item>,
Data2: DataContainer<Item = Item>,
Data3: DataContainerMut<Item = Item>,
>(
alpha: Item,
mat_a: &GenericBaseMatrix<Item, L1, Data1, Dynamic, Dynamic>,
mat_b: &GenericBaseMatrix<Item, L2, Data2, Dynamic, Dynamic>,
beta: Item,
mat_c: &mut GenericBaseMatrix<Item, L3, Data3, Dynamic, Dynamic>,
) {
let m = mat_a.layout().dim().0;
let k = mat_a.layout().dim().1;
let n = mat_b.layout().dim().1;

for m_index in 0..m {
for n_index in 0..n {
*mat_c.get_mut(m_index, n_index) *= beta;
for k_index in 0..k {
*mat_c.get_mut(m_index, n_index) +=
alpha * mat_a.get(m_index, k_index) * mat_b.get(k_index, n_index);
}
}
}
}

#[test]
fn test_matmul_f64() {
let mut mat_a = MatrixD::<f64, RowMajor>::zeros_from_dim(4, 6);
let mut mat_b = MatrixD::<f64, RowMajor>::zeros_from_dim(6, 5);
let mut mat_c_actual = MatrixD::<f64, RowMajor>::zeros_from_dim(4, 5);
let mut mat_c_expect = MatrixD::<f64, RowMajor>::zeros_from_dim(4, 5);

let mut rng = rand::rngs::StdRng::seed_from_u64(0);

mat_a.fill_from_rand_standard_normal(&mut rng);
mat_b.fill_from_rand_standard_normal(&mut rng);
mat_c_actual.fill_from_rand_standard_normal(&mut rng);

for index in 0..mat_c_actual.layout().number_of_elements() {
*mat_c_expect.get1d_mut(index) = mat_c_actual.get1d(index);
}

let alpha = 1.0;
let beta = 0.0;

matmul_expect(alpha, &mat_a, &mat_b, beta, &mut mat_c_expect);
f64::matmul(alpha, &mat_a, &mat_b, beta, &mut mat_c_actual);

for index in 0..mat_c_expect.layout().number_of_elements() {
assert_relative_eq!(mat_c_actual.get1d(index), mat_c_expect.get1d(index));
}
}
}

// impl<'a, 'b, Item: Scalar> Dot<ColVectorD<'a, Item>> for CMatrixD<'b, Item> {
// type Output = ColVectorD<'static, Item>;

Expand Down

0 comments on commit 530217f

Please sign in to comment.