-
Notifications
You must be signed in to change notification settings - Fork 2
Add UKAEA-TGLFNN network #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
7f657d3
b00e087
0f9a73e
da996c6
c796d43
d508de3
ce3eb67
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,224 @@ | ||
| import dataclasses | ||
theo-brown marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| import json | ||
| from typing import Final | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| import optax | ||
| import yaml | ||
|
|
||
| from fusion_surrogates.networks import GaussianMLPEnsemble | ||
theo-brown marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| from fusion_surrogates.utils import normalize, unnormalize | ||
theo-brown marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
# Names of the 15 model input features. The order matters: the normalization
# statistics in TGLFNNModelStats.load are assembled by iterating this list.
INPUT_LABELS: Final[list[str]] = [
    "RLNS_1",
    "RLTS_1",
    "RLTS_2",
    "TAUS_2",
    "RMIN_LOC",
    "DRMAJDX_LOC",
    "Q_LOC",
    "SHAT",
    "XNUE",
    "KAPPA_LOC",
    "S_KAPPA_LOC",
    "DELTA_LOC",
    "S_DELTA_LOC",
    "BETAE",
    "ZEFF",
]
# Names of the predicted fluxes. Used both as keys into the per-flux parameter
# trees and as the stacking order of the prediction outputs.
OUTPUT_LABELS: Final[list[str]] = ["efe_gb", "efi_gb", "pfi_gb"]
|
|
||
|
|
||
@dataclasses.dataclass
class TGLFNNModelConfig:
    """Hyperparameters and pre/post-processing switches for the TGLFNN model.

    Only ``n_ensemble``, ``num_hiddens`` and ``dropout`` are read from the
    config file by ``load``; the remaining fields keep their defaults unless
    set explicitly by the caller.
    """

    n_ensemble: int
    num_hiddens: int
    dropout: float
    normalize: bool = True
    unnormalize: bool = True
    hidden_size: int = 512

    @classmethod
    def load(cls, config_path: str) -> "TGLFNNModelConfig":
        """Build a config from a YAML file.

        Args:
          config_path: Path to a YAML file containing the keys
            ``num_estimators``, ``model_size`` and ``dropout``.

        Returns:
          A config whose ``n_ensemble``, ``num_hiddens`` and ``dropout`` are
          taken from those keys respectively.

        Raises:
          KeyError: If one of the required keys is missing from the file.
        """
        with open(config_path, "r") as stream:
            raw = yaml.safe_load(stream)

        # Dataclass field name -> key used in the YAML file.
        yaml_key_for_field = {
            "n_ensemble": "num_estimators",
            "num_hiddens": "model_size",
            "dropout": "dropout",
        }
        return cls(**{field: raw[key] for field, key in yaml_key_for_field.items()})
|
|
||
|
|
||
@dataclasses.dataclass
class TGLFNNModelStats:
    """Mean and standard deviation of every model input and output."""

    input_mean: jax.Array
    input_std: jax.Array
    output_mean: jax.Array
    output_std: jax.Array

    @classmethod
    def load(cls, stats_path: str) -> "TGLFNNModelStats":
        """Read the statistics from a JSON file.

        Args:
          stats_path: Path to a JSON file mapping each label in INPUT_LABELS
            and OUTPUT_LABELS to an object with ``mean`` and ``std`` entries.

        Returns:
          Stats whose arrays follow the ordering of INPUT_LABELS and
          OUTPUT_LABELS.

        Raises:
          KeyError: If a label, or its ``mean``/``std`` entry, is missing.
        """
        with open(stats_path, "r") as stream:
            table = json.load(stream)

        def gather(labels: list[str], key: str) -> jax.Array:
            # Collect one statistic for every label, preserving label order.
            return jnp.array([table[name][key] for name in labels])

        return cls(
            input_mean=gather(INPUT_LABELS, "mean"),
            input_std=gather(INPUT_LABELS, "std"),
            output_mean=gather(OUTPUT_LABELS, "mean"),
            output_std=gather(OUTPUT_LABELS, "std"),
        )
