Skip to content

Commit 1bd17e4

Browse files
committed
Start blas source to target precomputation, distributed
1 parent 53b163c commit 1bd17e4

File tree

2 files changed

+157
-5
lines changed

2 files changed

+157
-5
lines changed

kifmm/src/fmm/field_translation/metadata/multi_node.rs

+157-3
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,13 @@ where
852852
..((local_load_displacement + local_load_count) as usize)];
853853
let check_surface_order = &self.check_surface_order[(local_load_displacement as usize)
854854
..((local_load_displacement + local_load_count) as usize)];
855-
// let mut local_shared_dim = Vec::new();
855+
856+
let mut u_r = Vec::new();
857+
let mut st_r = Vec::new();
858+
let mut c_u_r = Vec::new();
859+
let mut c_vt_r = Vec::new();
860+
let mut cutoff_rank_r = Vec::new();
861+
let mut directional_cutoff_ranks_r = Vec::new();
856862

857863
for (&equivalent_surface_order, &check_surface_order, level) in izip!(equivalent_surface_order, check_surface_order, 0..=total_depth)
858864
{
@@ -970,10 +976,158 @@ where
970976
let mut sigma_mat = rlst_dynamic_array2!(Scalar, [cutoff_rank, cutoff_rank]);
971977
let mut vt = rlst_dynamic_array2!(Scalar, [cutoff_rank, nvt]);
972978

973-
979+
// Store compressed M2L operators
980+
let thin_nrows = se2tc_thin.shape()[0];
981+
let nst = se2tc_thin.shape()[1];
982+
let k = std::cmp::min(thin_nrows, nst);
983+
let mut st;
984+
let mut _gamma;
985+
let mut _r;
986+
987+
if self.source_to_target.surface_diff() == 0 {
988+
st = rlst_dynamic_array2!(Scalar, u_big.r().transpose().shape());
989+
st.fill_from(u_big.r().transpose())
990+
} else {
991+
match &self.source_to_target.svd_mode {
992+
&FmmSvdMode::Random {
993+
n_components,
994+
normaliser,
995+
n_oversamples,
996+
random_state,
997+
} => {
998+
let target_rank;
999+
if let Some(n_components) = n_components {
1000+
target_rank = n_components
1001+
} else {
1002+
// Estimate target rank
1003+
let max_equivalent_surface_ncoeffs =
1004+
self.n_coeffs_equivalent_surface.iter().max().unwrap();
1005+
let max_check_surface_ncoeffs =
1006+
self.n_coeffs_check_surface.iter().max().unwrap();
1007+
target_rank =
1008+
max_equivalent_surface_ncoeffs.max(max_check_surface_ncoeffs) / 2;
1009+
}
1010+
1011+
(_gamma, _r, st) = Scalar::rsvd_fixed_rank(
1012+
&se2tc_thin,
1013+
target_rank,
1014+
n_oversamples,
1015+
normaliser,
1016+
random_state,
1017+
)
1018+
.unwrap();
1019+
}
1020+
FmmSvdMode::Deterministic => {
1021+
_r = rlst_dynamic_array2!(Scalar, [thin_nrows, k]);
1022+
_gamma = vec![Scalar::zero().re(); k];
1023+
st = rlst_dynamic_array2!(Scalar, [k, nst]);
1024+
se2tc_thin
1025+
.into_svd_alloc(
1026+
_r.r_mut(),
1027+
st.r_mut(),
1028+
&mut _gamma[..],
1029+
SvdMode::Reduced,
1030+
)
1031+
.unwrap();
1032+
}
1033+
}
1034+
}
1035+
1036+
u.fill_from(u_big.into_subview([0, 0], [mu, cutoff_rank]));
1037+
vt.fill_from(vt_big.into_subview([0, 0], [cutoff_rank, nvt]));
1038+
for (j, s) in sigma.iter().enumerate().take(cutoff_rank) {
1039+
unsafe {
1040+
*sigma_mat.get_unchecked_mut([j, j]) = Scalar::from(*s).unwrap();
1041+
}
1042+
}
1043+
1044+
let mut s_trunc = rlst_dynamic_array2!(Scalar, [nst, cutoff_rank]);
1045+
for j in 0..cutoff_rank {
1046+
for i in 0..nst {
1047+
unsafe { *s_trunc.get_unchecked_mut([i, j]) = *st.get_unchecked([j, i]) }
1048+
}
1049+
}
1050+
1051+
let c_u = Mutex::new(Vec::new());
1052+
let c_vt = Mutex::new(Vec::new());
1053+
let directional_cutoff_ranks =
1054+
Mutex::new(vec![0usize; self.source_to_target.transfer_vectors.len()]);
1055+
1056+
for _ in 0..NTRANSFER_VECTORS_KIFMM {
1057+
c_u.lock()
1058+
.unwrap()
1059+
.push(rlst_dynamic_array2!(Scalar, [1, 1]));
1060+
c_vt.lock()
1061+
.unwrap()
1062+
.push(rlst_dynamic_array2!(Scalar, [1, 1]));
1063+
}
1064+
1065+
(0..NTRANSFER_VECTORS_KIFMM).into_par_iter().for_each(|i| {
1066+
let vt_block = vt.r().into_subview([0, i * n_cols], [cutoff_rank, n_cols]);
1067+
1068+
let tmp = empty_array::<Scalar, 2>().simple_mult_into_resize(
1069+
sigma_mat.r(),
1070+
empty_array::<Scalar, 2>().simple_mult_into_resize(vt_block.r(), s_trunc.r()),
1071+
);
1072+
1073+
let mut u_i = rlst_dynamic_array2!(Scalar, [cutoff_rank, cutoff_rank]);
1074+
let mut sigma_i = vec![Scalar::zero().re(); cutoff_rank];
1075+
let mut vt_i = rlst_dynamic_array2!(Scalar, [cutoff_rank, cutoff_rank]);
1076+
1077+
tmp.into_svd_alloc(u_i.r_mut(), vt_i.r_mut(), &mut sigma_i, SvdMode::Full)
1078+
.unwrap();
1079+
1080+
let directional_cutoff_rank =
1081+
find_cutoff_rank(&sigma_i, self.source_to_target.threshold, cutoff_rank);
1082+
1083+
let mut u_i_compressed =
1084+
rlst_dynamic_array2!(Scalar, [cutoff_rank, directional_cutoff_rank]);
1085+
let mut vt_i_compressed_ =
1086+
rlst_dynamic_array2!(Scalar, [directional_cutoff_rank, cutoff_rank]);
1087+
1088+
let mut sigma_mat_i_compressed = rlst_dynamic_array2!(
1089+
Scalar,
1090+
[directional_cutoff_rank, directional_cutoff_rank]
1091+
);
1092+
1093+
u_i_compressed
1094+
.fill_from(u_i.into_subview([0, 0], [cutoff_rank, directional_cutoff_rank]));
1095+
vt_i_compressed_
1096+
.fill_from(vt_i.into_subview([0, 0], [directional_cutoff_rank, cutoff_rank]));
1097+
1098+
for (j, s) in sigma_i.iter().enumerate().take(directional_cutoff_rank) {
1099+
unsafe {
1100+
*sigma_mat_i_compressed.get_unchecked_mut([j, j]) =
1101+
Scalar::from(*s).unwrap();
1102+
}
1103+
}
1104+
1105+
let vt_i_compressed = empty_array::<Scalar, 2>()
1106+
.simple_mult_into_resize(sigma_mat_i_compressed.r(), vt_i_compressed_.r());
1107+
1108+
directional_cutoff_ranks.lock().unwrap()[i] = directional_cutoff_rank;
1109+
c_u.lock().unwrap()[i] = u_i_compressed;
1110+
c_vt.lock().unwrap()[i] = vt_i_compressed;
1111+
});
1112+
1113+
let mut st_trunc = rlst_dynamic_array2!(Scalar, [cutoff_rank, nst]);
1114+
st_trunc.fill_from(s_trunc.transpose());
1115+
1116+
let c_vt = std::mem::take(&mut *c_vt.lock().unwrap());
1117+
let c_u = std::mem::take(&mut *c_u.lock().unwrap());
1118+
let directional_cutoff_ranks =
1119+
std::mem::take(&mut *directional_cutoff_ranks.lock().unwrap());
1120+
1121+
u_r.push(u);
1122+
st_r.push(st_trunc);
1123+
c_u_r.push(c_u);
1124+
c_vt_r.push(c_vt);
1125+
cutoff_rank_r.push(cutoff_rank);
1126+
directional_cutoff_ranks_r.push(directional_cutoff_ranks);
9741127
}
9751128

976-
// TODO
1129+
// TODO Communicate local results back
1130+
9771131

9781132
// // Compute unique M2L interactions at level 3, shallowest level which contains them all
9791133
// // Compute interaction matrices between source and unique targets, defined by unique transfer vectors

kifmm/src/fmm/field_translation/metadata/single_node.rs

-2
Original file line numberDiff line numberDiff line change
@@ -1228,8 +1228,6 @@ where
12281228
.directional_cutoff_ranks
12291229
.push(directional_cutoff_ranks);
12301230
}
1231-
1232-
// self.source_to_target = result;
12331231
}
12341232
}
12351233

0 commit comments

Comments
 (0)