Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added fusion_surrogates/__init__.py
Empty file.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from flax import typing as flax_typing
import flax.linen as nn
from fusion_surrogates import networks
from fusion_surrogates import transforms
import immutabledict
import jax
import jax.numpy as jnp
Expand All @@ -41,20 +42,6 @@
VERSION: Final[str] = '11D'


def normalize(
    data: jax.Array, *, mean: jax.Array, stddev: jax.Array
) -> jax.Array:
    """Standardizes `data` with the given mean and standard deviation.

    Args:
        data: Values to standardize.
        mean: Value subtracted from `data`.
        stddev: Value `data - mean` is divided by. Entries equal to zero are
            treated as one so the division never hits a zero denominator.

    Returns:
        `(data - mean) / stddev`, with zero `stddev` entries replaced by one.
    """
    safe_stddev = jnp.where(stddev == 0, 1, stddev)
    centered = data - mean
    return centered / safe_stddev


def unnormalize(
    data: jax.Array, *, mean: jax.Array, stddev: jax.Array
) -> jax.Array:
    """Maps standardized `data` back to the original distribution.

    Args:
        data: Standardized values.
        mean: Value added back after rescaling.
        stddev: Scale applied to `data`. Entries equal to zero are treated as
            one, mirroring the handling used when normalizing.

    Returns:
        `data * stddev + mean`, with zero `stddev` entries replaced by one.
    """
    safe_stddev = jnp.where(stddev == 0, 1, stddev)
    rescaled = data * safe_stddev
    return rescaled + mean


