Add UKAEA-TGLFNN network #11
Changes from all commits: 7f657d3, b00e087, 0f9a73e, da996c6, c796d43, d508de3, ce3eb67
@@ -0,0 +1,14 @@

"""Tools for transforming input and output tensors."""

import jax
import jax.numpy as jnp


def normalize(data: jax.Array, *, mean: jax.Array, stddev: jax.Array) -> jax.Array:
    """Normalizes data to have mean 0 and stddev 1."""
    return (data - mean) / jnp.where(stddev == 0, 1, stddev)


def unnormalize(data: jax.Array, *, mean: jax.Array, stddev: jax.Array) -> jax.Array:
    """Unnormalizes data to the original distribution."""
    return data * jnp.where(stddev == 0, 1, stddev) + mean
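
A minimal usage sketch of these helpers (assuming they are importable as fusion_surrogates.transforms, as the other files in this PR do). The jnp.where guard means features with zero spread are shifted but not scaled, so there is no division by zero and normalize/unnormalize round-trip exactly:

```python
import jax.numpy as jnp

from fusion_surrogates import transforms

# Per-feature statistics; the last feature is constant (stddev == 0).
mean = jnp.array([1.0, -2.0, 5.0])
stddev = jnp.array([2.0, 0.5, 0.0])

data = jnp.array([[3.0, -2.5, 5.0],
                  [1.0, -1.5, 5.0]])

normed = transforms.normalize(data, mean=mean, stddev=stddev)
restored = transforms.unnormalize(normed, mean=mean, stddev=stddev)
assert jnp.allclose(restored, data)
```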
@@ -0,0 +1,4 @@

"""UKAEA-TGLFNN.

This package contains the implementation of UKAEA's TGLF surrogate for STEP.
"""
@@ -0,0 +1,83 @@

"""Base code for UKAEA's TGLFNN model.

UKAEA-TGLFNN is a neural network surrogate model for the gyrokinetics code TGLF, developed by Lorenzo Zanisi at UKAEA.
The model is trained on a dataset generated from JETTO TGLF runs in the STEP design space.
Hence, it is best suited to modelling transport in spherical tokamaks.
"""

import dataclasses
import json
import typing
from typing import Literal

import jax
import jax.numpy as jnp
import optax
import yaml

from fusion_surrogates import networks
from fusion_surrogates import transforms

OutputLabel = Literal["efe_gb", "efi_gb", "pfi_gb"]
OUTPUT_LABELS = typing.get_args(OutputLabel)

InputLabel = Literal[
    "RLNS_1",
    "RLTS_1",
    "RLTS_2",
    "TAUS_2",
    "RMIN_LOC",
    "DRMAJDX_LOC",
    "Q_LOC",
    "SHAT",
    "XNUE",
    "KAPPA_LOC",
    "S_KAPPA_LOC",
    "DELTA_LOC",
    "S_DELTA_LOC",
    "BETAE",
    "ZEFF",
]
INPUT_LABELS = typing.get_args(InputLabel)


@dataclasses.dataclass
class TGLFNNModelConfig:
    n_ensemble: int
    num_hiddens: int
    dropout: float
    normalize: bool = True
    unnormalize: bool = True
    hidden_size: int = 512

    @classmethod
    def load(cls, config_path: str) -> "TGLFNNModelConfig":
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)

        return cls(
            n_ensemble=config["num_estimators"],
            num_hiddens=config["model_size"],
            dropout=config["dropout"],
        )


@dataclasses.dataclass
class TGLFNNModelStats:
    input_mean: jax.Array
    input_std: jax.Array
    output_mean: jax.Array
    output_std: jax.Array

    @classmethod
    def load(cls, stats_path: str) -> "TGLFNNModelStats":
        with open(stats_path, "r") as f:
            stats = json.load(f)

        return cls(
            input_mean=jnp.array([stats[label]["mean"] for label in INPUT_LABELS]),
            input_std=jnp.array([stats[label]["std"] for label in INPUT_LABELS]),
            output_mean=jnp.array([stats[label]["mean"] for label in OUTPUT_LABELS]),
            output_std=jnp.array([stats[label]["std"] for label in OUTPUT_LABELS]),
        )
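
A sketch of the on-disk formats these loaders appear to expect, inferred only from the keys read above (num_estimators, model_size and dropout in the YAML config; one mean/std pair per label in the stats JSON). The values below are placeholders, not the shipped config.yaml or stats.json:

```python
import json
import tempfile

import yaml

from fusion_surrogates.ukaea_tglfnn import config as ukaea_tglfnn_config

# Hypothetical config: 5 ensemble members, 4 hidden layers, 5% dropout.
config_dict = {"num_estimators": 5, "model_size": 4, "dropout": 0.05}

# Hypothetical stats: one mean/std pair per input and per output label.
stats_dict = {
    label: {"mean": 0.0, "std": 1.0}
    for label in ukaea_tglfnn_config.INPUT_LABELS + ukaea_tglfnn_config.OUTPUT_LABELS
}

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    yaml.safe_dump(config_dict, f)
    config_path = f.name
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(stats_dict, f)
    stats_path = f.name

config = ukaea_tglfnn_config.TGLFNNModelConfig.load(config_path)
stats = ukaea_tglfnn_config.TGLFNNModelStats.load(stats_path)
print(config.n_ensemble, config.hidden_size)  # 5 512 (hidden_size keeps its default)
print(stats.input_mean.shape, stats.output_mean.shape)  # (15,) (3,)
```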
@@ -0,0 +1,53 @@

"""Implementation of UKAEA-TGLFNN as loaded from ONNX."""

from typing import Literal

import jax
import jax.numpy as jnp
import jaxonnxruntime
from jaxonnxruntime import backend as jaxort_backend
import onnx

from fusion_surrogates.ukaea_tglfnn import config as ukaea_tglfnn_config

jaxonnxruntime.config.update("jaxort_only_allow_initializers_as_static_args", False)


class ONNXTGLFNNModel:
    def __init__(
        self,
        efe_onnx_path: str,
        efi_onnx_path: str,
        pfi_onnx_path: str,
    ) -> "TGLFNNModel":
        self.models = {}

        efe_model = onnx.load_model(efe_onnx_path)
        self.models["efe_gb"] = jaxort_backend.Backend.prepare(efe_model)

        efi_model = onnx.load_model(efi_onnx_path)
        self.models["efi_gb"] = jaxort_backend.Backend.prepare(efi_model)

        pfi_model = onnx.load_model(pfi_onnx_path)
        self.models["pfi_gb"] = jaxort_backend.Backend.prepare(pfi_model)

        self._input_dtype = jnp.float32
        self._input_node_label = "input"

    def _predict_single_flux(
        self, flux: ukaea_tglfnn_config.OutputLabel, inputs: jax.Array
    ) -> jax.Array:
        output = self.models[flux].run(
            {self._input_node_label: inputs.astype(self._input_dtype)}
        )
        return jnp.stack([jnp.squeeze(output[0]), jnp.squeeze(output[1])], axis=-1)

    def predict(self, inputs: jax.Array):
        return jnp.stack(
            [
                self._predict_single_flux(f, inputs)
                for f in ukaea_tglfnn_config.OUTPUT_LABELS
            ],
            axis=-2,
        )

Collaborator (on the -> "TGLFNNModel" annotation of __init__): Remove this type annotation.
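
A hedged usage sketch for the ONNX loader. The .onnx paths are placeholders copied from the test file later in this diff, and the batch shape is an assumption; the (..., 3, 2) output layout follows the stacking in _predict_single_flux and predict above:

```python
import jax.numpy as jnp

from fusion_surrogates.ukaea_tglfnn import onnx_model

# Placeholder paths mirroring the test file in this PR.
model = onnx_model.ONNXTGLFNNModel(
    efe_onnx_path="models/1.0.1/regressor_efe_gb_onnx.onnx",
    efi_onnx_path="models/1.0.1/regressor_efi_gb_onnx.onnx",
    pfi_onnx_path="models/1.0.1/regressor_pfi_gb_onnx.onnx",
)

# 15 TGLF inputs per point, in the order of INPUT_LABELS.
inputs = jnp.zeros((8, 15))
outputs = model.predict(inputs)
# outputs[..., i, :] stacks the two network outputs for the i-th flux,
# in the order efe_gb, efi_gb, pfi_gb.
print(outputs.shape)  # expected (8, 3, 2), following the stacking in predict()
```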
@@ -0,0 +1,108 @@

"""Implementation of UKAEA-TGLFNN as loaded from a PyTorch checkpoint."""

import jax
import jax.numpy as jnp
import optax
import torch

from fusion_surrogates import networks
from fusion_surrogates import transforms
from fusion_surrogates.ukaea_tglfnn import config as ukaea_tglfnn_config


def _convert_pytorch_state_dict(
    pytorch_state_dict: dict, config: ukaea_tglfnn_config.TGLFNNModelConfig
) -> optax.Params:
    params = {}
    for i in range(config.n_ensemble):
        model_dict = {}
        for j in range(config.num_hiddens):
            layer_dict = {
                "kernel": jnp.array(
                    pytorch_state_dict[f"models.{i}.model.{j*3}.weight"]
                ).T,
                "bias": jnp.array(pytorch_state_dict[f"models.{i}.model.{j*3}.bias"]).T,
            }
            model_dict[f"Dense_{j}"] = layer_dict
        params[f"GaussianMLP_{i}"] = model_dict
    return {"params": params}


class PytorchTGLFNNModel:
    def __init__(
        self,
        config_path: str,
        stats_path: str,
        efe_gb_checkpoint_path: str,
        efi_gb_checkpoint_path: str,
        pfi_gb_checkpoint_path: str,
        map_location: str = "cpu",
    ):
        self.config = ukaea_tglfnn_config.TGLFNNModelConfig.load(config_path)
        self.stats = ukaea_tglfnn_config.TGLFNNModelStats.load(stats_path)

        with open(efe_gb_checkpoint_path, "rb") as f:
            efe_gb_params = _convert_pytorch_state_dict(
                torch.load(f, map_location=map_location), self.config
            )
        with open(efi_gb_checkpoint_path, "rb") as f:
            efi_gb_params = _convert_pytorch_state_dict(
                torch.load(f, map_location=map_location), self.config
            )
        with open(pfi_gb_checkpoint_path, "rb") as f:
            pfi_gb_params = _convert_pytorch_state_dict(
                torch.load(f, map_location=map_location), self.config
            )

        self.params = {
            "efe_gb": efe_gb_params,
            "efi_gb": efi_gb_params,
            "pfi_gb": pfi_gb_params,
        }

        self.network = networks.GaussianMLPEnsemble(
            n_ensemble=self.config.n_ensemble,
            hidden_size=self.config.hidden_size,
            num_hiddens=self.config.num_hiddens,
            dropout=self.config.dropout,
            activation="relu",
        )

    def predict(
        self,
        inputs: jax.Array,
    ) -> jax.Array:
        """Compute the model prediction for the given inputs.

        Args:
            inputs: The input data to the model. Must be shape (..., 15).

        Returns:
            A jax.Array of shape (..., 3, 2), where output[..., i, 0]
            and output[..., i, 1] are the mean and variance for the ith flux output.
            Outputs are in the order of OUTPUT_LABELS, i.e. efe_gb, efi_gb, pfi_gb.
        """
        if self.config.normalize:
            inputs = transforms.normalize(
                inputs, mean=self.stats.input_mean, stddev=self.stats.input_std
            )

        output = jnp.stack(
            [
                self.network.apply(self.params[label], inputs, deterministic=True)
                for label in ukaea_tglfnn_config.OUTPUT_LABELS
            ],
            axis=-2,
        )

        if self.config.unnormalize:
            mean = output[..., 0]
            var = output[..., 1]

            unnormalized_mean = transforms.unnormalize(
                mean, mean=self.stats.output_mean, stddev=self.stats.output_std
            )

            output = jnp.stack([unnormalized_mean, var], axis=-1)

        return output
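
To illustrate the checkpoint layout that _convert_pytorch_state_dict assumes (a Linear layer at every third module index of each ensemble member, i.e. keys of the form models.{i}.model.{j*3}.weight), here is a self-contained sketch using a randomly initialised fake state dict; the layer sizes and the direct call to the module-private helper are for illustration only:

```python
import jax
import torch

from fusion_surrogates.ukaea_tglfnn import config as ukaea_tglfnn_config
from fusion_surrogates.ukaea_tglfnn import pytorch_model

# Tiny fake config: 2 ensemble members, 2 Linear layers each.
config = ukaea_tglfnn_config.TGLFNNModelConfig(
    n_ensemble=2, num_hiddens=2, dropout=0.1, hidden_size=8
)

# Fake checkpoint: layer j of member i lives at module index j * 3 of an
# nn.Sequential, which is the key pattern the converter reads.
state_dict = {}
for i in range(config.n_ensemble):
    for j in range(config.num_hiddens):
        in_size = 15 if j == 0 else config.hidden_size
        state_dict[f"models.{i}.model.{j * 3}.weight"] = torch.randn(
            config.hidden_size, in_size
        )
        state_dict[f"models.{i}.model.{j * 3}.bias"] = torch.randn(config.hidden_size)

params = pytorch_model._convert_pytorch_state_dict(state_dict, config)
# Kernels are transposed to the Flax convention, e.g. Dense_0 kernel -> (15, 8).
print(jax.tree_util.tree_map(lambda x: x.shape, params))
```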
Collaborator: Now that the loaders are in different files, we should split these tests as well.
@@ -0,0 +1,61 @@

# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp
from absl.testing import absltest
from numpy import testing

from fusion_surrogates.ukaea_tglfnn import pytorch_model
from fusion_surrogates.ukaea_tglfnn import onnx_model


class PyTorchTGLFNNModelTest(absltest.TestCase):
    model = pytorch_model.PytorchTGLFNNModel(
        config_path="models/1.0.1/config.yaml",
        stats_path="models/1.0.1/stats.json",
        efe_gb_checkpoint_path="models/1.0.1/regressor_efe_gb.pt",
        efi_gb_checkpoint_path="models/1.0.1/regressor_efi_gb.pt",
        pfi_gb_checkpoint_path="models/1.0.1/regressor_pfi_gb.pt",
        map_location="cpu",
    )

    reference_inputs = jnp.load("test_data/input.npy")
    reference_outputs = jnp.load("test_data/output.npy")

    def test_matches_reference(self):
        predicted_outputs = self.model.predict(self.reference_inputs)
        testing.assert_allclose(
            self.reference_outputs[..., 0], predicted_outputs, rtol=1e-3
        )


class ONNXTGLFNNModelTest(absltest.TestCase):
    model = onnx_model.ONNXTGLFNNModel(
        efe_onnx_path="models/1.0.1/regressor_efe_gb_onnx.onnx",
        efi_onnx_path="models/1.0.1/regressor_efi_gb_onnx.onnx",
        pfi_onnx_path="models/1.0.1/regressor_pfi_gb_onnx.onnx",
    )

    reference_inputs = jnp.load("test_data/input.npy")
    reference_outputs = jnp.load("test_data/output.npy")

    def test_matches_reference(self):
        predicted_outputs = self.model.predict(self.reference_inputs)
        testing.assert_allclose(
            self.reference_outputs, predicted_outputs[..., 0], rtol=1e-3
        )


if __name__ == "__main__":
    absltest.main()

Collaborator (on the models/1.0.1 checkpoint paths): We should add the model data and test data in this PR as well. My understanding is that you are waiting for the approval on your side before pushing the data. Is that correct?
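
Both test classes above check their loader against the same test_data/output.npy, which suggests the two loaders should also agree with each other. A direct cross-check, sketched here with the same placeholder paths and an assumed tolerance, might look like:

```python
import jax.numpy as jnp
from numpy import testing

from fusion_surrogates.ukaea_tglfnn import onnx_model
from fusion_surrogates.ukaea_tglfnn import pytorch_model

torch_loader = pytorch_model.PytorchTGLFNNModel(
    config_path="models/1.0.1/config.yaml",
    stats_path="models/1.0.1/stats.json",
    efe_gb_checkpoint_path="models/1.0.1/regressor_efe_gb.pt",
    efi_gb_checkpoint_path="models/1.0.1/regressor_efi_gb.pt",
    pfi_gb_checkpoint_path="models/1.0.1/regressor_pfi_gb.pt",
)
onnx_loader = onnx_model.ONNXTGLFNNModel(
    efe_onnx_path="models/1.0.1/regressor_efe_gb_onnx.onnx",
    efi_onnx_path="models/1.0.1/regressor_efi_gb_onnx.onnx",
    pfi_onnx_path="models/1.0.1/regressor_pfi_gb_onnx.onnx",
)

inputs = jnp.load("test_data/input.npy")
# Both loaders return a (..., 3, 2) stack; compare the predicted means.
testing.assert_allclose(
    torch_loader.predict(inputs)[..., 0],
    onnx_loader.predict(inputs)[..., 0],
    rtol=1e-3,
)
```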
Review comment (on the transforms module above): Please move the tests associated with these transforms out of qlknn_model_test.py into transforms_test.py.