From 5fe7a47341404e68c59efe658cad45b493309a28 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Thu, 14 Mar 2024 09:52:29 +0000 Subject: [PATCH] now have M and F, tests pass --- examples/logistic.ds | 4 +- src/discretise/discrete_model.rs | 127 ++++++++++++++------- src/execution/codegen.rs | 109 +++++++++++++----- src/execution/compiler.rs | 186 ++++++++++++++----------------- src/execution/translation.rs | 4 +- src/lib.rs | 13 ++- 6 files changed, 261 insertions(+), 182 deletions(-) diff --git a/examples/logistic.ds b/examples/logistic.ds index 27b6ecd..48b3220 100644 --- a/examples/logistic.ds +++ b/examples/logistic.ds @@ -9,11 +9,11 @@ dudt_i { dydt = 0, dzdt = 0, } -F_i { +M_i { dydt, 0, } -G_i { +F_i { (r * y) * (1 - (y / k)), (2 * y) - z, } diff --git a/src/discretise/discrete_model.rs b/src/discretise/discrete_model.rs index e3e2c56..24bee0e 100644 --- a/src/discretise/discrete_model.rs +++ b/src/discretise/discrete_model.rs @@ -35,6 +35,7 @@ pub struct DiscreteModel<'s> { time_indep_defns: Vec>, time_dep_defns: Vec>, state_dep_defns: Vec>, + dstate_dep_defns: Vec>, inputs: Vec>, state: Tensor<'s>, state_dot: Tensor<'s>, @@ -87,6 +88,7 @@ impl<'s> DiscreteModel<'s> { time_indep_defns: Vec::new(), time_dep_defns: Vec::new(), state_dep_defns: Vec::new(), + dstate_dep_defns: Vec::new(), inputs: Vec::new(), state: Tensor::new_empty("u"), state_dot: Tensor::new_empty("u_dot"), @@ -95,41 +97,6 @@ impl<'s> DiscreteModel<'s> { } } - // residual = F(t, u, u_dot) - G(t, u) - // return a tensor equal to the residual - pub fn residual(&self) -> Tensor<'s> { - let mut residual = self.lhs.clone(); - residual.set_name("residual"); - let indices = self.lhs.indices().to_vec(); - let lhs = Ast { - kind: AstKind::new_indexed_name("F", indices.clone()), - span: None, - }; - let rhs = Ast { - kind: AstKind::new_indexed_name("G", indices), - span: None, - }; - let name = "residual"; - let indices = self.lhs.indices().to_vec(); - let layout = self.lhs.layout_ptr().clone(); - let elmts = vec![ - TensorBlock::new( - None, - Index::from_vec(vec![0]), - indices.clone(), - self.lhs.layout_ptr().clone(), - self.lhs.layout_ptr().clone(), - Ast { - kind: AstKind::new_binop('-', lhs, rhs), - span: None, - }, - ) - ]; - Tensor::new(name, elmts, layout, indices) - } - - - fn build_array(array: &ast::Tensor<'s>, env: &mut Env) -> Option> { let rank = array.indices().len(); let mut elmts = Vec::new(); @@ -317,6 +284,14 @@ impl<'s> DiscreteModel<'s> { if let Some(built) = Self::build_array(tensor, &mut env) { ret.stop = Some(built); } + // check that stop is not dependent on dudt + let stop = env.get("stop").unwrap(); + if stop.is_dstatedt_dependent() { + env.errs_mut().push(ValidationError::new( + "stop must not be dependent on dudt".to_string(), + tensor_ast.span, + )); + } } "out" => { read_out = true; @@ -329,6 +304,14 @@ impl<'s> DiscreteModel<'s> { } ret.out = built; } + // check that out is not dependent on dudt + let out = env.get("out").unwrap(); + if out.is_dstatedt_dependent() { + env.errs_mut().push(ValidationError::new( + "out must not be dependent on dudt".to_string(), + tensor_ast.span, + )); + } } _name => { if let Some(built) = Self::build_array(tensor, &mut env) { @@ -336,6 +319,7 @@ impl<'s> DiscreteModel<'s> { if let Some(env_entry) = env.get(built.name()) { let dependent_on_state = env_entry.is_state_dependent(); let dependent_on_time = env_entry.is_time_dependent(); + let dependent_on_dudt = env_entry.is_dstatedt_dependent(); if is_input { // inputs must be constants if dependent_on_time || dependent_on_state { @@ -347,10 +331,14 @@ impl<'s> DiscreteModel<'s> { ret.inputs.push(built); } else if !dependent_on_time { ret.time_indep_defns.push(built); - } else if dependent_on_time && !dependent_on_state { + } else if dependent_on_time && !dependent_on_state && !dependent_on_dudt { ret.time_dep_defns.push(built); - } else { + } else if dependent_on_state { ret.state_dep_defns.push(built); + } else if dependent_on_dudt { + ret.dstate_dep_defns.push(built); + } else { + panic!("all the cases should be covered") } } } @@ -635,6 +623,7 @@ impl<'s> DiscreteModel<'s> { let rhs = Tensor::new_no_layout("F", f_elmts, vec!['i']); let name = model.name; let stop = None; + let dstate_dep_defns = Vec::new(); DiscreteModel { name, lhs, @@ -646,6 +635,7 @@ impl<'s> DiscreteModel<'s> { time_indep_defns, time_dep_defns, state_dep_defns, + dstate_dep_defns, is_algebraic, stop, } @@ -664,6 +654,10 @@ impl<'s> DiscreteModel<'s> { pub fn state_dep_defns(&self) -> &[Tensor] { self.state_dep_defns.as_ref() } + + pub fn dstate_dep_defns(&self) -> &[Tensor] { + self.dstate_dep_defns.as_ref() + } pub fn state(&self) -> &Tensor<'s> { &self.state @@ -749,6 +743,54 @@ mod tests { assert_eq!(discrete.out.elmts()[2].expr().to_string(), "z"); println!("{}", discrete); } + + #[test] + fn tensor_classification() { + let text = " + in = [r, k, ] + r { 1, } + k { 1, } + z { 2 * r } + g { 2 * t } + u_i { + y = 1, + z = 0, + } + u2_i { + 2 * y, + 2 * z, + } + dudt_i { + dydt = 0, + dzdt = 0, + } + dudt2_i { + 2 * dydt, + 0, + } + M_i { + dydt, + 0, + } + F_i { + (r * y) * (1 - (y / k)), + (2 * y) - z, + } + out_i { + y, + t, + z, + } + "; + let model = parse_ds_string(text).unwrap(); + let model = DiscreteModel::build("$name", &model).unwrap(); + assert_eq!(model.inputs().iter().map(|t| t.name()).collect::>(), ["r", "k"]); + assert_eq!(model.time_indep_defns().iter().map(|t| t.name()).collect::>(), ["z"]); + assert_eq!(model.time_dep_defns().iter().map(|t| t.name()).collect::>(), ["g"]); + assert_eq!(model.state_dep_defns().iter().map(|t| t.name()).collect::>(), ["u2"]); + assert_eq!(model.dstate_dep_defns().iter().map(|t| t.name()).collect::>(), ["dudt2"]); + assert_eq!(model.inputs().iter().map(|t| t.name()).collect::>(), ["r", "k"]); + } macro_rules! count { () => (0usize); @@ -924,7 +966,7 @@ mod tests { 1, } " ["F and u must have the same shape",], - error_f_dep_on_dudt: " + error_dep_on_dudt: " u_i { y = 1, } @@ -934,10 +976,13 @@ mod tests { F_i { dydt, } + stop_i { + dydt, + } out_i { - y, + dydt, } - " ["G and u must have the same shape",], + " ["F must not be dependent on dudt", "stop must not be dependent on dudt", "out must not be dependent on dudt",], error_m_dep_on_u: " u_i { y = 1, @@ -957,7 +1002,7 @@ mod tests { out_i { y, } - " ["G and u must have the same shape",], + " ["M must not be dependent on u",], ); diff --git a/src/execution/codegen.rs b/src/execution/codegen.rs index 4a6965d..646bf63 100644 --- a/src/execution/codegen.rs +++ b/src/execution/codegen.rs @@ -21,13 +21,14 @@ use crate::execution::{Translation, TranslationFrom, TranslationTo, DataLayout}; /// /// Calling this is innately `unsafe` because there's no guarantee it doesn't /// do `unsafe` operations internally. -pub type StopFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, up: *const RealType, data: *mut RealType, root: *mut RealType); -pub type ResidualFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, up: *const RealType, data: *mut RealType, rr: *mut RealType); -pub type ResidualGradientFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, du: *const RealType, up: *const RealType, dup: *const RealType, data: *mut RealType, ddata: *mut RealType, rr: *mut RealType, drr: *mut RealType); -pub type U0Func = unsafe extern "C" fn(data: *mut RealType, u: *mut RealType, up: *mut RealType); -pub type U0GradientFunc = unsafe extern "C" fn(data: *mut RealType, ddata: *mut RealType, u: *mut RealType, du: *mut RealType, up: *mut RealType, dup: *mut RealType); -pub type CalcOutFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, up: *const RealType, data: *mut RealType); -pub type CalcOutGradientFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, du: *const RealType, up: *const RealType, dup: *const RealType, data: *mut RealType, ddata: *mut RealType); +pub type StopFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, data: *mut RealType, root: *mut RealType); +pub type RhsFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, data: *mut RealType, rr: *mut RealType); +pub type RhsGradientFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, du: *const RealType, data: *mut RealType, ddata: *mut RealType, rr: *mut RealType, drr: *mut RealType); +pub type MassFunc = unsafe extern "C" fn(time: RealType, v: *const RealType, data: *mut RealType, mv: *mut RealType); +pub type U0Func = unsafe extern "C" fn(data: *mut RealType, u: *mut RealType); +pub type U0GradientFunc = unsafe extern "C" fn(data: *mut RealType, ddata: *mut RealType, u: *mut RealType, du: *mut RealType); +pub type CalcOutFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, data: *mut RealType); +pub type CalcOutGradientFunc = unsafe extern "C" fn(time: RealType, u: *const RealType, du: *const RealType, data: *mut RealType, ddata: *mut RealType); pub type GetDimsFunc = unsafe extern "C" fn(states: *mut u32, inputs: *mut u32, outputs: *mut u32, data: *mut u32, stop: *mut u32); pub type SetInputsFunc = unsafe extern "C" fn(inputs: *const RealType, data: *mut RealType); pub type SetInputsGradientFunc = unsafe extern "C" fn(inputs: *const RealType, dinputs: *const RealType, data: *mut RealType, ddata: *mut RealType); @@ -147,7 +148,7 @@ impl<'ctx> CodeGen<'ctx> { self.variables.insert(name.to_owned(), value); } - fn insert_state(&mut self, u: &Tensor, dudt: &Tensor) { + fn insert_state(&mut self, u: &Tensor) { let mut data_index = 0; for blk in u.elmts() { if let Some(name) = blk.name() { @@ -158,7 +159,10 @@ impl<'ctx> CodeGen<'ctx> { } data_index += blk.nnz(); } - data_index = 0; + + } + fn insert_dot_state(&mut self, dudt: &Tensor) { + let mut data_index = 0; for blk in dudt.elmts() { if let Some(name) = blk.name() { let ptr = self.variables.get("dudt").unwrap(); @@ -896,10 +900,10 @@ impl<'ctx> CodeGen<'ctx> { let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); let fn_type = void_type.fn_type( - &[real_ptr_type.into(), real_ptr_type.into(), real_ptr_type.into()] + &[real_ptr_type.into(), real_ptr_type.into()] , false ); - let fn_arg_names = &[ "data", "u0", "dudt0"]; + let fn_arg_names = &[ "data", "u0"]; let function = self.module.add_function("set_u0", fn_type, None); let basic_block = self.context.append_basic_block(function, "entry"); self.fn_value_opt = Some(function); @@ -919,7 +923,6 @@ impl<'ctx> CodeGen<'ctx> { } self.jit_compile_tensor(model.state(), Some(*self.get_param("u0")))?; - self.jit_compile_tensor(model.state_dot(), Some(*self.get_param("dudt0")))?; self.builder.build_return(None)?; @@ -941,10 +944,10 @@ impl<'ctx> CodeGen<'ctx> { let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); let fn_type = void_type.fn_type( - &[self.real_type.into(), real_ptr_type.into(), real_ptr_type.into(), real_ptr_type.into()] + &[self.real_type.into(), real_ptr_type.into(), real_ptr_type.into()] , false ); - let fn_arg_names = &["t", "u", "dudt", "data"]; + let fn_arg_names = &["t", "u", "data"]; let function = self.module.add_function("calc_out", fn_type, None); let basic_block = self.context.append_basic_block(function, "entry"); self.fn_value_opt = Some(function); @@ -956,7 +959,7 @@ impl<'ctx> CodeGen<'ctx> { self.insert_param(name, alloca); } - self.insert_state(model.state(), model.state_dot()); + self.insert_state(model.state()); self.insert_data(model); self.insert_indices(); @@ -983,10 +986,10 @@ impl<'ctx> CodeGen<'ctx> { let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); let fn_type = void_type.fn_type( - &[self.real_type.into(), real_ptr_type.into(), real_ptr_type.into(), real_ptr_type.into(), real_ptr_type.into()] + &[self.real_type.into(), real_ptr_type.into(), real_ptr_type.into(), real_ptr_type.into()] , false ); - let fn_arg_names = &["t", "u", "dudt", "data", "root"]; + let fn_arg_names = &["t", "u", "data", "root"]; let function = self.module.add_function("calc_stop", fn_type, None); let basic_block = self.context.append_basic_block(function, "entry"); self.fn_value_opt = Some(function); @@ -998,7 +1001,7 @@ impl<'ctx> CodeGen<'ctx> { self.insert_param(name, alloca); } - self.insert_state(model.state(), model.state_dot()); + self.insert_state(model.state()); self.insert_data(model); self.insert_indices(); @@ -1022,16 +1025,16 @@ impl<'ctx> CodeGen<'ctx> { } - pub fn compile_residual<'m>(& mut self, model: &'m DiscreteModel) -> Result> { + pub fn compile_rhs<'m>(& mut self, model: &'m DiscreteModel) -> Result> { self.clear(); let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); let fn_type = void_type.fn_type( - &[self.real_type.into(), real_ptr_type.into(), real_ptr_type.into(), real_ptr_type.into(), real_ptr_type.into()] + &[self.real_type.into(), real_ptr_type.into(), real_ptr_type.into(), real_ptr_type.into()] , false ); - let fn_arg_names = &["t", "u", "dudt", "data", "rr"]; - let function = self.module.add_function("residual", fn_type, None); + let fn_arg_names = &["t", "u", "data", "rr"]; + let function = self.module.add_function("rhs", fn_type, None); let basic_block = self.context.append_basic_block(function, "entry"); self.fn_value_opt = Some(function); self.builder.position_at_end(basic_block); @@ -1042,7 +1045,7 @@ impl<'ctx> CodeGen<'ctx> { self.insert_param(name, alloca); } - self.insert_state(model.state(), model.state_dot()); + self.insert_state(model.state()); self.insert_data(model); self.insert_indices(); @@ -1051,20 +1054,66 @@ impl<'ctx> CodeGen<'ctx> { self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?; } - // TODO: could split state dep defns into before and after F and G + // TODO: could split state dep defns into before and after F for a in model.state_dep_defns() { self.jit_compile_tensor(a, Some(*self.get_var(a)))?; } - // F and G - self.jit_compile_tensor(model.lhs(), Some(*self.get_var(model.lhs())))?; - self.jit_compile_tensor(model.rhs(), Some(*self.get_var(model.rhs())))?; + // F + let res_ptr = self.get_param("rr"); + self.jit_compile_tensor(model.rhs(), Some(*res_ptr))?; - // compute residual here as dummy array - let residual = model.residual(); + self.builder.build_return(None)?; + + if function.verify(true) { + self.fpm.run_on(&function); + Ok(function) + } else { + function.print_to_stderr(); + unsafe { + function.delete(); + } + Err(anyhow!("Invalid generated function.")) + } + } + pub fn compile_mass<'m>(& mut self, model: &'m DiscreteModel) -> Result> { + self.clear(); + let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); + let void_type = self.context.void_type(); + let fn_type = void_type.fn_type( + &[self.real_type.into(), real_ptr_type.into(), real_ptr_type.into(), real_ptr_type.into()] + , false + ); + let fn_arg_names = &["t", "dudt", "data", "rr"]; + let function = self.module.add_function("mass", fn_type, None); + let basic_block = self.context.append_basic_block(function, "entry"); + self.fn_value_opt = Some(function); + self.builder.position_at_end(basic_block); + + for (i, arg) in function.get_param_iter().enumerate() { + let name = fn_arg_names[i]; + let alloca = self.function_arg_alloca(name, arg); + self.insert_param(name, alloca); + } + + self.insert_dot_state(model.state_dot()); + self.insert_data(model); + self.insert_indices(); + + // calculate time dependant definitions + for tensor in model.time_dep_defns() { + self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?; + } + + for a in model.dstate_dep_defns() { + self.jit_compile_tensor(a, Some(*self.get_var(a)))?; + } + + // mass let res_ptr = self.get_param("rr"); - let _res_ptr = self.jit_compile_tensor(&residual, Some(*res_ptr))?; + self.jit_compile_tensor(model.lhs(), Some(*res_ptr))?; + self.builder.build_return(None)?; if function.verify(true) { diff --git a/src/execution/compiler.rs b/src/execution/compiler.rs index 7c24515..c4f6fe5 100644 --- a/src/execution/compiler.rs +++ b/src/execution/compiler.rs @@ -16,23 +16,23 @@ use crate::utils::find_runtime_path; use std::process::Command; -use super::codegen::CalcOutGradientFunc; +use super::codegen::{CalcOutGradientFunc, MassFunc, RhsFunc, RhsGradientFunc}; use super::codegen::CompileGradientArgType; use super::codegen::GetDimsFunc; use super::codegen::GetOutFunc; -use super::codegen::ResidualGradientFunc; use super::codegen::SetIdFunc; use super::codegen::SetInputsFunc; use super::codegen::SetInputsGradientFunc; use super::codegen::U0GradientFunc; -use super::{CodeGen, codegen::{U0Func, ResidualFunc, CalcOutFunc, StopFunc}, data_layout::DataLayout}; +use super::{CodeGen, codegen::{U0Func, CalcOutFunc, StopFunc}, data_layout::DataLayout}; struct JitFunctions<'ctx> { set_u0: JitFunction<'ctx, U0Func>, - residual: JitFunction<'ctx, ResidualFunc>, + rhs: JitFunction<'ctx, RhsFunc>, + mass: JitFunction<'ctx, MassFunc>, calc_out: JitFunction<'ctx, CalcOutFunc>, calc_stop: JitFunction<'ctx, StopFunc>, set_id: JitFunction<'ctx, SetIdFunc>, @@ -43,7 +43,7 @@ struct JitFunctions<'ctx> { struct JitGradFunctions<'ctx> { set_u0_grad: JitFunction<'ctx, U0GradientFunc>, - residual_grad: JitFunction<'ctx, ResidualGradientFunc>, + rhs_grad: JitFunction<'ctx, RhsGradientFunc>, calc_out_grad: JitFunction<'ctx, CalcOutGradientFunc>, set_inputs_grad: JitFunction<'ctx, SetInputsGradientFunc>, } @@ -116,12 +116,13 @@ impl Compiler { let mut codegen = CodeGen::new(model, context, module, real_type, real_type_str); let _set_u0 = codegen.compile_set_u0(model)?; - let _set_u0_grad = codegen.compile_gradient(_set_u0, &[CompileGradientArgType::Dup, CompileGradientArgType::Dup, CompileGradientArgType::Dup])?; + let _set_u0_grad = codegen.compile_gradient(_set_u0, &[CompileGradientArgType::Dup, CompileGradientArgType::Dup])?; let _calc_stop = codegen.compile_calc_stop(model)?; - let _residual = codegen.compile_residual(model)?; - let _residual_grad = codegen.compile_gradient(_residual, &[CompileGradientArgType::Const, CompileGradientArgType::Dup, CompileGradientArgType::Dup, CompileGradientArgType::Dup, CompileGradientArgType::DupNoNeed])?; + let _rhs= codegen.compile_rhs(model)?; + let _mass = codegen.compile_mass(model)?; + let _rhs_grad = codegen.compile_gradient(_rhs, &[CompileGradientArgType::Const, CompileGradientArgType::Dup, CompileGradientArgType::Dup, CompileGradientArgType::DupNoNeed])?; let _calc_out = codegen.compile_calc_out(model)?; - let _calc_out_grad = codegen.compile_gradient(_calc_out, &[CompileGradientArgType::Const, CompileGradientArgType::Dup, CompileGradientArgType::Dup, CompileGradientArgType::Dup])?; + let _calc_out_grad = codegen.compile_gradient(_calc_out, &[CompileGradientArgType::Const, CompileGradientArgType::Dup, CompileGradientArgType::Dup])?; let _set_id = codegen.compile_set_id(model)?; let _get_dims= codegen.compile_get_dims(model)?; let _set_inputs = codegen.compile_set_inputs(model)?; @@ -156,7 +157,8 @@ impl Compiler { let ee = module.create_jit_execution_engine(OptimizationLevel::None).map_err(|e| anyhow::anyhow!("Error creating execution engine: {:?}", e))?; let set_u0 = Compiler::jit("set_u0", &ee)?; - let residual = Compiler::jit("residual", &ee)?; + let rhs = Compiler::jit("rhs", &ee)?; + let mass = Compiler::jit("mass", &ee)?; let calc_stop = Compiler::jit("calc_stop", &ee)?; let calc_out = Compiler::jit("calc_out", &ee)?; let set_id = Compiler::jit("set_id", &ee)?; @@ -166,14 +168,15 @@ impl Compiler { let set_inputs_grad = Compiler::jit("set_inputs_grad", &ee)?; let calc_out_grad = Compiler::jit("calc_out_grad", &ee)?; - let residual_grad = Compiler::jit("residual_grad", &ee)?; + let rhs_grad = Compiler::jit("rhs_grad", &ee)?; let set_u0_grad = Compiler::jit("set_u0_grad", &ee)?; let data = CompilerData { codegen, jit_functions: JitFunctions { set_u0, - residual, + rhs, + mass, calc_out, set_id, get_dims, @@ -183,7 +186,7 @@ impl Compiler { }, jit_grad_functions: JitGradFunctions { set_u0_grad, - residual_grad, + rhs_grad, calc_out_grad, set_inputs_grad, }, @@ -330,37 +333,27 @@ impl Compiler { Some(&data[index..index+nnz]) } - pub fn set_u0(&self, yy: &mut [f64], yp: &mut [f64], data: &mut [f64]) -> Result<()> { + pub fn set_u0(&self, yy: &mut [f64], data: &mut [f64]) -> Result<()> { let number_of_states = *self.borrow_number_of_states(); if yy.len() != number_of_states { return Err(anyhow!("Expected {} states, got {}", number_of_states, yy.len())); } - if yp.len() != number_of_states { - return Err(anyhow!("Expected {} state derivatives, got {}", number_of_states, yp.len())); - } self.with_data(|compiler| { let yy_ptr = yy.as_mut_ptr(); - let yp_ptr = yp.as_mut_ptr(); let data_ptr = data.as_mut_ptr(); - unsafe { compiler.jit_functions.set_u0.call(data_ptr, yy_ptr, yp_ptr); } + unsafe { compiler.jit_functions.set_u0.call(data_ptr, yy_ptr); } }); Ok(()) } - pub fn set_u0_grad(&self, yy: &mut [f64], dyy: &mut [f64], yp: &mut [f64], dyp: &mut [f64], data: &mut [f64], ddata: &mut [f64]) -> Result<()> { + pub fn set_u0_grad(&self, yy: &mut [f64], dyy: &mut [f64], data: &mut [f64], ddata: &mut [f64]) -> Result<()> { let number_of_states = *self.borrow_number_of_states(); if yy.len() != number_of_states { return Err(anyhow!("Expected {} states, got {}", number_of_states, yy.len())); } - if yp.len() != number_of_states { - return Err(anyhow!("Expected {} state derivatives, got {}", number_of_states, yp.len())); - } if dyy.len() != number_of_states { return Err(anyhow!("Expected {} states for dyy, got {}", number_of_states, dyy.len())); } - if dyp.len() != number_of_states { - return Err(anyhow!("Expected {} state derivatives for dyp, got {}", number_of_states, dyp.len())); - } if data.len() != self.data_len() { return Err(anyhow!("Expected {} data, got {}", self.data_len(), data.len())); } @@ -369,25 +362,20 @@ impl Compiler { } self.with_data(|compiler| { let yy_ptr = yy.as_mut_ptr(); - let yp_ptr = yp.as_mut_ptr(); let data_ptr = data.as_mut_ptr(); let dyy_ptr = dyy.as_mut_ptr(); - let dyp_ptr = dyp.as_mut_ptr(); let ddata_ptr = ddata.as_mut_ptr(); - unsafe { compiler.jit_grad_functions.set_u0_grad.call(data_ptr, ddata_ptr, yy_ptr, dyy_ptr, yp_ptr, dyp_ptr); } + unsafe { compiler.jit_grad_functions.set_u0_grad.call(data_ptr, ddata_ptr, yy_ptr, dyy_ptr); } }); Ok(()) } - pub fn calc_stop(&self, t: f64, yy: &[f64], yp: &[f64], data: &mut [f64], stop: &mut [f64]) -> Result<()> { + pub fn calc_stop(&self, t: f64, yy: &[f64], data: &mut [f64], stop: &mut [f64]) -> Result<()> { let (n_states, _, _, n_data, n_stop) = self.get_dims(); if yy.len() != n_states { return Err(anyhow!("Expected {} states, got {}", n_states, yy.len())); } - if yp.len() != n_states { - return Err(anyhow!("Expected {} state derivatives, got {}", n_states, yp.len())); - } if data.len() != n_data { return Err(anyhow!("Expected {} data, got {}", n_data, data.len())); } @@ -396,22 +384,38 @@ impl Compiler { } self.with_data(|compiler| { let yy_ptr = yy.as_ptr(); - let yp_ptr = yp.as_ptr(); let data_ptr = data.as_mut_ptr(); let stop_ptr = stop.as_mut_ptr(); - unsafe { compiler.jit_functions.calc_stop.call(t, yy_ptr, yp_ptr, data_ptr, stop_ptr); } + unsafe { compiler.jit_functions.calc_stop.call(t, yy_ptr, data_ptr, stop_ptr); } }); Ok(()) } - pub fn residual(&self, t: f64, yy: &[f64], yp: &[f64], data: &mut [f64], rr: &mut [f64]) -> Result<()> { + pub fn rhs(&self, t: f64, yy: &[f64], data: &mut [f64], rr: &mut [f64]) -> Result<()> { let number_of_states = *self.borrow_number_of_states(); if yy.len() != number_of_states { return Err(anyhow!("Expected {} states, got {}", number_of_states, yy.len())); } + if rr.len() != number_of_states { + return Err(anyhow!("Expected {} residual states, got {}", number_of_states, rr.len())); + } + if data.len() != self.data_len() { + return Err(anyhow!("Expected {} data, got {}", self.data_len(), data.len())); + } + self.with_data(|compiler| { + let yy_ptr = yy.as_ptr(); + let rr_ptr = rr.as_mut_ptr(); + let data_ptr = data.as_mut_ptr(); + unsafe { compiler.jit_functions.rhs.call(t, yy_ptr, data_ptr, rr_ptr); } + }); + Ok(()) + } + + pub fn mass(&self, t: f64, yp: &[f64], data: &mut [f64], rr: &mut [f64]) -> Result<()> { + let number_of_states = *self.borrow_number_of_states(); if yp.len() != number_of_states { - return Err(anyhow!("Expected {} state derivatives, got {}", number_of_states, yp.len())); + return Err(anyhow!("Expected {} states, got {}", number_of_states, yp.len())); } if rr.len() != number_of_states { return Err(anyhow!("Expected {} residual states, got {}", number_of_states, rr.len())); @@ -420,11 +424,10 @@ impl Compiler { return Err(anyhow!("Expected {} data, got {}", self.data_len(), data.len())); } self.with_data(|compiler| { - let yy_ptr = yy.as_ptr(); let yp_ptr = yp.as_ptr(); let rr_ptr = rr.as_mut_ptr(); let data_ptr = data.as_mut_ptr(); - unsafe { compiler.jit_functions.residual.call(t, yy_ptr, yp_ptr, data_ptr, rr_ptr); } + unsafe { compiler.jit_functions.mass.call(t, yp_ptr, data_ptr, rr_ptr); } }); Ok(()) } @@ -440,23 +443,17 @@ impl Compiler { } #[allow(clippy::too_many_arguments)] - pub fn residual_grad(&self, t: f64, yy: &[f64], dyy: &[f64], yp: &[f64], dyp: &[f64], data: &mut [f64], ddata: &mut [f64], rr: &mut [f64], drr: &mut [f64]) -> Result<()> { + pub fn rhs_grad(&self, t: f64, yy: &[f64], dyy: &[f64], data: &mut [f64], ddata: &mut [f64], rr: &mut [f64], drr: &mut [f64]) -> Result<()> { let number_of_states = *self.borrow_number_of_states(); if yy.len() != number_of_states { return Err(anyhow!("Expected {} states, got {}", number_of_states, yy.len())); } - if yp.len() != number_of_states { - return Err(anyhow!("Expected {} state derivatives, got {}", number_of_states, yp.len())); - } if rr.len() != number_of_states { return Err(anyhow!("Expected {} residual states, got {}", number_of_states, rr.len())); } if dyy.len() != number_of_states { return Err(anyhow!("Expected {} states for dyy, got {}", number_of_states, dyy.len())); } - if dyp.len() != number_of_states { - return Err(anyhow!("Expected {} state derivatives for dyp, got {}", number_of_states, dyp.len())); - } if drr.len() != number_of_states { return Err(anyhow!("Expected {} residual states for drr, got {}", number_of_states, drr.len())); } @@ -468,67 +465,52 @@ impl Compiler { } self.with_data(|compiler| { let yy_ptr = yy.as_ptr(); - let yp_ptr = yp.as_ptr(); let rr_ptr = rr.as_mut_ptr(); let dyy_ptr = dyy.as_ptr(); - let dyp_ptr = dyp.as_ptr(); let drr_ptr = drr.as_mut_ptr(); let data_ptr = data.as_mut_ptr(); let ddata_ptr = ddata.as_mut_ptr(); - unsafe { compiler.jit_grad_functions.residual_grad.call(t, yy_ptr, dyy_ptr, yp_ptr, dyp_ptr, data_ptr, ddata_ptr, rr_ptr, drr_ptr); } + unsafe { compiler.jit_grad_functions.rhs_grad.call(t, yy_ptr, dyy_ptr, data_ptr, ddata_ptr, rr_ptr, drr_ptr); } }); Ok(()) } - pub fn calc_out(&self, t: f64, yy: &[f64], yp: &[f64], data: &mut [f64]) -> Result<()> { + pub fn calc_out(&self, t: f64, yy: &[f64], data: &mut [f64]) -> Result<()> { let number_of_states = *self.borrow_number_of_states(); if yy.len() != *self.borrow_number_of_states() { return Err(anyhow!("Expected {} states, got {}", number_of_states, yy.len())); } - if yp.len() != *self.borrow_number_of_states() { - return Err(anyhow!("Expected {} state derivatives, got {}", number_of_states, yp.len())); - } if data.len() != self.data_len() { return Err(anyhow!("Expected {} data, got {}", self.data_len(), data.len())); } self.with_data(|compiler| { let yy_ptr = yy.as_ptr(); - let yp_ptr = yp.as_ptr(); let data_ptr = data.as_mut_ptr(); - unsafe { compiler.jit_functions.calc_out.call(t, yy_ptr, yp_ptr, data_ptr ); } + unsafe { compiler.jit_functions.calc_out.call(t, yy_ptr, data_ptr ); } }); Ok(()) } - #[allow(clippy::too_many_arguments)] - pub fn calc_out_grad(&self, t: f64, yy: &[f64], dyy: &[f64], yp: &[f64], dyp: &[f64], data: &mut [f64], ddata: &mut [f64]) -> Result<()> { + pub fn calc_out_grad(&self, t: f64, yy: &[f64], dyy: &[f64], data: &mut [f64], ddata: &mut [f64]) -> Result<()> { let number_of_states = *self.borrow_number_of_states(); if yy.len() != *self.borrow_number_of_states() { return Err(anyhow!("Expected {} states, got {}", number_of_states, yy.len())); } - if yp.len() != *self.borrow_number_of_states() { - return Err(anyhow!("Expected {} state derivatives, got {}", number_of_states, yp.len())); - } if data.len() != self.data_len() { return Err(anyhow!("Expected {} data, got {}", self.data_len(), data.len())); } if dyy.len() != *self.borrow_number_of_states() { return Err(anyhow!("Expected {} states for dyy, got {}", number_of_states, dyy.len())); } - if dyp.len() != *self.borrow_number_of_states() { - return Err(anyhow!("Expected {} state derivatives for dyp, got {}", number_of_states, dyp.len())); - } if ddata.len() != self.data_len() { return Err(anyhow!("Expected {} data for ddata, got {}", self.data_len(), ddata.len())); } self.with_data(|compiler| { let yy_ptr = yy.as_ptr(); - let yp_ptr = yp.as_ptr(); let data_ptr = data.as_mut_ptr(); let dyy_ptr = dyy.as_ptr(); - let dyp_ptr = dyp.as_ptr(); let ddata_ptr = ddata.as_mut_ptr(); - unsafe { compiler.jit_grad_functions.calc_out_grad.call(t, yy_ptr, dyy_ptr, yp_ptr, dyp_ptr, data_ptr, ddata_ptr); } + unsafe { compiler.jit_grad_functions.calc_out_grad.call(t, yy_ptr, dyy_ptr, data_ptr, ddata_ptr); } }); Ok(()) } @@ -722,19 +704,20 @@ mod tests { fn test_from_discrete_str() { let text = " u { y = 1 } - G { -y } + F { -y } out { y } "; let compiler = Compiler::from_discrete_str(text).unwrap(); let mut u0 = vec![0.]; - let mut up0 = vec![1.]; + let up0 = vec![2.]; let mut res = vec![0.]; let mut data = compiler.get_new_data(); - compiler.set_u0(u0.as_mut_slice(), up0.as_mut_slice(), data.as_mut_slice()).unwrap(); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()).unwrap(); assert_relative_eq!(u0.as_slice(), vec![1.].as_slice()); - assert_relative_eq!(up0.as_slice(), vec![0.].as_slice()); - compiler.residual(0., u0.as_slice(), up0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap(); - assert_relative_eq!(res.as_slice(), vec![1.].as_slice()); + compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap(); + assert_relative_eq!(res.as_slice(), vec![-1.].as_slice()); + compiler.mass(0., up0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap(); + assert_relative_eq!(res.as_slice(), vec![2.].as_slice()); } @@ -747,10 +730,10 @@ mod tests { dudt_i { dydt = 0, } - F_i { + M_i { dydt, } - G_i { + F_i { y * (1 - y), } stop_i { @@ -764,13 +747,12 @@ mod tests { let discrete_model = DiscreteModel::build("$name", &model).unwrap(); let compiler = Compiler::from_discrete_model(&discrete_model, "test_output/compiler_test_stop").unwrap(); let mut u0 = vec![1.]; - let mut up0 = vec![1.]; let mut res = vec![0.]; let mut stop = vec![0.]; let mut data = compiler.get_new_data(); - compiler.set_u0(u0.as_mut_slice(), up0.as_mut_slice(), data.as_mut_slice()).unwrap(); - compiler.residual(0., u0.as_slice(), up0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap(); - compiler.calc_stop(0., u0.as_slice(), up0.as_slice(), data.as_mut_slice(), stop.as_mut_slice()).unwrap(); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()).unwrap(); + compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap(); + compiler.calc_stop(0., u0.as_slice(), data.as_mut_slice(), stop.as_mut_slice()).unwrap(); assert_relative_eq!(stop[0], 0.5); assert_eq!(stop.len(), 1); } @@ -790,7 +772,6 @@ mod tests { }; let compiler = Compiler::from_discrete_model(&discrete_model, tmp_loc).unwrap(); let mut u0 = vec![1.]; - let mut up0 = vec![1.]; let mut res = vec![0.]; let mut data = compiler.get_new_data(); let mut grad_data = Vec::new(); @@ -801,21 +782,20 @@ mod tests { let mut results = Vec::new(); let inputs = vec![1.; n_inputs]; compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()).unwrap(); - compiler.set_u0(u0.as_mut_slice(), up0.as_mut_slice(), data.as_mut_slice()).unwrap(); - compiler.residual(0., u0.as_slice(), up0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap(); - compiler.calc_out(0., u0.as_slice(), up0.as_slice(), data.as_mut_slice()).unwrap(); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()).unwrap(); + compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap(); + compiler.calc_out(0., u0.as_slice(), data.as_mut_slice()).unwrap(); results.push(compiler.get_tensor_data(tensor_name, data.as_slice()).unwrap().to_vec()); for i in 0..n_inputs { let mut dinputs = vec![0.; n_inputs]; dinputs[i] = 1.0; let mut ddata = compiler.get_new_data(); let mut du0 = vec![0.]; - let mut dup0 = vec![0.]; let mut dres = vec![0.]; compiler.set_inputs_grad(inputs.as_slice(), dinputs.as_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice()).unwrap(); - compiler.set_u0_grad(u0.as_mut_slice(), du0.as_mut_slice(), up0.as_mut_slice(), dup0.as_mut_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice()).unwrap(); - compiler.residual_grad(0., u0.as_slice(), du0.as_slice(), up0.as_slice(), dup0.as_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice()).unwrap(); - compiler.calc_out_grad(0., u0.as_slice(), du0.as_slice(), up0.as_slice(), dup0.as_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice()).unwrap(); + compiler.set_u0_grad(u0.as_mut_slice(), du0.as_mut_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice()).unwrap(); + compiler.rhs_grad(0., u0.as_slice(), du0.as_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice()).unwrap(); + compiler.calc_out_grad(0., u0.as_slice(), du0.as_slice(), grad_data[i].as_mut_slice(), ddata.as_mut_slice()).unwrap(); results.push(compiler.get_tensor_data(tensor_name, ddata.as_slice()).unwrap().to_vec()); } results @@ -941,10 +921,10 @@ mod tests { dydt = p, }} {} - F_i {{ + M_i {{ dydt, }} - G_i {{ + F_i {{ y, }} out_i {{ @@ -985,10 +965,10 @@ mod tests { r { 2 * y * p, } - F_i { + M_i { dydt, } - G_i { + F_i { r, } out_i { @@ -1006,9 +986,7 @@ mod tests { }; let compiler = Compiler::from_discrete_model(&discrete_model, "test_output/compiler_test_repeated_grad").unwrap(); let mut u0 = vec![1.]; - let mut up0 = vec![1.]; let mut du0 = vec![1.]; - let mut dup0 = vec![1.]; let mut res = vec![0.]; let mut dres = vec![0.]; let mut data = compiler.get_new_data(); @@ -1019,9 +997,9 @@ mod tests { let inputs = vec![2.; n_inputs]; let dinputs = vec![1.; n_inputs]; compiler.set_inputs_grad(inputs.as_slice(), dinputs.as_slice(), data.as_mut_slice(), ddata.as_mut_slice()).unwrap(); - compiler.set_u0_grad(u0.as_mut_slice(), du0.as_mut_slice(), up0.as_mut_slice(), dup0.as_mut_slice(), data.as_mut_slice(), ddata.as_mut_slice()).unwrap(); - compiler.residual_grad(0., u0.as_slice(), du0.as_slice(), up0.as_slice(), dup0.as_slice(), data.as_mut_slice(), ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice()).unwrap(); - assert_relative_eq!(dres.as_slice(), vec![-8.].as_slice()); + compiler.set_u0_grad(u0.as_mut_slice(), du0.as_mut_slice(), data.as_mut_slice(), ddata.as_mut_slice()).unwrap(); + compiler.rhs_grad(0., u0.as_slice(), du0.as_slice(), data.as_mut_slice(), ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice()).unwrap(); + assert_relative_eq!(dres.as_slice(), vec![8.].as_slice()); } } @@ -1040,11 +1018,11 @@ mod tests { dydt = 0, 0, } - F_i { + M_i { dydt, 0, } - G_i { + F_i { y - 1, x - 2, } @@ -1075,17 +1053,21 @@ mod tests { assert_eq!(id, vec![1.0, 0.0]); let mut u = vec![0., 0.]; - let mut up = vec![0., 0.]; - compiler.set_u0(u.as_mut_slice(), up.as_mut_slice(), data.as_mut_slice()).unwrap(); + compiler.set_u0(u.as_mut_slice(), data.as_mut_slice()).unwrap(); assert_relative_eq!(u.as_slice(), vec![1., 2.].as_slice()); - assert_relative_eq!(up.as_slice(), vec![0., 0.].as_slice()); let mut rr = vec![1., 1.]; - compiler.residual(0., u.as_slice(), up.as_slice(), data.as_mut_slice(), rr.as_mut_slice()).unwrap(); + compiler.rhs(0., u.as_slice(), data.as_mut_slice(), rr.as_mut_slice()).unwrap(); assert_relative_eq!(rr.as_slice(), vec![0., 0.].as_slice()); + + let up = vec![2., 3.]; + rr = vec![1., 1.]; + compiler.mass(0., up.as_slice(), data.as_mut_slice(), rr.as_mut_slice()).unwrap(); + assert_relative_eq!(rr.as_slice(), vec![2., 0.].as_slice()); - compiler.calc_out(0., u.as_slice(), up.as_slice(), data.as_mut_slice()).unwrap(); + compiler.calc_out(0., u.as_slice(), data.as_mut_slice()).unwrap(); let out = compiler.get_out(data.as_slice()); assert_relative_eq!(out, vec![1., 2., 4.].as_slice()); + } } \ No newline at end of file diff --git a/src/execution/translation.rs b/src/execution/translation.rs index 5ebc0d0..d52a6d6 100644 --- a/src/execution/translation.rs +++ b/src/execution/translation.rs @@ -224,10 +224,10 @@ mod tests { dudt_i {{ dydt = 0, }} - F_i {{ + M_i {{ dydt, }} - G_i {{ + F_i {{ y, }} out_i {{ diff --git a/src/lib.rs b/src/lib.rs index 864c455..8912cbe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -120,17 +120,20 @@ mod tests { let dzdt = 2. * dydt; let inputs = vec![r, k]; let mut u0 = vec![y, z]; - let mut up0 = vec![dydt, dzdt]; let mut data = compiler.get_new_data(); compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()).unwrap(); - compiler.set_u0(u0.as_mut_slice(), up0.as_mut_slice(), data.as_mut_slice()).unwrap(); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()).unwrap(); u0 = vec![y, z]; - up0 = vec![dydt, dzdt]; + let up0 = vec![dydt, dzdt]; let mut res = vec![1., 1.]; - compiler.residual(0., u0.as_slice(), up0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap(); - let expected_value = vec![0., 0.]; + compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap(); + let expected_value = vec![dydt, 2.0 * y - z]; + assert_relative_eq!(res.as_slice(), expected_value.as_slice()); + + compiler.mass(0., up0.as_slice(), data.as_mut_slice(), res.as_mut_slice()).unwrap(); + let expected_value = vec![dydt, 0.]; assert_relative_eq!(res.as_slice(), expected_value.as_slice()); }