Skip to content

Commit c1c833d

Browse files
feat: add gradient and adjoints wrt inputs (#45)
* added sens function to llvm and updating interface * add tests for sgrad and srgrad * tests for sgrad and srgrad seem to work * add mass_rgrad * fix some interface sigs, check for reserved names * fix llvm intrinsic bug * cranelift grad impl dupnoneed * cargo fmt
1 parent 4bc7255 commit c1c833d

File tree

6 files changed

+838
-181
lines changed

6 files changed

+838
-181
lines changed

src/discretise/discrete_model.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,27 @@ impl<'s> DiscreteModel<'s> {
101101

102102
fn build_array(array: &ast::Tensor<'s>, env: &mut Env) -> Option<Tensor<'s>> {
103103
let rank = array.indices().len();
104+
let reserved_names = [
105+
"u0",
106+
"t",
107+
"data",
108+
"root",
109+
"thread_id",
110+
"thread_dim",
111+
"rr",
112+
"states",
113+
"inputs",
114+
"outputs",
115+
"hass_mass",
116+
];
117+
if reserved_names.contains(&array.name()) {
118+
let span = env.current_span().to_owned();
119+
env.errs_mut().push(ValidationError::new(
120+
format!("{} is a reserved name", array.name()),
121+
span,
122+
));
123+
return None;
124+
}
104125
let mut elmts = Vec::new();
105126
let mut start = Index::zeros(rank);
106127
let nerrs = env.errs().len();
@@ -148,6 +169,16 @@ impl<'s> DiscreteModel<'s> {
148169
i64::try_from(elmt_layout.shape()[0]).unwrap()
149170
};
150171

172+
if reserved_names
173+
.contains(&name.as_ref().unwrap_or(&"".to_string()).as_str())
174+
{
175+
let span = env.current_span().to_owned();
176+
env.errs_mut().push(ValidationError::new(
177+
format!("{} is a reserved name", name.as_ref().unwrap()),
178+
span,
179+
));
180+
}
181+
151182
elmts.push(TensorBlock::new(
152183
name,
153184
start.clone(),
@@ -446,7 +477,7 @@ impl<'s> DiscreteModel<'s> {
446477
None,
447478
)],
448479
);
449-
ret.out = Self::build_array(&out_tensor, &mut env).unwrap();
480+
ret.out = Self::build_array(&out_tensor, &mut env).unwrap_or(Tensor::new_empty("out"));
450481
}
451482
if let Some(span) = span_f {
452483
Self::check_match(&ret.rhs, &ret.state, span, &mut env);

0 commit comments

Comments
 (0)