@dataclasses.dataclass
class QLKNNStatsData:
"""Stats data for normalization in QLKNNModel."""
Expand Down Expand Up @@ -258,7 +245,7 @@ def predict_with_params(
the raw model prediction
"""
if self._config.normalize_inputs and self._config.stats_data is not None:
inputs = normalize(
inputs = transforms.normalize(
inputs,
mean=self._config.stats_data.input_mean,
stddev=self._config.stats_data.input_stddev,
Expand All @@ -281,7 +268,7 @@ def predict_targets(self, inputs: jax.Array) -> jax.Array:
jax.tree_util.tree_map(lambda x: x[0], self._params), inputs
)
if self._config.normalize_targets and self._config.stats_data is not None:
outputs = unnormalize(
outputs = transforms.unnormalize(
outputs,
mean=self._config.stats_data.target_mean,
stddev=self._config.stats_data.target_stddev,
Expand Down
14 changes: 14 additions & 0 deletions fusion_surrogates/transforms.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move the tests associated with these transforms out of qlknn_model_test.py and into transforms_test.py.

Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Tools for transforming input and output tensors."""

import jax
import jax.numpy as jnp


def normalize(data: jax.Array, *, mean: jax.Array, stddev: jax.Array) -> jax.Array:
    """Standardizes `data` with the given mean and standard deviation.

    Args:
        data: Values to standardize.
        mean: Value subtracted from `data`.
        stddev: Value `data - mean` is divided by. Entries equal to zero are
            treated as one so the division never hits a zero denominator.

    Returns:
        `(data - mean) / stddev`, with zero `stddev` entries replaced by one.
    """
    denominator = jnp.where(stddev == 0, 1, stddev)
    shifted = data - mean
    return shifted / denominator


def unnormalize(data: jax.Array, *, mean: jax.Array, stddev: jax.Array) -> jax.Array:
    """Maps standardized `data` back to the original distribution.

    Args:
        data: Standardized values.
        mean: Value added back after rescaling.
        stddev: Scale applied to `data`. Entries equal to zero are treated as
            one, mirroring the handling used when normalizing.

    Returns:
        `data * stddev + mean`, with zero `stddev` entries replaced by one.
    """
    scale = jnp.where(stddev == 0, 1, stddev)
    scaled = data * scale
    return scaled + mean
4 changes: 4 additions & 0 deletions fusion_surrogates/ukaea_tglfnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""UKAEA-TGLFNN.

This package contains the implementation of UKAEA's TGLF surrogate for STEP.
"""
83 changes: 83 additions & 0 deletions fusion_surrogates/ukaea_tglfnn/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Base code for UKAEA's TGLFNN model.

UKAEA-TGLFNN is a neural network surrogate model for the gyrokinetics code TGLF, developed by Lorenzo Zanisi at UKAEA.
The model is trained on a dataset generated from JETTO TGLF runs in the STEP design space.
Hence, it is best suited to modelling transport in spherical tokamaks.
"""

import dataclasses
import json
import typing
from typing import Literal


import jax
import jax.numpy as jnp
import optax
import yaml

from fusion_surrogates import networks
from fusion_surrogates import transforms

# Names of the model's flux outputs. OUTPUT_LABELS is the same set of names as
# a tuple, in declaration order; the stats loader and the model classes
# iterate over it, so this order fixes the order of the stacked outputs.
OutputLabel = Literal["efe_gb", "efi_gb", "pfi_gb"]
OUTPUT_LABELS = typing.get_args(OutputLabel)

# Names of the 15 model input features. INPUT_LABELS is the same set of names
# as a tuple, in declaration order; the stats loader uses this order when
# building the input mean / std arrays.
# NOTE(review): presumably the columns of the `inputs` array passed to the
# models must follow this same order -- confirm against the training pipeline.
InputLabel = Literal[
    "RLNS_1",
    "RLTS_1",
    "RLTS_2",
    "TAUS_2",
    "RMIN_LOC",
    "DRMAJDX_LOC",
    "Q_LOC",
    "SHAT",
    "XNUE",
    "KAPPA_LOC",
    "S_KAPPA_LOC",
    "DELTA_LOC",
    "S_DELTA_LOC",
    "BETAE",
    "ZEFF",
]
INPUT_LABELS = typing.get_args(InputLabel)


@dataclasses.dataclass
class TGLFNNModelConfig:
    """Hyperparameters and pre/post-processing switches for a TGLFNN model."""

    # Number of ensemble members (YAML key "num_estimators").
    n_ensemble: int
    # Layer count handed to the network as `num_hiddens` (YAML key
    # "model_size").
    num_hiddens: int
    # Dropout value handed to the network (YAML key "dropout").
    dropout: float
    # Whether predictions standardize the inputs with the stored input stats.
    normalize: bool = True
    # Whether predictions map the predicted mean back with the output stats.
    unnormalize: bool = True
    # Layer width handed to the network as `hidden_size`. `load` does not read
    # this from the YAML file, so loaded configs always use the default.
    hidden_size: int = 512

    @classmethod
    def load(cls, config_path: str) -> "TGLFNNModelConfig":
        """Builds a config from a YAML file.

        Args:
            config_path: Path to a YAML file containing the keys
                "num_estimators", "model_size" and "dropout".

        Returns:
            A config populated from those keys; every other field keeps its
            default value.

        Raises:
            KeyError: If a required key is absent from the parsed mapping.
        """
        # Decode explicitly as UTF-8 so parsing does not depend on the
        # platform's locale encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        return cls(
            n_ensemble=config["num_estimators"],
            num_hiddens=config["model_size"],
            dropout=config["dropout"],
        )


@dataclasses.dataclass
class TGLFNNModelStats:
    """Per-feature statistics used to normalize inputs and unnormalize outputs."""

    # Mean / standard deviation of each input feature, in INPUT_LABELS order.
    input_mean: jax.Array
    input_std: jax.Array
    # Mean / standard deviation of each output flux, in OUTPUT_LABELS order.
    output_mean: jax.Array
    output_std: jax.Array

    @classmethod
    def load(cls, stats_path: str) -> "TGLFNNModelStats":
        """Builds the stats from a JSON file.

        Args:
            stats_path: Path to a JSON file mapping every input and output
                label to an object with "mean" and "std" entries.

        Returns:
            Stats whose arrays are ordered like INPUT_LABELS / OUTPUT_LABELS.

        Raises:
            KeyError: If a label, or its "mean" / "std" entry, is missing.
        """
        # Decode explicitly as UTF-8 rather than the platform's locale
        # encoding.
        with open(stats_path, "r", encoding="utf-8") as f:
            stats = json.load(f)

        def _collect(labels, key):
            # One value per label, preserving the label order.
            return jnp.array([stats[label][key] for label in labels])

        return cls(
            input_mean=_collect(INPUT_LABELS, "mean"),
            input_std=_collect(INPUT_LABELS, "std"),
            output_mean=_collect(OUTPUT_LABELS, "mean"),
            output_std=_collect(OUTPUT_LABELS, "std"),
        )
