From a6b239a60a0810113613104354f2c137aee615f9 Mon Sep 17 00:00:00 2001 From: Philippe Hamel Date: Thu, 29 Jan 2026 12:48:29 -0800 Subject: [PATCH] Fix for numerical instability with QLKNN rotation. PiperOrigin-RevId: 862857911 --- .../transport_model/qlknn_transport_model.py | 42 +++++++++++++++---- .../qualikiz_based_transport_model.py | 5 +-- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/torax/_src/transport_model/qlknn_transport_model.py b/torax/_src/transport_model/qlknn_transport_model.py index a4542cdd1..eca68978b 100644 --- a/torax/_src/transport_model/qlknn_transport_model.py +++ b/torax/_src/transport_model/qlknn_transport_model.py @@ -22,6 +22,7 @@ import jax from jax import numpy as jnp from torax._src import array_typing +from torax._src import constants from torax._src import state from torax._src.config import runtime_params as runtime_params_lib from torax._src.geometry import geometry @@ -185,17 +186,29 @@ def _calculate_rotation_rule_factor( ) -def _apply_rotation_rule( +def _maybe_apply_rotation_rule( model_output: base_qlknn_model.ModelOutput, qualikiz_inputs: qualikiz_based_transport_model.QualikizInputs, + rotation_mode: qualikiz_based_transport_model.RotationMode, + geo: geometry.Geometry, ) -> base_qlknn_model.ModelOutput: """Apply the rotation scaling factor to the model output (Victor rule).""" + if rotation_mode == qualikiz_based_transport_model.RotationMode.OFF: + # Rotation is disabled. Do not apply the rotation rule. + return model_output + + gamma_E_GB = qualikiz_inputs.gamma_E_GB + if rotation_mode == qualikiz_based_transport_model.RotationMode.HALF_RADIUS: + # Only consider contribution from the outer half-radius (rho > 0.5). + gamma_E_GB = gamma_E_GB * jnp.where(geo.rho_face_norm > 0.5, 1.0, 0.0) + f_rot_rule = _calculate_rotation_rule_factor(qualikiz_inputs) - gamma_max = model_output['gamma_max'].squeeze() + lower_bound = 1e-4 - gamma_max = jnp.maximum(gamma_max, lower_bound) + gamma_max = jnp.maximum(model_output['gamma_max'].squeeze(), lower_bound) + scaling_factor = jnp.clip( - 1 + f_rot_rule * jnp.abs(qualikiz_inputs.gamma_E_GB) / gamma_max + 1.0 + f_rot_rule * jnp.abs(gamma_E_GB) / gamma_max, 0.0 ) # Add an extra dimension to match model outputs. scaling_factor = scaling_factor[..., jnp.newaxis] @@ -206,6 +219,15 @@ def _apply_rotation_rule( if flux.endswith('_itg') or flux.endswith('_tem'): updated_model_output[flux] = scaling_factor * updated_model_output[flux] + # Make tiny flux values exactly zero. + # This prevents numerical instabilities in the solver. The cause of + # these instabilities is not yet well understood. + updated_model_output[flux] = jnp.where( + jnp.abs(updated_model_output[flux]) < constants.CONSTANTS.eps, + 0.0, + updated_model_output[flux], + ) + return updated_model_output @@ -309,11 +331,13 @@ def _combined( lambda: feature_scan, # Called when False ) model_output = model.predict(feature_scan) - if ( - runtime_config_inputs.transport.rotation_mode - != qualikiz_based_transport_model.RotationMode.OFF - ): - model_output = _apply_rotation_rule(model_output, qualikiz_inputs) + + model_output = _maybe_apply_rotation_rule( + model_output, + qualikiz_inputs, + runtime_config_inputs.transport.rotation_mode, + geo, + ) model_output = _filter_model_output( model_output=model_output, diff --git a/torax/_src/transport_model/qualikiz_based_transport_model.py b/torax/_src/transport_model/qualikiz_based_transport_model.py index d1f32822e..07fa5bce7 100644 --- a/torax/_src/transport_model/qualikiz_based_transport_model.py +++ b/torax/_src/transport_model/qualikiz_based_transport_model.py @@ -224,10 +224,6 @@ def _get_v_ExB(): geo=geo, poloidal_velocity_multiplier=poloidal_velocity_multiplier, ) - v_ExB = transport.rotation_multiplier * v_ExB - if transport.rotation_mode == RotationMode.HALF_RADIUS: - # Only consider contribution from the outer half-radius (rho > 0.5). - v_ExB = v_ExB * jnp.where(geo.rho_face_norm > 0.5, 1, 0) return v_ExB # gamma_E_SI = r / q * d(v_ExB * q / r)/dr @@ -245,6 +241,7 @@ def _get_v_ExB(): gamma_E_SI = rmid_face / q * cv.face_grad( x=rmid, x_left=rmid_face[0], x_right=rmid_face[-1] ) + gamma_E_SI = gamma_E_SI * transport.rotation_multiplier # We need different normalizations for QuaLiKiz and QLKNN models. c_ref = jnp.sqrt(constants.keV_to_J / constants.m_amu)