From 65a3d27f054cba9b0fc0b818d6b0403825487ea4 Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Sat, 7 Dec 2024 09:15:03 +0000 Subject: [PATCH] bug: stop calculates time and state dep tensors (#27) --- src/execution/compiler.rs | 7 +++++++ src/execution/cranelift/codegen.rs | 12 +++++++++++- src/execution/llvm/codegen.rs | 12 +++++++++++- 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/execution/compiler.rs b/src/execution/compiler.rs index de32824..e014f6c 100644 --- a/src/execution/compiler.rs +++ b/src/execution/compiler.rs @@ -618,6 +618,7 @@ mod tests { twoy_i { 2 * y } F_i { y * (1 - y), } out_i { twoy_i } + stop_i { twoy_i - 0.5 } "; let model = parse_ds_string(full_text).unwrap(); let discrete_model = DiscreteModel::build("$name", &model).unwrap(); @@ -633,6 +634,12 @@ mod tests { compiler.calc_out(0., u0.as_slice(), data.as_mut_slice()); let out = compiler.get_out(data.as_slice()); assert_relative_eq!(out[0], 4.); + let mut stop = vec![0.]; + compiler.calc_stop(0., u0.as_slice(), data.as_mut_slice(), stop.as_mut_slice()); + assert_relative_eq!(stop[0], 3.5); + u0[0] = 0.5; + compiler.calc_stop(0., u0.as_slice(), data.as_mut_slice(), stop.as_mut_slice()); + assert_relative_eq!(stop[0], 0.5); } #[test] diff --git a/src/execution/cranelift/codegen.rs b/src/execution/cranelift/codegen.rs index 7b18d1a..d9c7ca1 100644 --- a/src/execution/cranelift/codegen.rs +++ b/src/execution/cranelift/codegen.rs @@ -332,7 +332,7 @@ impl CodegenModule for CraneliftModule { codegen.jit_compile_tensor(tensor, None, false)?; } - // TODO: could split state dep defns into before and after F + // calculate state dependant definitions for a in model.state_dep_defns() { codegen.jit_compile_tensor(a, None, false)?; } @@ -355,6 +355,16 @@ impl CodegenModule for CraneliftModule { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); if let Some(stop) = model.stop() { + // calculate time dependant definitions + for tensor in model.time_dep_defns() { + codegen.jit_compile_tensor(tensor, None, false)?; + } + + // calculate state dependant definitions + for a in model.state_dep_defns() { + codegen.jit_compile_tensor(a, None, false)?; + } + let root = *codegen.variables.get("root").unwrap(); codegen.jit_compile_tensor(stop, Some(root), false)?; } diff --git a/src/execution/llvm/codegen.rs b/src/execution/llvm/codegen.rs index 75b00a9..e88b453 100644 --- a/src/execution/llvm/codegen.rs +++ b/src/execution/llvm/codegen.rs @@ -1892,7 +1892,7 @@ impl<'ctx> CodeGen<'ctx> { self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?; } - // TODO: could split state dep defns into before and after F + // calculate state dependant definitions for a in model.state_dep_defns() { self.jit_compile_tensor(a, Some(*self.get_var(a)))?; } @@ -1951,6 +1951,16 @@ impl<'ctx> CodeGen<'ctx> { self.insert_indices(); if let Some(stop) = model.stop() { + // calculate time dependant definitions + for tensor in model.time_dep_defns() { + self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?; + } + + // calculate state dependant definitions + for a in model.state_dep_defns() { + self.jit_compile_tensor(a, Some(*self.get_var(a)))?; + } + let res_ptr = self.get_param("root"); self.jit_compile_tensor(stop, Some(*res_ptr))?; }