From 3f40ae7f19e06a3f1204b2c5d7c59ea5a7dba99c Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Thu, 3 Oct 2024 14:59:52 +0100 Subject: [PATCH] bug: fix for case if inputs come in different order (#23) --- Cargo.toml | 1 + src/discretise/discrete_model.rs | 9 ++++ src/execution/compiler.rs | 66 ++++++++++++++++++++++++++++++ src/execution/cranelift/codegen.rs | 60 ++++++++++++++++----------- 4 files changed, 111 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 50fb07c..aa66b30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ cmake = { version = "0.1.50", optional = true } [dev-dependencies] divan = "0.1.14" +env_logger = "0.11.5" [[bench]] name = "evaluation" diff --git a/src/discretise/discrete_model.rs b/src/discretise/discrete_model.rs index 1c05331..7987680 100644 --- a/src/discretise/discrete_model.rs +++ b/src/discretise/discrete_model.rs @@ -373,6 +373,15 @@ impl<'s> DiscreteModel<'s> { } } + // reorder inputs to match the order defined in "in = [ ... ]" + ret.inputs.sort_by_key(|t| { + model + .inputs + .iter() + .position(|&name| name == t.name()) + .unwrap() + }); + // set is_algebraic for every state based on equations if ret.state_dot.is_some() && ret.lhs.is_some() { let state_dot = ret.state_dot.as_ref().unwrap(); diff --git a/src/execution/compiler.rs b/src/execution/compiler.rs index 2ef3e47..45d8d4f 100644 --- a/src/execution/compiler.rs +++ b/src/execution/compiler.rs @@ -593,6 +593,34 @@ mod tests { assert_eq!(stop.len(), 1); } + #[test] + fn test_vector_add_scalar_cranelift() { + let n = 1; + let u = vec![1.0; n]; + let full_text = format!( + " + u_i {{ + {} + }} + F_i {{ + u_i + 1.0, + }} + out_i {{ + u_i + }} + ", + (0..n) + .map(|i| format!("x{} = {},", i, u[i])) + .collect::>() + .join("\n"), + ); + let model = parse_ds_string(&full_text).unwrap(); + let name = "$name"; + let discrete_model = DiscreteModel::build(name, &model).unwrap(); + env_logger::builder().is_test(true).try_init().unwrap(); + let _compiler = Compiler::::from_discrete_model(&discrete_model).unwrap(); + } + fn tensor_test_common(text: &str, tensor_name: &str) -> Vec> { let full_text = format!( " @@ -981,4 +1009,42 @@ mod tests { let out = compiler.get_out(data.as_slice()); assert_relative_eq!(out, vec![1., 2., 4.].as_slice()); } + + #[test] + fn test_inputs() { + let full_text = " + in = [c, a, b] + a { 1 } b { 2 } c { 3 } + u { y = 0 } + F { y } + out { y } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = DiscreteModel::build("test_inputs", &model).unwrap(); + + let compiler = Compiler::::from_discrete_model(&discrete_model).unwrap(); + let mut data = compiler.get_new_data(); + let inputs = vec![1.0, 2.0, 3.0]; + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + + for (name, expected_value) in vec![("a", vec![2.0]), ("b", vec![3.0]), ("c", vec![1.0])] { + let inputs = compiler.get_tensor_data(name, data.as_slice()).unwrap(); + assert_relative_eq!(inputs, expected_value.as_slice()); + } + + #[cfg(feature = "llvm")] + { + let compiler = + Compiler::::from_discrete_model(&discrete_model).unwrap(); + let mut data = compiler.get_new_data(); + let inputs = vec![1.0, 2.0, 3.0]; + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + + for (name, expected_value) in vec![("a", vec![2.0]), ("b", vec![3.0]), ("c", vec![1.0])] + { + let inputs = compiler.get_tensor_data(name, data.as_slice()).unwrap(); + assert_relative_eq!(inputs, expected_value.as_slice()); + } + } + } } diff --git a/src/execution/cranelift/codegen.rs b/src/execution/cranelift/codegen.rs index 62e1807..0676869 100644 --- a/src/execution/cranelift/codegen.rs +++ b/src/execution/cranelift/codegen.rs @@ -52,7 +52,7 @@ impl CraneliftModule { .module .declare_function(name, Linkage::Export, &self.ctx.func.signature)?; - //println!("Declared function: {}", name); + //println!("Declared function: {} -------------------------------------------------------------------------------------", name); //println!("IR:\n{}", self.ctx.func); // Define the function to jit. This finishes compilation, although @@ -263,7 +263,7 @@ impl CodegenModule for CraneliftModule { // write indices data as a global data object // convect the indices to bytes - let int_type = types::I32; + let int_type = ptr_type; let real_type = types::F64; let mut vec8: Vec = vec![]; for elem in layout.indices() { @@ -875,9 +875,6 @@ impl<'ctx> CraneliftCodeGen<'ctx> { let one = self.builder.ins().iconst(int_type, 1); let zero = self.builder.ins().iconst(int_type, 0); - let expr_index_var = self.decl_stack_slot(self.int_type, Some(zero)); - let elmt_index_var = self.decl_stack_slot(self.int_type, Some(zero)); - // setup indices, loop through the nested loops let mut indices = Vec::new(); let mut blocks = Vec::new(); @@ -896,6 +893,13 @@ impl<'ctx> CraneliftCodeGen<'ctx> { (None, 0) }; + //let expr_index_var = self.decl_stack_slot(self.int_type, Some(zero)); + let elmt_index_var = if contract_sum.is_some() { + Some(self.decl_stack_slot(self.int_type, Some(zero))) + } else { + None + }; + for i in 0..expr_rank { let block = self.builder.create_block(); let curr_index = self.builder.append_block_param(block, self.int_type); @@ -914,28 +918,22 @@ impl<'ctx> CraneliftCodeGen<'ctx> { preblock = block; } - let elmt_index = self - .builder - .ins() - .stack_load(self.int_type, elmt_index_var, 0); - // load and increment the expression index - let expr_index = self - .builder - .ins() - .stack_load(self.int_type, expr_index_var, 0); - let next_expr_index = self.builder.ins().iadd(expr_index, one); - self.builder - .ins() - .stack_store(next_expr_index, expr_index_var, 0); + //let expr_index = self + // .builder + // .ins() + // .stack_load(self.int_type, expr_index_var, 0); + //let next_expr_index = self.builder.ins().iadd(expr_index, one); + //self.builder + // .ins() + // .stack_store(next_expr_index, expr_index_var, 0); let expr = if is_tangent { elmt.tangent_expr() } else { elmt.expr() }; - let float_value = - self.jit_compile_expr(name, expr, indices.as_slice(), elmt, Some(expr_index))?; + let float_value = self.jit_compile_expr(name, expr, indices.as_slice(), elmt, None)?; if contract_sum.is_some() { let contract_sum_value = @@ -947,6 +945,14 @@ impl<'ctx> CraneliftCodeGen<'ctx> { .ins() .stack_store(new_contract_sum_value, contract_sum.unwrap(), 0); } else { + let expr_index = if indices.is_empty() { + zero + } else { + indices + .iter() + .skip(1) + .fold(indices[0], |acc, x| self.builder.ins().imul(acc, *x)) + }; self.jit_compile_broadcast_and_store( name, elmt, @@ -955,20 +961,24 @@ impl<'ctx> CraneliftCodeGen<'ctx> { translation, preblock, )?; - let next_elmt_index = self.builder.ins().iadd(elmt_index, one); - self.builder - .ins() - .stack_store(next_elmt_index, elmt_index_var, 0); + //let next_elmt_index = self.builder.ins().iadd(elmt_index, one); + //self.builder + // .ins() + // .stack_store(next_elmt_index, elmt_index_var, 0); } // unwind the nested loops for i in (0..expr_rank).rev() { // update and store contract sum if i == expr_rank - contract_by - 1 && contract_sum.is_some() { + let elmt_index = + self.builder + .ins() + .stack_load(self.int_type, elmt_index_var.unwrap(), 0); let next_elmt_index = self.builder.ins().iadd(elmt_index, one); self.builder .ins() - .stack_store(next_elmt_index, elmt_index_var, 0); + .stack_store(next_elmt_index, elmt_index_var.unwrap(), 0); let contract_sum_value = self.builder