Skip to content

Commit

Permalink
feat: out tensor is optional and provided externally via args (#47)
Browse files Browse the repository at this point in the history
* make out optional and remove from data layout

* fix optional out bugs
  • Loading branch information
martinjrobins authored Jan 28, 2025
1 parent d77ec3d commit f14884c
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 145 deletions.
44 changes: 16 additions & 28 deletions src/discretise/discrete_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub struct DiscreteModel<'s> {
name: &'s str,
lhs: Option<Tensor<'s>>,
rhs: Tensor<'s>,
out: Tensor<'s>,
out: Option<Tensor<'s>>,
time_indep_defns: Vec<Tensor<'s>>,
time_dep_defns: Vec<Tensor<'s>>,
state_dep_defns: Vec<Tensor<'s>>,
Expand Down Expand Up @@ -74,7 +74,10 @@ impl fmt::Display for DiscreteModel<'_> {
if let Some(stop) = &self.stop {
writeln!(f, "{}", stop)?;
}
writeln!(f, "{}", self.out)
if let Some(out) = &self.out {
writeln!(f, "{}", out)?;
}
Ok(())
}
}

Expand All @@ -86,7 +89,7 @@ impl<'s> DiscreteModel<'s> {
name,
lhs: None,
rhs: Tensor::new_empty("F"),
out: Tensor::new_empty("out"),
out: None,
time_indep_defns: Vec::new(),
time_dep_defns: Vec::new(),
state_dep_defns: Vec::new(),
Expand Down Expand Up @@ -258,7 +261,6 @@ impl<'s> DiscreteModel<'s> {
let mut env = Env::default();
let mut ret = Self::new(name);
let mut read_state = false;
let mut read_out = false;
let mut span_f = None;
let mut span_m = None;
for tensor_ast in model.tensors.iter() {
Expand Down Expand Up @@ -346,15 +348,14 @@ impl<'s> DiscreteModel<'s> {
}
}
"out" => {
read_out = true;
if let Some(built) = Self::build_array(tensor, &mut env) {
if built.rank() > 1 {
env.errs_mut().push(ValidationError::new(
"output shape must be a scalar or 1D vector".to_string(),
tensor_ast.span,
));
}
ret.out = built;
ret.out = Some(built);
}
// check that out is not dependent on dudt
if let Some(out) = env.get("out") {
Expand Down Expand Up @@ -467,18 +468,6 @@ impl<'s> DiscreteModel<'s> {
span_all,
));
}
// add default out if not defined
if !read_out && read_state {
let out_tensor = ast::Tensor::new(
"out",
ret.state.indices().to_vec(),
vec![Ast::new_tensor_elmt(
Ast::new_name("u", ret.state.indices().to_vec(), false),
None,
)],
);
ret.out = Self::build_array(&out_tensor, &mut env).unwrap_or(Tensor::new_empty("out"));
}
if let Some(span) = span_f {
Self::check_match(&ret.rhs, &ret.state, span, &mut env);
}
Expand Down Expand Up @@ -697,7 +686,7 @@ impl<'s> DiscreteModel<'s> {
inputs,
state,
state_dot: Some(state_dot),
out: out_array,
out: Some(out_array),
time_indep_defns,
time_dep_defns,
state_dep_defns,
Expand Down Expand Up @@ -733,8 +722,8 @@ impl<'s> DiscreteModel<'s> {
self.state_dot.as_ref()
}

pub fn out(&self) -> &Tensor<'s> {
&self.out
pub fn out(&self) -> Option<&Tensor<'s>> {
self.out.as_ref()
}

pub fn lhs(&self) -> Option<&Tensor<'s>> {
Expand Down Expand Up @@ -792,7 +781,7 @@ mod tests {
assert_eq!(discrete.rhs.name(), "F");
assert_eq!(discrete.state.shape()[0], 1);
assert_eq!(discrete.state.elmts().len(), 1);
assert_eq!(discrete.out.elmts().len(), 3);
assert_eq!(discrete.out().unwrap().elmts().len(), 3);
println!("{}", discrete);
}
#[test]
Expand All @@ -808,9 +797,9 @@ mod tests {
let model_info = ModelInfo::build("logistic_growth", &models).unwrap();
assert_eq!(model_info.errors.len(), 0);
let discrete = DiscreteModel::from(&model_info);
assert_eq!(discrete.out.elmts()[0].expr().to_string(), "y");
assert_eq!(discrete.out.elmts()[1].expr().to_string(), "t");
assert_eq!(discrete.out.elmts()[2].expr().to_string(), "z");
assert_eq!(discrete.out().unwrap().elmts()[0].expr().to_string(), "y");
assert_eq!(discrete.out().unwrap().elmts()[1].expr().to_string(), "t");
assert_eq!(discrete.out().unwrap().elmts()[2].expr().to_string(), "z");
println!("{}", discrete);
}

Expand Down Expand Up @@ -1245,7 +1234,7 @@ mod tests {
}

#[test]
fn test_default_out() {
fn test_no_out() {
let text = "
u_i {
y = 1,
Expand All @@ -1256,8 +1245,7 @@ mod tests {
";
let model = parse_ds_string(text).unwrap();
let model = DiscreteModel::build("$name", &model).unwrap();
assert_eq!(model.out().elmts().len(), 1);
assert_eq!(model.out().elmts()[0].expr().to_string(), "u_i");
assert!(model.out().is_none());
}

#[test]
Expand Down
Loading

0 comments on commit f14884c

Please sign in to comment.