Skip to content

Commit a6b239a

Browse files
hamelphiTorax team
authored andcommitted
Fix for numerical instability with QLKNN rotation.
PiperOrigin-RevId: 862857911
1 parent 3e9f3c6 commit a6b239a

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

torax/_src/transport_model/qlknn_transport_model.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import jax
2323
from jax import numpy as jnp
2424
from torax._src import array_typing
25+
from torax._src import constants
2526
from torax._src import state
2627
from torax._src.config import runtime_params as runtime_params_lib
2728
from torax._src.geometry import geometry
@@ -185,17 +186,29 @@ def _calculate_rotation_rule_factor(
185186
)
186187

187188

188-
def _apply_rotation_rule(
189+
def _maybe_apply_rotation_rule(
189190
model_output: base_qlknn_model.ModelOutput,
190191
qualikiz_inputs: qualikiz_based_transport_model.QualikizInputs,
192+
rotation_mode: qualikiz_based_transport_model.RotationMode,
193+
geo: geometry.Geometry,
191194
) -> base_qlknn_model.ModelOutput:
192195
"""Apply the rotation scaling factor to the model output (Victor rule)."""
196+
if rotation_mode == qualikiz_based_transport_model.RotationMode.OFF:
197+
# Rotation is disabled. Do not apply the rotation rule.
198+
return model_output
199+
200+
gamma_E_GB = qualikiz_inputs.gamma_E_GB
201+
if rotation_mode == qualikiz_based_transport_model.RotationMode.HALF_RADIUS:
202+
# Only consider contribution from the outer half-radius (rho > 0.5).
203+
gamma_E_GB = gamma_E_GB * jnp.where(geo.rho_face_norm > 0.5, 1.0, 0.0)
204+
193205
f_rot_rule = _calculate_rotation_rule_factor(qualikiz_inputs)
194-
gamma_max = model_output['gamma_max'].squeeze()
206+
195207
lower_bound = 1e-4
196-
gamma_max = jnp.maximum(gamma_max, lower_bound)
208+
gamma_max = jnp.maximum(model_output['gamma_max'].squeeze(), lower_bound)
209+
197210
scaling_factor = jnp.clip(
198-
1 + f_rot_rule * jnp.abs(qualikiz_inputs.gamma_E_GB) / gamma_max
211+
1.0 + f_rot_rule * jnp.abs(gamma_E_GB) / gamma_max, 0.0
199212
)
200213
# Add an extra dimension to match model outputs.
201214
scaling_factor = scaling_factor[..., jnp.newaxis]
@@ -206,6 +219,15 @@ def _apply_rotation_rule(
206219
if flux.endswith('_itg') or flux.endswith('_tem'):
207220
updated_model_output[flux] = scaling_factor * updated_model_output[flux]
208221

222+
# Make tiny flux values exactly zero.
223+
# This prevents numerical instabilities in the solver. The cause of
224+
# these instabilities is not yet well understood.
225+
updated_model_output[flux] = jnp.where(
226+
jnp.abs(updated_model_output[flux]) < constants.CONSTANTS.eps,
227+
0.0,
228+
updated_model_output[flux],
229+
)
230+
209231
return updated_model_output
210232

211233

@@ -309,11 +331,13 @@ def _combined(
309331
lambda: feature_scan, # Called when False
310332
)
311333
model_output = model.predict(feature_scan)
312-
if (
313-
runtime_config_inputs.transport.rotation_mode
314-
!= qualikiz_based_transport_model.RotationMode.OFF
315-
):
316-
model_output = _apply_rotation_rule(model_output, qualikiz_inputs)
334+
335+
model_output = _maybe_apply_rotation_rule(
336+
model_output,
337+
qualikiz_inputs,
338+
runtime_config_inputs.transport.rotation_mode,
339+
geo,
340+
)
317341

318342
model_output = _filter_model_output(
319343
model_output=model_output,

torax/_src/transport_model/qualikiz_based_transport_model.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,6 @@ def _get_v_ExB():
224224
geo=geo,
225225
poloidal_velocity_multiplier=poloidal_velocity_multiplier,
226226
)
227-
v_ExB = transport.rotation_multiplier * v_ExB
228-
if transport.rotation_mode == RotationMode.HALF_RADIUS:
229-
# Only consider contribution from the outer half-radius (rho > 0.5).
230-
v_ExB = v_ExB * jnp.where(geo.rho_face_norm > 0.5, 1, 0)
231227
return v_ExB
232228

233229
# gamma_E_SI = r / q * d(v_ExB * q / r)/dr
@@ -245,6 +241,7 @@ def _get_v_ExB():
245241
gamma_E_SI = rmid_face / q * cv.face_grad(
246242
x=rmid, x_left=rmid_face[0], x_right=rmid_face[-1]
247243
)
244+
gamma_E_SI = gamma_E_SI * transport.rotation_multiplier
248245

249246
# We need different normalizations for QuaLiKiz and QLKNN models.
250247
c_ref = jnp.sqrt(constants.keV_to_J / constants.m_amu)

0 commit comments

Comments
 (0)