Skip to content

Commit 00bb388

Browse files
committed
Unit tests for TGLFInputs and TGLFBasedTransportModel
1 parent 283cb22 commit 00bb388

File tree

3 files changed

+249
-31
lines changed

3 files changed

+249
-31
lines changed

torax/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class Constants:
5151
epsilon0: chex.Numeric
5252
mu0: chex.Numeric
5353
eps: chex.Numeric
54+
c: chex.Numeric
5455

5556

5657
CONSTANTS: Final[Constants] = Constants(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for torax.transport_model.tglf_based_transport_model."""
16+
from absl.testing import absltest
17+
from absl.testing import parameterized
18+
import chex
19+
import jax.numpy as jnp
20+
from torax import core_profile_setters
21+
from torax import state
22+
from torax.config import runtime_params as general_runtime_params
23+
from torax.config import runtime_params_slice
24+
from torax.geometry import circular_geometry
25+
from torax.geometry import geometry
26+
from torax.pedestal_model import pedestal_model as pedestal_model_lib
27+
from torax.pedestal_model import set_tped_nped
28+
from torax.sources import source_models as source_models_lib
29+
from torax.transport_model import tglf_based_transport_model
30+
from torax.transport_model import quasilinear_transport_model
31+
from torax.transport_model import runtime_params as runtime_params_lib
32+
33+
34+
def _get_model_inputs(transport: tglf_based_transport_model.RuntimeParams):
35+
"""Returns the model inputs for testing."""
36+
runtime_params = general_runtime_params.GeneralRuntimeParams()
37+
geo = circular_geometry.build_circular_geometry()
38+
source_models_builder = source_models_lib.SourceModelsBuilder()
39+
source_models = source_models_builder()
40+
pedestal_model_builder = (
41+
set_tped_nped.SetTemperatureDensityPedestalModelBuilder()
42+
)
43+
dynamic_runtime_params_slice = (
44+
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
45+
runtime_params=runtime_params,
46+
transport=transport,
47+
sources=source_models_builder.runtime_params,
48+
pedestal=pedestal_model_builder.runtime_params,
49+
torax_mesh=geo.torax_mesh,
50+
)(
51+
t=runtime_params.numerics.t_initial,
52+
)
53+
)
54+
static_slice = runtime_params_slice.build_static_runtime_params_slice(
55+
runtime_params=runtime_params,
56+
source_runtime_params=source_models_builder.runtime_params,
57+
torax_mesh=geo.torax_mesh,
58+
)
59+
core_profiles = core_profile_setters.initial_core_profiles(
60+
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
61+
static_runtime_params_slice=static_slice,
62+
geo=geo,
63+
source_models=source_models,
64+
)
65+
return dynamic_runtime_params_slice, geo, core_profiles
66+
67+
68+
class TGLFBasedTransportModelTest(parameterized.TestCase):
69+
"""Unit tests for the `torax.transport_model.tglf_based_transport_model` module."""
70+
71+
def test_tglf_based_transport_model_output_shapes(self):
72+
"""Tests that the core transport output has the right shapes."""
73+
transport = tglf_based_transport_model.RuntimeParams(
74+
**runtime_params_lib.RuntimeParams()
75+
)
76+
transport_model = FakeTGLFBasedTransportModel()
77+
dynamic_runtime_params_slice, geo, core_profiles = _get_model_inputs(
78+
transport
79+
)
80+
pedestal_model = set_tped_nped.SetTemperatureDensityPedestalModel()
81+
pedestal_model_outputs = pedestal_model(
82+
dynamic_runtime_params_slice, geo, core_profiles
83+
)
84+
85+
core_transport = transport_model(
86+
dynamic_runtime_params_slice, geo, core_profiles, pedestal_model_outputs
87+
)
88+
expected_shape = geo.rho_face_norm.shape
89+
self.assertEqual(core_transport.chi_face_ion.shape, expected_shape)
90+
self.assertEqual(core_transport.chi_face_el.shape, expected_shape)
91+
self.assertEqual(core_transport.d_face_el.shape, expected_shape)
92+
self.assertEqual(core_transport.v_face_el.shape, expected_shape)
93+
94+
def test_tglf_based_transport_model_prepare_tglf_inputs_shapes(self):
95+
"""Tests that the tglf inputs have the expected shapes."""
96+
transport = tglf_based_transport_model.RuntimeParams(
97+
**runtime_params_lib.RuntimeParams()
98+
)
99+
dynamic_runtime_params_slice, geo, core_profiles = _get_model_inputs(
100+
transport
101+
)
102+
transport_model = FakeTGLFBasedTransportModel()
103+
tglf_inputs = transport_model._prepare_tglf_inputs(
104+
Zeff_face=dynamic_runtime_params_slice.plasma_composition.Zeff_face,
105+
q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor,
106+
geo=geo,
107+
core_profiles=core_profiles,
108+
)
109+
110+
# Inputs that are 1D
111+
vector_keys = [
112+
'chiGB',
113+
'lref_over_lti',
114+
'lref_over_lte',
115+
'lref_over_lne',
116+
'lref_over_lni0',
117+
'lref_over_lni1',
118+
'Ti_over_Te',
119+
'drmaj',
120+
'q',
121+
's_hat',
122+
'nu_ee',
123+
'kappa',
124+
'kappa_shear',
125+
'delta',
126+
'delta_shear',
127+
'beta_e',
128+
'Zeff',
129+
]
130+
# Inputs that are 0D
131+
scalar_keys = ['Rmaj', 'Rmin']
132+
133+
expected_vector_length = geo.rho_face_norm.shape[0]
134+
for key in vector_keys:
135+
try:
136+
self.assertEqual(
137+
getattr(tglf_inputs, key).shape, (expected_vector_length,)
138+
)
139+
except Exception as e:
140+
print(key, getattr(tglf_inputs, key))
141+
raise e
142+
for key in scalar_keys:
143+
self.assertEqual(getattr(tglf_inputs, key).shape, ())
144+
145+
146+
class FakeTGLFBasedTransportModel(
147+
tglf_based_transport_model.TGLFBasedTransportModel
148+
):
149+
"""Fake TGLFBasedTransportModel for testing purposes."""
150+
151+
def __init__(self):
152+
super().__init__()
153+
self._frozen = True
154+
155+
# pylint: disable=invalid-name
156+
def prepare_tglf_inputs(
157+
self,
158+
Zeff_face: chex.Array,
159+
q_correction_factor: chex.Numeric,
160+
geo: geometry.Geometry,
161+
core_profiles: state.CoreProfiles,
162+
) -> tglf_based_transport_model.TGLFInputs:
163+
"""Exposing prepare_tglf_inputs for testing."""
164+
return self._prepare_tglf_inputs(
165+
Zeff_face=Zeff_face,
166+
q_correction_factor=q_correction_factor,
167+
geo=geo,
168+
core_profiles=core_profiles,
169+
)
170+
171+
# pylint: enable=invalid-name
172+
173+
def _call_implementation(
174+
self,
175+
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
176+
geo: geometry.Geometry,
177+
core_profiles: state.CoreProfiles,
178+
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
179+
) -> state.CoreTransport:
180+
tglf_inputs = self._prepare_tglf_inputs(
181+
Zeff_face=dynamic_runtime_params_slice.plasma_composition.Zeff_face,
182+
q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor,
183+
geo=geo,
184+
core_profiles=core_profiles,
185+
)
186+
187+
transport = dynamic_runtime_params_slice.transport
188+
# Assert required for pytype.
189+
assert isinstance(
190+
transport,
191+
tglf_based_transport_model.DynamicRuntimeParams,
192+
)
193+
194+
return self._make_core_transport(
195+
qi=jnp.ones(geo.rho_face_norm.shape) * 0.4,
196+
qe=jnp.ones(geo.rho_face_norm.shape) * 0.5,
197+
pfe=jnp.ones(geo.rho_face_norm.shape) * 1.6,
198+
quasilinear_inputs=tglf_inputs,
199+
transport=transport,
200+
geo=geo,
201+
core_profiles=core_profiles,
202+
gradient_reference_length=geo.Rmaj, # TODO
203+
gyrobohm_flux_reference_length=geo.Rmin, # TODO
204+
)
205+
206+
207+
if __name__ == '__main__':
208+
absltest.main()

torax/transport_model/tglf_based_transport_model.py

+40-31
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import chex
1717
from jax import numpy as jnp
1818

19-
from torax import geometry
19+
from torax.geometry import geometry
2020
from torax import physics
2121
from torax import state
2222
from torax.constants import CONSTANTS
@@ -60,8 +60,8 @@ class TGLFInputs(quasilinear_transport_model.QuasilinearInputs):
6060

6161
# Ti/Te
6262
Ti_over_Te: chex.Array
63-
# dRmaj/dr
64-
dRmaj: chex.Array
63+
# drmaj/dr (flux surface centroid major radius gradient)
64+
drmaj: chex.Array
6565
# q
6666
q: chex.Array
6767
# r/q dq/dr
@@ -88,17 +88,18 @@ class TGLFBasedTransportModel(
8888
"""Base class for TGLF-based transport models."""
8989

9090
def _prepare_tglf_inputs(
91+
self,
9192
Zeff_face: chex.Array,
9293
q_correction_factor: chex.Numeric,
9394
geo: geometry.Geometry,
9495
core_profiles: state.CoreProfiles,
9596
) -> TGLFInputs:
96-
## Shorthand 'standard' variables
97+
# Shorthand 'standard' variables
9798
Te_keV = core_profiles.temp_el.face_value()
9899
Te_eV = Te_keV * 1e3
99100
Te_J = Te_keV * CONSTANTS.keV2J
100101
Ti_keV = core_profiles.temp_ion.face_value()
101-
ne = core_profiles.ne * core_profiles.nref
102+
ne = core_profiles.ne.face_value() * core_profiles.nref
102103
# q must be recalculated since in the nonlinear solver psi has intermediate
103104
# states in the iterative solve
104105
q, _ = physics.calc_q_from_psi(
@@ -107,29 +108,33 @@ def _prepare_tglf_inputs(
107108
q_correction_factor=q_correction_factor,
108109
)
109110

110-
## Reference values used for TGLF-specific normalisation
111-
# https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization
112-
# https://gafusion.github.io/doc/geometry.html#effective-field
113-
# B_unit = 1/r d(psi_tor)/dr = q/r dpsi/dr
114-
# Note: TGLF uses geo.rmid = (Rmax - Rmin)/2 as the radial coordinate
115-
# This means all gradients are calculated w.r.t. rmid
116-
m_D_amu = 2.014 # Mass of deuterium
111+
# Reference values used for TGLF-specific normalisation
112+
# - 'a' in TGLF means the minor radius at the LCFS
113+
# - 'r' in TGLF means the flux surface centroid minor radius. Gradients are
114+
# taken w.r.t. r
115+
# https://gafusion.github.io/doc/tglf/tglf_list.html#rmin-loc
116+
# - B_unit = 1/r d(psi_tor)/dr = q/r dpsi/dr
117+
# https://gafusion.github.io/doc/geometry.html#effective-field
118+
# - c_s (ion sound speed)
119+
# https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization
120+
m_D_amu = 2.014 # Mass of deuterium - TODO: load from lookup table
117121
m_D = m_D_amu * CONSTANTS.mp # Mass of deuterium
118122
c_s = (Te_J / m_D) ** 0.5
119-
a = geo.Rmin[-1] # Minor radius at LCFS
120-
B_unit = q / (geo.rmid) * jnp.gradient(core_profiles.psi, geo.rmid)
123+
a = geo.Rmin # Device minor radius at LCFS
124+
r = geo.rmid_face # Flux surface centroid minor radius
125+
B_unit = q / r * jnp.gradient(core_profiles.psi.face_value(), r)
121126

122-
## Dimensionless gradients, eg lref_over_lti where lref=amin, lti = -ti / (dti/dr)
127+
# Dimensionless gradients
123128
normalized_log_gradients = quasilinear_transport_model.NormalizedLogarithmicGradients.from_profiles(
124129
core_profiles=core_profiles,
125-
radial_coordinate=geo.rmid,
130+
radial_coordinate=geo.rmid, # TODO: Why does this have to be a variable on the cell grid?
126131
reference_length=a,
127132
)
128133

129-
## Dimensionless temperature ratio
134+
# Dimensionless temperature ratio
130135
Ti_over_Te = Ti_keV / Te_keV
131136

132-
## Dimensionless electron-electron collision frequency = nu_ee / (c_s/a)
137+
# Dimensionless electron-electron collision frequency = nu_ee / (c_s/a)
133138
# https://gafusion.github.io/doc/tglf/tglf_list.html#xnue
134139
# https://gafusion.github.io/doc/cgyro/cgyro_list.html#cgyro-nu-ee
135140
# Note: In the TGLF docs, XNUE is mislabelled as electron-ion collision frequency.
@@ -143,35 +148,39 @@ def _prepare_tglf_inputs(
143148
)
144149
nu_ee = normalised_nu_ee / (c_s / a)
145150

146-
## Safety factor, q
151+
# Safety factor, q
147152
# https://gafusion.github.io/doc/tglf/tglf_list.html#q-sa
148153
# defined before
149154

150-
## Safety factor shear, s_hat = r/q dq/dr
155+
# Safety factor shear, s_hat = r/q dq/dr
151156
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-shat-sa
152157
# Note: calc_s_from_psi_rmid gives rq dq/dr, so we divide by q**2
158+
# r_mid = r
153159
s_hat = physics.calc_s_from_psi_rmid(geo, core_profiles.psi) / q**2
154160

155-
## Electron beta
161+
# Electron beta
156162
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-betae
157163
# Note: Te in eV
158164
beta_e = 8 * jnp.pi * ne * Te_eV / B_unit**2
159165

160-
## Major radius shear = dRmaj/dr
166+
# Major radius shear = drmaj/drmin, where 'rmaj' is the flux surface centroid
167+
# major radius and 'rmin' the flux surface centroid minor radius
161168
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-drmajdx-loc
162-
dRmaj = jnp.gradient(geo.Rmaj, geo.rmid)
169+
rmaj = (
170+
geo.Rin_face + geo.Rout_face
171+
) / 2 # Flux surface centroid maj radius
172+
drmaj = jnp.gradient(rmaj, r)
163173

164-
## Elongation shear = r/kappa dkappa/dr
174+
# Elongation shear = r/kappa dkappa/dr
165175
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-s-kappa-loc
166176
kappa = geo.elongation_face
167-
kappa_shear = geo.rmid_face / kappa * jnp.gradient(kappa, geo.rmid_face)
177+
kappa_shear = geo.rmid_face / kappa * jnp.gradient(kappa, r)
168178

169-
## Triangularity shear = r ddelta/dr
179+
# Triangularity shear = r ddelta/dr
170180
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-s-delta-loc
171-
delta = geo.delta_face
172-
delta_shear = geo.rmid_face * jnp.gradient(delta, geo.rmid_face)
181+
delta_shear = r * jnp.gradient(geo.delta_face, r)
173182

174-
## Gyrobohm diffusivity
183+
# Gyrobohm diffusivity
175184
# https://gafusion.github.io/doc/tglf/tglf_table.html#id7
176185
# https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization
177186
# Note: TGLF uses the same normalisation as CGYRO
@@ -198,13 +207,13 @@ def _prepare_tglf_inputs(
198207
lref_over_lni1=normalized_log_gradients.lref_over_lni1,
199208
# From TGLFInputs
200209
Ti_over_Te=Ti_over_Te,
201-
dRmaj=dRmaj,
210+
drmaj=drmaj,
202211
q=q,
203212
s_hat=s_hat,
204213
nu_ee=nu_ee,
205214
kappa=kappa,
206215
kappa_shear=kappa_shear,
207-
delta=delta,
216+
delta=geo.delta_face,
208217
delta_shear=delta_shear,
209218
beta_e=beta_e,
210219
Zeff=Zeff_face,

0 commit comments

Comments
 (0)