Skip to content

Commit 5df00bc

Browse files
committed
Add to build_sim
1 parent c2a7ba0 commit 5df00bc

File tree

3 files changed

+37
-5
lines changed

3 files changed

+37
-5
lines changed

torax/config/build_sim.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from torax.transport_model import constant as constant_transport
4242
from torax.transport_model import critical_gradient as critical_gradient_transport
4343
from torax.transport_model import qlknn_transport_model
44+
from torax.transport_model import tglfnn_transport_model
4445
# pylint: disable=g-import-not-at-top
4546
try:
4647
from torax.transport_model import qualikiz_transport_model
@@ -583,6 +584,16 @@ def build_transport_model_builder_from_config(
583584
**qualikiz_params,
584585
)
585586
)
587+
elif transport_model == 'tglfnn':
588+
tglfnn_params = dict(transport_config.pop('tglfnn_params', {}))
589+
tglfnn_params.update(transport_config)
590+
return tglfnn_transport_model.TGLFNNTransportModelBuilder(
591+
runtime_params=config_args.recursive_replace(
592+
tglfnn_transport_model.RuntimeParams(),
593+
**tglfnn_params,
594+
)
595+
)
596+
586597
# pylint: enable=undefined-variable
587598
raise ValueError(f'Unknown transport model: {transport_model}')
588599

torax/examples/torax_tglfnn.py

Whitespace-only changes.

torax/transport_model/tglf_surrogate_transport_model.py renamed to torax/transport_model/tglfnn_transport_model.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from copy import deepcopy
33
from pathlib import Path
4-
4+
import dataclasses
55
import chex
66
import jax.numpy as jnp
77
from flax import linen as nn
@@ -12,7 +12,9 @@
1212
from torax.pedestal_model import pedestal_model as pedestal_model_lib
1313
from torax.transport_model import tglf_based_transport_model
1414
from warnings import warn
15-
from torax.transport_model.tglf_based_transport_model import TGLFInputs
15+
from torax.transport_model.tglf_based_transport_model import TGLFInputs, RuntimeParams
16+
from torax.transport_model import transport_model
17+
from typing import Callable
1618

1719

1820
class TGLFNNSurrogate(nn.Module):
@@ -118,9 +120,7 @@ def get_params_from_pytorch_state_dict(self, pytorch_state_dict: dict):
118120
return params
119121

120122

121-
class TGLFSurrogateTransportModel(
122-
tglf_based_transport_model.TGLFBasedTransportModel
123-
):
123+
class TGLFNNTransportModel(tglf_based_transport_model.TGLFBasedTransportModel):
124124
"""Calculate turbulent transport coefficients using a TGLF surrogate model."""
125125

126126
def __init__(
@@ -239,3 +239,24 @@ def _call_implementation(
239239
gradient_reference_length=geo.Rmin, # Device minor radius at LCFS
240240
gyrobohm_flux_reference_length=geo.Rmin, # TODO: Check
241241
)
242+
243+
244+
@dataclasses.dataclass(kw_only=True)
245+
class TGLFNNTransportModelBuilder(transport_model.TransportModelBuilder):
246+
"""When called, instantiates a TGLFSurrogateTransportModel."""
247+
248+
runtime_params: RuntimeParams = dataclasses.field(
249+
default_factory=RuntimeParams
250+
)
251+
weights_path: str = (
252+
"/home/theo/documents/ukaea/torax/tglfnn/1.0.0/tglfnn_checkpoint.json"
253+
)
254+
scaling_path: str = "/home/theo/documents/ukaea/torax/tglfnn/1.0.0/stats.json"
255+
256+
def __call__(
257+
self,
258+
) -> TGLFNNTransportModel:
259+
return TGLFNNTransportModel(
260+
path_to_model_weights_json=(self.weights_path),
261+
path_to_model_scaling_json=(self.scaling_path),
262+
)

0 commit comments

Comments
 (0)