53 changes: 53 additions & 0 deletions fusion_surrogates/ukaea_tglfnn/onnx_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Implementation of UKAEA-TGLFNN as loaded from ONNX."""

from typing import Literal

import jax
import jax.numpy as jnp
import jaxonnxruntime
from jaxonnxruntime import backend as jaxort_backend
import onnx

from fusion_surrogates.ukaea_tglfnn import config as ukaea_tglfnn_config

# Import-time side effect: sets the jaxonnxruntime option
# "jaxort_only_allow_initializers_as_static_args" to False for the process.
# NOTE(review): presumably required for the TGLFNN ONNX graphs to be prepared
# by the backend -- confirm against the jaxonnxruntime documentation.
jaxonnxruntime.config.update("jaxort_only_allow_initializers_as_static_args", False)


class ONNXTGLFNNModel:
    """UKAEA-TGLFNN backed by three ONNX graphs, one per output flux.

    Each graph is loaded with `onnx.load_model` and prepared with the
    jaxonnxruntime backend; the prepared models are stored in `self.models`
    under the keys "efe_gb", "efi_gb" and "pfi_gb".
    """

    def __init__(
        self,
        efe_onnx_path: str,
        efi_onnx_path: str,
        pfi_onnx_path: str,
    ) -> None:
        """Loads and prepares the three ONNX models.

        Args:
            efe_onnx_path: Path to the ONNX file for the "efe_gb" output.
            efi_onnx_path: Path to the ONNX file for the "efi_gb" output.
            pfi_onnx_path: Path to the ONNX file for the "pfi_gb" output.
        """
        onnx_paths = {
            "efe_gb": efe_onnx_path,
            "efi_gb": efi_onnx_path,
            "pfi_gb": pfi_onnx_path,
        }
        self.models = {
            label: jaxort_backend.Backend.prepare(onnx.load_model(path))
            for label, path in onnx_paths.items()
        }

        # Inputs are cast to this dtype and fed under this graph input name.
        self._input_dtype = jnp.float32
        self._input_node_label = "input"

    def _predict_single_flux(
        self, flux: ukaea_tglfnn_config.OutputLabel, inputs: jax.Array
    ) -> jax.Array:
        """Runs the model for one flux.

        Args:
            flux: Key of the model to run.
            inputs: Model inputs; cast to float32 before the call.

        Returns:
            The first two arrays returned by the model, each squeezed, stacked
            along a new last axis.
        """
        output = self.models[flux].run(
            {self._input_node_label: inputs.astype(self._input_dtype)}
        )
        # NOTE(review): presumably output[0] / output[1] are the predicted mean
        # and variance -- confirm against the exported graphs. Also note that
        # jnp.squeeze drops every size-1 axis, including a batch axis of
        # length one.
        return jnp.stack([jnp.squeeze(output[0]), jnp.squeeze(output[1])], axis=-1)

    def predict(self, inputs: jax.Array) -> jax.Array:
        """Runs all three flux models on `inputs`.

        Args:
            inputs: Model inputs, forwarded unchanged to every flux model.

        Returns:
            The per-flux results stacked along axis -2, in the order of
            OUTPUT_LABELS.
        """
        return jnp.stack(
            [
                self._predict_single_flux(f, inputs)
                for f in ukaea_tglfnn_config.OUTPUT_LABELS
            ],
            axis=-2,
        )
108 changes: 108 additions & 0 deletions fusion_surrogates/ukaea_tglfnn/pytorch_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Implementation of UKAEA-TGLFNN as loaded from Pytorch checkpoint."""

import jax
import jax.numpy as jnp
import optax
import torch

from fusion_surrogates import networks
from fusion_surrogates import transforms
from fusion_surrogates.ukaea_tglfnn import config as ukaea_tglfnn_config


