Skip to content

Commit

Permalink
bug: calculate internal tensors in out (#26)
Browse files Browse the repository at this point in the history
* add test

* bug: calc internal tensors in out

* llvm < 13 no longer supported by inkwell

* fix clippy
  • Loading branch information
martinjrobins authored Dec 7, 2024
1 parent 534b651 commit fed49a9
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 23 deletions.
18 changes: 9 additions & 9 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub struct DsModel<'a> {
pub tensors: Vec<Box<Ast<'a>>>,
}

impl<'a> fmt::Display for DsModel<'a> {
impl fmt::Display for DsModel<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if self.inputs.len() > 1 {
write!(f, "in = [")?;
Expand Down Expand Up @@ -56,7 +56,7 @@ pub struct Equation<'a> {
pub rhs: Box<Ast<'a>>,
}

impl<'a> fmt::Display for Equation<'a> {
impl fmt::Display for Equation<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} = {}", self.lhs, self.rhs,)
}
Expand All @@ -75,7 +75,7 @@ pub struct Name<'a> {
pub is_tangent: bool,
}

impl<'a> fmt::Display for Name<'a> {
impl fmt::Display for Name<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if self.is_tangent {
write!(f, "d_")?;
Expand Down Expand Up @@ -109,7 +109,7 @@ pub struct Domain<'a> {
pub dim: usize,
}

impl<'a> fmt::Display for Domain<'a> {
impl fmt::Display for Domain<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.range).and_then(|_| {
if self.dim == 1 {
Expand Down Expand Up @@ -189,7 +189,7 @@ impl<'a> Tensor<'a> {
}
}

impl<'a> fmt::Display for Tensor<'a> {
impl fmt::Display for Tensor<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.name)?;
if !self.indices.is_empty() {
Expand All @@ -213,7 +213,7 @@ pub struct Indice<'a> {
pub sep: Option<&'a str>,
}

impl<'a> fmt::Display for Indice<'a> {
impl fmt::Display for Indice<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.first)?;
if let Some(ref last) = self.last {
Expand All @@ -230,7 +230,7 @@ pub struct Vector<'a> {
pub data: Vec<Box<Ast<'a>>>,
}

impl<'a> fmt::Display for Vector<'a> {
impl fmt::Display for Vector<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "(")?;
for (i, elmt) in self.data.iter().enumerate() {
Expand Down Expand Up @@ -279,7 +279,7 @@ pub struct NamedGradient<'a> {
pub gradient_wrt: Box<Ast<'a>>,
}

impl<'a> fmt::Display for NamedGradient<'a> {
impl fmt::Display for NamedGradient<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "d{}d{}", self.gradient_of, self.gradient_wrt)
}
Expand Down Expand Up @@ -913,7 +913,7 @@ impl<'a> Ast<'a> {
}
}

impl<'a> fmt::Display for Ast<'a> {
impl fmt::Display for Ast<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self.kind {
AstKind::Model(model) => {
Expand Down
2 changes: 1 addition & 1 deletion src/continuous/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub struct Variable<'s> {
pub init_conditions: Vec<BoundaryCondition<'s>>,
}

impl<'a> fmt::Display for Variable<'a> {
impl fmt::Display for Variable<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let deps_disp: Vec<_> = self
.dependents
Expand Down
2 changes: 1 addition & 1 deletion src/discretise/discrete_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub struct DiscreteModel<'s> {
stop: Option<Tensor<'s>>,
}

impl<'s> fmt::Display for DiscreteModel<'s> {
impl fmt::Display for DiscreteModel<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if !self.inputs.is_empty() {
write!(f, "in = [")?;
Expand Down
4 changes: 2 additions & 2 deletions src/discretise/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ impl<'s> TensorBlock<'s> {
}
}

impl<'s> fmt::Display for TensorBlock<'s> {
impl fmt::Display for TensorBlock<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "(")?;
for i in 0..self.rank() {
Expand Down Expand Up @@ -212,7 +212,7 @@ impl<'s> Tensor<'s> {
}
}

impl<'s> fmt::Display for Tensor<'s> {
impl fmt::Display for Tensor<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if !self.indices.is_empty() {
write!(f, "{}_", self.name)?;
Expand Down
34 changes: 34 additions & 0 deletions src/execution/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,40 @@ mod tests {
assert_eq!(stop.len(), 1);
}

fn test_out_depends_on_internal_tensor<T: CodegenModule>() {
let full_text = "
u_i { y = 1 }
twoy_i { 2 * y }
F_i { y * (1 - y), }
out_i { twoy_i }
";
let model = parse_ds_string(full_text).unwrap();
let discrete_model = DiscreteModel::build("$name", &model).unwrap();
let compiler = Compiler::<T>::from_discrete_model(&discrete_model).unwrap();
let mut u0 = vec![1.];
let mut data = compiler.get_new_data();
// need this to set the constants
compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice());
compiler.calc_out(0., u0.as_slice(), data.as_mut_slice());
let out = compiler.get_out(data.as_slice());
assert_relative_eq!(out[0], 2.);
u0[0] = 2.;
compiler.calc_out(0., u0.as_slice(), data.as_mut_slice());
let out = compiler.get_out(data.as_slice());
assert_relative_eq!(out[0], 4.);
}

#[test]
fn test_out_depends_on_internal_tensor_cranelift() {
test_out_depends_on_internal_tensor::<CraneliftModule>();
}

#[cfg(feature = "llvm")]
#[test]
fn test_out_depends_on_internal_tensor_llvm() {
test_out_depends_on_internal_tensor::<crate::LlvmModule>();
}

#[test]
fn test_vector_add_scalar_cranelift() {
let n = 1;
Expand Down
13 changes: 10 additions & 3 deletions src/execution/cranelift/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ pub struct CraneliftModule {
/// context per thread, though this isn't in the simple demo here.
ctx: codegen::Context,

/// The data description, which is to data objects what `ctx` is to functions.
//data_description: DataDescription,

/// The module, with the jit backend, which manages the JIT'd
/// functions.
module: JITModule,
Expand Down Expand Up @@ -330,6 +327,16 @@ impl CodegenModule for CraneliftModule {
let arg_names = &["t", "u", "data"];
let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);

// calculate time dependant definitions
for tensor in model.time_dep_defns() {
codegen.jit_compile_tensor(tensor, None, false)?;
}

// TODO: could split state dep defns into before and after F
for a in model.state_dep_defns() {
codegen.jit_compile_tensor(a, None, false)?;
}

codegen.jit_compile_tensor(model.out(), None, false)?;
codegen.builder.ins().return_(&[]);
codegen.builder.finalize();
Expand Down
22 changes: 15 additions & 7 deletions src/execution/llvm/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ impl<'ctx> CodeGen<'ctx> {
self.insert_tensor(model.rhs());
}

#[llvm_versions(4.0..=14.0)]
#[llvm_versions(13.0..=14.0)]
fn pointer_type(_context: &'ctx Context, ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> {
ty.ptr_type(AddressSpace::default())
}
Expand All @@ -508,7 +508,7 @@ impl<'ctx> CodeGen<'ctx> {
context.ptr_type(AddressSpace::default())
}

#[llvm_versions(4.0..=14.0)]
#[llvm_versions(13.0..=14.0)]
fn fn_pointer_type(_context: &'ctx Context, ty: FunctionType<'ctx>) -> PointerType<'ctx> {
ty.ptr_type(AddressSpace::default())
}
Expand All @@ -518,7 +518,7 @@ impl<'ctx> CodeGen<'ctx> {
context.ptr_type(AddressSpace::default())
}

#[llvm_versions(4.0..=14.0)]
#[llvm_versions(13.0..=14.0)]
fn insert_indices(&mut self) {
if let Some(indices) = self.globals.indices.as_ref() {
let zero = self.context.i32_type().const_int(0, false);
Expand Down Expand Up @@ -549,7 +549,7 @@ impl<'ctx> CodeGen<'ctx> {
self.variables.insert(name.to_owned(), value);
}

#[llvm_versions(4.0..=14.0)]
#[llvm_versions(13.0..=14.0)]
fn build_gep<T: BasicType<'ctx>>(
&self,
_ty: T,
Expand Down Expand Up @@ -579,7 +579,7 @@ impl<'ctx> CodeGen<'ctx> {
}
}

#[llvm_versions(4.0..=14.0)]
#[llvm_versions(13.0..=14.0)]
fn build_load<T: BasicType<'ctx>>(
&self,
_ty: T,
Expand All @@ -599,7 +599,7 @@ impl<'ctx> CodeGen<'ctx> {
self.builder.build_load(ty, ptr, name).map_err(|e| e.into())
}

#[llvm_versions(4.0..=14.0)]
#[llvm_versions(13.0..=14.0)]
fn get_ptr_to_index<T: BasicType<'ctx>>(
builder: &Builder<'ctx>,
_ty: T,
Expand Down Expand Up @@ -1887,7 +1887,15 @@ impl<'ctx> CodeGen<'ctx> {
self.insert_data(model);
self.insert_indices();

// TODO: could split state dep defns into before and after F and G
// calculate time dependant definitions
for tensor in model.time_dep_defns() {
self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
}

// TODO: could split state dep defns into before and after F
for a in model.state_dep_defns() {
self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
}

self.jit_compile_tensor(model.out(), Some(*self.get_var(model.out())))?;
self.builder.build_return(None)?;
Expand Down

0 comments on commit fed49a9

Please sign in to comment.