|
1 | 1 | import json |
2 | 2 | from copy import deepcopy |
3 | 3 | from pathlib import Path |
4 | | - |
| 4 | +import dataclasses |
5 | 5 | import chex |
6 | 6 | import jax.numpy as jnp |
7 | 7 | from flax import linen as nn |
|
12 | 12 | from torax.pedestal_model import pedestal_model as pedestal_model_lib |
13 | 13 | from torax.transport_model import tglf_based_transport_model |
14 | 14 | from warnings import warn |
15 | | -from torax.transport_model.tglf_based_transport_model import TGLFInputs |
| 15 | +from torax.transport_model.tglf_based_transport_model import TGLFInputs, RuntimeParams |
| 16 | +from torax.transport_model import transport_model |
| 17 | +from typing import Callable |
16 | 18 |
|
17 | 19 |
|
18 | 20 | class TGLFNNSurrogate(nn.Module): |
@@ -118,9 +120,7 @@ def get_params_from_pytorch_state_dict(self, pytorch_state_dict: dict): |
118 | 120 | return params |
119 | 121 |
|
120 | 122 |
|
121 | | -class TGLFSurrogateTransportModel( |
122 | | - tglf_based_transport_model.TGLFBasedTransportModel |
123 | | -): |
| 123 | +class TGLFNNTransportModel(tglf_based_transport_model.TGLFBasedTransportModel): |
124 | 124 | """Calculate turbulent transport coefficients using a TGLF surrogate model.""" |
125 | 125 |
|
126 | 126 | def __init__( |
@@ -239,3 +239,24 @@ def _call_implementation( |
239 | 239 | gradient_reference_length=geo.Rmin, # Device minor radius at LCFS |
240 | 240 | gyrobohm_flux_reference_length=geo.Rmin, # TODO: Check |
241 | 241 | ) |
| 242 | + |
| 243 | + |
| 244 | +@dataclasses.dataclass(kw_only=True) |
| 245 | +class TGLFNNTransportModelBuilder(transport_model.TransportModelBuilder): |
| 246 | + """When called, instantiates a TGLFSurrogateTransportModel.""" |
| 247 | + |
| 248 | + runtime_params: RuntimeParams = dataclasses.field( |
| 249 | + default_factory=RuntimeParams |
| 250 | + ) |
| 251 | + weights_path: str = ( |
| 252 | + "/home/theo/documents/ukaea/torax/tglfnn/1.0.0/tglfnn_checkpoint.json" |
| 253 | + ) |
| 254 | + scaling_path: str = "/home/theo/documents/ukaea/torax/tglfnn/1.0.0/stats.json" |
| 255 | + |
| 256 | + def __call__( |
| 257 | + self, |
| 258 | + ) -> TGLFNNTransportModel: |
| 259 | + return TGLFNNTransportModel( |
| 260 | + path_to_model_weights_json=(self.weights_path), |
| 261 | + path_to_model_scaling_json=(self.scaling_path), |
| 262 | + ) |
0 commit comments