Skip to content

Commit

Permalink
feat: add gradient and adjoints wrt inputs (#45)
Browse files Browse the repository at this point in the history
* added sens function to llvm and updating interface

* add tests for sgrad and srgrad

* tests for sgrad and srgrad seem to work

* add mass_rgrad

* fix some interface sigs, check for reserved names

* fix llvm intrinsic bug

* cranelift grad impl dupnoneed

* cargo fmt
  • Loading branch information
martinjrobins authored Jan 27, 2025
1 parent 4bc7255 commit c1c833d
Show file tree
Hide file tree
Showing 6 changed files with 838 additions and 181 deletions.
33 changes: 32 additions & 1 deletion src/discretise/discrete_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,27 @@ impl<'s> DiscreteModel<'s> {

fn build_array(array: &ast::Tensor<'s>, env: &mut Env) -> Option<Tensor<'s>> {
let rank = array.indices().len();
let reserved_names = [
"u0",
"t",
"data",
"root",
"thread_id",
"thread_dim",
"rr",
"states",
"inputs",
"outputs",
"hass_mass",
];
if reserved_names.contains(&array.name()) {
let span = env.current_span().to_owned();
env.errs_mut().push(ValidationError::new(
format!("{} is a reserved name", array.name()),
span,
));
return None;
}
let mut elmts = Vec::new();
let mut start = Index::zeros(rank);
let nerrs = env.errs().len();
Expand Down Expand Up @@ -148,6 +169,16 @@ impl<'s> DiscreteModel<'s> {
i64::try_from(elmt_layout.shape()[0]).unwrap()
};

if reserved_names
.contains(&name.as_ref().unwrap_or(&"".to_string()).as_str())
{
let span = env.current_span().to_owned();
env.errs_mut().push(ValidationError::new(
format!("{} is a reserved name", name.as_ref().unwrap()),
span,
));
}

elmts.push(TensorBlock::new(
name,
start.clone(),
Expand Down Expand Up @@ -446,7 +477,7 @@ impl<'s> DiscreteModel<'s> {
None,
)],
);
ret.out = Self::build_array(&out_tensor, &mut env).unwrap();
ret.out = Self::build_array(&out_tensor, &mut env).unwrap_or(Tensor::new_empty("out"));
}
if let Some(span) = span_f {
Self::check_match(&ret.rhs, &ret.state, span, &mut env);
Expand Down
Loading

0 comments on commit c1c833d

Please sign in to comment.