
Commit cc39071

Add UKAEA-TGLFNN
1 parent 5ba12e6 commit cc39071

File tree

1 file changed: +173 -0 lines changed

@@ -0,0 +1,173 @@
import dataclasses
import json
from typing import Final

import jax
import jax.numpy as jnp
import optax
import yaml

from fusion_transport_surrogates.networks import GaussianMLPEnsemble
from fusion_transport_surrogates.utils import normalize, unnormalize

INPUT_LABELS: Final[list[str]] = [
    "RLNS_1",
    "RLTS_1",
    "RLTS_2",
    "TAUS_2",
    "RMIN_LOC",
    "DRMAJDX_LOC",
    "Q_LOC",
    "SHAT",
    "XNUE",
    "KAPPA_LOC",
    "S_KAPPA_LOC",
    "DELTA_LOC",
    "S_DELTA_LOC",
    "BETAE",
    "ZEFF",
]
OUTPUT_LABELS: Final[list[str]] = ["efe_gb", "efi_gb", "pfi_gb"]


@dataclasses.dataclass
class TGLFNNModelConfig:
    n_ensemble: int
    n_hidden_layers: int
    dropout: float
    normalize: bool
    unnormalize: bool
    hidden_size: int = 512  # not read from the YAML config; keeps its default

    @classmethod
    def load(cls, config_path: str) -> "TGLFNNModelConfig":
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)

        return cls(
            n_ensemble=config["num_estimators"],
            n_hidden_layers=config["model_size"],
            dropout=config["dropout"],
            normalize=config["scale"],
            unnormalize=config["denormalise"],
        )

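# Illustrative YAML layout that TGLFNNModelConfig.load expects, based on the keys
# it reads above (values are placeholders, not the shipped configuration):
#
#   num_estimators: 5
#   model_size: 3
#   dropout: 0.05
#   scale: true
#   denormalise: true
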

@dataclasses.dataclass
class TGLFNNModelStats:
    input_mean: jax.Array
    input_std: jax.Array
    output_mean: jax.Array
    output_std: jax.Array

    @classmethod
    def load(cls, stats_path: str) -> "TGLFNNModelStats":
        with open(stats_path, "r") as f:
            stats = json.load(f)

        return cls(
            input_mean=jnp.array([stats[label]["mean"] for label in INPUT_LABELS]),
            input_std=jnp.array([stats[label]["std"] for label in INPUT_LABELS]),
            output_mean=jnp.array([stats[label]["mean"] for label in OUTPUT_LABELS]),
            output_std=jnp.array([stats[label]["std"] for label in OUTPUT_LABELS]),
        )

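# Illustrative stats-file layout, inferred from the lookups above: a JSON object
# mapping every input and output label to its training-set statistics, e.g.
#
#   {"RLNS_1": {"mean": 2.1, "std": 1.3}, ..., "efe_gb": {"mean": 0.8, "std": 2.4}}
#
# (numbers are placeholders).
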

class TGLFNNModel:
    """Wraps a GaussianMLPEnsemble applied with separate parameters per OUTPUT_LABELS entry."""

    def __init__(
        self,
        config: TGLFNNModelConfig,
        stats: TGLFNNModelStats,
        params: optax.Params | None,
    ):
        self.config = config
        self.stats = stats
        self.params = params
        self.network = GaussianMLPEnsemble(
            n_ensemble=config.n_ensemble,
            hidden_size=config.hidden_size,
            n_hidden_layers=config.n_hidden_layers,
            dropout=config.dropout,
        )

    @classmethod
    def load_from_pytorch(
        cls,
        config_path: str,
        stats_path: str,
        efe_gb_checkpoint_path: str,
        efi_gb_checkpoint_path: str,
        pfi_gb_checkpoint_path: str,
        *args,
        **kwargs,
    ) -> "TGLFNNModel":
        import torch  # local import so torch is only needed for checkpoint conversion

        def _convert_pytorch_state_dict(
            pytorch_state_dict: dict, config: TGLFNNModelConfig
        ) -> optax.Params:
            params = {}
            for i in range(config.n_ensemble):
                model_dict = {}
                for j in range(config.n_hidden_layers):
                    # Stride of 3: each Linear in the PyTorch Sequential is assumed
                    # to be followed by an activation and a dropout layer.
                    layer_dict = {
                        "kernel": jnp.array(
                            pytorch_state_dict[f"models.{i}.model.{j * 3}.weight"]
                        ).T,
                        "bias": jnp.array(
                            pytorch_state_dict[f"models.{i}.model.{j * 3}.bias"]
                        ).T,
                    }
                    model_dict[f"Dense_{j}"] = layer_dict
                params[f"GaussianMLP_{i}"] = model_dict
            return {"params": params}

        config = TGLFNNModelConfig.load(config_path)
        stats = TGLFNNModelStats.load(stats_path)

        with open(efe_gb_checkpoint_path, "rb") as f:
            efe_gb_params = _convert_pytorch_state_dict(
                torch.load(f, *args, **kwargs), config
            )
        with open(efi_gb_checkpoint_path, "rb") as f:
            efi_gb_params = _convert_pytorch_state_dict(
                torch.load(f, *args, **kwargs), config
            )
        with open(pfi_gb_checkpoint_path, "rb") as f:
            pfi_gb_params = _convert_pytorch_state_dict(
                torch.load(f, *args, **kwargs), config
            )

        params = {
            "efe_gb": efe_gb_params,
            "efi_gb": efi_gb_params,
            "pfi_gb": pfi_gb_params,
        }

        return cls(config, stats, params)

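    # Resulting parameter pytree (inferred from the conversion above): one entry per
    # output label, each of the form
    #   {"params": {"GaussianMLP_0": {"Dense_0": {"kernel": ..., "bias": ...}, ...}, ...}}
    # with a GaussianMLP_{i} subtree per ensemble member and a Dense_{j} kernel/bias
    # pair per converted layer.
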
    def predict(
        self,
        inputs: jax.Array,
    ) -> jax.Array:
        if self.config.normalize:
            inputs = normalize(
                inputs, mean=self.stats.input_mean, stddev=self.stats.input_std
            )

        # Stack the per-label predictions along a new trailing axis, in OUTPUT_LABELS order.
        output = jnp.stack(
            [
                self.network.apply(self.params[label], inputs, deterministic=True)
                for label in OUTPUT_LABELS
            ],
            axis=-1,
        )

        if self.config.unnormalize:
            output = unnormalize(
                output, mean=self.stats.output_mean, stddev=self.stats.output_std
            )

        return output
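
# Usage sketch (paths and the map_location argument are placeholders/assumptions,
# not part of this commit):
#
#   model = TGLFNNModel.load_from_pytorch(
#       "config.yaml",
#       "stats.json",
#       "efe_gb_checkpoint.pt",
#       "efi_gb_checkpoint.pt",
#       "pfi_gb_checkpoint.pt",
#       map_location="cpu",  # forwarded to torch.load via **kwargs
#   )
#   fluxes = model.predict(inputs)  # inputs: (..., len(INPUT_LABELS)) array in INPUT_LABELS order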

0 commit comments
