-
Notifications
You must be signed in to change notification settings - Fork 1
Add UKAEA-TGLFNN network #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
FYI, we have just renamed the library to fusion_surrogates, and this will cause import issues with your changes.
Apologies for the trouble. |
618cb81
to
b00e087
Compare
This is complete pending the release of the model weights and/or data for testing. I could also plausibly add more unit tests. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies for the late review. I didn't set my Github notifications properly apparently. It should be fixed now, so I hope I can respond more quickly in the future.
Thanks for the review @hamelphi. It was a bit of a half-baked PR because things were changing internally and we weren't sure about the methods we'd use for exporting/sharing these models. Things should be a bit more consistent now, following your comments! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing my comments.
What is missing:
- Settling on the API for TGLFNNModel: Do we want a common class or protocol for onnx and pytorch? Up to you. I think it would make sense personally.
- Clean up tests
- Add model and test data.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please move the tests associated to these transforms outside from qlknn_model_test.py into transforms_test.py
efe_onnx_path: str, | ||
efi_onnx_path: str, | ||
pfi_onnx_path: str, | ||
) -> "TGLFNNModel": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this type annotation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that the loaders are in different files, we should split these tests as well.
model = pytorch_model.PytorchTGLFNNModel( | ||
config_path="models/1.0.1/config.yaml", | ||
stats_path="models/1.0.1/stats.json", | ||
efe_gb_checkpoint_path="models/1.0.1/regressor_efe_gb.pt", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should add the model data and test data in this PR as well. My understanding is that you are waiting for the approval on your side before pushing the data. Is that correct?
FYI, I refactored the file structure to follow what you suggested in this PR. I added the |
Adding a refactored version of the UKAEA-TGLFNN from google-deepmind/torax#477. I've tried to match the format of the model file to the qlknn model - feedback welcome!
Tasks: