fix external variable initialization #1775
base: master
Conversation
Could you point out an example for using this code?
Here: https://github.com/bonneted/sbinn/blob/main/sbinn/sbinn_jax.py, an implementation of SBINN using JAX. For this first training stage, we want to use the default values of the unknowns.
The code modification seems necessary. But there is another example, https://github.com/lululxvi/deepxde/blob/master/examples/pinn_inverse/Lorenz_inverse.py, which works well (at least worked earlier).
This one was already working well because there is no pretraining without the external variables. The problem occurs when we compile without the external trainable variables, which is when we want the PDE to use the default values of the unknowns.
The code seems OK. But the underlying logic becomes extremely complicated now. In fact, you can simply add the external trainable variables in the first compile.
That's true in that case, but it can be interesting to start training the model with frozen parameters (for example https://doi.org/10.1126/sciadv.abk0644).
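For illustration, here is a minimal sketch of that two-stage workflow on a toy inverse problem. The ODE, network sizes, and the unknown `C` are placeholders rather than the sbinn setup, and it is written in the usual deepxde style (not the jax-specific PDE signature), just to show the "frozen first compile, trainable second compile" pattern being discussed:

```python
import numpy as np
import deepxde as dde

# Unknown decay rate with a default (frozen) value of 1.0.
C = dde.Variable(1.0)

def ode(x, y):
    # Residual of dy/dt + C * y = 0; the synthetic data below use C = 2.
    dy_dt = dde.grad.jacobian(y, x)
    return dy_dt + C * y

geom = dde.geometry.TimeDomain(0, 1)
ic = dde.icbc.IC(geom, lambda x: 1, lambda _, on_initial: on_initial)

# Synthetic observations of the true solution exp(-2 t).
t_obs = np.linspace(0, 1, 10)[:, None]
observe_y = dde.icbc.PointSetBC(t_obs, np.exp(-2 * t_obs))

data = dde.data.PDE(geom, ode, [ic, observe_y],
                    num_domain=50, num_boundary=2, anchors=t_obs)
net = dde.nn.FNN([1] + [32] * 2 + [1], "tanh", "Glorot normal")
model = dde.Model(data, net)

# Stage 1: pretrain the network with C frozen at its default value.
model.compile("adam", lr=1e-3)
model.train(iterations=5000)

# Stage 2: recompile with C trainable and continue training.
model.compile("adam", lr=1e-3, external_trainable_variables=[C])
model.train(iterations=10000,
            callbacks=[dde.callbacks.VariableValue([C], period=1000)])
```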
Please resolve the conflicts. |
I've resolved the conflict based on your improved logic. |
```python
external_trainable_variables_val = [
    var.value for var in self.external_trainable_variables
]
self.params = [self.net.params, external_trainable_variables_val]
```
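As a side note on why extracting `.value` matters for jax: the variable wrapper objects themselves are not arrays, so they would sit as opaque leaves in the `params` pytree. A toy illustration (the `Variable` class below is a stand-in, not deepxde's):

```python
import jax
import jax.numpy as jnp

class Variable:
    # Toy stand-in for an external trainable variable with a .value attribute.
    def __init__(self, value):
        self.value = jnp.asarray(value)

ext_vars = [Variable(1.0), Variable(2.0)]

# Extracting .value yields plain arrays, which jax treats as pytree leaves
# that gradients and optimizer updates can flow through.
ext_vals = [v.value for v in ext_vars]
print(jax.tree_util.tree_leaves(ext_vals))  # two jnp arrays

# The wrapper objects themselves would be unregistered, opaque leaves instead.
print(jax.tree_util.tree_leaves(ext_vars))  # two Variable objects
```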
What if self.external_trainable_variables is []?
It's working fine; var.value is never evaluated, and the output is [].
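A quick way to see this (plain Python, with a toy stand-in for the variable objects):

```python
class Var:
    def __init__(self, value):
        self.value = value

external_trainable_variables = []  # the default when none are passed to compile()

# The comprehension body is never executed for an empty list, so .value is
# never read and the result is simply another empty list.
vals = [var.value for var in external_trainable_variables]
assert vals == []
```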
In this case,
```python
self.params = [self.net.params, []]
```
which is strange.
It was already like that before the introduction of external_trainable_variables_val in a previous commit. It used to be (see bonneted@3810a98):
```python
self.params = [self.net.params, self.external_trainable_variables]
```
with self.external_trainable_variables = [] by default.
We can see how it works in the original commit by @ZongrenZou (bonneted@09c74ab): self.params is split into nn_params and ext_params in the outputs_losses function:
```python
def outputs_losses(params, training, inputs, targets, losses_fn):
    nn_params, ext_params = params
    # TODO: Add auxiliary vars
    def outputs_fn(inputs):
        return self.net.apply(nn_params, inputs, training=training)

    outputs_ = self.net.apply(nn_params, inputs, training=training)
    # Data losses
    # We use aux so that self.data.losses is a pure function.
    aux = [outputs_fn, ext_params] if ext_params else [outputs_fn]
    losses = losses_fn(targets, outputs_, loss_fn, inputs, self, aux=aux)
```
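To make the data flow concrete, here is a hedged, self-contained toy in JAX (not deepxde internals): it shows why nn_params and ext_params travel together inside `params`, so the loss stays a pure function that jax.grad can differentiate with respect to both the network weights and the external unknowns.

```python
import jax
import jax.numpy as jnp

def apply_net(nn_params, x):
    # Toy "network": a single linear layer.
    w, b = nn_params
    return x @ w + b

def loss(params, x, targets):
    nn_params, ext_params = params               # same split as outputs_losses
    outputs = apply_net(nn_params, x)
    c = ext_params[0] if ext_params else 1.0     # external unknown, default 1.0
    residual = outputs - c * targets             # stand-in for a PDE residual
    return jnp.mean(residual ** 2)

x = jnp.ones((4, 2))
targets = jnp.ones((4, 1))
nn_params = (jnp.zeros((2, 1)), jnp.zeros((1,)))
params = [nn_params, [jnp.array(2.0)]]

# Because the unknown is part of `params`, its gradient is computed alongside
# the network weights and an optimizer can update both in the same step.
grads = jax.grad(loss)(params, x, targets)
print(grads[1])  # gradient with respect to the external unknown
```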
I've faced two bugs when trying to implement https://github.com/lu-group/sbinn/blob/b2c1c94d6564732189722f6e6772af0f63cb0d8c/sbinn/sbinn_tf.py#L8:

- In model.py: the external variables were not initialized on the second compile, because the parameters of the net already were.
- In pde.py: if I don't compile with external variables, I still want the code to work with the default values of the unknowns. I think this code can be safely modified only for jax, because the line after was already only for jax.
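For the pde.py point, a hedged sketch of the kind of fallback being described (the names resolve_unknowns and declared_variables are assumptions for illustration, not the actual deepxde code): when the model was compiled with external trainable variables, their current values arrive through `aux`; otherwise the PDE falls back to the defaults stored on the variables.

```python
def resolve_unknowns(aux, declared_variables):
    # aux is [outputs_fn] or [outputs_fn, ext_params], as built in
    # outputs_losses above; declared_variables are the variable objects the
    # PDE refers to, used here only for their default values.
    if len(aux) > 1:
        return aux[1]                               # trainable values from compile()
    return [v.value for v in declared_variables]    # frozen default values


class _Var:
    # Toy stand-in for a variable object carrying a default value.
    def __init__(self, value):
        self.value = value


outputs_fn = lambda x: x  # placeholder for the network output function
defaults = [_Var(1.0), _Var(5.0)]

print(resolve_unknowns([outputs_fn], defaults))              # -> [1.0, 5.0]
print(resolve_unknowns([outputs_fn, [2.0, 7.0]], defaults))  # -> [2.0, 7.0]
```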