1919import jax
2020from jax import numpy as jnp
2121from torax ._src import constants
22+ from torax ._src import math_utils
2223from torax ._src .edge import collisional_radiative_models
2324from torax ._src .edge import divertor_sol_1d as divertor_sol_1d_lib
2425from torax ._src .edge import extended_lengyel_defaults
@@ -101,9 +102,11 @@ def body_fun(_, carry):
101102 # Solve for the impurity concentration required to achieve the target
102103 # temperature for a given q_parallel. This also updates the divertor and
103104 # separatrix Z_eff values in sol_model, used downstream.
104- current_sol_model . state . c_z_prefactor , physics_outcome = (
105- _solve_for_c_z_prefactor ( sol_model = current_sol_model )
105+ c_z_prefactor , physics_outcome = _solve_for_c_z_prefactor (
106+ sol_model = current_sol_model
106107 )
108+ # Clip to physical values (non-negative impurity concentration).
109+ current_sol_model .state .c_z_prefactor = jnp .maximum (c_z_prefactor , 0.0 )
107110
108111 # Update alpha_t for the next loop iteration.
109112 current_sol_model .state .alpha_t = divertor_sol_1d_lib .calc_alpha_t (
@@ -246,12 +249,12 @@ def forward_mode_newton_solver(
246249 """
247250 # 1. Create initial guess state vector.
248251 # Uses log space for strictly positive variables and to improve conditioning.
249- # alpha_t is left linear since should always remain O(1) and log steps
250- # can lead to numerical issues due to exponential amplification. Positivity is
251- # enforced via softplus when unpacking .
252+ # alpha_t must be strictly positive; positivity is enforced via softplus in
253+ # the residual. The initial guess is therefore passed through inverse
254+ # softplus so that it is consistent with that parameterization.
252255 x0 = jnp .stack ([
253256 jnp .log (initial_sol_model .state .q_parallel ),
254- initial_sol_model .state .alpha_t ,
257+ math_utils . inverse_softplus ( initial_sol_model .state .alpha_t ) ,
255258 jnp .log (initial_sol_model .state .kappa_e ),
256259 jnp .log (initial_sol_model .state .T_e_target ),
257260 ])
@@ -317,13 +320,13 @@ def inverse_mode_newton_solver(
317320 # 1. Create initial guess state vector.
318321
319322 # Uses log space for strictly positive variables and to improve conditioning.
320- # alpha_t is left linear since should always remain O(1) and log steps
321- # can lead to numerical issues due to exponential amplification. Positivity is
322- # enforced via softplus when unpacking .
323+ # alpha_t must be strictly positive; positivity is enforced via softplus in
324+ # the residual. The initial guess is therefore passed through inverse
325+ # softplus so that it is consistent with that parameterization.
323326
324327 x0 = jnp .stack ([
325328 jnp .log (initial_sol_model .state .q_parallel ),
326- initial_sol_model .state .alpha_t ,
329+ math_utils . inverse_softplus ( initial_sol_model .state .alpha_t ) ,
327330 jnp .log (initial_sol_model .state .kappa_e ),
328331 initial_sol_model .state .c_z_prefactor ,
329332 ])
@@ -342,11 +345,15 @@ def inverse_mode_newton_solver(
342345 )
343346
344347 # 4. Construct final model.
348+ # Clip c_z_prefactor to 0.0 if it is negative (unphysical solution).
349+ # Negative values are allowed during the solve process (to ensure smooth
350+ # gradients for the solver), but the final physical state must have
351+ # non-negative concentrations.
345352 final_state = divertor_sol_1d_lib .ExtendedLengyelState (
346353 q_parallel = jnp .exp (x_root [0 ]),
347354 alpha_t = jax .nn .softplus (x_root [1 ]),
348355 kappa_e = jnp .exp (x_root [2 ]),
349- c_z_prefactor = x_root [3 ],
356+ c_z_prefactor = jnp . maximum ( x_root [3 ], 0.0 ) ,
350357 T_e_target = fixed_Tt ,
351358 )
352359
@@ -355,6 +362,7 @@ def inverse_mode_newton_solver(
355362 )
356363
357364 # 5. Re-calculate physics outcome at final state to return the physics_outcome
365+ # This uses the clipped (physical) c_z_prefactor.
358366 _ , physics_outcome = _solve_for_c_z_prefactor (sol_model = final_sol_model )
359367
360368 solver_status = ExtendedLengyelSolverStatus (
@@ -454,7 +462,7 @@ def _forward_residual(
454462 at_calc_safe = jnp .maximum (at_calc , constants .CONSTANTS .eps )
455463
456464 r_qp = jnp .log (qp_calc_safe ) - x_vec [0 ]
457- r_at = at_calc_safe - current_state . alpha_t
465+ r_at = math_utils . inverse_softplus ( at_calc_safe ) - x_vec [ 1 ]
458466 r_ke = jnp .log (ke_calc_safe ) - x_vec [2 ]
459467 r_Tt = jnp .log (Tt_calc_safe ) - x_vec [3 ]
460468
@@ -468,11 +476,15 @@ def _inverse_residual(
468476) -> jax .Array :
469477 """Calculates the residual vector for Inverse Mode F(x) = 0."""
470478 # 1. Construct physical state from vector guess.
479+ # Note: c_z_prefactor is clipped to be non-negative for physics calculations
480+ # to avoid NaNs (e.g. in Z_eff -> alpha_t). The solver is allowed to explore
481+ # negative values in x_vec[3] to properly find the root (even if unphysical),
482+ # but the state used for consistent physics checks must be valid.
471483 current_state = divertor_sol_1d_lib .ExtendedLengyelState (
472484 q_parallel = jnp .exp (x_vec [0 ]),
473485 alpha_t = jax .nn .softplus (x_vec [1 ]),
474486 kappa_e = jnp .exp (x_vec [2 ]),
475- c_z_prefactor = x_vec [3 ],
487+ c_z_prefactor = jnp . maximum ( x_vec [3 ], 0.0 ) ,
476488 T_e_target = fixed_Tt ,
477489 )
478490
@@ -508,9 +520,12 @@ def _inverse_residual(
508520 at_calc_safe = jnp .maximum (at_calc , constants .CONSTANTS .eps )
509521
510522 r_qp = jnp .log (qp_calc_safe ) - x_vec [0 ]
511- r_at = at_calc_safe - current_state . alpha_t
523+ r_at = math_utils . inverse_softplus ( at_calc_safe ) - x_vec [ 1 ]
512524 r_ke = jnp .log (ke_calc_safe ) - x_vec [2 ]
513- r_cz = cz_calc - current_state .c_z_prefactor
525+ # Residual for c_z compares the calculated required c_z against the
526+ # *raw* solver guess x_vec[3], not the clipped state value.
527+ # This provides a gradient signal even when x_vec[3] is negative.
528+ r_cz = cz_calc - x_vec [3 ]
514529
515530 return jnp .stack ([r_qp , r_at , r_ke , r_cz ])
516531
@@ -629,12 +644,6 @@ def _solve_for_c_z_prefactor(
629644 PhysicsOutcome .SUCCESS ,
630645 )
631646
632- # c_z is related to impurity density which physically cannot be negative.
633- # The natural floor of c_z_prefactor is zero.
634- c_z_prefactor = jnp .where (
635- status == PhysicsOutcome .SUCCESS , c_z_prefactor , 0.0
636- )
637-
638647 return c_z_prefactor , status
639648
640649
0 commit comments