From 20bab19ade4548ab122581ee5e3c28cf972b8c8c Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Tue, 28 Jan 2025 16:47:53 +0000 Subject: [PATCH 1/5] feat: time_indep_defn split to constant_dfns and input_dep_defns --- src/discretise/discrete_model.rs | 90 +++++++++++++++++----- src/discretise/env.rs | 23 ++++-- src/execution/compiler.rs | 24 +++++- src/execution/cranelift/codegen.rs | 16 +++- src/execution/data_layout.rs | 64 +++++++++++----- src/execution/interface.rs | 2 + src/execution/llvm/codegen.rs | 115 ++++++++++++++++++++++++----- src/execution/module.rs | 1 + src/execution/translation.rs | 2 +- 9 files changed, 269 insertions(+), 68 deletions(-) diff --git a/src/discretise/discrete_model.rs b/src/discretise/discrete_model.rs index 93b4575..dd5c8f4 100644 --- a/src/discretise/discrete_model.rs +++ b/src/discretise/discrete_model.rs @@ -31,7 +31,8 @@ pub struct DiscreteModel<'s> { lhs: Option>, rhs: Tensor<'s>, out: Option>, - time_indep_defns: Vec>, + constant_defns: Vec>, + input_dep_defns: Vec>, time_dep_defns: Vec>, state_dep_defns: Vec>, dstate_dep_defns: Vec>, @@ -54,7 +55,10 @@ impl fmt::Display for DiscreteModel<'_> { writeln!(f, "{}", input)?; } } - for defn in &self.time_indep_defns { + for defn in &self.constant_defns { + writeln!(f, "{}", defn)?; + } + for defn in &self.input_dep_defns { writeln!(f, "{}", defn)?; } for defn in &self.time_dep_defns { @@ -90,7 +94,8 @@ impl<'s> DiscreteModel<'s> { lhs: None, rhs: Tensor::new_empty("F"), out: None, - time_indep_defns: Vec::new(), + constant_defns: Vec::new(), + input_dep_defns: Vec::new(), time_dep_defns: Vec::new(), state_dep_defns: Vec::new(), dstate_dep_defns: Vec::new(), @@ -258,7 +263,7 @@ impl<'s> DiscreteModel<'s> { } pub fn build(name: &'s str, model: &'s ast::DsModel) -> Result { - let mut env = Env::default(); + let mut env = Env::new(model.inputs.as_slice()); let mut ret = Self::new(name); let mut read_state = false; let mut span_f = None; @@ -374,6 +379,7 @@ impl<'s> DiscreteModel<'s> { let dependent_on_state = env_entry.is_state_dependent(); let dependent_on_time = env_entry.is_time_dependent(); let dependent_on_dudt = env_entry.is_dstatedt_dependent(); + let dependent_on_input = env_entry.is_input_dependent(); if is_input { // inputs must be constants if dependent_on_time || dependent_on_state { @@ -383,12 +389,11 @@ impl<'s> DiscreteModel<'s> { )); } ret.inputs.push(built); + } else if !dependent_on_time && !dependent_on_input { + ret.constant_defns.push(built); } else if !dependent_on_time { - ret.time_indep_defns.push(built); - } else if dependent_on_time - && !dependent_on_state - && !dependent_on_dudt - { + ret.input_dep_defns.push(built); + } else if !dependent_on_state && !dependent_on_dudt { ret.time_dep_defns.push(built); } else if dependent_on_state { ret.state_dep_defns.push(built); @@ -670,7 +675,7 @@ impl<'s> DiscreteModel<'s> { .iter() .map(DiscreteModel::dfn_to_array) .collect(); - let time_indep_defns = const_defns + let constant_defns = const_defns .iter() .map(DiscreteModel::dfn_to_array) .collect(); @@ -687,7 +692,8 @@ impl<'s> DiscreteModel<'s> { state, state_dot: Some(state_dot), out: Some(out_array), - time_indep_defns, + constant_defns, + input_dep_defns: Vec::new(), // todo: need to implement time_dep_defns, state_dep_defns, dstate_dep_defns, @@ -700,9 +706,14 @@ impl<'s> DiscreteModel<'s> { self.inputs.as_ref() } - pub fn time_indep_defns(&self) -> &[Tensor] { - self.time_indep_defns.as_ref() + pub fn constant_defns(&self) -> &[Tensor] { + self.constant_defns.as_ref() } + + pub fn 
input_dep_defns(&self) -> &[Tensor] { + self.input_dep_defns.as_ref() + } + pub fn time_dep_defns(&self) -> &[Tensor] { self.time_dep_defns.as_ref() } @@ -772,7 +783,8 @@ mod tests { let model_info = ModelInfo::build("circuit", &models).unwrap(); assert_eq!(model_info.errors.len(), 0); let discrete = DiscreteModel::from(&model_info); - assert_eq!(discrete.time_indep_defns.len(), 0); + assert_eq!(discrete.input_dep_defns().len(), 0); + assert_eq!(discrete.constant_defns().len(), 0); assert_eq!(discrete.time_dep_defns.len(), 1); assert_eq!(discrete.time_dep_defns[0].name(), "inputVoltage"); assert_eq!(discrete.state_dep_defns.len(), 1); @@ -849,7 +861,15 @@ mod tests { ); assert_eq!( model - .time_indep_defns() + .constant_defns() + .iter() + .map(|t| t.name()) + .collect::>(), + Vec::<&str>::new() + ); + assert_eq!( + model + .input_dep_defns() .iter() .map(|t| t.name()) .collect::>(), @@ -1163,7 +1183,7 @@ mod tests { let model = parse_ds_string(model_text.as_str()).unwrap(); match DiscreteModel::build("$name", &model) { Ok(model) => { - let tensor = model.time_indep_defns.iter().chain(model.time_dep_defns.iter()).find(|t| t.name() == $tensor_name).unwrap(); + let tensor = model.constant_defns().iter().chain(model.time_dep_defns.iter()).find(|t| t.name() == $tensor_name).unwrap(); let tensor_string = format!("{}", tensor).chars().filter(|c| !c.is_whitespace()).collect::(); let tensor_check_string = $tensor_string.chars().filter(|c| !c.is_whitespace()).collect::(); assert_eq!(tensor_string, tensor_check_string); @@ -1248,6 +1268,40 @@ mod tests { assert!(model.out().is_none()); } + #[test] + fn test_constants_and_input_dep() { + let text = " + in = [r] + r { 1, } + k { 1, } + r2 { 2 * r } + u_i { + y = k, + } + F_i { + r * y, + } + "; + let model = parse_ds_string(text).unwrap(); + let model = DiscreteModel::build("$name", &model).unwrap(); + assert_eq!( + model + .constant_defns() + .iter() + .map(|t| t.name()) + .collect::>(), + ["k"] + ); + assert_eq!( + model + .input_dep_defns() + .iter() + .map(|t| t.name()) + .collect::>(), + ["r2"] + ); + } + #[test] fn test_sparse_layout() { let text = " @@ -1272,12 +1326,12 @@ mod tests { let model = parse_ds_string(text).unwrap(); let model = DiscreteModel::build("$name", &model).unwrap(); let r = model - .time_indep_defns() + .constant_defns() .iter() .find(|t| t.name() == "r") .unwrap(); let b = model - .time_indep_defns() + .constant_defns() .iter() .find(|t| t.name() == "b") .unwrap(); diff --git a/src/discretise/env.rs b/src/discretise/env.rs index fc0dffd..38e4abd 100644 --- a/src/discretise/env.rs +++ b/src/discretise/env.rs @@ -14,6 +14,7 @@ pub struct EnvVar { is_time_dependent: bool, is_state_dependent: bool, is_dstatedt_dependent: bool, + is_input_dependent: bool, is_algebraic: bool, } @@ -34,6 +35,10 @@ impl EnvVar { self.is_algebraic } + pub fn is_input_dependent(&self) -> bool { + self.is_input_dependent + } + pub fn layout(&self) -> &Layout { self.layout.as_ref() } @@ -43,10 +48,11 @@ pub struct Env { current_span: Option, errs: ValidationErrors, vars: HashMap, + inputs: Vec, } -impl Default for Env { - fn default() -> Self { +impl Env { + pub fn new(inputs: &[&str]) -> Self { let mut vars = HashMap::new(); vars.insert( "t".to_string(), @@ -55,6 +61,7 @@ impl Default for Env { is_time_dependent: true, is_state_dependent: false, is_dstatedt_dependent: false, + is_input_dependent: false, is_algebraic: true, }, ); @@ -62,11 +69,9 @@ impl Default for Env { errs: ValidationErrors::default(), vars, current_span: None, + inputs: 
inputs.iter().map(|s| s.to_string()).collect(), } } -} - -impl Env { pub fn is_tensor_time_dependent(&self, tensor: &Tensor) -> bool { if tensor.name() == "u" || tensor.name() == "dudt" { return true; @@ -83,6 +88,12 @@ impl Env { self.is_tensor_dependent_on(tensor, "u") } + pub fn is_tensor_input_dependent(&self, tensor: &Tensor) -> bool { + self.inputs + .iter() + .any(|input| self.is_tensor_dependent_on(tensor, input)) + } + pub fn is_tensor_dstatedt_dependent(&self, tensor: &Tensor) -> bool { self.is_tensor_dependent_on(tensor, "dudt") } @@ -112,6 +123,7 @@ impl Env { is_time_dependent: self.is_tensor_time_dependent(var), is_state_dependent: self.is_tensor_state_dependent(var), is_dstatedt_dependent: self.is_tensor_dstatedt_dependent(var), + is_input_dependent: self.is_tensor_input_dependent(var), }, ); } @@ -125,6 +137,7 @@ impl Env { is_time_dependent: self.is_tensor_time_dependent(var), is_state_dependent: self.is_tensor_state_dependent(var), is_dstatedt_dependent: self.is_tensor_dstatedt_dependent(var), + is_input_dependent: self.is_tensor_input_dependent(var), }, ); } diff --git a/src/execution/compiler.rs b/src/execution/compiler.rs index 399f2da..a03412c 100644 --- a/src/execution/compiler.rs +++ b/src/execution/compiler.rs @@ -10,8 +10,8 @@ use super::{ interface::{ BarrierInitFunc, CalcOutGradFunc, CalcOutRevGradFunc, CalcOutSensGradFunc, CalcOutSensRevGradFunc, GetInputsFunc, MassRevGradFunc, RhsGradFunc, RhsRevGradFunc, - RhsSensGradFunc, RhsSensRevGradFunc, SetInputsGradFunc, SetInputsRevGradFunc, U0GradFunc, - U0RevGradFunc, + RhsSensGradFunc, RhsSensRevGradFunc, SetConstantsFunc, SetInputsGradFunc, + SetInputsRevGradFunc, U0GradFunc, U0RevGradFunc, }, module::CodegenModule, }; @@ -48,6 +48,7 @@ struct JitFunctions { set_inputs: SetInputsFunc, get_inputs: GetInputsFunc, barrier_init: Option, + set_constants: SetConstantsFunc, } struct JitGradFunctions { @@ -176,6 +177,7 @@ impl Compiler { let get_dims = module.compile_get_dims(model)?; let set_inputs = module.compile_set_inputs(model)?; let get_inputs = module.compile_get_inputs(model)?; + let set_constants = module.compile_set_constants(model)?; module.pre_autodiff_optimisation()?; @@ -217,6 +219,9 @@ impl Compiler { } else { None }; + let set_constants = unsafe { + std::mem::transmute::<*const u8, SetConstantsFunc>(module.jit(set_constants)?) + }; let set_u0 = unsafe { std::mem::transmute::<*const u8, U0Func>(module.jit(set_u0)?) }; let rhs = unsafe { std::mem::transmute::<*const u8, RhsFunc>(module.jit(rhs)?) }; let mass = unsafe { std::mem::transmute::<*const u8, MassFunc>(module.jit(mass)?) 
}; @@ -308,12 +313,13 @@ impl Compiler { None }; - Ok(Self { + let mut ret = Self { module, jit_functions: JitFunctions { set_u0, rhs, mass, + set_constants, calc_out, get_inputs, calc_stop, @@ -338,7 +344,11 @@ impl Compiler { has_mass, thread_pool, thread_lock, - }) + }; + + // all done, can set constants now + ret.set_constants(); + Ok(ret) } pub fn get_tensor_data<'a>(&self, name: &str, data: &'a [f64]) -> Option<&'a [f64]> { @@ -417,6 +427,12 @@ impl Compiler { } } + fn set_constants(&mut self) { + self.with_threading(|i, dim| unsafe { + (self.jit_functions.set_constants)(i, dim); + }); + } + pub fn set_u0(&self, yy: &mut [f64], data: &mut [f64]) { self.check_state_len(yy, "yy"); self.with_threading(|i, dim| unsafe { diff --git a/src/execution/cranelift/codegen.rs b/src/execution/cranelift/codegen.rs index 9ba299f..85668ff 100644 --- a/src/execution/cranelift/codegen.rs +++ b/src/execution/cranelift/codegen.rs @@ -29,6 +29,7 @@ pub struct CraneliftModule { layout: DataLayout, indices_id: DataId, + constants_id: DataId, thread_counter: Option, //triple: Triple, @@ -422,7 +423,7 @@ impl CodegenModule for CraneliftModule { let mut nbarrier = 0; #[allow(clippy::explicit_counter_loop)] - for a in model.time_indep_defns() { + for a in model.input_dep_defns() { codegen.jit_compile_tensor(a, None, true)?; codegen.jit_compile_call_barrier(nbarrier); nbarrier += 1; @@ -511,6 +512,12 @@ impl CodegenModule for CraneliftModule { let layout = DataLayout::new(model); + // define constant global data + let mut data_description = DataDescription::new(); + data_description.define_zeroinit(layout.constants().len()); + let constants_id = module.declare_data("constants", Linkage::Local, false, false)?; + module.define_data(constants_id, &data_description)?; + // write indices data as a global data object // convect the indices to bytes //let int_type = ptr_type; @@ -556,6 +563,7 @@ impl CodegenModule for CraneliftModule { ctx: module.make_context(), module, indices_id, + constants_id, int_type, real_type, real_ptr_type: ptr_type, @@ -583,7 +591,7 @@ impl CodegenModule for CraneliftModule { let mut nbarrier = 0; #[allow(clippy::explicit_counter_loop)] - for a in model.time_indep_defns() { + for a in model.input_dep_defns() { codegen.jit_compile_tensor(a, None, false)?; codegen.jit_compile_call_barrier(nbarrier); nbarrier += 1; @@ -2093,9 +2101,11 @@ impl<'ctx> CraneliftCodeGen<'ctx> { } } + // todo: insert constant tensors + // insert all tensors in data if it exists in args let tensors = model.inputs().iter(); - let tensors = tensors.chain(model.time_indep_defns().iter()); + let tensors = tensors.chain(model.input_dep_defns().iter()); let tensors = tensors.chain(model.time_dep_defns().iter()); let tensors = tensors.chain(model.state_dep_defns().iter()); diff --git a/src/execution/data_layout.rs b/src/execution/data_layout.rs index 1603839..2e4cc53 100644 --- a/src/execution/data_layout.rs +++ b/src/execution/data_layout.rs @@ -5,8 +5,8 @@ use crate::discretise::{DiscreteModel, Layout, RcLayout, Tensor}; use super::Translation; // there are three different layouts: -// 1. the data layout is a mapping from tensors to the index of the first element in the data array. -// Each tensor in the data layout is a contiguous array of nnz elements +// 1. the data layout is a mapping from tensors to the index of the first element in the data or constants array. +// Each tensor in the data or constants layout is a contiguous array of nnz elements // 2. 
the layout layout is a mapping from Layout to the index of the first element in the indices array. // Only sparse layouts are stored, and each sparse layout is a contiguous array of nnz*rank elements // 3. the translation layout is a mapping from layout from-to pairs to the index of the first element in the indices array. @@ -19,6 +19,7 @@ pub struct DataLayout { layout_index_map: HashMap, translate_index_map: HashMap<(RcLayout, RcLayout), usize>, data: Vec, + constants: Vec, indices: Vec, layout_map: HashMap, } @@ -30,21 +31,21 @@ impl DataLayout { let mut layout_index_map = HashMap::new(); let mut translate_index_map = HashMap::new(); let mut data = Vec::new(); + let mut constants = Vec::new(); let mut indices = Vec::new(); let mut layout_map = HashMap::new(); - let mut add_tensor = |tensor: &Tensor| { - let is_not_in_data = tensor.name() == "u" - || tensor.name() == "dudt" - || tensor.name() == "rhs" - || tensor.name() == "lhs" - || tensor.name() == "out"; + let mut add_tensor = |tensor: &Tensor, in_data: bool, in_constants: bool| { // insert the data (non-zeros) for each tensor layout_map.insert(tensor.name().to_string(), tensor.layout_ptr().clone()); - if !is_not_in_data { + if in_data { data_index_map.insert(tensor.name().to_string(), data.len()); data_length_map.insert(tensor.name().to_string(), tensor.nnz()); data.extend(vec![0.0; tensor.nnz()]); + } else if in_constants { + data_index_map.insert(tensor.name().to_string(), constants.len()); + data_length_map.insert(tensor.name().to_string(), tensor.nnz()); + constants.extend(vec![0.0; tensor.nnz()]); } // add the translation info for each block-tensor pair @@ -73,26 +74,44 @@ impl DataLayout { } }; - model.inputs().iter().for_each(&mut add_tensor); - model.time_indep_defns().iter().for_each(&mut add_tensor); - model.time_dep_defns().iter().for_each(&mut add_tensor); - add_tensor(model.state()); + model + .constant_defns() + .iter() + .for_each(|c| add_tensor(c, false, true)); + model + .inputs() + .iter() + .for_each(|i| add_tensor(i, true, false)); + model + .input_dep_defns() + .iter() + .for_each(|i| add_tensor(i, true, false)); + model + .time_dep_defns() + .iter() + .for_each(|i| add_tensor(i, true, false)); + add_tensor(model.state(), false, false); if let Some(state_dot) = model.state_dot() { - add_tensor(state_dot); + add_tensor(state_dot, false, false); } - model.state_dep_defns().iter().for_each(&mut add_tensor); + model + .state_dep_defns() + .iter() + .for_each(|i| add_tensor(i, true, false)); if let Some(lhs) = model.lhs() { - add_tensor(lhs); + add_tensor(lhs, false, false); } - add_tensor(model.rhs()); + add_tensor(model.rhs(), false, false); if let Some(out) = model.out() { - add_tensor(out); + add_tensor(out, false, false); } // add layout info for "t" let t_layout = RcLayout::new(Layout::new_scalar()); layout_map.insert("t".to_string(), t_layout); + // todo: could we just calculate constants now? 
+ Self { data_index_map, layout_index_map, @@ -101,6 +120,7 @@ impl DataLayout { translate_index_map, layout_map, data_length_map, + constants, } } @@ -160,6 +180,14 @@ impl DataLayout { self.data.as_mut_slice() } + pub fn constants(&self) -> &[f64] { + self.constants.as_ref() + } + + pub fn constants_mut(&mut self) -> &mut [f64] { + self.constants.as_mut_slice() + } + pub fn indices(&self) -> &[i32] { self.indices.as_ref() } diff --git a/src/execution/interface.rs b/src/execution/interface.rs index 6f00280..20ef58d 100644 --- a/src/execution/interface.rs +++ b/src/execution/interface.rs @@ -3,6 +3,8 @@ type UIntType = u32; pub type BarrierInitFunc = unsafe extern "C" fn(); +pub type SetConstantsFunc = unsafe extern "C" fn(threadId: UIntType, threadDim: UIntType); + pub type StopFunc = unsafe extern "C" fn( time: RealType, u: *const RealType, diff --git a/src/execution/llvm/codegen.rs b/src/execution/llvm/codegen.rs index 10d7cdf..9969846 100644 --- a/src/execution/llvm/codegen.rs +++ b/src/execution/llvm/codegen.rs @@ -280,6 +280,10 @@ impl CodegenModule for LlvmModule { self.codegen_mut().compile_set_u0(model) } + fn compile_set_constants(&mut self, model: &DiscreteModel) -> Result { + self.codegen_mut().compile_set_constants(model) + } + fn compile_calc_out(&mut self, model: &DiscreteModel) -> Result { self.codegen_mut().compile_calc_out(model, false) } @@ -648,17 +652,18 @@ impl CodegenModule for LlvmModule { struct Globals<'ctx> { indices: Option>, + constants: Option>, thread_counter: Option>, } impl<'ctx> Globals<'ctx> { fn new( layout: &DataLayout, - context: &'ctx inkwell::context::Context, module: &Module<'ctx>, + int_type: IntType<'ctx>, + real_type: FloatType<'ctx>, threaded: bool, ) -> Self { - let int_type = context.i32_type(); let thread_counter = if threaded { let tc = module.add_global( int_type, @@ -677,6 +682,19 @@ impl<'ctx> Globals<'ctx> { } else { None }; + let constants = if layout.constants().is_empty() { + None + } else { + let constants_array_type = + real_type.array_type(u32::try_from(layout.constants().len()).unwrap()); + let constants = module.add_global( + constants_array_type, + Some(AddressSpace::default()), + "enzyme_const_constants", + ); + constants.set_constant(true); + Some(constants) + }; let indices = if layout.indices().is_empty() { None } else { @@ -688,14 +706,19 @@ impl<'ctx> Globals<'ctx> { .map(|&i| int_type.const_int(i.try_into().unwrap(), false)) .collect::>(); let indices_value = int_type.const_array(indices_array_values.as_slice()); - let indices = - module.add_global(indices_array_type, Some(AddressSpace::default()), "indices"); + let indices = module.add_global( + indices_array_type, + Some(AddressSpace::default()), + "enzyme_const_indices", + ); + indices.set_constant(true); indices.set_initializer(&indices_value); Some(indices) }; Self { indices, thread_counter, + constants, } } } @@ -791,7 +814,7 @@ impl<'ctx> CodeGen<'ctx> { let builder = context.create_builder(); let layout = DataLayout::new(model); let module = context.create_module(model.name()); - let globals = Globals::new(&layout, context, &module, threaded); + let globals = Globals::new(&layout, &module, int_type, real_type, threaded); let ee = module .create_jit_execution_engine(OptimizationLevel::Aggressive) .map_err(|e| anyhow::anyhow!("Error creating execution engine: {:?}", e))?; @@ -874,6 +897,48 @@ impl<'ctx> CodeGen<'ctx> { .map_err(|e| anyhow!("Error building call to printf: {}", e)) } + fn compile_set_constants(&mut self, model: &DiscreteModel) -> Result> { + 
self.clear(); + let void_type = self.context.void_type(); + let fn_type = void_type.fn_type(&[self.int_type.into(), self.int_type.into()], false); + let fn_arg_names = &["thread_id", "thread_dim"]; + let function = self.module.add_function("set_constants", fn_type, None); + + let basic_block = self.context.append_basic_block(function, "entry"); + self.fn_value_opt = Some(function); + self.builder.position_at_end(basic_block); + + for (i, arg) in function.get_param_iter().enumerate() { + let name = fn_arg_names[i]; + let alloca = self.function_arg_alloca(name, arg); + self.insert_param(name, alloca); + } + + self.insert_indices(); + self.insert_constants(model); + + let mut nbarriers = 0; + let total_barriers = (model.constant_defns().len()) as u64; + #[allow(clippy::explicit_counter_loop)] + for a in model.constant_defns() { + self.jit_compile_tensor(a, Some(*self.get_var(a)))?; + self.jit_compile_call_barrier(nbarriers, total_barriers); + nbarriers += 1; + } + + self.builder.build_return(None)?; + + if function.verify(true) { + Ok(function) + } else { + function.print_to_stderr(); + unsafe { + function.delete(); + } + Err(anyhow!("Invalid generated function.")) + } + } + fn compile_barrier_init(&mut self) -> Result> { self.clear(); let void_type = self.context.void_type(); @@ -1209,18 +1274,29 @@ impl<'ctx> CodeGen<'ctx> { self.module.write_bitcode_to_path(path); } + fn insert_constants(&mut self, model: &DiscreteModel) { + if let Some(constants) = self.globals.constants.as_ref() { + self.insert_param("constants", constants.as_pointer_value()); + for tensor in model.constant_defns() { + self.insert_tensor(tensor, true); + } + } + } + fn insert_data(&mut self, model: &DiscreteModel) { + self.insert_constants(model); + for tensor in model.inputs() { - self.insert_tensor(tensor); + self.insert_tensor(tensor, false); } - for tensor in model.time_indep_defns() { - self.insert_tensor(tensor); + for tensor in model.input_dep_defns() { + self.insert_tensor(tensor, false); } for tensor in model.time_dep_defns() { - self.insert_tensor(tensor); + self.insert_tensor(tensor, false); } for tensor in model.state_dep_defns() { - self.insert_tensor(tensor); + self.insert_tensor(tensor, false); } } @@ -1328,8 +1404,9 @@ impl<'ctx> CodeGen<'ctx> { data_index += blk.nnz(); } } - fn insert_tensor(&mut self, tensor: &Tensor) { - let ptr = *self.variables.get("data").unwrap(); + fn insert_tensor(&mut self, tensor: &Tensor, is_constant: bool) { + let var_name = if is_constant { "constants" } else { "data" }; + let ptr = *self.variables.get(var_name).unwrap(); let mut data_index = self.layout.get_data_index(tensor.name()).unwrap(); let i = self .context @@ -2638,9 +2715,9 @@ impl<'ctx> CodeGen<'ctx> { self.insert_indices(); let mut nbarriers = 0; - let total_barriers = (model.time_indep_defns().len() + 1) as u64; + let total_barriers = (model.input_dep_defns().len() + 1) as u64; #[allow(clippy::explicit_counter_loop)] - for a in model.time_indep_defns() { + for a in model.input_dep_defns() { self.jit_compile_tensor(a, Some(*self.get_var(a)))?; self.jit_compile_call_barrier(nbarriers, total_barriers); nbarriers += 1; @@ -2719,9 +2796,9 @@ impl<'ctx> CodeGen<'ctx> { let mut total_barriers = (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64; if include_constants { - total_barriers += model.time_indep_defns().len() as u64; + total_barriers += model.input_dep_defns().len() as u64; // calculate time independant definitions - for tensor in model.time_indep_defns() { + for tensor in 
model.input_dep_defns() { self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?; self.jit_compile_call_barrier(nbarriers, total_barriers); nbarriers += 1; @@ -2882,9 +2959,9 @@ impl<'ctx> CodeGen<'ctx> { let mut total_barriers = (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64; if include_constants { - total_barriers += model.time_indep_defns().len() as u64; + total_barriers += model.input_dep_defns().len() as u64; // calculate constant definitions - for tensor in model.time_indep_defns() { + for tensor in model.input_dep_defns() { self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?; self.jit_compile_call_barrier(nbarriers, total_barriers); nbarriers += 1; @@ -3426,7 +3503,7 @@ impl<'ctx> CodeGen<'ctx> { let mut inputs_index = 0usize; for input in model.inputs() { let name = format!("input_{}", input.name()); - self.insert_tensor(input); + self.insert_tensor(input, false); let ptr = self.get_var(input); // loop thru the elements of this input and set/get them using the inputs ptr let inputs_start_index = self.int_type.const_int(inputs_index as u64, false); diff --git a/src/execution/module.rs b/src/execution/module.rs index 7186ff1..3b25509 100644 --- a/src/execution/module.rs +++ b/src/execution/module.rs @@ -21,6 +21,7 @@ pub trait CodegenModule: Sized + Sync { fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result; fn compile_get_inputs(&mut self, model: &DiscreteModel) -> Result; fn compile_set_id(&mut self, model: &DiscreteModel) -> Result; + fn compile_set_constants(&mut self, model: &DiscreteModel) -> Result; fn compile_mass_rgrad( &mut self, diff --git a/src/execution/translation.rs b/src/execution/translation.rs index 7475d79..de8d711 100644 --- a/src/execution/translation.rs +++ b/src/execution/translation.rs @@ -326,7 +326,7 @@ mod tests { panic!("{}", e.as_error_message(full_text.as_str())); } }; - let tensor = discrete_model.time_indep_defns().iter().find(|t| t.elmts().iter().find(|blk| blk.name() == Some($blk_name)).is_some()).unwrap(); + let tensor = discrete_model.constant_defns().iter().find(|t| t.elmts().iter().find(|blk| blk.name() == Some($blk_name)).is_some()).unwrap(); let blk = tensor.elmts().iter().find(|blk| blk.name() == Some($blk_name)).unwrap(); let translation = Translation::new(blk.expr_layout(), blk.layout(), &blk.start(), tensor.layout_ptr()); assert_eq!(translation.to_string(), $expected_value); From 7de9bca9209e3d68dea4d6c367072555ea6220bd Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Tue, 28 Jan 2025 23:17:50 +0000 Subject: [PATCH 2/5] working through bugs --- src/execution/compiler.rs | 59 ++++++++++++++++++++++++++++++ src/execution/cranelift/codegen.rs | 57 +++++++++++++++++++++++++++-- src/execution/data_layout.rs | 12 +++++- src/execution/llvm/codegen.rs | 19 +++++++--- src/execution/module.rs | 2 + 5 files changed, 139 insertions(+), 10 deletions(-) diff --git a/src/execution/compiler.rs b/src/execution/compiler.rs index a03412c..b330431 100644 --- a/src/execution/compiler.rs +++ b/src/execution/compiler.rs @@ -352,11 +352,23 @@ impl Compiler { } pub fn get_tensor_data<'a>(&self, name: &str, data: &'a [f64]) -> Option<&'a [f64]> { + if self.module.layout().is_constant(name) { + return None; + } let index = self.module.layout().get_data_index(name)?; let nnz = self.module.layout().get_data_length(name)?; Some(&data[index..index + nnz]) } + pub fn get_constants_data(&self, name: &str) -> Option<&[f64]> { + if !self.module.layout().is_constant(name) { + return None; + } + let index = 
self.module.layout().get_data_index(name)?; + let nnz = self.module.layout().get_data_length(name)?; + Some(&self.module.get_constants()[index..index + nnz]) + } + pub fn get_tensor_data_mut<'a>( &self, name: &str, @@ -976,6 +988,53 @@ mod tests { use approx::assert_relative_eq; use super::*; + + fn test_constants() { + let full_text = " + in = [a] + a { 1 } + b { 2 } + a2 { a * a } + b2 { b * b } + u_i { y = 1 } + F_i { y * b } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = DiscreteModel::build("$name", &model).unwrap(); + let compiler = Compiler::::from_discrete_model(&discrete_model, Default::default()).unwrap(); + // b and b2 should already be set + let mut data = compiler.get_new_data(); + let b = compiler.get_constants_data("b").unwrap(); + let b2 = compiler.get_constants_data("b2").unwrap(); + assert_relative_eq!(b[0], 2.); + assert_relative_eq!(b2[0], 4.); + // a and a2 should not be set (be 0) + let a = compiler.get_tensor_data("a", &data).unwrap(); + let a2 = compiler.get_tensor_data("a2", &data).unwrap(); + assert_relative_eq!(a[0], 0.); + assert_relative_eq!(a2[0], 0.); + // set the inputs and u0 + let inputs = vec![1.]; + compiler.set_inputs(&inputs, data.as_mut_slice()); + let mut u0 = vec![0.]; + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + // now a and a2 should be set + let a = compiler.get_tensor_data("a", &data).unwrap(); + let a2 = compiler.get_tensor_data("a2", &data).unwrap(); + assert_relative_eq!(a[0], 1.); + assert_relative_eq!(a2[0], 1.); + } + + #[test] + fn test_constants_cranelift() { + test_constants::(); + } + + #[cfg(feature = "llvm")] + #[test] + fn test_constants_llvm() { + test_constants::(); + } #[cfg(feature = "llvm")] #[test] diff --git a/src/execution/cranelift/codegen.rs b/src/execution/cranelift/codegen.rs index 85668ff..07a08dc 100644 --- a/src/execution/cranelift/codegen.rs +++ b/src/execution/cranelift/codegen.rs @@ -5,6 +5,7 @@ use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{DataDescription, DataId, FuncId, FuncOrDataId, Linkage, Module}; use std::collections::HashMap; use std::iter::zip; +use std::slice::from_raw_parts; use target_lexicon::{Endianness, PointerWidth, Triple}; use crate::ast::{Ast, AstKind}; @@ -206,6 +207,11 @@ unsafe impl Sync for CraneliftModule {} impl CodegenModule for CraneliftModule { type FuncId = FuncId; + + fn get_constants(&self) -> &[f64] { + let data = self.module.get_finalized_data(self.constants_id); + unsafe { from_raw_parts(data.0 as *const f64, data.1 / 8) } + } fn compile_mass_rgrad( &mut self, @@ -404,6 +410,28 @@ impl CodegenModule for CraneliftModule { codegen.builder.finalize(); self.declare_function("set_inputs_grad") } + + fn compile_set_constants(&mut self, model: &DiscreteModel) -> Result { + let arg_types = &[ + self.int_type, + self.int_type, + ]; + let arg_names = &["threadId", "threadDim"]; + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + let mut nbarrier = 0; + #[allow(clippy::explicit_counter_loop)] + for a in model.constant_defns() { + codegen.jit_compile_tensor(a, None, false)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + // Emit the return instruction. 
+ codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + + self.declare_function("set_constants") + } fn compile_set_u0_grad( &mut self, @@ -513,16 +541,16 @@ impl CodegenModule for CraneliftModule { let layout = DataLayout::new(model); // define constant global data + let int_type = types::I32; + let real_type = types::F64; let mut data_description = DataDescription::new(); - data_description.define_zeroinit(layout.constants().len()); - let constants_id = module.declare_data("constants", Linkage::Local, false, false)?; + data_description.define_zeroinit(layout.constants().len() * (real_type.bytes() as usize)); + let constants_id = module.declare_data("constants", Linkage::Local, true, false)?; module.define_data(constants_id, &data_description)?; // write indices data as a global data object // convect the indices to bytes //let int_type = ptr_type; - let int_type = types::I32; - let real_type = types::F64; let mut vec8: Vec = vec![]; for elem in layout.indices() { // convert indices to i64 @@ -935,6 +963,7 @@ struct CraneliftCodeGen<'a> { functions: HashMap, layout: &'a DataLayout, indices: GlobalValue, + constants: GlobalValue, threaded: bool, } @@ -1030,6 +1059,10 @@ impl<'ctx> CraneliftCodeGen<'ctx> { AstKind::Number(value) => Ok(self.fconst(*value)), AstKind::Name(iname) => { let ptr = if iname.is_tangent { + // tangent of a constant is zero + if self.layout.is_constant(iname.name) { + return Ok(self.fconst(0.0)); + } let name = self.get_tangent_tensor_name(iname.name); self.builder .use_var(*self.variables.get(name.as_str()).unwrap()) @@ -2026,6 +2059,10 @@ impl<'ctx> CraneliftCodeGen<'ctx> { let indices = module .module .declare_data_in_func(module.indices_id, builder.func); + + let constants = module + .module + .declare_data_in_func(module.constants_id, builder.func); // Create the entry block, to start emitting code in. 
let entry_block = builder.create_block(); @@ -2053,6 +2090,7 @@ impl<'ctx> CraneliftCodeGen<'ctx> { module: &mut module.module, tensor_ptr: None, indices, + constants, variables: HashMap::new(), mem_flags: MemFlags::new(), functions: HashMap::new(), @@ -2102,6 +2140,17 @@ impl<'ctx> CraneliftCodeGen<'ctx> { } // todo: insert constant tensors + + + let constants = codegen + .builder + .ins() + .global_value(codegen.real_ptr_type, codegen.constants); + for tensor in model.constant_defns() { + let data_index = + i64::try_from(codegen.layout.get_data_index(tensor.name()).unwrap()).unwrap(); + codegen.insert_tensor(tensor, constants, data_index, false); + } // insert all tensors in data if it exists in args let tensors = model.inputs().iter(); diff --git a/src/execution/data_layout.rs b/src/execution/data_layout.rs index 2e4cc53..f666da4 100644 --- a/src/execution/data_layout.rs +++ b/src/execution/data_layout.rs @@ -14,6 +14,7 @@ use super::Translation; // We also store a mapping from tensor names to their layout, so that we can easily look up the layout of a tensor #[derive(Debug)] pub struct DataLayout { + is_constant_map: HashMap, data_index_map: HashMap, data_length_map: HashMap, layout_index_map: HashMap, @@ -26,6 +27,7 @@ pub struct DataLayout { impl DataLayout { pub fn new(model: &DiscreteModel) -> Self { + let mut is_constant_map = HashMap::new(); let mut data_index_map = HashMap::new(); let mut data_length_map = HashMap::new(); let mut layout_index_map = HashMap::new(); @@ -42,17 +44,20 @@ impl DataLayout { data_index_map.insert(tensor.name().to_string(), data.len()); data_length_map.insert(tensor.name().to_string(), tensor.nnz()); data.extend(vec![0.0; tensor.nnz()]); + is_constant_map.insert(tensor.name().to_string(), false); } else if in_constants { data_index_map.insert(tensor.name().to_string(), constants.len()); data_length_map.insert(tensor.name().to_string(), tensor.nnz()); constants.extend(vec![0.0; tensor.nnz()]); + is_constant_map.insert(tensor.name().to_string(), true); } // add the translation info for each block-tensor pair for blk in tensor.elmts() { - // need layouts of all named tensor blocks + // need layouts and is_constant of all named tensor blocks if let Some(name) = blk.name() { layout_map.insert(name.to_string(), blk.layout().clone()); + is_constant_map.insert(name.to_string(), in_constants); } // insert the layout info for each tensor expression @@ -113,6 +118,7 @@ impl DataLayout { // todo: could we just calculate constants now? 
Self { + is_constant_map, data_index_map, layout_index_map, data, @@ -128,6 +134,10 @@ impl DataLayout { pub fn get_layout(&self, name: &str) -> Option<&RcLayout> { self.layout_map.get(name) } + + pub fn is_constant(&self, name: &str) -> bool { + *self.is_constant_map.get(name).unwrap() + } // get the index of the data array for the given tensor name pub fn get_data_index(&self, name: &str) -> Option { diff --git a/src/execution/llvm/codegen.rs b/src/execution/llvm/codegen.rs index 9969846..52e3d73 100644 --- a/src/execution/llvm/codegen.rs +++ b/src/execution/llvm/codegen.rs @@ -23,6 +23,7 @@ use llvm_sys::core::{ LLVMBuildCall2, LLVMGetArgOperand, LLVMGetBasicBlockParent, LLVMGetGlobalParent, LLVMGetInstructionParent, LLVMGetNamedFunction, LLVMGlobalGetValueType, }; +use llvm_sys::execution_engine::LLVMGetGlobalValueAddress; use llvm_sys::prelude::{LLVMBuilderRef, LLVMValueRef}; use std::collections::HashMap; use std::ffi::CString; @@ -276,6 +277,13 @@ impl CodegenModule for LlvmModule { } } + fn get_constants(&self) -> &[f64] { + let constants_name = CString::new("constants").unwrap(); + let constants_ptr = unsafe { LLVMGetGlobalValueAddress(self.codegen().ee.as_mut_ptr(), constants_name.into_raw()) as *const f64 }; + let constants_size = self.layout().constants().len(); + unsafe { std::slice::from_raw_parts(constants_ptr, constants_size) } + } + fn compile_set_u0(&mut self, model: &DiscreteModel) -> Result { self.codegen_mut().compile_set_u0(model) } @@ -620,10 +628,10 @@ impl CodegenModule for LlvmModule { let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline"); barrier_func.remove_enum_attribute(AttributeLoc::Function, nolinline_kind_id); } - //self.codegen() - // .module() - // .print_to_file("post_autodiff_optimisation.ll") - // .unwrap(); + self.codegen() + .module() + .print_to_file("post_autodiff_optimisation.ll") + .unwrap(); let initialization_config = &InitializationConfig::default(); Target::initialize_all(initialization_config); @@ -692,7 +700,8 @@ impl<'ctx> Globals<'ctx> { Some(AddressSpace::default()), "enzyme_const_constants", ); - constants.set_constant(true); + constants.set_constant(false); + constants.set_linkage(Linkage::AvailableExternally); Some(constants) }; let indices = if layout.indices().is_empty() { diff --git a/src/execution/module.rs b/src/execution/module.rs index 3b25509..43aed4c 100644 --- a/src/execution/module.rs +++ b/src/execution/module.rs @@ -104,6 +104,8 @@ pub trait CodegenModule: Sized + Sync { fn jit(&mut self, func_id: Self::FuncId) -> Result<*const u8>; fn jit_barrier_init(&mut self) -> Result<*const u8>; + + fn get_constants(&self) -> &[f64]; fn pre_autodiff_optimisation(&mut self) -> Result<()>; fn post_autodiff_optimisation(&mut self) -> Result<()>; From afddc28b09fd6794723d49ce0f7277cb437cb968 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Wed, 29 Jan 2025 14:56:43 +0000 Subject: [PATCH 3/5] fix bugs, turn atomic add back on --- src/execution/compiler.rs | 57 +++++++++++++++++++----------- src/execution/cranelift/codegen.rs | 12 +++---- src/execution/data_layout.rs | 4 +-- src/execution/llvm/codegen.rs | 20 ++++++----- src/execution/module.rs | 2 +- 5 files changed, 56 insertions(+), 39 deletions(-) diff --git a/src/execution/compiler.rs b/src/execution/compiler.rs index b330431..367b5ef 100644 --- a/src/execution/compiler.rs +++ b/src/execution/compiler.rs @@ -988,7 +988,7 @@ mod tests { use approx::assert_relative_eq; use super::*; - + fn test_constants() { let full_text = " in = [a] @@ -1001,7 +1001,8 @@ mod tests { 
"; let model = parse_ds_string(full_text).unwrap(); let discrete_model = DiscreteModel::build("$name", &model).unwrap(); - let compiler = Compiler::::from_discrete_model(&discrete_model, Default::default()).unwrap(); + let compiler = + Compiler::::from_discrete_model(&discrete_model, Default::default()).unwrap(); // b and b2 should already be set let mut data = compiler.get_new_data(); let b = compiler.get_constants_data("b").unwrap(); @@ -1024,7 +1025,7 @@ mod tests { assert_relative_eq!(a[0], 1.); assert_relative_eq!(a2[0], 1.); } - + #[test] fn test_constants_cranelift() { test_constants::(); @@ -1227,12 +1228,22 @@ mod tests { compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); compiler.rhs(0., u0.as_slice(), data.as_mut_slice(), res.as_mut_slice()); compiler.calc_out(0., u0.as_slice(), data.as_mut_slice(), out.as_mut_slice()); - results.push( - compiler - .get_tensor_data(tensor_name, data.as_slice()) - .unwrap() - .to_vec(), - ); + let tensor_is_constant = compiler.module().layout().is_constant(tensor_name); + let tensor_len = compiler + .module() + .layout() + .get_data_length(tensor_name) + .unwrap(); + if tensor_is_constant { + results.push(compiler.get_constants_data(tensor_name).unwrap().to_vec()); + } else { + results.push( + compiler + .get_tensor_data(tensor_name, data.as_slice()) + .unwrap() + .to_vec(), + ); + } // forward mode let mut dinputs = vec![0.; n_inputs]; dinputs.fill(1.); @@ -1270,14 +1281,18 @@ mod tests { out.as_slice(), dout.as_mut_slice(), ); - results.push( - compiler - .get_tensor_data(tensor_name, ddata.as_slice()) - .unwrap() - .to_vec(), - ); + if tensor_is_constant { + results.push(vec![0.; tensor_len]); + } else { + results.push( + compiler + .get_tensor_data(tensor_name, ddata.as_slice()) + .unwrap() + .to_vec(), + ); + } // reverse-mode - if compiler.module().supports_reverse_autodiff() { + if compiler.module().supports_reverse_autodiff() && !tensor_is_constant { let mut ddata = compiler.get_new_data(); let dtensor = compiler .get_tensor_data_mut(tensor_name, ddata.as_mut_slice()) @@ -1402,6 +1417,12 @@ mod tests { ddata.as_mut_slice(), ); results.push(dinputs.to_vec()); + } else { + results.push(vec![0.; n_inputs]); + results.push(vec![0.; tensor_len]); + results.push(vec![0.; tensor_len]); + results.push(vec![0.; n_inputs]); + results.push(vec![0.; n_inputs]); } results } @@ -1577,7 +1598,6 @@ mod tests { let results = tensor_test_common::(full_text.as_str(), $tensor_name, CompilerMode::MultiThreaded(None)); assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice()); - assert_eq!(results.len(), 2); } #[cfg(feature = "llvm")] @@ -1594,7 +1614,6 @@ mod tests { let results = tensor_test_common::(full_text.as_str(), $tensor_name, CompilerMode::SingleThreaded); assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice()); - assert_eq!(results.len(), 2); } )* } @@ -1646,14 +1665,12 @@ mod tests { let results = tensor_test_common::(full_text.as_str(), $tensor_name, CompilerMode::SingleThreaded); assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice()); assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice()); - assert_eq!(results.len(), 2); #[cfg(feature = "rayon")] { let results = tensor_test_common::(full_text.as_str(), $tensor_name, CompilerMode::MultiThreaded(None)); assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice()); assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice()); - assert_eq!(results.len(), 2); #[cfg(feature = "llvm")] { diff --git 
a/src/execution/cranelift/codegen.rs b/src/execution/cranelift/codegen.rs index 07a08dc..4f4b939 100644 --- a/src/execution/cranelift/codegen.rs +++ b/src/execution/cranelift/codegen.rs @@ -207,7 +207,7 @@ unsafe impl Sync for CraneliftModule {} impl CodegenModule for CraneliftModule { type FuncId = FuncId; - + fn get_constants(&self) -> &[f64] { let data = self.module.get_finalized_data(self.constants_id); unsafe { from_raw_parts(data.0 as *const f64, data.1 / 8) } @@ -410,12 +410,9 @@ impl CodegenModule for CraneliftModule { codegen.builder.finalize(); self.declare_function("set_inputs_grad") } - + fn compile_set_constants(&mut self, model: &DiscreteModel) -> Result { - let arg_types = &[ - self.int_type, - self.int_type, - ]; + let arg_types = &[self.int_type, self.int_type]; let arg_names = &["threadId", "threadDim"]; let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -2059,7 +2056,7 @@ impl<'ctx> CraneliftCodeGen<'ctx> { let indices = module .module .declare_data_in_func(module.indices_id, builder.func); - + let constants = module .module .declare_data_in_func(module.constants_id, builder.func); @@ -2140,7 +2137,6 @@ impl<'ctx> CraneliftCodeGen<'ctx> { } // todo: insert constant tensors - let constants = codegen .builder diff --git a/src/execution/data_layout.rs b/src/execution/data_layout.rs index f666da4..24ddecc 100644 --- a/src/execution/data_layout.rs +++ b/src/execution/data_layout.rs @@ -49,8 +49,8 @@ impl DataLayout { data_index_map.insert(tensor.name().to_string(), constants.len()); data_length_map.insert(tensor.name().to_string(), tensor.nnz()); constants.extend(vec![0.0; tensor.nnz()]); - is_constant_map.insert(tensor.name().to_string(), true); } + is_constant_map.insert(tensor.name().to_string(), in_constants); // add the translation info for each block-tensor pair for blk in tensor.elmts() { @@ -134,7 +134,7 @@ impl DataLayout { pub fn get_layout(&self, name: &str) -> Option<&RcLayout> { self.layout_map.get(name) } - + pub fn is_constant(&self, name: &str) -> bool { *self.is_constant_map.get(name).unwrap() } diff --git a/src/execution/llvm/codegen.rs b/src/execution/llvm/codegen.rs index 52e3d73..4149bbd 100644 --- a/src/execution/llvm/codegen.rs +++ b/src/execution/llvm/codegen.rs @@ -278,8 +278,12 @@ impl CodegenModule for LlvmModule { } fn get_constants(&self) -> &[f64] { - let constants_name = CString::new("constants").unwrap(); - let constants_ptr = unsafe { LLVMGetGlobalValueAddress(self.codegen().ee.as_mut_ptr(), constants_name.into_raw()) as *const f64 }; + let constants_name = CString::new("enzyme_const_constants").unwrap(); + let constants_ptr = unsafe { + LLVMGetGlobalValueAddress(self.codegen().ee.as_mut_ptr(), constants_name.into_raw()) + as *const f64 + }; + assert!(!constants_ptr.is_null()); let constants_size = self.layout().constants().len(); unsafe { std::slice::from_raw_parts(constants_ptr, constants_size) } } @@ -628,10 +632,10 @@ impl CodegenModule for LlvmModule { let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline"); barrier_func.remove_enum_attribute(AttributeLoc::Function, nolinline_kind_id); } - self.codegen() - .module() - .print_to_file("post_autodiff_optimisation.ll") - .unwrap(); + //self.codegen() + // .module() + // .print_to_file("post_autodiff_optimisation.ll") + // .unwrap(); let initialization_config = &InitializationConfig::default(); Target::initialize_all(initialization_config); @@ -701,7 +705,7 @@ impl<'ctx> Globals<'ctx> { "enzyme_const_constants", ); constants.set_constant(false); - 
constants.set_linkage(Linkage::AvailableExternally); + constants.set_initializer(&constants_array_type.const_zero()); Some(constants) }; let indices = if layout.indices().is_empty() { @@ -3299,7 +3303,7 @@ impl<'ctx> CodeGen<'ctx> { args_uncacheable.as_mut_ptr(), args_uncacheable.len(), std::ptr::null_mut(), - 0, + 1, ) }; if self.threaded { diff --git a/src/execution/module.rs b/src/execution/module.rs index 43aed4c..a347ef4 100644 --- a/src/execution/module.rs +++ b/src/execution/module.rs @@ -104,7 +104,7 @@ pub trait CodegenModule: Sized + Sync { fn jit(&mut self, func_id: Self::FuncId) -> Result<*const u8>; fn jit_barrier_init(&mut self) -> Result<*const u8>; - + fn get_constants(&self) -> &[f64]; fn pre_autodiff_optimisation(&mut self) -> Result<()>; From b16890b1d5ad965fe67edb030a4f0dd49d39fb6e Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Wed, 29 Jan 2025 14:58:49 +0000 Subject: [PATCH 4/5] update enzyme --- Enzyme | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Enzyme b/Enzyme index 5c632cc..df197be 160000 --- a/Enzyme +++ b/Enzyme @@ -1 +1 @@ -Subproject commit 5c632cca6325eadb728df0783c0c41be9eb6cd00 +Subproject commit df197be4f1909067ac3bc68b746ce9f3e406476f From 39803c08629bec215944732a08a22bfac3ca6bde Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Wed, 29 Jan 2025 19:28:20 +0000 Subject: [PATCH 5/5] turn off llvm multithreaded tests for macos --- src/execution/compiler.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/execution/compiler.rs b/src/execution/compiler.rs index 367b5ef..e8b294e 100644 --- a/src/execution/compiler.rs +++ b/src/execution/compiler.rs @@ -1666,12 +1666,15 @@ mod tests { assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice()); assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice()); + #[cfg(feature = "rayon")] { let results = tensor_test_common::(full_text.as_str(), $tensor_name, CompilerMode::MultiThreaded(None)); assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice()); assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice()); + // todo: multi-threaded llvm not working on macos + #[cfg(not(target_os = "macos"))] #[cfg(feature = "llvm")] { use crate::execution::llvm::codegen::LlvmModule;
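
---
Usage sketch (not part of the patches): the snippet below is adapted from the test_constants test added in PATCH 2 and is illustrative only. The displayed patches strip generic arguments, so the CraneliftModule type parameter, the surrounding tests-module context (use super::*; and approx::assert_relative_eq!), and the exact input values are assumptions rather than patch content. It shows the intent of the split: tensors with no input dependence (constant_defns) are evaluated once by set_constants at the end of Compiler::from_discrete_model and read via get_constants_data, while input-dependent tensors (input_dep_defns) stay in the per-instance data array and are only filled after set_inputs/set_u0.

    // Sketch only -- assumed to sit next to the existing tests in
    // src/execution/compiler.rs, so `use super::*;` and
    // `use approx::assert_relative_eq;` are in scope.
    fn constants_usage_sketch() {
        let text = "
            in = [a]
            a { 1 }
            b { 2 }
            a2 { a * a }
            b2 { b * b }
            u_i { y = 1 }
            F_i { y * b }
        ";
        let model = parse_ds_string(text).unwrap();
        let discrete_model = DiscreteModel::build("$name", &model).unwrap();
        // set_constants runs once inside from_discrete_model, so b and b2
        // (no input dependence) are already evaluated into the constants global.
        let compiler =
            Compiler::<CraneliftModule>::from_discrete_model(&discrete_model, Default::default())
                .unwrap();
        let mut data = compiler.get_new_data();
        assert_relative_eq!(compiler.get_constants_data("b").unwrap()[0], 2.);
        assert_relative_eq!(compiler.get_constants_data("b2").unwrap()[0], 4.);
        // a2 depends on the input a, so it lives in `data` and is still zero here.
        assert_relative_eq!(compiler.get_tensor_data("a2", &data).unwrap()[0], 0.);
        // After set_inputs + set_u0 the input-dependent definitions are filled in.
        let inputs = vec![1.];
        compiler.set_inputs(&inputs, data.as_mut_slice());
        let mut u0 = vec![0.];
        compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice());
        assert_relative_eq!(compiler.get_tensor_data("a2", &data).unwrap()[0], 1.);
    }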