Skip to content

Commit fed49a9

Browse files
bug: calculate internal tensors in out (#26)
* add test * bug: calc internal tensors in out * llvm < 13 no longer supported by inkwell * fix clippy
1 parent 534b651 commit fed49a9

File tree

7 files changed

+72
-23
lines changed

7 files changed

+72
-23
lines changed

src/ast/mod.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub struct DsModel<'a> {
1010
pub tensors: Vec<Box<Ast<'a>>>,
1111
}
1212

13-
impl<'a> fmt::Display for DsModel<'a> {
13+
impl fmt::Display for DsModel<'_> {
1414
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1515
if self.inputs.len() > 1 {
1616
write!(f, "in = [")?;
@@ -56,7 +56,7 @@ pub struct Equation<'a> {
5656
pub rhs: Box<Ast<'a>>,
5757
}
5858

59-
impl<'a> fmt::Display for Equation<'a> {
59+
impl fmt::Display for Equation<'_> {
6060
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
6161
write!(f, "{} = {}", self.lhs, self.rhs,)
6262
}
@@ -75,7 +75,7 @@ pub struct Name<'a> {
7575
pub is_tangent: bool,
7676
}
7777

78-
impl<'a> fmt::Display for Name<'a> {
78+
impl fmt::Display for Name<'_> {
7979
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
8080
if self.is_tangent {
8181
write!(f, "d_")?;
@@ -109,7 +109,7 @@ pub struct Domain<'a> {
109109
pub dim: usize,
110110
}
111111

112-
impl<'a> fmt::Display for Domain<'a> {
112+
impl fmt::Display for Domain<'_> {
113113
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
114114
write!(f, "{}", self.range).and_then(|_| {
115115
if self.dim == 1 {
@@ -189,7 +189,7 @@ impl<'a> Tensor<'a> {
189189
}
190190
}
191191

192-
impl<'a> fmt::Display for Tensor<'a> {
192+
impl fmt::Display for Tensor<'_> {
193193
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
194194
write!(f, "{}", self.name)?;
195195
if !self.indices.is_empty() {
@@ -213,7 +213,7 @@ pub struct Indice<'a> {
213213
pub sep: Option<&'a str>,
214214
}
215215

216-
impl<'a> fmt::Display for Indice<'a> {
216+
impl fmt::Display for Indice<'_> {
217217
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
218218
write!(f, "{}", self.first)?;
219219
if let Some(ref last) = self.last {
@@ -230,7 +230,7 @@ pub struct Vector<'a> {
230230
pub data: Vec<Box<Ast<'a>>>,
231231
}
232232

233-
impl<'a> fmt::Display for Vector<'a> {
233+
impl fmt::Display for Vector<'_> {
234234
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
235235
write!(f, "(")?;
236236
for (i, elmt) in self.data.iter().enumerate() {
@@ -279,7 +279,7 @@ pub struct NamedGradient<'a> {
279279
pub gradient_wrt: Box<Ast<'a>>,
280280
}
281281

282-
impl<'a> fmt::Display for NamedGradient<'a> {
282+
impl fmt::Display for NamedGradient<'_> {
283283
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
284284
write!(f, "d{}d{}", self.gradient_of, self.gradient_wrt)
285285
}
@@ -913,7 +913,7 @@ impl<'a> Ast<'a> {
913913
}
914914
}
915915

916-
impl<'a> fmt::Display for Ast<'a> {
916+
impl fmt::Display for Ast<'_> {
917917
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
918918
match &self.kind {
919919
AstKind::Model(model) => {

src/continuous/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ pub struct Variable<'s> {
5959
pub init_conditions: Vec<BoundaryCondition<'s>>,
6060
}
6161

62-
impl<'a> fmt::Display for Variable<'a> {
62+
impl fmt::Display for Variable<'_> {
6363
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
6464
let deps_disp: Vec<_> = self
6565
.dependents

src/discretise/discrete_model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ pub struct DiscreteModel<'s> {
4242
stop: Option<Tensor<'s>>,
4343
}
4444

45-
impl<'s> fmt::Display for DiscreteModel<'s> {
45+
impl fmt::Display for DiscreteModel<'_> {
4646
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
4747
if !self.inputs.is_empty() {
4848
write!(f, "in = [")?;

src/discretise/tensor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ impl<'s> TensorBlock<'s> {
108108
}
109109
}
110110

111-
impl<'s> fmt::Display for TensorBlock<'s> {
111+
impl fmt::Display for TensorBlock<'_> {
112112
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
113113
write!(f, "(")?;
114114
for i in 0..self.rank() {
@@ -212,7 +212,7 @@ impl<'s> Tensor<'s> {
212212
}
213213
}
214214

215-
impl<'s> fmt::Display for Tensor<'s> {
215+
impl fmt::Display for Tensor<'_> {
216216
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
217217
if !self.indices.is_empty() {
218218
write!(f, "{}_", self.name)?;

src/execution/compiler.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,40 @@ mod tests {
612612
assert_eq!(stop.len(), 1);
613613
}
614614

615+
fn test_out_depends_on_internal_tensor<T: CodegenModule>() {
616+
let full_text = "
617+
u_i { y = 1 }
618+
twoy_i { 2 * y }
619+
F_i { y * (1 - y), }
620+
out_i { twoy_i }
621+
";
622+
let model = parse_ds_string(full_text).unwrap();
623+
let discrete_model = DiscreteModel::build("$name", &model).unwrap();
624+
let compiler = Compiler::<T>::from_discrete_model(&discrete_model).unwrap();
625+
let mut u0 = vec![1.];
626+
let mut data = compiler.get_new_data();
627+
// need this to set the constants
628+
compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice());
629+
compiler.calc_out(0., u0.as_slice(), data.as_mut_slice());
630+
let out = compiler.get_out(data.as_slice());
631+
assert_relative_eq!(out[0], 2.);
632+
u0[0] = 2.;
633+
compiler.calc_out(0., u0.as_slice(), data.as_mut_slice());
634+
let out = compiler.get_out(data.as_slice());
635+
assert_relative_eq!(out[0], 4.);
636+
}
637+
638+
#[test]
639+
fn test_out_depends_on_internal_tensor_cranelift() {
640+
test_out_depends_on_internal_tensor::<CraneliftModule>();
641+
}
642+
643+
#[cfg(feature = "llvm")]
644+
#[test]
645+
fn test_out_depends_on_internal_tensor_llvm() {
646+
test_out_depends_on_internal_tensor::<crate::LlvmModule>();
647+
}
648+
615649
#[test]
616650
fn test_vector_add_scalar_cranelift() {
617651
let n = 1;

src/execution/cranelift/codegen.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@ pub struct CraneliftModule {
2222
/// context per thread, though this isn't in the simple demo here.
2323
ctx: codegen::Context,
2424

25-
/// The data description, which is to data objects what `ctx` is to functions.
26-
//data_description: DataDescription,
27-
2825
/// The module, with the jit backend, which manages the JIT'd
2926
/// functions.
3027
module: JITModule,
@@ -330,6 +327,16 @@ impl CodegenModule for CraneliftModule {
330327
let arg_names = &["t", "u", "data"];
331328
let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
332329

330+
// calculate time dependant definitions
331+
for tensor in model.time_dep_defns() {
332+
codegen.jit_compile_tensor(tensor, None, false)?;
333+
}
334+
335+
// TODO: could split state dep defns into before and after F
336+
for a in model.state_dep_defns() {
337+
codegen.jit_compile_tensor(a, None, false)?;
338+
}
339+
333340
codegen.jit_compile_tensor(model.out(), None, false)?;
334341
codegen.builder.ins().return_(&[]);
335342
codegen.builder.finalize();

src/execution/llvm/codegen.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ impl<'ctx> CodeGen<'ctx> {
498498
self.insert_tensor(model.rhs());
499499
}
500500

501-
#[llvm_versions(4.0..=14.0)]
501+
#[llvm_versions(13.0..=14.0)]
502502
fn pointer_type(_context: &'ctx Context, ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> {
503503
ty.ptr_type(AddressSpace::default())
504504
}
@@ -508,7 +508,7 @@ impl<'ctx> CodeGen<'ctx> {
508508
context.ptr_type(AddressSpace::default())
509509
}
510510

511-
#[llvm_versions(4.0..=14.0)]
511+
#[llvm_versions(13.0..=14.0)]
512512
fn fn_pointer_type(_context: &'ctx Context, ty: FunctionType<'ctx>) -> PointerType<'ctx> {
513513
ty.ptr_type(AddressSpace::default())
514514
}
@@ -518,7 +518,7 @@ impl<'ctx> CodeGen<'ctx> {
518518
context.ptr_type(AddressSpace::default())
519519
}
520520

521-
#[llvm_versions(4.0..=14.0)]
521+
#[llvm_versions(13.0..=14.0)]
522522
fn insert_indices(&mut self) {
523523
if let Some(indices) = self.globals.indices.as_ref() {
524524
let zero = self.context.i32_type().const_int(0, false);
@@ -549,7 +549,7 @@ impl<'ctx> CodeGen<'ctx> {
549549
self.variables.insert(name.to_owned(), value);
550550
}
551551

552-
#[llvm_versions(4.0..=14.0)]
552+
#[llvm_versions(13.0..=14.0)]
553553
fn build_gep<T: BasicType<'ctx>>(
554554
&self,
555555
_ty: T,
@@ -579,7 +579,7 @@ impl<'ctx> CodeGen<'ctx> {
579579
}
580580
}
581581

582-
#[llvm_versions(4.0..=14.0)]
582+
#[llvm_versions(13.0..=14.0)]
583583
fn build_load<T: BasicType<'ctx>>(
584584
&self,
585585
_ty: T,
@@ -599,7 +599,7 @@ impl<'ctx> CodeGen<'ctx> {
599599
self.builder.build_load(ty, ptr, name).map_err(|e| e.into())
600600
}
601601

602-
#[llvm_versions(4.0..=14.0)]
602+
#[llvm_versions(13.0..=14.0)]
603603
fn get_ptr_to_index<T: BasicType<'ctx>>(
604604
builder: &Builder<'ctx>,
605605
_ty: T,
@@ -1887,7 +1887,15 @@ impl<'ctx> CodeGen<'ctx> {
18871887
self.insert_data(model);
18881888
self.insert_indices();
18891889

1890-
// TODO: could split state dep defns into before and after F and G
1890+
// calculate time dependant definitions
1891+
for tensor in model.time_dep_defns() {
1892+
self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
1893+
}
1894+
1895+
// TODO: could split state dep defns into before and after F
1896+
for a in model.state_dep_defns() {
1897+
self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
1898+
}
18911899

18921900
self.jit_compile_tensor(model.out(), Some(*self.get_var(model.out())))?;
18931901
self.builder.build_return(None)?;

0 commit comments

Comments
 (0)