Skip to content

Commit cba2552

Browse files
committed
Unit tests for TGLFInputs and TGLFBasedTransportModel
1 parent 6b62f36 commit cba2552

File tree

3 files changed

+248
-31
lines changed

3 files changed

+248
-31
lines changed

torax/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class Constants:
5151
epsilon0: chex.Numeric
5252
mu0: chex.Numeric
5353
eps: chex.Numeric
54+
c: chex.Numeric
5455

5556

5657
CONSTANTS: Final[Constants] = Constants(
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for torax.transport_model.tglf_based_transport_model."""
16+
from absl.testing import absltest
17+
from absl.testing import parameterized
18+
import chex
19+
import jax.numpy as jnp
20+
from torax import core_profile_setters
21+
from torax import state
22+
from torax.config import runtime_params as general_runtime_params
23+
from torax.config import runtime_params_slice
24+
from torax.geometry import geometry
25+
from torax.pedestal_model import pedestal_model as pedestal_model_lib
26+
from torax.pedestal_model import set_tped_nped
27+
from torax.sources import source_models as source_models_lib
28+
from torax.transport_model import tglf_based_transport_model
29+
from torax.transport_model import quasilinear_transport_model
30+
from torax.transport_model import runtime_params as runtime_params_lib
31+
32+
33+
def _get_model_inputs(transport: tglf_based_transport_model.RuntimeParams):
34+
"""Returns the model inputs for testing."""
35+
runtime_params = general_runtime_params.GeneralRuntimeParams()
36+
geo = geometry.build_circular_geometry()
37+
source_models_builder = source_models_lib.SourceModelsBuilder()
38+
source_models = source_models_builder()
39+
pedestal_model_builder = (
40+
set_tped_nped.SetTemperatureDensityPedestalModelBuilder()
41+
)
42+
dynamic_runtime_params_slice = (
43+
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
44+
runtime_params=runtime_params,
45+
transport=transport,
46+
sources=source_models_builder.runtime_params,
47+
pedestal=pedestal_model_builder.runtime_params,
48+
torax_mesh=geo.torax_mesh,
49+
)(
50+
t=runtime_params.numerics.t_initial,
51+
)
52+
)
53+
static_slice = runtime_params_slice.build_static_runtime_params_slice(
54+
runtime_params=runtime_params,
55+
source_runtime_params=source_models_builder.runtime_params,
56+
torax_mesh=geo.torax_mesh,
57+
)
58+
core_profiles = core_profile_setters.initial_core_profiles(
59+
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
60+
static_runtime_params_slice=static_slice,
61+
geo=geo,
62+
source_models=source_models,
63+
)
64+
return dynamic_runtime_params_slice, geo, core_profiles
65+
66+
67+
class TGLFBasedTransportModelTest(parameterized.TestCase):
68+
"""Unit tests for the `torax.transport_model.tglf_based_transport_model` module."""
69+
70+
def test_tglf_based_transport_model_output_shapes(self):
71+
"""Tests that the core transport output has the right shapes."""
72+
transport = tglf_based_transport_model.RuntimeParams(
73+
**runtime_params_lib.RuntimeParams()
74+
)
75+
transport_model = FakeTGLFBasedTransportModel()
76+
dynamic_runtime_params_slice, geo, core_profiles = _get_model_inputs(
77+
transport
78+
)
79+
pedestal_model = set_tped_nped.SetTemperatureDensityPedestalModel()
80+
pedestal_model_outputs = pedestal_model(
81+
dynamic_runtime_params_slice, geo, core_profiles
82+
)
83+
84+
core_transport = transport_model(
85+
dynamic_runtime_params_slice, geo, core_profiles, pedestal_model_outputs
86+
)
87+
expected_shape = geo.rho_face_norm.shape
88+
self.assertEqual(core_transport.chi_face_ion.shape, expected_shape)
89+
self.assertEqual(core_transport.chi_face_el.shape, expected_shape)
90+
self.assertEqual(core_transport.d_face_el.shape, expected_shape)
91+
self.assertEqual(core_transport.v_face_el.shape, expected_shape)
92+
93+
def test_tglf_based_transport_model_prepare_tglf_inputs_shapes(self):
94+
"""Tests that the tglf inputs have the expected shapes."""
95+
transport = tglf_based_transport_model.RuntimeParams(
96+
**runtime_params_lib.RuntimeParams()
97+
)
98+
dynamic_runtime_params_slice, geo, core_profiles = _get_model_inputs(
99+
transport
100+
)
101+
transport_model = FakeTGLFBasedTransportModel()
102+
tglf_inputs = transport_model._prepare_tglf_inputs(
103+
Zeff_face=dynamic_runtime_params_slice.plasma_composition.Zeff_face,
104+
q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor,
105+
geo=geo,
106+
core_profiles=core_profiles,
107+
)
108+
109+
# Inputs that are 1D
110+
vector_keys = [
111+
'chiGB',
112+
'lref_over_lti',
113+
'lref_over_lte',
114+
'lref_over_lne',
115+
'lref_over_lni0',
116+
'lref_over_lni1',
117+
'Ti_over_Te',
118+
'drmaj',
119+
'q',
120+
's_hat',
121+
'nu_ee',
122+
'kappa',
123+
'kappa_shear',
124+
'delta',
125+
'delta_shear',
126+
'beta_e',
127+
'Zeff',
128+
]
129+
# Inputs that are 0D
130+
scalar_keys = ['Rmaj', 'Rmin']
131+
132+
expected_vector_length = geo.rho_face_norm.shape[0]
133+
for key in vector_keys:
134+
try:
135+
self.assertEqual(
136+
getattr(tglf_inputs, key).shape, (expected_vector_length,)
137+
)
138+
except Exception as e:
139+
print(key, getattr(tglf_inputs, key))
140+
raise e
141+
for key in scalar_keys:
142+
self.assertEqual(getattr(tglf_inputs, key).shape, ())
143+
144+
145+
class FakeTGLFBasedTransportModel(
146+
tglf_based_transport_model.TGLFBasedTransportModel
147+
):
148+
"""Fake TGLFBasedTransportModel for testing purposes."""
149+
150+
def __init__(self):
151+
super().__init__()
152+
self._frozen = True
153+
154+
# pylint: disable=invalid-name
155+
def prepare_tglf_inputs(
156+
self,
157+
Zeff_face: chex.Array,
158+
q_correction_factor: chex.Numeric,
159+
geo: geometry.Geometry,
160+
core_profiles: state.CoreProfiles,
161+
) -> tglf_based_transport_model.TGLFInputs:
162+
"""Exposing prepare_tglf_inputs for testing."""
163+
return self._prepare_tglf_inputs(
164+
Zeff_face=Zeff_face,
165+
q_correction_factor=q_correction_factor,
166+
geo=geo,
167+
core_profiles=core_profiles,
168+
)
169+
170+
# pylint: enable=invalid-name
171+
172+
def _call_implementation(
173+
self,
174+
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
175+
geo: geometry.Geometry,
176+
core_profiles: state.CoreProfiles,
177+
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
178+
) -> state.CoreTransport:
179+
tglf_inputs = self._prepare_tglf_inputs(
180+
Zeff_face=dynamic_runtime_params_slice.plasma_composition.Zeff_face,
181+
q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor,
182+
geo=geo,
183+
core_profiles=core_profiles,
184+
)
185+
186+
transport = dynamic_runtime_params_slice.transport
187+
# Assert required for pytype.
188+
assert isinstance(
189+
transport,
190+
tglf_based_transport_model.DynamicRuntimeParams,
191+
)
192+
193+
return self._make_core_transport(
194+
qi=jnp.ones(geo.rho_face_norm.shape) * 0.4,
195+
qe=jnp.ones(geo.rho_face_norm.shape) * 0.5,
196+
pfe=jnp.ones(geo.rho_face_norm.shape) * 1.6,
197+
quasilinear_inputs=tglf_inputs,
198+
transport=transport,
199+
geo=geo,
200+
core_profiles=core_profiles,
201+
gradient_reference_length=geo.Rmaj, # TODO
202+
gyrobohm_flux_reference_length=geo.Rmin, # TODO
203+
)
204+
205+
206+
if __name__ == '__main__':
207+
absltest.main()

torax/transport_model/tglf_based_transport_model.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import chex
1717
from jax import numpy as jnp
1818

19-
from torax import geometry
19+
from torax.geometry import geometry
2020
from torax import physics
2121
from torax import state
2222
from torax.constants import CONSTANTS
@@ -60,8 +60,8 @@ class TGLFInputs(quasilinear_transport_model.QuasilinearInputs):
6060

6161
# Ti/Te
6262
Ti_over_Te: chex.Array
63-
# dRmaj/dr
64-
dRmaj: chex.Array
63+
# drmaj/dr (flux surface centroid major radius gradient)
64+
drmaj: chex.Array
6565
# q
6666
q: chex.Array
6767
# r/q dq/dr
@@ -88,17 +88,18 @@ class TGLFBasedTransportModel(
8888
"""Base class for TGLF-based transport models."""
8989

9090
def _prepare_tglf_inputs(
91+
self,
9192
Zeff_face: chex.Array,
9293
q_correction_factor: chex.Numeric,
9394
geo: geometry.Geometry,
9495
core_profiles: state.CoreProfiles,
9596
) -> TGLFInputs:
96-
## Shorthand 'standard' variables
97+
# Shorthand 'standard' variables
9798
Te_keV = core_profiles.temp_el.face_value()
9899
Te_eV = Te_keV * 1e3
99100
Te_J = Te_keV * CONSTANTS.keV2J
100101
Ti_keV = core_profiles.temp_ion.face_value()
101-
ne = core_profiles.ne * core_profiles.nref
102+
ne = core_profiles.ne.face_value() * core_profiles.nref
102103
# q must be recalculated since in the nonlinear solver psi has intermediate
103104
# states in the iterative solve
104105
q, _ = physics.calc_q_from_psi(
@@ -107,29 +108,33 @@ def _prepare_tglf_inputs(
107108
q_correction_factor=q_correction_factor,
108109
)
109110

110-
## Reference values used for TGLF-specific normalisation
111-
# https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization
112-
# https://gafusion.github.io/doc/geometry.html#effective-field
113-
# B_unit = 1/r d(psi_tor)/dr = q/r dpsi/dr
114-
# Note: TGLF uses geo.rmid = (Rmax - Rmin)/2 as the radial coordinate
115-
# This means all gradients are calculated w.r.t. rmid
116-
m_D_amu = 2.014 # Mass of deuterium
111+
# Reference values used for TGLF-specific normalisation
112+
# - 'a' in TGLF means the minor radius at the LCFS
113+
# - 'r' in TGLF means the flux surface centroid minor radius. Gradients are
114+
# taken w.r.t. r
115+
# https://gafusion.github.io/doc/tglf/tglf_list.html#rmin-loc
116+
# - B_unit = 1/r d(psi_tor)/dr = q/r dpsi/dr
117+
# https://gafusion.github.io/doc/geometry.html#effective-field
118+
# - c_s (ion sound speed)
119+
# https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization
120+
m_D_amu = 2.014 # Mass of deuterium - TODO: load from lookup table
117121
m_D = m_D_amu * CONSTANTS.mp # Mass of deuterium
118122
c_s = (Te_J / m_D) ** 0.5
119-
a = geo.Rmin[-1] # Minor radius at LCFS
120-
B_unit = q / (geo.rmid) * jnp.gradient(core_profiles.psi, geo.rmid)
123+
a = geo.Rmin # Device minor radius at LCFS
124+
r = geo.rmid_face # Flux surface centroid minor radius
125+
B_unit = q / r * jnp.gradient(core_profiles.psi.face_value(), r)
121126

122-
## Dimensionless gradients, eg lref_over_lti where lref=amin, lti = -ti / (dti/dr)
127+
# Dimensionless gradients
123128
normalized_log_gradients = quasilinear_transport_model.NormalizedLogarithmicGradients.from_profiles(
124129
core_profiles=core_profiles,
125-
radial_coordinate=geo.rmid,
130+
radial_coordinate=geo.rmid, # TODO: Why does this have to be a variable on the cell grid?
126131
reference_length=a,
127132
)
128133

129-
## Dimensionless temperature ratio
134+
# Dimensionless temperature ratio
130135
Ti_over_Te = Ti_keV / Te_keV
131136

132-
## Dimensionless electron-electron collision frequency = nu_ee / (c_s/a)
137+
# Dimensionless electron-electron collision frequency = nu_ee / (c_s/a)
133138
# https://gafusion.github.io/doc/tglf/tglf_list.html#xnue
134139
# https://gafusion.github.io/doc/cgyro/cgyro_list.html#cgyro-nu-ee
135140
# Note: In the TGLF docs, XNUE is mislabelled as electron-ion collision frequency.
@@ -143,35 +148,39 @@ def _prepare_tglf_inputs(
143148
)
144149
nu_ee = normalised_nu_ee / (c_s / a)
145150

146-
## Safety factor, q
151+
# Safety factor, q
147152
# https://gafusion.github.io/doc/tglf/tglf_list.html#q-sa
148153
# defined before
149154

150-
## Safety factor shear, s_hat = r/q dq/dr
155+
# Safety factor shear, s_hat = r/q dq/dr
151156
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-shat-sa
152157
# Note: calc_s_from_psi_rmid gives rq dq/dr, so we divide by q**2
158+
# r_mid = r
153159
s_hat = physics.calc_s_from_psi_rmid(geo, core_profiles.psi) / q**2
154160

155-
## Electron beta
161+
# Electron beta
156162
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-betae
157163
# Note: Te in eV
158164
beta_e = 8 * jnp.pi * ne * Te_eV / B_unit**2
159165

160-
## Major radius shear = dRmaj/dr
166+
# Major radius shear = drmaj/drmin, where 'rmaj' is the flux surface centroid
167+
# major radius and 'rmin' the flux surface centroid minor radius
161168
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-drmajdx-loc
162-
dRmaj = jnp.gradient(geo.Rmaj, geo.rmid)
169+
rmaj = (
170+
geo.Rin_face + geo.Rout_face
171+
) / 2 # Flux surface centroid maj radius
172+
drmaj = jnp.gradient(rmaj, r)
163173

164-
## Elongation shear = r/kappa dkappa/dr
174+
# Elongation shear = r/kappa dkappa/dr
165175
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-s-kappa-loc
166176
kappa = geo.elongation_face
167-
kappa_shear = geo.rmid_face / kappa * jnp.gradient(kappa, geo.rmid_face)
177+
kappa_shear = geo.rmid_face / kappa * jnp.gradient(kappa, r)
168178

169-
## Triangularity shear = r ddelta/dr
179+
# Triangularity shear = r ddelta/dr
170180
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-s-delta-loc
171-
delta = geo.delta_face
172-
delta_shear = geo.rmid_face * jnp.gradient(delta, geo.rmid_face)
181+
delta_shear = r * jnp.gradient(geo.delta_face, r)
173182

174-
## Gyrobohm diffusivity
183+
# Gyrobohm diffusivity
175184
# https://gafusion.github.io/doc/tglf/tglf_table.html#id7
176185
# https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization
177186
# Note: TGLF uses the same normalisation as CGYRO
@@ -198,13 +207,13 @@ def _prepare_tglf_inputs(
198207
lref_over_lni1=normalized_log_gradients.lref_over_lni1,
199208
# From TGLFInputs
200209
Ti_over_Te=Ti_over_Te,
201-
dRmaj=dRmaj,
210+
drmaj=drmaj,
202211
q=q,
203212
s_hat=s_hat,
204213
nu_ee=nu_ee,
205214
kappa=kappa,
206215
kappa_shear=kappa_shear,
207-
delta=delta,
216+
delta=geo.delta_face,
208217
delta_shear=delta_shear,
209218
beta_e=beta_e,
210219
Zeff=Zeff_face,

0 commit comments

Comments
 (0)