|
|
||
|
|
||
class TGLFNNModel:
    """UKAEA TGLFNN surrogate evaluated with a JAX ``GaussianMLPEnsemble``.

    A single network definition is shared by all fluxes; each flux in
    OUTPUT_LABELS has its own parameter tree stored under that label in
    ``params``.
    """

    def __init__(
        self,
        config: TGLFNNModelConfig,
        stats: TGLFNNModelStats,
        params: optax.Params | None,
    ):
        """Create the model.

        Args:
          config: Network hyperparameters and normalization switches.
          stats: Input/output statistics used for (un)normalization.
          params: Mapping from output label to that flux's network parameters
            (as built by ``load_from_pytorch``), or None if parameters are
            not available yet. ``predict`` requires them to be set.
        """
        self.config = config
        self.stats = stats
        self.params = params
        self.network = GaussianMLPEnsemble(
            n_ensemble=config.n_ensemble,
            hidden_size=config.hidden_size,
            num_hiddens=config.num_hiddens,
            dropout=config.dropout,
            activation="relu",
        )

    @classmethod
    def load_from_pytorch(
        cls,
        config_path: str,
        stats_path: str,
        efe_gb_checkpoint_path: str,
        efi_gb_checkpoint_path: str,
        pfi_gb_checkpoint_path: str,
        *args,
        **kwargs,
    ) -> "TGLFNNModel":
        """Build a model from PyTorch checkpoints, one per output flux.

        Args:
          config_path: YAML config file, read by ``TGLFNNModelConfig.load``.
          stats_path: JSON statistics file, read by ``TGLFNNModelStats.load``.
          efe_gb_checkpoint_path: Checkpoint holding the ``efe_gb`` state dict.
          efi_gb_checkpoint_path: Checkpoint holding the ``efi_gb`` state dict.
          pfi_gb_checkpoint_path: Checkpoint holding the ``pfi_gb`` state dict.
          *args: Extra positional arguments forwarded to ``torch.load``.
          **kwargs: Extra keyword arguments forwarded to ``torch.load``
            (for example ``map_location``).

        Returns:
          A model whose ``params`` maps each output label to its converted
          parameter tree.
        """
        # Function-scope import: torch is only needed on this code path.
        import torch

        def _convert_pytorch_state_dict(
            pytorch_state_dict: dict, config: TGLFNNModelConfig
        ) -> optax.Params:
            """Rename and transpose a PyTorch state dict into the JAX layout."""
            params = {}
            for i in range(config.n_ensemble):
                model_dict = {}
                for j in range(config.num_hiddens):
                    # The j-th dense layer is read from index 3 * j of the
                    # PyTorch module list. NOTE(review): presumably each
                    # block is (Linear, activation, dropout) - confirm.
                    prefix = f"models.{i}.model.{j*3}"
                    model_dict[f"Dense_{j}"] = {
                        # Transposed to go from PyTorch's weight layout to
                        # the kernel layout used on the JAX side.
                        "kernel": jnp.array(pytorch_state_dict[f"{prefix}.weight"]).T,
                        "bias": jnp.array(pytorch_state_dict[f"{prefix}.bias"]).T,
                    }
                params[f"GaussianMLP_{i}"] = model_dict
            return {"params": params}

        config = TGLFNNModelConfig.load(config_path)
        stats = TGLFNNModelStats.load(stats_path)

        checkpoint_paths = {
            "efe_gb": efe_gb_checkpoint_path,
            "efi_gb": efi_gb_checkpoint_path,
            "pfi_gb": pfi_gb_checkpoint_path,
        }
        params = {}
        for label, checkpoint_path in checkpoint_paths.items():
            with open(checkpoint_path, "rb") as f:
                params[label] = _convert_pytorch_state_dict(
                    torch.load(f, *args, **kwargs), config
                )

        return cls(config, stats, params)

    def predict(
        self,
        inputs: jax.Array,
    ) -> jax.Array:
        """Compute the model prediction for the given inputs.

        Args:
          inputs: The input data to the model. Must be shape (..., 15).

        Returns:
          A jax.Array of shape (..., 3, 2), where output[..., i, 0]
          and output[..., i, 1] are the mean and variance for the ith flux
          output. Outputs are in the order of OUTPUT_LABELS, i.e. efe_gb,
          efi_gb, pfi_gb.

        Raises:
          ValueError: If the model was constructed with ``params=None``.
        """
        if self.params is None:
            raise ValueError(
                "TGLFNNModel.predict requires model parameters, but params is None."
            )

        if self.config.normalize:
            inputs = normalize(
                inputs, mean=self.stats.input_mean, stddev=self.stats.input_std
            )

        # One forward pass per flux, stacked along a new second-to-last axis.
        output = jnp.stack(
            [
                self.network.apply(self.params[label], inputs, deterministic=True)
                for label in OUTPUT_LABELS
            ],
            axis=-2,
        )

        if self.config.unnormalize:
            mean = output[..., 0]
            var = output[..., 1]

            unnormalized_mean = unnormalize(
                mean, mean=self.stats.output_mean, stddev=self.stats.output_std
            )

            # NOTE(review): only the mean is rescaled; the variance is passed
            # through unchanged. If it is expressed in normalized units it
            # would need scaling by output_std**2 - confirm the intent.
            output = jnp.stack([unnormalized_mean, var], axis=-1)

        return output
