Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions torax/_src/transport_model/qlknn_transport_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import jax
from jax import numpy as jnp
from torax._src import array_typing
from torax._src import constants
from torax._src import state
from torax._src.config import runtime_params as runtime_params_lib
from torax._src.geometry import geometry
Expand Down Expand Up @@ -185,17 +186,29 @@ def _calculate_rotation_rule_factor(
)


def _apply_rotation_rule(
def _maybe_apply_rotation_rule(
model_output: base_qlknn_model.ModelOutput,
qualikiz_inputs: qualikiz_based_transport_model.QualikizInputs,
rotation_mode: qualikiz_based_transport_model.RotationMode,
geo: geometry.Geometry,
) -> base_qlknn_model.ModelOutput:
"""Apply the rotation scaling factor to the model output (Victor rule)."""
if rotation_mode == qualikiz_based_transport_model.RotationMode.OFF:
# Rotation is disabled. Do not apply the rotation rule.
return model_output

gamma_E_GB = qualikiz_inputs.gamma_E_GB
if rotation_mode == qualikiz_based_transport_model.RotationMode.HALF_RADIUS:
# Only consider contribution from the outer half-radius (rho > 0.5).
gamma_E_GB = gamma_E_GB * jnp.where(geo.rho_face_norm > 0.5, 1.0, 0.0)

f_rot_rule = _calculate_rotation_rule_factor(qualikiz_inputs)
gamma_max = model_output['gamma_max'].squeeze()

lower_bound = 1e-4
gamma_max = jnp.maximum(gamma_max, lower_bound)
gamma_max = jnp.maximum(model_output['gamma_max'].squeeze(), lower_bound)

scaling_factor = jnp.clip(
1 + f_rot_rule * jnp.abs(qualikiz_inputs.gamma_E_GB) / gamma_max
1.0 + f_rot_rule * jnp.abs(gamma_E_GB) / gamma_max, 0.0
)
# Add an extra dimension to match model outputs.
scaling_factor = scaling_factor[..., jnp.newaxis]
Expand All @@ -206,6 +219,15 @@ def _apply_rotation_rule(
if flux.endswith('_itg') or flux.endswith('_tem'):
updated_model_output[flux] = scaling_factor * updated_model_output[flux]

# Make tiny flux values exactly zero.
# This prevents numerical instabilities in the solver. The cause of
# these instabilities is not yet well understood.
updated_model_output[flux] = jnp.where(
jnp.abs(updated_model_output[flux]) < constants.CONSTANTS.eps,
0.0,
updated_model_output[flux],
)

return updated_model_output


Expand Down Expand Up @@ -309,11 +331,13 @@ def _combined(
lambda: feature_scan, # Called when False
)
model_output = model.predict(feature_scan)
if (
runtime_config_inputs.transport.rotation_mode
!= qualikiz_based_transport_model.RotationMode.OFF
):
model_output = _apply_rotation_rule(model_output, qualikiz_inputs)

model_output = _maybe_apply_rotation_rule(
model_output,
qualikiz_inputs,
runtime_config_inputs.transport.rotation_mode,
geo,
)

model_output = _filter_model_output(
model_output=model_output,
Expand Down
5 changes: 1 addition & 4 deletions torax/_src/transport_model/qualikiz_based_transport_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,6 @@ def _get_v_ExB():
geo=geo,
poloidal_velocity_multiplier=poloidal_velocity_multiplier,
)
v_ExB = transport.rotation_multiplier * v_ExB
if transport.rotation_mode == RotationMode.HALF_RADIUS:
# Only consider contribution from the outer half-radius (rho > 0.5).
v_ExB = v_ExB * jnp.where(geo.rho_face_norm > 0.5, 1, 0)
return v_ExB

# gamma_E_SI = r / q * d(v_ExB * q / r)/dr
Expand All @@ -245,6 +241,7 @@ def _get_v_ExB():
gamma_E_SI = rmid_face / q * cv.face_grad(
x=rmid, x_left=rmid_face[0], x_right=rmid_face[-1]
)
gamma_E_SI = gamma_E_SI * transport.rotation_multiplier

# We need different normalizations for QuaLiKiz and QLKNN models.
c_ref = jnp.sqrt(constants.keV_to_J / constants.m_amu)
Expand Down
Loading