def _convert_pytorch_state_dict(
    pytorch_state_dict: dict, config: ukaea_tglfnn_config.TGLFNNModelConfig
) -> optax.Params:
    """Re-keys a PyTorch ensemble state dict into a nested parameter dict.

    Args:
        pytorch_state_dict: Mapping with entries named
            "models.{i}.model.{3*j}.weight" and "models.{i}.model.{3*j}.bias"
            for every ensemble member i and layer j.
        config: Supplies `n_ensemble` and `num_hiddens`, the number of members
            and of layers per member to convert.

    Returns:
        `{"params": {"GaussianMLP_{i}": {"Dense_{j}": {"kernel", "bias"}}}}`
        where each kernel is the transposed PyTorch weight.
    """

    def _dense_params(member: int, layer: int) -> dict:
        # Only every third entry of the PyTorch module list is read.
        # NOTE(review): presumably the entries in between are parameter-free
        # modules (activation / dropout) -- confirm against the PyTorch model.
        weight = pytorch_state_dict[f"models.{member}.model.{layer*3}.weight"]
        bias = pytorch_state_dict[f"models.{member}.model.{layer*3}.bias"]
        return {
            "kernel": jnp.array(weight).T,
            "bias": jnp.array(bias).T,
        }

    ensemble_params = {
        f"GaussianMLP_{member}": {
            f"Dense_{layer}": _dense_params(member, layer)
            for layer in range(config.num_hiddens)
        }
        for member in range(config.n_ensemble)
    }
    return {"params": ensemble_params}


class PytorchTGLFNNModel:
    """UKAEA-TGLFNN rebuilt from PyTorch checkpoints, one per output flux.

    The constructor loads the model config, the normalization stats and the
    three checkpoints, converts each checkpoint into a parameter dict, and
    builds a single `GaussianMLPEnsemble` that is applied with each flux's
    parameters.
    """

    def __init__(
        self,
        config_path: str,
        stats_path: str,
        efe_gb_checkpoint_path: str,
        efi_gb_checkpoint_path: str,
        pfi_gb_checkpoint_path: str,
        map_location: str = "cpu",
    ):
        """Loads config, stats and the three checkpoints.

        Args:
            config_path: Path to the model config file.
            stats_path: Path to the normalization stats file.
            efe_gb_checkpoint_path: Checkpoint for the "efe_gb" output.
            efi_gb_checkpoint_path: Checkpoint for the "efi_gb" output.
            pfi_gb_checkpoint_path: Checkpoint for the "pfi_gb" output.
            map_location: Forwarded to `torch.load`.
        """
        self.config = ukaea_tglfnn_config.TGLFNNModelConfig.load(config_path)
        self.stats = ukaea_tglfnn_config.TGLFNNModelStats.load(stats_path)

        checkpoint_paths = {
            "efe_gb": efe_gb_checkpoint_path,
            "efi_gb": efi_gb_checkpoint_path,
            "pfi_gb": pfi_gb_checkpoint_path,
        }
        converted = {}
        for label, path in checkpoint_paths.items():
            # NOTE(review): torch.load may unpickle arbitrary objects depending
            # on the PyTorch version; only pass trusted checkpoint files.
            with open(path, "rb") as checkpoint_file:
                converted[label] = _convert_pytorch_state_dict(
                    torch.load(checkpoint_file, map_location=map_location),
                    self.config,
                )
        self.params = converted

        self.network = networks.GaussianMLPEnsemble(
            n_ensemble=self.config.n_ensemble,
            hidden_size=self.config.hidden_size,
            num_hiddens=self.config.num_hiddens,
            dropout=self.config.dropout,
            activation="relu",
        )

    def predict(
        self,
        inputs: jax.Array,
    ) -> jax.Array:
        """Computes the model prediction for the given inputs.

        Args:
            inputs: The input data to the model. Must be shape (..., 15).

        Returns:
            A jax.Array of shape (..., 3, 2), where output[..., i, 0] and
            output[..., i, 1] are the mean and variance for the ith flux
            output. Outputs follow the order of OUTPUT_LABELS, i.e. efe_gb,
            efi_gb, pfi_gb.
        """
        if self.config.normalize:
            inputs = transforms.normalize(
                inputs, mean=self.stats.input_mean, stddev=self.stats.input_std
            )

        per_flux = [
            self.network.apply(self.params[label], inputs, deterministic=True)
            for label in ukaea_tglfnn_config.OUTPUT_LABELS
        ]
        output = jnp.stack(per_flux, axis=-2)

        if not self.config.unnormalize:
            return output

        # Only the mean is mapped back with the output stats.
        # NOTE(review): the variance is returned as produced by the network;
        # presumably it would need scaling by output_std ** 2 to be expressed
        # in the same units as the unnormalized mean -- confirm intent.
        predicted_mean = transforms.unnormalize(
            output[..., 0],
            mean=self.stats.output_mean,
            stddev=self.stats.output_std,
        )
        predicted_var = output[..., 1]
        return jnp.stack([predicted_mean, predicted_var], axis=-1)