|
|
||
|
|
||
class ONNXTGLFNNModel:
    """TGLFNN surrogate that runs one ONNX graph per flux via jaxonnxruntime."""

    def __init__(
        self,
        efe_onnx_path: str,
        efi_onnx_path: str,
        pfi_onnx_path: str,
    ) -> None:
        """Load and prepare the three ONNX models.

        Args:
          efe_onnx_path: Path to the ONNX model stored under ``efe_gb``.
          efi_onnx_path: Path to the ONNX model stored under ``efi_gb``.
          pfi_onnx_path: Path to the ONNX model stored under ``pfi_gb``.
        """
        # Function-scope imports: only this class needs the ONNX stack.
        import onnx
        from jaxonnxruntime import config
        from jaxonnxruntime.backend import Backend as ONNXJaxBackend

        config.update("jaxort_only_allow_initializers_as_static_args", False)

        onnx_paths = {
            "efe_gb": efe_onnx_path,
            "efi_gb": efi_onnx_path,
            "pfi_gb": pfi_onnx_path,
        }
        # Prepared backend for each flux, keyed by output label.
        self.models = {
            label: ONNXJaxBackend.prepare(onnx.load_model(path))
            for label, path in onnx_paths.items()
        }

        # Inputs are cast to this dtype and fed to the graph under this name.
        self._input_dtype = jnp.float32
        self._input_node_label = "input"

    def _predict_single_flux(self, flux: str, inputs: jax.Array) -> jax.Array:
        """Run the model for one flux.

        Args:
          flux: Key into ``self.models`` (one of the output labels).
          inputs: Model inputs; cast to ``self._input_dtype`` before running.

        Returns:
          The first two model outputs, each squeezed, stacked along a new
          last axis.
        """
        output = self.models[flux].run(
            {self._input_node_label: inputs.astype(self._input_dtype)}
        )
        return jnp.stack([jnp.squeeze(output[0]), jnp.squeeze(output[1])], axis=-1)

    def predict(self, inputs: jax.Array) -> jax.Array:
        """Run all flux models on the same inputs.

        Args:
          inputs: Model inputs, passed unchanged to every flux model.

        Returns:
          The per-flux predictions stacked along the second-to-last axis, in
          the order of OUTPUT_LABELS.
        """
        return jnp.stack(
            [self._predict_single_flux(f, inputs) for f in OUTPUT_LABELS], axis=-2
        )
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Now that the loaders are in different files, we should split these tests as well. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| # Copyright 2024 DeepMind Technologies Limited. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import jax.numpy as jnp | ||
| from absl.testing import absltest | ||
| from numpy import testing | ||
|
|
||
| from fusion_surrogates.ukaea_tglfnn.ukaea_tglfnn import ( | ||
theo-brown marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ONNXTGLFNNModel, | ||
| TGLFNNModel, | ||
| ) | ||
|
|
||
|
|
||
class UKAEA_TGLFNNModelTest(absltest.TestCase):
    """Compares the PyTorch-converted TGLFNN model with stored reference data."""

    @classmethod
    def setUpClass(cls):
        """Load the model and reference arrays once for the whole test class.

        Doing this here rather than in the class body keeps file I/O out of
        module import, so a missing file is reported as a test error.
        """
        super().setUpClass()
        # NOTE(review): the model checkpoints and the test_data files read
        # below still need to be added alongside this test (see PR review).
        cls.model = TGLFNNModel.load_from_pytorch(
            config_path="models/1.0.1/config.yaml",
            stats_path="models/1.0.1/stats.json",
            efe_gb_checkpoint_path="models/1.0.1/regressor_efe_gb.pt",
            efi_gb_checkpoint_path="models/1.0.1/regressor_efi_gb.pt",
            pfi_gb_checkpoint_path="models/1.0.1/regressor_pfi_gb.pt",
            map_location="cpu",
        )
        cls.reference_inputs = jnp.load("test_data/input.npy")
        cls.reference_outputs = jnp.load("test_data/output.npy")

    def test_matches_reference(self):
        """Predictions agree with the reference to a relative tolerance of 1e-3."""
        predicted_outputs = self.model.predict(self.reference_inputs)
        # assert_allclose takes (actual, desired); rtol is relative to desired.
        # NOTE(review): here the reference is indexed with [..., 0], whereas
        # the ONNX test indexes the prediction instead - confirm which is
        # intended for this model's output layout.
        testing.assert_allclose(
            predicted_outputs, self.reference_outputs[..., 0], rtol=1e-3
        )
