diff --git a/src/discretise/discrete_model.rs b/src/discretise/discrete_model.rs
index 5fd1395..0b1985b 100644
--- a/src/discretise/discrete_model.rs
+++ b/src/discretise/discrete_model.rs
@@ -101,6 +101,27 @@ impl<'s> DiscreteModel<'s> {
     fn build_array(array: &ast::Tensor<'s>, env: &mut Env) -> Option<Tensor<'s>> {
         let rank = array.indices().len();
+        let reserved_names = [
+            "u0",
+            "t",
+            "data",
+            "root",
+            "thread_id",
+            "thread_dim",
+            "rr",
+            "states",
+            "inputs",
+            "outputs",
+            "has_mass",
+        ];
+        if reserved_names.contains(&array.name()) {
+            let span = env.current_span().to_owned();
+            env.errs_mut().push(ValidationError::new(
+                format!("{} is a reserved name", array.name()),
+                span,
+            ));
+            return None;
+        }
         let mut elmts = Vec::new();
         let mut start = Index::zeros(rank);
         let nerrs = env.errs().len();
@@ -148,6 +169,16 @@ impl<'s> DiscreteModel<'s> {
                 i64::try_from(elmt_layout.shape()[0]).unwrap()
             };

+            if reserved_names
+                .contains(&name.as_ref().unwrap_or(&"".to_string()).as_str())
+            {
+                let span = env.current_span().to_owned();
+                env.errs_mut().push(ValidationError::new(
+                    format!("{} is a reserved name", name.as_ref().unwrap()),
+                    span,
+                ));
+            }
+
             elmts.push(TensorBlock::new(
                 name,
                 start.clone(),
@@ -446,7 +477,7 @@ impl<'s> DiscreteModel<'s> {
                 None,
             )],
         );
-            ret.out = Self::build_array(&out_tensor, &mut env).unwrap();
+            ret.out = Self::build_array(&out_tensor, &mut env).unwrap_or(Tensor::new_empty("out"));
     }
     if let Some(span) = span_f {
         Self::check_match(&ret.rhs, &ret.state, span, &mut env);
diff --git a/src/execution/compiler.rs b/src/execution/compiler.rs
index 2f792a2..5bd953a 100644
--- a/src/execution/compiler.rs
+++ b/src/execution/compiler.rs
@@ -9,8 +9,10 @@ use crate::{
 use super::{
     interface::{
-        BarrierInitFunc, CalcOutGradientFunc, CalcOutReverseGradientFunc, RhsGradientFunc,
-        SetInputsGradientFunc, U0GradientFunc,
+        BarrierInitFunc, CalcOutGradFunc, CalcOutRevGradFunc, CalcOutSensGradFunc,
+        CalcOutSensRevGradFunc, GetInputsFunc, MassRevGradFunc, RhsGradFunc, RhsRevGradFunc,
+        RhsSensGradFunc, RhsSensRevGradFunc, SetInputsGradFunc, SetInputsRevGradFunc, U0GradFunc,
+        U0RevGradFunc,
     },
     module::CodegenModule,
 };
@@ -45,22 +47,34 @@ struct JitFunctions {
     set_id: SetIdFunc,
     get_dims: GetDimsFunc,
     set_inputs: SetInputsFunc,
+    get_inputs: GetInputsFunc,
     get_out: GetOutFunc,
     barrier_init: Option<BarrierInitFunc>,
 }

 struct JitGradFunctions {
-    set_u0_grad: U0GradientFunc,
-    rhs_grad: RhsGradientFunc,
-    calc_out_grad: CalcOutGradientFunc,
-    set_inputs_grad: SetInputsGradientFunc,
+    set_u0_grad: U0GradFunc,
+    rhs_grad: RhsGradFunc,
+    calc_out_grad: CalcOutGradFunc,
+    set_inputs_grad: SetInputsGradFunc,
 }

 struct JitGradRFunctions {
-    set_u0_rgrad: U0GradientFunc,
-    rhs_rgrad: RhsGradientFunc,
-    calc_out_rgrad: CalcOutReverseGradientFunc,
-    set_inputs_rgrad: SetInputsGradientFunc,
+    set_u0_rgrad: U0RevGradFunc,
+    rhs_rgrad: RhsRevGradFunc,
+    mass_rgrad: MassRevGradFunc,
+    calc_out_rgrad: CalcOutRevGradFunc,
+    set_inputs_rgrad: SetInputsRevGradFunc,
+}
+
+struct JitSensGradFunctions {
+    rhs_sgrad: RhsSensGradFunc,
+    calc_out_sgrad: CalcOutSensGradFunc,
+}
+
+struct JitSensRevGradFunctions {
+    rhs_rgrad: RhsSensRevGradFunc,
+    calc_out_rgrad: CalcOutSensRevGradFunc,
 }

 pub struct Compiler<M: CodegenModule> {
@@ -68,6 +82,8 @@ pub struct Compiler<M: CodegenModule> {
     jit_functions: JitFunctions,
     jit_grad_functions: JitGradFunctions,
     jit_grad_r_functions: Option<JitGradRFunctions>,
+    jit_sens_grad_functions: Option<JitSensGradFunctions>,
+    jit_sens_rev_grad_functions: Option<JitSensRevGradFunctions>,
     number_of_states: usize,
     number_of_parameters: usize,
@@ -158,6 +174,7 @@ impl<M: CodegenModule> Compiler<M> {
         let set_id =
module.compile_set_id(model)?; let get_dims = module.compile_get_dims(model)?; let set_inputs = module.compile_set_inputs(model)?; + let get_inputs = module.compile_get_inputs(model)?; let get_output = module.compile_get_tensor(model, "out")?; module.pre_autodiff_optimisation()?; @@ -171,11 +188,24 @@ impl Compiler { let mut rhs_rgrad = None; let mut calc_out_rgrad = None; let mut set_inputs_rgrad = None; + let mut rhs_sgrad = None; + let mut rhs_srgrad = None; + let mut calc_out_sgrad = None; + let mut calc_out_srgrad = None; + let mut mass_rgrad = None; if module.supports_reverse_autodiff() { set_u0_rgrad = Some(module.compile_set_u0_rgrad(&set_u0, model)?); rhs_rgrad = Some(module.compile_rhs_rgrad(&rhs, model)?); calc_out_rgrad = Some(module.compile_calc_out_rgrad(&calc_out, model)?); set_inputs_rgrad = Some(module.compile_set_inputs_rgrad(&set_inputs, model)?); + mass_rgrad = Some(module.compile_mass_rgrad(&mass, model)?); + + let rhs_full = module.compile_rhs_full(model)?; + rhs_sgrad = Some(module.compile_rhs_sgrad(&rhs_full, model)?); + rhs_srgrad = Some(module.compile_rhs_srgrad(&rhs_full, model)?); + let calc_out_full = module.compile_calc_out_full(model)?; + calc_out_sgrad = Some(module.compile_calc_out_sgrad(&calc_out_full, model)?); + calc_out_srgrad = Some(module.compile_calc_out_srgrad(&calc_out_full, model)?); } module.post_autodiff_optimisation()?; @@ -199,42 +229,83 @@ impl Compiler { unsafe { std::mem::transmute::<*const u8, GetDimsFunc>(module.jit(get_dims)?) }; let set_inputs = unsafe { std::mem::transmute::<*const u8, SetInputsFunc>(module.jit(set_inputs)?) }; + let get_inputs = + unsafe { std::mem::transmute::<*const u8, GetInputsFunc>(module.jit(get_inputs)?) }; let get_out = unsafe { std::mem::transmute::<*const u8, GetOutFunc>(module.jit(get_output)?) }; let set_u0_grad = - unsafe { std::mem::transmute::<*const u8, U0GradientFunc>(module.jit(set_u0_grad)?) }; + unsafe { std::mem::transmute::<*const u8, U0GradFunc>(module.jit(set_u0_grad)?) }; let rhs_grad = - unsafe { std::mem::transmute::<*const u8, RhsGradientFunc>(module.jit(rhs_grad)?) }; + unsafe { std::mem::transmute::<*const u8, RhsGradFunc>(module.jit(rhs_grad)?) }; let calc_out_grad = unsafe { - std::mem::transmute::<*const u8, CalcOutGradientFunc>(module.jit(calc_out_grad)?) + std::mem::transmute::<*const u8, CalcOutGradFunc>(module.jit(calc_out_grad)?) }; let set_inputs_grad = unsafe { - std::mem::transmute::<*const u8, SetInputsGradientFunc>(module.jit(set_inputs_grad)?) + std::mem::transmute::<*const u8, SetInputsGradFunc>(module.jit(set_inputs_grad)?) 
}; let jit_grad_r_functions = if module.supports_reverse_autodiff() { Some(JitGradRFunctions { set_u0_rgrad: unsafe { - std::mem::transmute::<*const u8, U0GradientFunc>( + std::mem::transmute::<*const u8, U0RevGradFunc>( module.jit(set_u0_rgrad.unwrap())?, ) }, rhs_rgrad: unsafe { - std::mem::transmute::<*const u8, RhsGradientFunc>( + std::mem::transmute::<*const u8, RhsRevGradFunc>( module.jit(rhs_rgrad.unwrap())?, ) }, calc_out_rgrad: unsafe { - std::mem::transmute::<*const u8, CalcOutReverseGradientFunc>( + std::mem::transmute::<*const u8, CalcOutRevGradFunc>( module.jit(calc_out_rgrad.unwrap())?, ) }, set_inputs_rgrad: unsafe { - std::mem::transmute::<*const u8, SetInputsGradientFunc>( + std::mem::transmute::<*const u8, SetInputsRevGradFunc>( module.jit(set_inputs_rgrad.unwrap())?, ) }, + mass_rgrad: unsafe { + std::mem::transmute::<*const u8, MassRevGradFunc>( + module.jit(mass_rgrad.unwrap())?, + ) + }, + }) + } else { + None + }; + + let jit_sens_grad_functions = if module.supports_reverse_autodiff() { + Some(JitSensGradFunctions { + rhs_sgrad: unsafe { + std::mem::transmute::<*const u8, RhsSensGradFunc>( + module.jit(rhs_sgrad.unwrap())?, + ) + }, + calc_out_sgrad: unsafe { + std::mem::transmute::<*const u8, CalcOutSensGradFunc>( + module.jit(calc_out_sgrad.unwrap())?, + ) + }, + }) + } else { + None + }; + + let jit_sens_rev_grad_functions = if module.supports_reverse_autodiff() { + Some(JitSensRevGradFunctions { + rhs_rgrad: unsafe { + std::mem::transmute::<*const u8, RhsSensRevGradFunc>( + module.jit(rhs_srgrad.unwrap())?, + ) + }, + calc_out_rgrad: unsafe { + std::mem::transmute::<*const u8, CalcOutSensRevGradFunc>( + module.jit(calc_out_srgrad.unwrap())?, + ) + }, }) } else { None @@ -247,6 +318,7 @@ impl Compiler { rhs, mass, calc_out, + get_inputs, calc_stop, set_id, get_dims, @@ -261,6 +333,8 @@ impl Compiler { set_inputs_grad, }, jit_grad_r_functions, + jit_sens_grad_functions, + jit_sens_rev_grad_functions, number_of_states, number_of_parameters, number_of_outputs, @@ -350,7 +424,7 @@ impl Compiler { }); } - pub fn set_u0_rgrad(&self, yy: &mut [f64], dyy: &mut [f64], data: &[f64], ddata: &mut [f64]) { + pub fn set_u0_rgrad(&self, yy: &[f64], dyy: &mut [f64], data: &[f64], ddata: &mut [f64]) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); self.check_data_len(data, "data"); @@ -361,9 +435,9 @@ impl Compiler { .as_ref() .expect("module does not support reverse autograd") .set_u0_rgrad)( - yy.as_ptr() as *mut f64, + yy.as_ptr(), dyy.as_ptr() as *mut f64, - data.as_ptr() as *mut f64, + data.as_ptr(), ddata.as_ptr() as *mut f64, i, dim, @@ -371,13 +445,7 @@ impl Compiler { }); } - pub fn set_u0_grad( - &self, - yy: &mut [f64], - dyy: &mut [f64], - data: &mut [f64], - ddata: &mut [f64], - ) { + pub fn set_u0_grad(&self, yy: &[f64], dyy: &mut [f64], data: &[f64], ddata: &mut [f64]) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); self.check_data_len(data, "data"); @@ -385,9 +453,9 @@ impl Compiler { self.with_threading(|i, dim| { unsafe { (self.jit_grad_functions.set_u0_grad)( - yy.as_ptr() as *mut f64, + yy.as_ptr(), dyy.as_ptr() as *mut f64, - data.as_ptr() as *mut f64, + data.as_ptr(), ddata.as_ptr() as *mut f64, i, dim, @@ -435,19 +503,19 @@ impl Compiler { self.has_mass } - pub fn mass(&self, t: f64, yp: &[f64], data: &mut [f64], rr: &mut [f64]) { + pub fn mass(&self, t: f64, v: &[f64], data: &mut [f64], mv: &mut [f64]) { if !self.has_mass { panic!("Model does not have a mass function"); } - self.check_state_len(yp, "yp"); - 
self.check_state_len(rr, "rr"); + self.check_state_len(v, "v"); + self.check_state_len(mv, "mv"); self.check_data_len(data, "data"); self.with_threading(|i, dim| unsafe { (self.jit_functions.mass)( t, - yp.as_ptr(), + v.as_ptr(), data.as_ptr() as *mut f64, - rr.as_ptr() as *mut f64, + mv.as_ptr() as *mut f64, i, dim, ) @@ -468,9 +536,9 @@ impl Compiler { t: f64, yy: &[f64], dyy: &[f64], - data: &mut [f64], + data: &[f64], ddata: &mut [f64], - rr: &mut [f64], + rr: &[f64], drr: &mut [f64], ) { self.check_state_len(yy, "yy"); @@ -484,9 +552,9 @@ impl Compiler { t, yy.as_ptr(), dyy.as_ptr(), - data.as_ptr() as *mut f64, + data.as_ptr(), ddata.as_ptr() as *mut f64, - rr.as_ptr() as *mut f64, + rr.as_ptr(), drr.as_ptr() as *mut f64, i, dim, @@ -519,10 +587,10 @@ impl Compiler { .rhs_rgrad)( t, yy.as_ptr(), - dyy.as_ptr(), - data.as_ptr() as *mut f64, + dyy.as_ptr() as *mut f64, + data.as_ptr(), ddata.as_ptr() as *mut f64, - rr.as_ptr() as *mut f64, + rr.as_ptr(), drr.as_ptr() as *mut f64, i, dim, @@ -530,22 +598,110 @@ impl Compiler { }); } - pub fn calc_out(&self, t: f64, yy: &[f64], data: &mut [f64]) { + pub fn mass_rgrad( + &self, + t: f64, + dv: &mut [f64], + data: &[f64], + ddata: &mut [f64], + dmv: &mut [f64], + ) { + self.check_state_len(dv, "dv"); + self.check_state_len(dmv, "dmv"); + self.check_data_len(data, "data"); + self.check_data_len(ddata, "ddata"); + self.with_threading(|i, dim| unsafe { + (self + .jit_grad_r_functions + .as_ref() + .expect("module does not support reverse autograd") + .mass_rgrad)( + t, + std::ptr::null(), + dv.as_ptr() as *mut f64, + data.as_ptr(), + ddata.as_ptr() as *mut f64, + std::ptr::null(), + dmv.as_ptr() as *mut f64, + i, + dim, + ) + }); + } + + pub fn rhs_sgrad( + &self, + t: f64, + yy: &[f64], + data: &[f64], + ddata: &mut [f64], + rr: &[f64], + drr: &mut [f64], + ) { self.check_state_len(yy, "yy"); + self.check_state_len(rr, "rr"); + self.check_state_len(drr, "drr"); self.check_data_len(data, "data"); + self.check_data_len(ddata, "ddata"); self.with_threading(|i, dim| unsafe { - (self.jit_functions.calc_out)(t, yy.as_ptr(), data.as_ptr() as *mut f64, i, dim) + (self + .jit_sens_grad_functions + .as_ref() + .expect("module does not support sens autograd") + .rhs_sgrad)( + t, + yy.as_ptr(), + data.as_ptr(), + ddata.as_ptr() as *mut f64, + rr.as_ptr(), + drr.as_ptr() as *mut f64, + i, + dim, + ) }); } - pub fn calc_out_grad( + pub fn rhs_srgrad( &self, t: f64, yy: &[f64], - dyy: &[f64], - data: &mut [f64], + data: &[f64], ddata: &mut [f64], + rr: &[f64], + drr: &mut [f64], ) { + self.check_state_len(yy, "yy"); + self.check_state_len(rr, "rr"); + self.check_state_len(drr, "drr"); + self.check_data_len(data, "data"); + self.check_data_len(ddata, "ddata"); + self.with_threading(|i, dim| unsafe { + (self + .jit_sens_rev_grad_functions + .as_ref() + .expect("module does not support sens autograd") + .rhs_rgrad)( + t, + yy.as_ptr(), + data.as_ptr(), + ddata.as_ptr() as *mut f64, + rr.as_ptr(), + drr.as_ptr() as *mut f64, + i, + dim, + ) + }); + } + + pub fn calc_out(&self, t: f64, yy: &[f64], data: &mut [f64]) { + self.check_state_len(yy, "yy"); + self.check_data_len(data, "data"); + self.with_threading(|i, dim| unsafe { + (self.jit_functions.calc_out)(t, yy.as_ptr(), data.as_ptr() as *mut f64, i, dim) + }); + } + + pub fn calc_out_grad(&self, t: f64, yy: &[f64], dyy: &[f64], data: &[f64], ddata: &mut [f64]) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); self.check_data_len(data, "data"); @@ -555,7 +711,7 @@ impl Compiler { t, 
yy.as_ptr(), dyy.as_ptr(), - data.as_ptr() as *mut f64, + data.as_ptr(), ddata.as_ptr() as *mut f64, i, dim, @@ -592,6 +748,46 @@ impl Compiler { }); } + pub fn calc_out_sgrad(&self, t: f64, yy: &[f64], data: &[f64], ddata: &mut [f64]) { + self.check_state_len(yy, "yy"); + self.check_data_len(data, "data"); + self.check_data_len(ddata, "ddata"); + self.with_threading(|i, dim| unsafe { + (self + .jit_sens_grad_functions + .as_ref() + .expect("module does not support sens autograd") + .calc_out_sgrad)( + t, + yy.as_ptr(), + data.as_ptr(), + ddata.as_ptr() as *mut f64, + i, + dim, + ) + }); + } + + pub fn calc_out_srgrad(&self, t: f64, yy: &[f64], data: &[f64], ddata: &mut [f64]) { + self.check_state_len(yy, "yy"); + self.check_data_len(data, "data"); + self.check_data_len(ddata, "ddata"); + self.with_threading(|i, dim| unsafe { + (self + .jit_sens_rev_grad_functions + .as_ref() + .expect("module does not support sens autograd") + .calc_out_rgrad)( + t, + yy.as_ptr(), + data.as_ptr(), + ddata.as_ptr() as *mut f64, + i, + dim, + ) + }); + } + /// Get various dimensions of the model /// /// # Returns @@ -630,6 +826,12 @@ impl Compiler { unsafe { (self.jit_functions.set_inputs)(inputs.as_ptr(), data.as_mut_ptr()) }; } + pub fn get_inputs(&self, inputs: &mut [f64], data: &[f64]) { + self.check_inputs_len(inputs, "inputs"); + self.check_data_len(data, "data"); + unsafe { (self.jit_functions.get_inputs)(inputs.as_mut_ptr(), data.as_ptr()) }; + } + pub fn set_inputs_grad( &self, inputs: &[f64], @@ -669,8 +871,8 @@ impl Compiler { .expect("module does not support reverse autograd") .set_inputs_rgrad)( inputs.as_ptr(), - dinputs.as_ptr(), - data.as_ptr() as *mut f64, + dinputs.as_mut_ptr(), + data.as_ptr(), ddata.as_mut_ptr(), ) }; @@ -726,50 +928,43 @@ mod tests { #[test] fn test_from_discrete_str_llvm() { use crate::execution::llvm::codegen::LlvmModule; - let text = " - u { y = 1 } - F { -y } - out { y } - "; - let compiler = Compiler::::from_discrete_str(text, Default::default()).unwrap(); - let (n_states, n_inputs, n_outputs, _n_data, n_stop, has_mass) = compiler.get_dims(); - assert_eq!(n_states, 1); - assert_eq!(n_inputs, 0); - assert_eq!(n_outputs, 1); - assert_eq!(n_stop, 0); - assert!(!has_mass); - let mut u0 = vec![0.]; - let mut res = vec![0.]; - let mut data = compiler.get_new_data(); - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); - assert_relative_eq!(u0.as_slice(), vec![1.].as_slice()); - compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice()); - assert_relative_eq!(res.as_slice(), vec![-1.].as_slice()); + test_from_discrete_str_common::(); } #[test] fn test_from_discrete_str_cranelift() { - let text = " + test_from_discrete_str_common::(); + } + + fn test_from_discrete_str_common() { + let text1 = " u { y = 1 } F { -y } out { y } "; - let compiler = - Compiler::::from_discrete_str(text, Default::default()).unwrap(); - let (n_states, n_inputs, n_outputs, _n_data, n_stop, has_mass) = compiler.get_dims(); - assert_eq!(n_states, 1); - assert_eq!(n_inputs, 0); - assert_eq!(n_outputs, 1); - assert_eq!(n_stop, 0); - assert!(!has_mass); - - let mut u0 = vec![0.]; - let mut res = vec![0.]; - let mut data = compiler.get_new_data(); - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); - assert_relative_eq!(u0.as_slice(), vec![1.].as_slice()); - compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice()); - assert_relative_eq!(res.as_slice(), vec![-1.].as_slice()); + let text2 = " + p { 1 } + u { p } + F { -u } + out { u } + "; + for text in [text2, 
text1] { + let compiler = Compiler::::from_discrete_str(text, Default::default()).unwrap(); + let (n_states, n_inputs, n_outputs, _n_data, n_stop, has_mass) = compiler.get_dims(); + assert_eq!(n_states, 1); + assert_eq!(n_inputs, 0); + assert_eq!(n_outputs, 1); + assert_eq!(n_stop, 0); + assert!(!has_mass); + + let mut u0 = vec![0.]; + let mut res = vec![0.]; + let mut data = compiler.get_new_data(); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + assert_relative_eq!(u0.as_slice(), vec![1.].as_slice()); + compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice()); + assert_relative_eq!(res.as_slice(), vec![-1.].as_slice()); + } } #[test] @@ -932,24 +1127,23 @@ mod tests { let mut ddata = compiler.get_new_data(); let mut du0 = vec![0.; n_states]; let mut dres = vec![0.; n_states]; - let mut grad_data = compiler.get_new_data(); compiler.set_inputs_grad( inputs.as_slice(), dinputs.as_slice(), - grad_data.as_mut_slice(), + data.as_mut_slice(), ddata.as_mut_slice(), ); compiler.set_u0_grad( u0.as_mut_slice(), du0.as_mut_slice(), - grad_data.as_mut_slice(), + data.as_mut_slice(), ddata.as_mut_slice(), ); compiler.rhs_grad( 0., u0.as_slice(), du0.as_slice(), - grad_data.as_mut_slice(), + data.as_mut_slice(), ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice(), @@ -958,7 +1152,7 @@ mod tests { 0., u0.as_slice(), du0.as_slice(), - grad_data.as_mut_slice(), + data.as_mut_slice(), ddata.as_mut_slice(), ); results.push( @@ -1002,13 +1196,79 @@ mod tests { data.as_slice(), ddata.as_mut_slice(), ); + compiler.get_inputs(dinputs.as_mut_slice(), ddata.as_slice()); + results.push(dinputs.to_vec()); + + // forward mode sens (rhs) + let mut ddata = compiler.get_new_data(); + let mut dres = vec![0.; n_states]; + let dinputs = vec![1.; n_inputs]; + compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice()); + compiler.rhs_sgrad( + 0., + u0.as_slice(), + data.as_slice(), + ddata.as_mut_slice(), + res.as_slice(), + dres.as_mut_slice(), + ); + results.push( + compiler + .get_tensor_data(tensor_name, ddata.as_slice()) + .unwrap() + .to_vec(), + ); + + // forward mode sens (calc_out) + let mut ddata = compiler.get_new_data(); + let dinputs = vec![1.; n_inputs]; + compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice()); + compiler.calc_out_sgrad(0., u0.as_slice(), data.as_slice(), ddata.as_mut_slice()); + results.push( + compiler + .get_tensor_data(tensor_name, ddata.as_slice()) + .unwrap() + .to_vec(), + ); + + // reverse mode sens (rhs) + let mut ddata = compiler.get_new_data(); + let dtensor = compiler + .get_tensor_data_mut(tensor_name, ddata.as_mut_slice()) + .unwrap(); + dtensor.fill(1.); + let mut dres = vec![0.; n_states]; + let mut dinputs = vec![0.; n_inputs]; + compiler.rhs_srgrad( + 0., + u0.as_slice(), + data.as_slice(), + ddata.as_mut_slice(), + res.as_slice(), + dres.as_mut_slice(), + ); compiler.set_inputs_rgrad( inputs.as_slice(), dinputs.as_mut_slice(), data.as_slice(), ddata.as_mut_slice(), ); + results.push(dinputs.to_vec()); + // reverse mode sens (calc_out) + let mut ddata = compiler.get_new_data(); + let dtensor = compiler + .get_tensor_data_mut(tensor_name, ddata.as_mut_slice()) + .unwrap(); + dtensor.fill(1.); + let mut dinputs = vec![0.; n_inputs]; + compiler.calc_out_srgrad(0., u0.as_slice(), data.as_slice(), ddata.as_mut_slice()); + compiler.set_inputs_rgrad( + inputs.as_slice(), + dinputs.as_mut_slice(), + data.as_slice(), + ddata.as_mut_slice(), + ); results.push(dinputs.to_vec()); } results @@ -1142,7 +1402,7 @@ mod tests { } macro_rules! 
tensor_grad_test { - ($($name:ident: $text:literal expect $tensor_name:literal $expected_grad:expr ; $expected_rgrad:expr,)*) => { + ($($name:ident: $text:literal expect $tensor_name:literal $expected_grad:expr ; $expected_rgrad:expr; $expected_sgrad:expr; $expected_srgrad:expr,)*) => { $( #[test] fn $name() { @@ -1177,6 +1437,10 @@ mod tests { let results = tensor_test_common::(full_text.as_str(), $tensor_name, CompilerMode::MultiThreaded(None)); assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice()); assert_relative_eq!(results[2].as_slice(), $expected_rgrad.as_slice()); + assert_relative_eq!(results[3].as_slice(), $expected_sgrad.as_slice()); + assert_relative_eq!(results[4].as_slice(), $expected_sgrad.as_slice()); + assert_relative_eq!(results[5].as_slice(), $expected_srgrad.as_slice()); + assert_relative_eq!(results[6].as_slice(), $expected_srgrad.as_slice()); } let results = tensor_test_common::(full_text.as_str(), $tensor_name, CompilerMode::MultiThreaded(None)); @@ -1190,6 +1454,10 @@ mod tests { let results = tensor_test_common::(full_text.as_str(), $tensor_name, CompilerMode::SingleThreaded); assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice()); assert_relative_eq!(results[2].as_slice(), $expected_rgrad.as_slice()); + assert_relative_eq!(results[3].as_slice(), $expected_sgrad.as_slice()); + assert_relative_eq!(results[4].as_slice(), $expected_sgrad.as_slice()); + assert_relative_eq!(results[5].as_slice(), $expected_srgrad.as_slice()); + assert_relative_eq!(results[6].as_slice(), $expected_srgrad.as_slice()); } let results = tensor_test_common::(full_text.as_str(), $tensor_name, CompilerMode::SingleThreaded); @@ -1201,19 +1469,21 @@ mod tests { } tensor_grad_test! { - const_grad: "r { 3 }" expect "r" vec![0.] ; vec![0.], - const_vec_grad: "r_i { 3, 4 }" expect "r" vec![0., 0.] ; vec![0.], - input_grad: "r { 2 * p * p }" expect "r" vec![4.] ; vec![4.], - input_vec_grad: "r_i { 2 * p * p, 3 * p }" expect "r" vec![4., 3.] ; vec![7.], - state_grad: "r { 2 * y }" expect "r" vec![2.] ; vec![2.], - input_and_state_grad: "r { 2 * y * p }" expect "r" vec![4.] ; vec![4.], - state_and_const_grad1: "r_i { 2 * y, 3 }" expect "r" vec![2., 0.] ; vec![2.], - state_and_const_grad2: "r_i { 3 * y, 2 * y }" expect "r" vec![3., 2.] ; vec![5.], + const_grad: "r { 3 }" expect "r" vec![0.] ; vec![0.] ; vec![0.] ; vec![0.], + const_vec_grad: "r_i { 3, 4 }" expect "r" vec![0., 0.] ; vec![0.] ; vec![0., 0.] ; vec![0.], + input_grad: "r { 2 * p * p }" expect "r" vec![4.] ; vec![4.] ; vec![4.] ; vec![4.], + input_vec_grad: "r_i { 2 * p * p, 3 * p }" expect "r" vec![4., 3.] ; vec![7.] ; vec![4., 3.] ; vec![7.], + state_grad: "r { 2 * y }" expect "r" vec![2.] ; vec![2.] ; vec![0.] ; vec![0.], + input_and_state_grad: "r { 2 * y * p }" expect "r" vec![4.] ; vec![4.] ; vec![2.] ; vec![2.], + state_and_const_grad1: "r_i { 2 * y, 3 }" expect "r" vec![2., 0.] ; vec![2.] ; vec![0., 0.] ; vec![0.], + state_and_const_grad2: "r_i { 3 * y, 2 * y }" expect "r" vec![3., 2.] ; vec![5.] ; vec![0., 0.] ; vec![0.], + state_and_const_grad3: "r_i { 2 * p, 3 }" expect "r" vec![2., 0.] ; vec![2.] ; vec![2., 0.] ; vec![2.], + state_and_const_grad4: "r_i { 3 * p, 2 * p }" expect "r" vec![3., 2.] ; vec![5.] ; vec![3., 2.] ; vec![5.], } macro_rules! 
tensor_test_big_state { - ($($name:ident: $text:literal expect $tensor_name:literal $expected_value:expr ; $expected_grad:expr ; $expected_rgrad:expr,)*) => { + ($($name:ident: $text:literal expect $tensor_name:literal $expected_value:expr ; $expected_grad:expr ; $expected_rgrad:expr ; $expected_sgrad:expr ; $expected_srgrad:expr,)*) => { $( #[test] fn $name() { @@ -1235,6 +1505,10 @@ mod tests { assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice()); assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice()); assert_relative_eq!(results[2].as_slice(), $expected_rgrad.as_slice()); + assert_relative_eq!(results[3].as_slice(), $expected_sgrad.as_slice()); + assert_relative_eq!(results[4].as_slice(), $expected_sgrad.as_slice()); + assert_relative_eq!(results[5].as_slice(), $expected_srgrad.as_slice()); + assert_relative_eq!(results[6].as_slice(), $expected_srgrad.as_slice()); } let results = tensor_test_common::(full_text.as_str(), $tensor_name, CompilerMode::SingleThreaded); @@ -1256,6 +1530,10 @@ mod tests { assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice()); assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice()); assert_relative_eq!(results[2].as_slice(), $expected_rgrad.as_slice()); + assert_relative_eq!(results[3].as_slice(), $expected_sgrad.as_slice()); + assert_relative_eq!(results[4].as_slice(), $expected_sgrad.as_slice()); + assert_relative_eq!(results[5].as_slice(), $expected_srgrad.as_slice()); + assert_relative_eq!(results[6].as_slice(), $expected_srgrad.as_slice()); } } } @@ -1264,11 +1542,12 @@ mod tests { } tensor_test_big_state! { - big_state_expr: "r_i { x_i + y_i }" expect "r" vec![2.; 50] ; vec![2.; 50] ; vec![100.], - big_state_multi: "r_i { x_i + y_i } b_i { x_i, r_i - y_i }" expect "b" vec![1.; 100] ; vec![1.; 100] ; vec![100.], - big_state_multi_w_scalar: "r { 1.0 + 1.0 } b_i { x_i, r - y_i }" expect "b" vec![1.; 100] ; vec![1.; 50].into_iter().chain(vec![-1.; 50].into_iter()).collect::>() ; vec![0.], - big_state_diag: "b_ij { (0..100, 0..100): 3.0 } r_i { b_ij * u_j }" expect "r" vec![3.; 100] ; vec![3.; 100] ; vec![300.], - big_state_tridiag: "b_ij { (0..100, 0..100): 3.0, (0..99, 1..100): 2.0, (1..100, 0..99): 1.0, (0, 99): 1.0, (99, 0): 2.0 } r_i { b_ij * u_j }" expect "r" vec![6.; 100]; vec![6.; 100]; vec![600.], + big_state_expr: "r_i { x_i + y_i }" expect "r" vec![2.; 50] ; vec![2.; 50] ; vec![100.] ; vec![0.; 50] ; vec![0.], + big_state_multi: "r_i { x_i + y_i } b_i { x_i, r_i - y_i }" expect "b" vec![1.; 100] ; vec![1.; 100] ; vec![100.] ; vec![0.; 100] ; vec![0.], + big_state_multi_w_scalar: "r { 1.0 + 1.0 } b_i { x_i, r - y_i }" expect "b" vec![1.; 100] ; vec![1.; 50].into_iter().chain(vec![-1.; 50].into_iter()).collect::>() ; vec![0.] ; vec![0.; 100] ; vec![0.], + big_state_diag: "b_ij { (0..100, 0..100): 3.0 } r_i { b_ij * u_j }" expect "r" vec![3.; 100] ; vec![3.; 100] ; vec![300.] ; vec![0.; 100] ; vec![0.], + big_state_tridiag: "b_ij { (0..100, 0..100): 3.0, (0..99, 1..100): 2.0, (1..100, 0..99): 1.0, (0, 99): 1.0, (99, 0): 2.0 } r_i { b_ij * u_j }" expect "r" vec![6.; 100]; vec![6.; 100]; vec![600.] ; vec![0.; 100]; vec![0.], + big_state_tridiag2: "b_ij { (0..100, 0..100): p + 2.0, (0..99, 1..100): 2.0, (1..100, 0..99): 1.0, (0, 99): 1.0, (99, 0): 2.0 } r_i { b_ij * u_j }" expect "r" vec![6.; 100]; vec![7.; 100]; vec![700.] 
; vec![1.; 100]; vec![100.], } #[cfg(feature = "llvm")] @@ -1390,9 +1669,12 @@ mod tests { let mut data = compiler.get_new_data(); let mut ddata = compiler.get_new_data(); let (_n_states, n_inputs, _n_outputs, _n_data, _n_stop, _has_mass) = compiler.get_dims(); + let inputs = vec![2.; n_inputs]; + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice()); for _i in 0..3 { - let inputs = vec![2.; n_inputs]; let dinputs = vec![1.; n_inputs]; compiler.set_inputs_grad( inputs.as_slice(), @@ -1529,4 +1811,38 @@ mod tests { } } } + + #[cfg(feature = "llvm")] + #[test] + fn test_mass_llvm() { + let full_text = " + dudt_i { dxdt = 1, dydt = 1, dzdt = 1 } + u_i { x = 1, y = 2, z = 3 } + F_i { x, y, z } + M_i { dxdt + dydt, dydt, dzdt } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = DiscreteModel::build("test_mass", &model).unwrap(); + + let compiler = + Compiler::::from_discrete_model(&discrete_model, Default::default()) + .unwrap(); + let mut data = compiler.get_new_data(); + let mut u0 = vec![0.0, 0.0, 0.0]; + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + let mut mv = vec![0.0, 0.0, 0.0]; + let mut v = vec![1.0, 1.0, 1.0]; + compiler.mass(0.0, v.as_slice(), data.as_mut_slice(), mv.as_mut_slice()); + assert_relative_eq!(mv.as_slice(), vec![2.0, 1.0, 1.0].as_slice()); + mv = vec![1.0, 1.0, 1.0]; + let mut ddata = compiler.get_new_data(); + compiler.mass_rgrad( + 0.0, + v.as_mut_slice(), + data.as_mut_slice(), + ddata.as_mut_slice(), + mv.as_mut_slice(), + ); + assert_relative_eq!(v.as_slice(), vec![2.0, 3.0, 2.0].as_slice()); + } } diff --git a/src/execution/cranelift/codegen.rs b/src/execution/cranelift/codegen.rs index 76fe5d9..8a80970 100644 --- a/src/execution/cranelift/codegen.rs +++ b/src/execution/cranelift/codegen.rs @@ -206,6 +206,14 @@ unsafe impl Sync for CraneliftModule {} impl CodegenModule for CraneliftModule { type FuncId = FuncId; + fn compile_mass_rgrad( + &mut self, + _func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result { + Err(anyhow!("not implemented")) + } + fn compile_calc_out_rgrad( &mut self, _func_id: &Self::FuncId, @@ -238,6 +246,46 @@ impl CodegenModule for CraneliftModule { Err(anyhow!("not implemented")) } + fn compile_calc_out_full(&mut self, _model: &DiscreteModel) -> Result { + Err(anyhow!("not implemented")) + } + + fn compile_calc_out_sgrad( + &mut self, + _func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result { + Err(anyhow!("not implemented")) + } + + fn compile_calc_out_srgrad( + &mut self, + _func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result { + Err(anyhow!("not implemented")) + } + + fn compile_rhs_full(&mut self, _model: &DiscreteModel) -> Result { + Err(anyhow!("not implemented")) + } + + fn compile_rhs_sgrad( + &mut self, + _func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result { + Err(anyhow!("not implemented")) + } + + fn compile_rhs_srgrad( + &mut self, + _func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result { + Err(anyhow!("not implemented")) + } + fn supports_reverse_autodiff(&self) -> bool { false } @@ -259,8 +307,6 @@ impl CodegenModule for CraneliftModule { let arg_names = &["t", "u", "du", "data", "ddata", "threadId", "threadDim"]; let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); - codegen.jit_compile_tensor(model.out(), None, false)?; - codegen.jit_compile_call_barrier(0); 
codegen.jit_compile_tensor(model.out(), None, true)?; codegen.builder.ins().return_(&[]); codegen.builder.finalize(); @@ -300,9 +346,6 @@ impl CodegenModule for CraneliftModule { // calculate time dependant definitions let mut nbarrier = 0; for tensor in model.time_dep_defns() { - codegen.jit_compile_tensor(tensor, None, false)?; - codegen.jit_compile_call_barrier(nbarrier); - nbarrier += 1; codegen.jit_compile_tensor(tensor, None, true)?; codegen.jit_compile_call_barrier(nbarrier); nbarrier += 1; @@ -310,18 +353,12 @@ impl CodegenModule for CraneliftModule { // TODO: could split state dep defns into before and after F for a in model.state_dep_defns() { - codegen.jit_compile_tensor(a, None, false)?; - codegen.jit_compile_call_barrier(nbarrier); - nbarrier += 1; codegen.jit_compile_tensor(a, None, true)?; codegen.jit_compile_call_barrier(nbarrier); nbarrier += 1; } // F - let res = *codegen.variables.get("rr").unwrap(); - codegen.jit_compile_tensor(model.rhs(), Some(res), false)?; - codegen.jit_compile_call_barrier(nbarrier); let res = *codegen.variables.get("drr").unwrap(); codegen.jit_compile_tensor(model.rhs(), Some(res), true)?; @@ -344,13 +381,9 @@ impl CodegenModule for CraneliftModule { let arg_names = &["inputs", "dinputs", "data", "ddata"]; let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); - let base_data_ptr = codegen.variables.get("data").unwrap(); - let base_data_ptr = codegen.builder.use_var(*base_data_ptr); - codegen.jit_compile_set_inputs(model, base_data_ptr, false); - let base_data_ptr = codegen.variables.get("ddata").unwrap(); let base_data_ptr = codegen.builder.use_var(*base_data_ptr); - codegen.jit_compile_set_inputs(model, base_data_ptr, true); + codegen.jit_compile_inputs(model, base_data_ptr, true, false); codegen.builder.ins().return_(&[]); codegen.builder.finalize(); @@ -374,21 +407,13 @@ impl CodegenModule for CraneliftModule { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); let mut nbarrier = 0; + #[allow(clippy::explicit_counter_loop)] for a in model.time_indep_defns() { - codegen.jit_compile_tensor(a, None, false)?; - codegen.jit_compile_call_barrier(nbarrier); - nbarrier += 1; codegen.jit_compile_tensor(a, None, true)?; codegen.jit_compile_call_barrier(nbarrier); nbarrier += 1; } - codegen.jit_compile_tensor( - model.state(), - Some(*codegen.variables.get("u0").unwrap()), - false, - )?; - codegen.jit_compile_call_barrier(nbarrier); codegen.jit_compile_tensor( model.state(), Some(*codegen.variables.get("du0").unwrap()), @@ -782,13 +807,27 @@ impl CodegenModule for CraneliftModule { let base_data_ptr = codegen.variables.get("data").unwrap(); let base_data_ptr = codegen.builder.use_var(*base_data_ptr); - codegen.jit_compile_set_inputs(model, base_data_ptr, false); + codegen.jit_compile_inputs(model, base_data_ptr, false, false); codegen.builder.ins().return_(&[]); codegen.builder.finalize(); self.declare_function("set_inputs") } + fn compile_get_inputs(&mut self, model: &DiscreteModel) -> Result { + let arg_types = &[self.real_ptr_type, self.real_ptr_type]; + let arg_names = &["inputs", "data"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + let base_data_ptr = codegen.variables.get("data").unwrap(); + let base_data_ptr = codegen.builder.use_var(*base_data_ptr); + codegen.jit_compile_inputs(model, base_data_ptr, false, true); + + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + self.declare_function("get_inputs") + } + fn compile_set_id(&mut self, model: 
&DiscreteModel) -> Result { let arg_types = &[self.real_ptr_type]; let arg_names = &["id"]; @@ -2054,11 +2093,12 @@ impl<'ctx> CraneliftCodeGen<'ctx> { codegen } - fn jit_compile_set_inputs( + fn jit_compile_inputs( &mut self, model: &DiscreteModel, base_data_ptr: Value, is_tangent: bool, + is_get: bool, ) { let mut inputs_index = 0; for input in model.inputs() { @@ -2096,13 +2136,23 @@ impl<'ctx> CraneliftCodeGen<'ctx> { let indexed_input_ptr = self.ptr_add_offset(self.real_type, input_ptr, curr_input_index_plus_start_index); let indexed_data_ptr = self.ptr_add_offset(self.real_type, data_ptr, curr_input_index); - let input_value = + if is_get { + let input_value = + self.builder + .ins() + .load(self.real_type, self.mem_flags, indexed_data_ptr, 0); self.builder .ins() - .load(self.real_type, self.mem_flags, indexed_input_ptr, 0); - self.builder - .ins() - .store(self.mem_flags, input_value, indexed_data_ptr, 0); + .store(self.mem_flags, input_value, indexed_input_ptr, 0); + } else { + let input_value = + self.builder + .ins() + .load(self.real_type, self.mem_flags, indexed_input_ptr, 0); + self.builder + .ins() + .store(self.mem_flags, input_value, indexed_data_ptr, 0); + } // increment loop index let one = self.builder.ins().iconst(self.int_type, 1); diff --git a/src/execution/interface.rs b/src/execution/interface.rs index 3a1499c..e658e44 100644 --- a/src/execution/interface.rs +++ b/src/execution/interface.rs @@ -19,13 +19,44 @@ pub type RhsFunc = unsafe extern "C" fn( threadId: UIntType, threadDim: UIntType, ); -pub type RhsGradientFunc = unsafe extern "C" fn( +pub type RhsGradFunc = unsafe extern "C" fn( time: RealType, u: *const RealType, du: *const RealType, - data: *mut RealType, + data: *const RealType, ddata: *mut RealType, - rr: *mut RealType, + rr: *const RealType, + drr: *mut RealType, + threadId: UIntType, + threadDim: UIntType, +); +pub type RhsRevGradFunc = unsafe extern "C" fn( + time: RealType, + u: *const RealType, + du: *mut RealType, + data: *const RealType, + ddata: *mut RealType, + rr: *const RealType, + drr: *mut RealType, + threadId: UIntType, + threadDim: UIntType, +); +pub type RhsSensGradFunc = unsafe extern "C" fn( + time: RealType, + u: *const RealType, + data: *const RealType, + ddata: *mut RealType, + rr: *const RealType, + drr: *mut RealType, + threadId: UIntType, + threadDim: UIntType, +); +pub type RhsSensRevGradFunc = unsafe extern "C" fn( + time: RealType, + u: *const RealType, + data: *const RealType, + ddata: *mut RealType, + rr: *const RealType, drr: *mut RealType, threadId: UIntType, threadDim: UIntType, @@ -38,16 +69,35 @@ pub type MassFunc = unsafe extern "C" fn( threadId: UIntType, threadDim: UIntType, ); +pub type MassRevGradFunc = unsafe extern "C" fn( + time: RealType, + v: *const RealType, + dv: *mut RealType, + data: *const RealType, + ddata: *mut RealType, + mv: *const RealType, + dmv: *mut RealType, + threadId: UIntType, + threadDim: UIntType, +); pub type U0Func = unsafe extern "C" fn( u: *mut RealType, data: *mut RealType, threadId: UIntType, threadDim: UIntType, ); -pub type U0GradientFunc = unsafe extern "C" fn( - u: *mut RealType, +pub type U0GradFunc = unsafe extern "C" fn( + u: *const RealType, du: *mut RealType, - data: *mut RealType, + data: *const RealType, + ddata: *mut RealType, + threadId: UIntType, + threadDim: UIntType, +); +pub type U0RevGradFunc = unsafe extern "C" fn( + u: *const RealType, + du: *mut RealType, + data: *const RealType, ddata: *mut RealType, threadId: UIntType, threadDim: UIntType, @@ -59,16 +109,16 @@ 
pub type CalcOutFunc = unsafe extern "C" fn( threadId: UIntType, threadDim: UIntType, ); -pub type CalcOutGradientFunc = unsafe extern "C" fn( +pub type CalcOutGradFunc = unsafe extern "C" fn( time: RealType, u: *const RealType, du: *const RealType, - data: *mut RealType, + data: *const RealType, ddata: *mut RealType, threadId: UIntType, threadDim: UIntType, ); -pub type CalcOutReverseGradientFunc = unsafe extern "C" fn( +pub type CalcOutRevGradFunc = unsafe extern "C" fn( time: RealType, u: *const RealType, du: *mut RealType, @@ -77,6 +127,22 @@ pub type CalcOutReverseGradientFunc = unsafe extern "C" fn( threadId: UIntType, threadDim: UIntType, ); +pub type CalcOutSensGradFunc = unsafe extern "C" fn( + time: RealType, + u: *const RealType, + data: *const RealType, + ddata: *mut RealType, + threadId: UIntType, + threadDim: UIntType, +); +pub type CalcOutSensRevGradFunc = unsafe extern "C" fn( + time: RealType, + u: *const RealType, + data: *const RealType, + ddata: *mut RealType, + threadId: UIntType, + threadDim: UIntType, +); pub type GetDimsFunc = unsafe extern "C" fn( states: *mut UIntType, inputs: *mut UIntType, @@ -86,10 +152,17 @@ pub type GetDimsFunc = unsafe extern "C" fn( has_mass: *mut UIntType, ); pub type SetInputsFunc = unsafe extern "C" fn(inputs: *const RealType, data: *mut RealType); -pub type SetInputsGradientFunc = unsafe extern "C" fn( +pub type GetInputsFunc = unsafe extern "C" fn(inputs: *mut RealType, data: *const RealType); +pub type SetInputsGradFunc = unsafe extern "C" fn( inputs: *const RealType, dinputs: *const RealType, - data: *mut RealType, + data: *const RealType, + ddata: *mut RealType, +); +pub type SetInputsRevGradFunc = unsafe extern "C" fn( + inputs: *const RealType, + dinputs: *mut RealType, + data: *const RealType, ddata: *mut RealType, ); pub type SetIdFunc = unsafe extern "C" fn(id: *mut RealType); diff --git a/src/execution/llvm/codegen.rs b/src/execution/llvm/codegen.rs index d2ddb11..429f83d 100644 --- a/src/execution/llvm/codegen.rs +++ b/src/execution/llvm/codegen.rs @@ -281,7 +281,11 @@ impl CodegenModule for LlvmModule { } fn compile_calc_out(&mut self, model: &DiscreteModel) -> Result { - self.codegen_mut().compile_calc_out(model) + self.codegen_mut().compile_calc_out(model, false) + } + + fn compile_calc_out_full(&mut self, model: &DiscreteModel) -> Result { + self.codegen_mut().compile_calc_out(model, true) } fn compile_calc_stop(&mut self, model: &DiscreteModel) -> Result { @@ -289,7 +293,12 @@ impl CodegenModule for LlvmModule { } fn compile_rhs(&mut self, model: &DiscreteModel) -> Result { - let ret = self.codegen_mut().compile_rhs(model); + let ret = self.codegen_mut().compile_rhs(model, false); + ret + } + + fn compile_rhs_full(&mut self, model: &DiscreteModel) -> Result { + let ret = self.codegen_mut().compile_rhs(model, true); ret } @@ -306,7 +315,11 @@ impl CodegenModule for LlvmModule { } fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result { - self.codegen_mut().compile_set_inputs(model) + self.codegen_mut().compile_inputs(model, false) + } + + fn compile_get_inputs(&mut self, model: &DiscreteModel) -> Result { + self.codegen_mut().compile_inputs(model, true) } fn compile_set_id(&mut self, model: &DiscreteModel) -> Result { @@ -321,8 +334,8 @@ impl CodegenModule for LlvmModule { self.codegen_mut().compile_gradient( *func_id, &[ - CompileGradientArgType::Dup, - CompileGradientArgType::Dup, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, 
CompileGradientArgType::Const, ], @@ -360,8 +373,8 @@ impl CodegenModule for LlvmModule { *func_id, &[ CompileGradientArgType::Const, - CompileGradientArgType::Dup, - CompileGradientArgType::Dup, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, @@ -370,6 +383,25 @@ impl CodegenModule for LlvmModule { ) } + fn compile_mass_rgrad( + &mut self, + func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result { + self.codegen_mut().compile_gradient( + *func_id, + &[ + CompileGradientArgType::Const, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::Const, + CompileGradientArgType::Const, + ], + CompileMode::Reverse, + ) + } + fn compile_rhs_rgrad( &mut self, func_id: &Self::FuncId, @@ -398,8 +430,8 @@ impl CodegenModule for LlvmModule { *func_id, &[ CompileGradientArgType::Const, - CompileGradientArgType::Dup, - CompileGradientArgType::Dup, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, ], @@ -432,7 +464,10 @@ impl CodegenModule for LlvmModule { ) -> Result { self.codegen_mut().compile_gradient( *func_id, - &[CompileGradientArgType::Dup, CompileGradientArgType::Dup], + &[ + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, + ], CompileMode::Forward, ) } @@ -452,6 +487,80 @@ impl CodegenModule for LlvmModule { ) } + fn compile_rhs_sgrad( + &mut self, + func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result { + self.codegen_mut().compile_gradient( + *func_id, + &[ + CompileGradientArgType::Const, + CompileGradientArgType::Const, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::Const, + CompileGradientArgType::Const, + ], + CompileMode::Forward, + ) + } + + fn compile_calc_out_sgrad( + &mut self, + func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result { + self.codegen_mut().compile_gradient( + *func_id, + &[ + CompileGradientArgType::Const, + CompileGradientArgType::Const, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::Const, + CompileGradientArgType::Const, + ], + CompileMode::Forward, + ) + } + + fn compile_calc_out_srgrad( + &mut self, + func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result { + self.codegen_mut().compile_gradient( + *func_id, + &[ + CompileGradientArgType::Const, + CompileGradientArgType::Const, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::Const, + CompileGradientArgType::Const, + ], + CompileMode::Reverse, + ) + } + + fn compile_rhs_srgrad( + &mut self, + func_id: &Self::FuncId, + _model: &DiscreteModel, + ) -> Result { + self.codegen_mut().compile_gradient( + *func_id, + &[ + CompileGradientArgType::Const, + CompileGradientArgType::Const, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::Const, + CompileGradientArgType::Const, + ], + CompileMode::Reverse, + ) + } + fn pre_autodiff_optimisation(&mut self) -> Result<()> { //let pass_manager_builder = PassManagerBuilder::create(); //pass_manager_builder.set_optimization_level(inkwell::OptimizationLevel::Default); @@ -503,6 +612,10 @@ impl CodegenModule for LlvmModule { let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline"); barrier_func.remove_enum_attribute(AttributeLoc::Function, nolinline_kind_id); } + //self.codegen() + // .module() + 
// .print_to_file("post_autodiff_optimisation.ll") + // .unwrap(); let initialization_config = &InitializationConfig::default(); Target::initialize_all(initialization_config); @@ -525,11 +638,6 @@ impl CodegenModule for LlvmModule { .run_passes(passes, &machine, PassBuilderOptions::create()) .map_err(|e| anyhow!("Failed to run passes: {:?}", e))?; - //self.codegen() - // .module() - // .print_to_file("post_autodiff_optimisation.ll") - // .unwrap(); - Ok(()) } } @@ -1288,7 +1396,8 @@ impl<'ctx> CodeGen<'ctx> { .take(arg_len) .map(|f| f.into()) .collect::>(); - intrinsic.get_declaration(&self.module, args_types.as_slice()) + // if we get an intrinsic, we don't need to add to the list of functions and can return early + return intrinsic.get_declaration(&self.module, args_types.as_slice()); } // some custom functions "sigmoid" => { @@ -2557,6 +2666,7 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_calc_out<'m>( &mut self, model: &'m DiscreteModel, + include_constants: bool, ) -> Result> { self.clear(); let void_type = self.context.void_type(); @@ -2571,7 +2681,12 @@ impl<'ctx> CodeGen<'ctx> { false, ); let fn_arg_names = &["t", "u", "data", "thread_id", "thread_dim"]; - let function = self.module.add_function("calc_out", fn_type, None); + let function_name = if include_constants { + "calc_out_full" + } else { + "calc_out" + }; + let function = self.module.add_function(function_name, fn_type, None); // add noalias let alias_id = Attribute::get_named_enum_kind_id("noalias"); @@ -2600,10 +2715,20 @@ impl<'ctx> CodeGen<'ctx> { //self.compile_print_value("thread_id", PrintValue::Int(thread_id.into_int_value()))?; //self.compile_print_value("thread_dim", PrintValue::Int(thread_dim.into_int_value()))?; - // calculate time dependant definitions let mut nbarriers = 0; - let total_barriers = + let mut total_barriers = (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64; + if include_constants { + total_barriers += model.time_indep_defns().len() as u64; + // calculate time independant definitions + for tensor in model.time_indep_defns() { + self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?; + self.jit_compile_call_barrier(nbarriers, total_barriers); + nbarriers += 1; + } + } + + // calculate time dependant definitions for tensor in model.time_dep_defns() { self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?; self.jit_compile_call_barrier(nbarriers, total_barriers); @@ -2709,7 +2834,11 @@ impl<'ctx> CodeGen<'ctx> { } } - pub fn compile_rhs<'m>(&mut self, model: &'m DiscreteModel) -> Result> { + pub fn compile_rhs<'m>( + &mut self, + model: &'m DiscreteModel, + include_constants: bool, + ) -> Result> { self.clear(); let void_type = self.context.void_type(); let fn_type = void_type.fn_type( @@ -2724,7 +2853,8 @@ impl<'ctx> CodeGen<'ctx> { false, ); let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"]; - let function = self.module.add_function("rhs", fn_type, None); + let function_name = if include_constants { "rhs_full" } else { "rhs" }; + let function = self.module.add_function(function_name, fn_type, None); // add noalias let alias_id = Attribute::get_named_enum_kind_id("noalias"); @@ -2747,10 +2877,20 @@ impl<'ctx> CodeGen<'ctx> { self.insert_data(model); self.insert_indices(); - // calculate time dependant definitions let mut nbarriers = 0; - let total_barriers = + let mut total_barriers = (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64; + if include_constants { + total_barriers += model.time_indep_defns().len() as 
u64; + // calculate constant definitions + for tensor in model.time_indep_defns() { + self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?; + self.jit_compile_call_barrier(nbarriers, total_barriers); + nbarriers += 1; + } + } + + // calculate time dependant definitions for tensor in model.time_dep_defns() { self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?; self.jit_compile_call_barrier(nbarriers, total_barriers); @@ -3254,14 +3394,19 @@ impl<'ctx> CodeGen<'ctx> { } } - pub fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result> { + pub fn compile_inputs( + &mut self, + model: &DiscreteModel, + is_get: bool, + ) -> Result> { self.clear(); let void_type = self.context.void_type(); let fn_type = void_type.fn_type( &[self.real_ptr_type.into(), self.real_ptr_type.into()], false, ); - let function = self.module.add_function("set_inputs", fn_type, None); + let function_name = if is_get { "get_inputs" } else { "set_inputs" }; + let function = self.module.add_function(function_name, fn_type, None); let mut block = self.context.append_basic_block(function, "entry"); self.fn_value_opt = Some(function); @@ -3279,7 +3424,7 @@ impl<'ctx> CodeGen<'ctx> { let name = format!("input_{}", input.name()); self.insert_tensor(input); let ptr = self.get_var(input); - // loop thru the elements of this input and set them using the inputs ptr + // loop thru the elements of this input and set/get them using the inputs ptr let inputs_start_index = self.int_type.const_int(inputs_index as u64, false); let start_index = self.int_type.const_int(0, false); let end_index = self @@ -3311,10 +3456,17 @@ impl<'ctx> CodeGen<'ctx> { curr_inputs_index, name.as_str(), ); - let input_value = self - .build_load(self.real_type, inputs_ptr, name.as_str())? - .into_float_value(); - self.builder.build_store(input_ptr, input_value)?; + if is_get { + let input_value = self + .build_load(self.real_type, input_ptr, name.as_str())? + .into_float_value(); + self.builder.build_store(inputs_ptr, input_value)?; + } else { + let input_value = self + .build_load(self.real_type, inputs_ptr, name.as_str())? 
+                    .into_float_value();
+                self.builder.build_store(input_ptr, input_value)?;
+            }

             // increment loop index
             let one = self.int_type.const_int(1, false);
diff --git a/src/execution/module.rs b/src/execution/module.rs
index 4e43f41..7186ff1 100644
--- a/src/execution/module.rs
+++ b/src/execution/module.rs
@@ -11,24 +11,35 @@ pub trait CodegenModule: Sized + Sync {
     fn new(triple: Triple, model: &DiscreteModel, threaded: bool) -> Result<Self>;
     fn compile_set_u0(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
     fn compile_calc_out(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
+    fn compile_calc_out_full(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
     fn compile_calc_stop(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
     fn compile_rhs(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
+    fn compile_rhs_full(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
     fn compile_mass(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
     fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
     fn compile_get_tensor(&mut self, model: &DiscreteModel, name: &str) -> Result<Self::FuncId>;
     fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
+    fn compile_get_inputs(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;
     fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<Self::FuncId>;

+    fn compile_mass_rgrad(
+        &mut self,
+        func_id: &Self::FuncId,
+        model: &DiscreteModel,
+    ) -> Result<Self::FuncId>;
+
     fn compile_set_u0_grad(
         &mut self,
         func_id: &Self::FuncId,
         model: &DiscreteModel,
     ) -> Result<Self::FuncId>;
+
     fn compile_rhs_grad(
         &mut self,
         func_id: &Self::FuncId,
         model: &DiscreteModel,
     ) -> Result<Self::FuncId>;
+
     fn compile_calc_out_grad(
         &mut self,
         func_id: &Self::FuncId,
@@ -64,6 +75,30 @@ pub trait CodegenModule: Sized + Sync {
         model: &DiscreteModel,
     ) -> Result<Self::FuncId>;

+    fn compile_rhs_sgrad(
+        &mut self,
+        func_id: &Self::FuncId,
+        model: &DiscreteModel,
+    ) -> Result<Self::FuncId>;
+
+    fn compile_rhs_srgrad(
+        &mut self,
+        func_id: &Self::FuncId,
+        model: &DiscreteModel,
+    ) -> Result<Self::FuncId>;
+
+    fn compile_calc_out_sgrad(
+        &mut self,
+        func_id: &Self::FuncId,
+        model: &DiscreteModel,
+    ) -> Result<Self::FuncId>;
+
+    fn compile_calc_out_srgrad(
+        &mut self,
+        func_id: &Self::FuncId,
+        model: &DiscreteModel,
+    ) -> Result<Self::FuncId>;
+
     fn supports_reverse_autodiff(&self) -> bool;

     fn jit(&mut self, func_id: Self::FuncId) -> Result<*const u8>;