61 changes: 61 additions & 0 deletions fusion_surrogates/ukaea_tglfnn/ukaea_tglfnn_test.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that the loaders are in different files, we should split these tests as well.

Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp
from absl.testing import absltest
from numpy import testing

from fusion_surrogates.ukaea_tglfnn import pytorch_model
from fusion_surrogates.ukaea_tglfnn import onnx_model


class PyTorchTGLFNNModelTest(absltest.TestCase):
    """Checks PytorchTGLFNNModel predictions against stored reference data."""

    @classmethod
    def setUpClass(cls):
        """Loads the model and reference arrays once for the whole class.

        Loading here, rather than in the class body, keeps file I/O out of
        module import so a missing file is reported as a test error instead of
        breaking test collection.
        """
        super().setUpClass()
        # NOTE(review): the checkpoints under models/1.0.1/ and the arrays
        # under test_data/ still need to be added alongside this test; the
        # paths are resolved relative to the current working directory.
        cls.model = pytorch_model.PytorchTGLFNNModel(
            config_path="models/1.0.1/config.yaml",
            stats_path="models/1.0.1/stats.json",
            efe_gb_checkpoint_path="models/1.0.1/regressor_efe_gb.pt",
            efi_gb_checkpoint_path="models/1.0.1/regressor_efi_gb.pt",
            pfi_gb_checkpoint_path="models/1.0.1/regressor_pfi_gb.pt",
            map_location="cpu",
        )
        cls.reference_inputs = jnp.load("test_data/input.npy")
        cls.reference_outputs = jnp.load("test_data/output.npy")

    def test_matches_reference(self):
        predicted_outputs = self.model.predict(self.reference_inputs)
        # The prediction has a trailing (mean, variance) axis; compare only the
        # mean with the reference, exactly as the ONNX model test does.
        # NOTE(review): assumes test_data/output.npy stores means only, with
        # shape (..., 3) -- confirm once the data is added.
        testing.assert_allclose(
            self.reference_outputs, predicted_outputs[..., 0], rtol=1e-3
        )


class ONNXTGLFNNModelTest(absltest.TestCase):
    """Checks ONNXTGLFNNModel predictions against stored reference data."""

    @classmethod
    def setUpClass(cls):
        """Loads the model and reference arrays once for the whole class.

        Loading here, rather than in the class body, keeps file I/O out of
        module import so a missing file is reported as a test error instead of
        breaking test collection.
        """
        super().setUpClass()
        # Paths are resolved relative to the current working directory.
        cls.model = onnx_model.ONNXTGLFNNModel(
            efe_onnx_path="models/1.0.1/regressor_efe_gb_onnx.onnx",
            efi_onnx_path="models/1.0.1/regressor_efi_gb_onnx.onnx",
            pfi_onnx_path="models/1.0.1/regressor_pfi_gb_onnx.onnx",
        )
        cls.reference_inputs = jnp.load("test_data/input.npy")
        cls.reference_outputs = jnp.load("test_data/output.npy")

    def test_matches_reference(self):
        predicted_outputs = self.model.predict(self.reference_inputs)
        # Index 0 of the trailing axis is compared with the reference.
        # NOTE(review): presumably that index holds the predicted mean --
        # confirm against the exported graphs.
        testing.assert_allclose(
            self.reference_outputs, predicted_outputs[..., 0], rtol=1e-3
        )


# Run the tests through absl's test runner when this file is executed directly.
if __name__ == "__main__":
    absltest.main()