|
|
||
|
|
||
class UKAEA_ONNXTGLFNNModelTest(absltest.TestCase):
    """Compares the ONNX-backed TGLFNN model with stored reference data."""

    @classmethod
    def setUpClass(cls):
        """Load the ONNX models and reference arrays once for the test class.

        Doing this here rather than in the class body keeps file I/O out of
        module import, so a missing file is reported as a test error.
        """
        super().setUpClass()
        cls.model = ONNXTGLFNNModel(
            efe_onnx_path="models/1.0.1/regressor_efe_gb_onnx.onnx",
            efi_onnx_path="models/1.0.1/regressor_efi_gb_onnx.onnx",
            pfi_onnx_path="models/1.0.1/regressor_pfi_gb_onnx.onnx",
        )
        cls.reference_inputs = jnp.load("test_data/input.npy")
        cls.reference_outputs = jnp.load("test_data/output.npy")

    def test_matches_reference(self):
        """The first predicted channel matches the reference within rtol=1e-3."""
        predicted_outputs = self.model.predict(self.reference_inputs)
        # assert_allclose takes (actual, desired); rtol is relative to desired.
        testing.assert_allclose(
            predicted_outputs[..., 0], self.reference_outputs, rtol=1e-3
        )
|
|
||
|
|
||
# Run the test cases with absl's test runner when executed as a script.
if __name__ == "__main__":
    absltest.main()
theo-brown marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| import jax | ||
theo-brown marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| import jax.numpy as jnp | ||
|
|
||
|
|
||
def normalize(data: jax.Array, *, mean: jax.Array, stddev: jax.Array) -> jax.Array:
    """Shifts data by mean and scales it by stddev.

    Entries of stddev equal to zero are treated as one, so the corresponding
    data is only shifted and no division by zero occurs.
    """
    scale = jnp.where(stddev == 0, 1, stddev)
    return (data - mean) / scale
|
|
||
|
|
||
def unnormalize(data: jax.Array, *, mean: jax.Array, stddev: jax.Array) -> jax.Array:
    """Maps normalized data back to the original distribution.

    Entries of stddev equal to zero are treated as one, mirroring normalize,
    so the corresponding data is only shifted by mean.
    """
    scale = jnp.where(stddev == 0, 1, stddev)
    return data * scale + mean
Uh oh!
There was an error while loading. Please reload this page.