2222import jax
2323from jax import numpy as jnp
2424from torax ._src import array_typing
25+ from torax ._src import constants
2526from torax ._src import state
2627from torax ._src .config import runtime_params as runtime_params_lib
2728from torax ._src .geometry import geometry
@@ -185,17 +186,29 @@ def _calculate_rotation_rule_factor(
185186 )
186187
187188
188- def _apply_rotation_rule (
189+ def _maybe_apply_rotation_rule (
189190 model_output : base_qlknn_model .ModelOutput ,
190191 qualikiz_inputs : qualikiz_based_transport_model .QualikizInputs ,
192+ rotation_mode : qualikiz_based_transport_model .RotationMode ,
193+ geo : geometry .Geometry ,
191194) -> base_qlknn_model .ModelOutput :
192195 """Apply the rotation scaling factor to the model output (Victor rule)."""
196+ if rotation_mode == qualikiz_based_transport_model .RotationMode .OFF :
197+ # Rotation is disabled. Do not apply the rotation rule.
198+ return model_output
199+
200+ gamma_E_GB = qualikiz_inputs .gamma_E_GB
201+ if rotation_mode == qualikiz_based_transport_model .RotationMode .HALF_RADIUS :
202+ # Only consider contribution from the outer half-radius (rho > 0.5).
203+ gamma_E_GB = gamma_E_GB * jnp .where (geo .rho_face_norm > 0.5 , 1.0 , 0.0 )
204+
193205 f_rot_rule = _calculate_rotation_rule_factor (qualikiz_inputs )
194- gamma_max = model_output [ 'gamma_max' ]. squeeze ()
206+
195207 lower_bound = 1e-4
196- gamma_max = jnp .maximum (gamma_max , lower_bound )
208+ gamma_max = jnp .maximum (model_output ['gamma_max' ].squeeze (), lower_bound )
209+
197210 scaling_factor = jnp .clip (
198- 1 + f_rot_rule * jnp .abs (qualikiz_inputs . gamma_E_GB ) / gamma_max
211+ 1.0 + f_rot_rule * jnp .abs (gamma_E_GB ) / gamma_max , 0.0
199212 )
200213 # Add an extra dimension to match model outputs.
201214 scaling_factor = scaling_factor [..., jnp .newaxis ]
@@ -206,6 +219,15 @@ def _apply_rotation_rule(
206219 if flux .endswith ('_itg' ) or flux .endswith ('_tem' ):
207220 updated_model_output [flux ] = scaling_factor * updated_model_output [flux ]
208221
222+ # Make tiny flux values exactly zero.
223+ # This prevents numerical instabilities in the solver. The cause of
224+ # these instabilities is not yet well understood.
225+ updated_model_output [flux ] = jnp .where (
226+ jnp .abs (updated_model_output [flux ]) < constants .CONSTANTS .eps ,
227+ 0.0 ,
228+ updated_model_output [flux ],
229+ )
230+
209231 return updated_model_output
210232
211233
@@ -309,11 +331,13 @@ def _combined(
309331 lambda : feature_scan , # Called when False
310332 )
311333 model_output = model .predict (feature_scan )
312- if (
313- runtime_config_inputs .transport .rotation_mode
314- != qualikiz_based_transport_model .RotationMode .OFF
315- ):
316- model_output = _apply_rotation_rule (model_output , qualikiz_inputs )
334+
335+ model_output = _maybe_apply_rotation_rule (
336+ model_output ,
337+ qualikiz_inputs ,
338+ runtime_config_inputs .transport .rotation_mode ,
339+ geo ,
340+ )
317341
318342 model_output = _filter_model_output (
319343 model_output = model_output ,
0 commit comments