Skip to content

Commit 35eb3da

Browse files
jcitrin authored and Torax team committed
Bugfix: use inverse softplus when converting alpha_t in extended lengyel
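Previously the conversion was unidirectional: softplus was applied when unpacking alpha_t from the solver state vector, but the inverse softplus was not applied when packing the initial guess or when forming the alpha_t residual. The solver was still able to eventually find the solutions, but now it should be more robust.
PiperOrigin-RevId: 860647236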
1 parent 3e9f3c6 commit 35eb3da

File tree

7 files changed

+144
-30
lines changed

7 files changed

+144
-30
lines changed

torax/_src/edge/divertor_sol_1d.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,9 +421,10 @@ def calc_alpha_t(
421421
nu_ee = jnp.exp(log_nu_ee)
422422

423423
# Z_eff correction to transform electron-electron collisions to ion-electron
424-
# collisions. Equation B2 in Eich 2020
424+
# collisions. Equation B2 in Eich 2020. Adding a small offset to Z_eff to
425+
# avoid numerical issues with the gradient at Z_eff=1 (no impurities).
425426
Z_eff_correction = (1.0 - 0.569) * jnp.exp(
426-
-(((Z_eff_separatrix - 1.0) / 3.25) ** 0.85)
427+
-(((Z_eff_separatrix - 1.0 + constants.CONSTANTS.eps) / 3.25) ** 0.85)
427428
) + 0.569
428429

429430
nu_ei = nu_ee * Z_eff_correction * Z_eff_separatrix

torax/_src/edge/extended_lengyel_solvers.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import jax
2020
from jax import numpy as jnp
2121
from torax._src import constants
22+
from torax._src import math_utils
2223
from torax._src.edge import collisional_radiative_models
2324
from torax._src.edge import divertor_sol_1d as divertor_sol_1d_lib
2425
from torax._src.edge import extended_lengyel_defaults
@@ -101,9 +102,11 @@ def body_fun(_, carry):
101102
# Solve for the impurity concentration required to achieve the target
102103
# temperature for a given q_parallel. This also updates the divertor and
103104
# separatrix Z_eff values in sol_model, used downstream.
104-
current_sol_model.state.c_z_prefactor, physics_outcome = (
105-
_solve_for_c_z_prefactor(sol_model=current_sol_model)
105+
c_z_prefactor, physics_outcome = _solve_for_c_z_prefactor(
106+
sol_model=current_sol_model
106107
)
108+
# Clip to physical values (non-negative impurity concentration).
109+
current_sol_model.state.c_z_prefactor = jnp.maximum(c_z_prefactor, 0.0)
107110

108111
# Update alpha_t for the next loop iteration.
109112
current_sol_model.state.alpha_t = divertor_sol_1d_lib.calc_alpha_t(
@@ -246,12 +249,12 @@ def forward_mode_newton_solver(
246249
"""
247250
# 1. Create initial guess state vector.
248251
# Uses log space for strictly positive variables and to improve conditioning.
249-
# alpha_t is left linear since should always remain O(1) and log steps
250-
# can lead to numerical issues due to exponential amplification. Positivity is
251-
# enforced via softplus when unpacking.
252+
# alpha_t positivity is enforced via softplus in the residual function.
253+
# Therefore we must inverse softplus the initial guess to maintain
254+
# consistency.
252255
x0 = jnp.stack([
253256
jnp.log(initial_sol_model.state.q_parallel),
254-
initial_sol_model.state.alpha_t,
257+
math_utils.inverse_softplus(initial_sol_model.state.alpha_t),
255258
jnp.log(initial_sol_model.state.kappa_e),
256259
jnp.log(initial_sol_model.state.T_e_target),
257260
])
@@ -317,13 +320,13 @@ def inverse_mode_newton_solver(
317320
# 1. Create initial guess state vector.
318321

319322
# Uses log space for strictly positive variables and to improve conditioning.
320-
# alpha_t is left linear since should always remain O(1) and log steps
321-
# can lead to numerical issues due to exponential amplification. Positivity is
322-
# enforced via softplus when unpacking.
323+
# alpha_t positivity is enforced via softplus in the residual function.
324+
# Therefore we must inverse softplus the initial guess to maintain
325+
# consistency.
323326

324327
x0 = jnp.stack([
325328
jnp.log(initial_sol_model.state.q_parallel),
326-
initial_sol_model.state.alpha_t,
329+
math_utils.inverse_softplus(initial_sol_model.state.alpha_t),
327330
jnp.log(initial_sol_model.state.kappa_e),
328331
initial_sol_model.state.c_z_prefactor,
329332
])
@@ -342,11 +345,15 @@ def inverse_mode_newton_solver(
342345
)
343346

344347
# 4. Construct final model.
348+
# Clip c_z_prefactor to 0.0 if it is negative (unphysical solution).
349+
# Negative values are allowed during the solve process (to ensure smooth
350+
# gradients for the solver), but the final physical state must have
351+
# non-negative concentrations.
345352
final_state = divertor_sol_1d_lib.ExtendedLengyelState(
346353
q_parallel=jnp.exp(x_root[0]),
347354
alpha_t=jax.nn.softplus(x_root[1]),
348355
kappa_e=jnp.exp(x_root[2]),
349-
c_z_prefactor=x_root[3],
356+
c_z_prefactor=jnp.maximum(x_root[3], 0.0),
350357
T_e_target=fixed_Tt,
351358
)
352359

@@ -355,6 +362,7 @@ def inverse_mode_newton_solver(
355362
)
356363

357364
# 5. Re-calculate physics outcome at final state to return the physics_outcome.
365+
# This uses the clipped (physical) c_z_prefactor.
358366
_, physics_outcome = _solve_for_c_z_prefactor(sol_model=final_sol_model)
359367

360368
solver_status = ExtendedLengyelSolverStatus(
@@ -454,7 +462,7 @@ def _forward_residual(
454462
at_calc_safe = jnp.maximum(at_calc, constants.CONSTANTS.eps)
455463

456464
r_qp = jnp.log(qp_calc_safe) - x_vec[0]
457-
r_at = at_calc_safe - current_state.alpha_t
465+
r_at = math_utils.inverse_softplus(at_calc_safe) - x_vec[1]
458466
r_ke = jnp.log(ke_calc_safe) - x_vec[2]
459467
r_Tt = jnp.log(Tt_calc_safe) - x_vec[3]
460468

@@ -468,11 +476,15 @@ def _inverse_residual(
468476
) -> jax.Array:
469477
"""Calculates the residual vector for Inverse Mode F(x) = 0."""
470478
# 1. Construct physical state from vector guess.
479+
# Note: c_z_prefactor is clipped to be non-negative for physics calculations
480+
# to avoid NaNs (e.g. in Z_eff -> alpha_t). The solver is allowed to explore
481+
# negative values in x_vec[3] to properly find the root (even if unphysical),
482+
# but the state used for consistent physics checks must be valid.
471483
current_state = divertor_sol_1d_lib.ExtendedLengyelState(
472484
q_parallel=jnp.exp(x_vec[0]),
473485
alpha_t=jax.nn.softplus(x_vec[1]),
474486
kappa_e=jnp.exp(x_vec[2]),
475-
c_z_prefactor=x_vec[3],
487+
c_z_prefactor=jnp.maximum(x_vec[3], 0.0),
476488
T_e_target=fixed_Tt,
477489
)
478490

@@ -508,9 +520,12 @@ def _inverse_residual(
508520
at_calc_safe = jnp.maximum(at_calc, constants.CONSTANTS.eps)
509521

510522
r_qp = jnp.log(qp_calc_safe) - x_vec[0]
511-
r_at = at_calc_safe - current_state.alpha_t
523+
r_at = math_utils.inverse_softplus(at_calc_safe) - x_vec[1]
512524
r_ke = jnp.log(ke_calc_safe) - x_vec[2]
513-
r_cz = cz_calc - current_state.c_z_prefactor
525+
# Residual for c_z compares the calculated required c_z against the
526+
# *raw* solver guess x_vec[3], not the clipped state value.
527+
# This provides a gradient signal even when x_vec[3] is negative.
528+
r_cz = cz_calc - x_vec[3]
514529

515530
return jnp.stack([r_qp, r_at, r_ke, r_cz])
516531

@@ -629,12 +644,6 @@ def _solve_for_c_z_prefactor(
629644
PhysicsOutcome.SUCCESS,
630645
)
631646

632-
# c_z is related to impurity density which physically cannot be negative.
633-
# The natural floor of c_z_prefactor is zero.
634-
c_z_prefactor = jnp.where(
635-
status == PhysicsOutcome.SUCCESS, c_z_prefactor, 0.0
636-
)
637-
638647
return c_z_prefactor, status
639648

640649

torax/_src/edge/tests/extended_lengyel_solver_test.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from absl.testing import absltest
16+
from absl.testing import parameterized
1617
import numpy as np
1718
from torax._src.edge import divertor_sol_1d
1819
from torax._src.edge import extended_lengyel_defaults
@@ -23,7 +24,7 @@
2324
# pylint: disable=invalid-name
2425

2526

26-
class ExtendedLengyelSolverInverseTest(absltest.TestCase):
27+
class ExtendedLengyelSolverInverseTest(parameterized.TestCase):
2728

2829
def setUp(self):
2930
super().setUp()
@@ -153,15 +154,11 @@ def test_unsuccessful_solve_for_c_z(self):
153154
calculated_c_z, status = extended_lengyel_solvers._solve_for_c_z_prefactor(
154155
sol_model=sol_model,
155156
)
156-
expected_c_z = 0.0
157157

158158
self.assertEqual(
159159
status, extended_lengyel_solvers.PhysicsOutcome.C_Z_PREFACTOR_NEGATIVE
160160
)
161-
np.testing.assert_allclose(
162-
calculated_c_z,
163-
expected_c_z,
164-
)
161+
self.assertLess(calculated_c_z, 0.0)
165162

166163
def test_inverse_unsuccessful_newton_solve_but_successful_hybrid_solve(self):
167164
# The initial guess state is deliberately set far from the solution, by

torax/_src/edge/tests/extended_lengyel_standalone_test.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from unittest import mock
1616
from absl.testing import absltest
17+
from absl.testing import parameterized
1718
import numpy as np
1819
from torax._src.edge import extended_lengyel_defaults
1920
from torax._src.edge import extended_lengyel_enums
@@ -23,7 +24,7 @@
2324
# pylint: disable=invalid-name
2425

2526

26-
class ExtendedLengyelTest(absltest.TestCase):
27+
class ExtendedLengyelTest(parameterized.TestCase):
2728

2829
def test_run_extended_lengyel_model_inverse_mode_fixed_point(self):
2930
"""Integration test for the full extended_lengyel model in inverse mode."""
@@ -549,6 +550,68 @@ def test_validate_inputs_for_computation_mode(self):
549550
seed_impurity_weights={},
550551
)
551552

553+
@parameterized.named_parameters(
554+
('low_ip', {'plasma_current': 2.0e6, 'power_crossing_separatrix': 10e6}),
555+
(
556+
'low_power',
557+
{'plasma_current': 15.0e6, 'power_crossing_separatrix': 1.0e6},
558+
),
559+
)
560+
def test_underpowered_scenario(self, inputs_update):
561+
"""Test scenario where input power is too low to reach target temperature.
562+
563+
This uses unrealistically low inputs for ITER-like scenarios (Ip or P_SOL)
564+
which results in required impurity concentration being negative
565+
(physically impossible).
566+
The solver should report this via physics_outcome and a non-zero residual.
567+
568+
Args:
569+
inputs_update: Dictionary of input parameters to override defaults.
570+
"""
571+
inputs = {
572+
'T_e_target': 5.0,
573+
'power_crossing_separatrix': 10e6,
574+
'separatrix_electron_density': 3e19,
575+
'main_ion_charge': 1.0,
576+
'mean_ion_charge_state': 1.0,
577+
'fixed_impurity_concentrations': {},
578+
'magnetic_field_on_axis': 5.3,
579+
'plasma_current': 15.0e6,
580+
'connection_length_target': 50.0,
581+
'connection_length_divertor': 10.0,
582+
'major_radius': 6.2,
583+
'minor_radius': 2.0,
584+
'elongation_psi95': 1.7,
585+
'triangularity_psi95': 0.33,
586+
'average_ion_mass': 2.0,
587+
'computation_mode': extended_lengyel_enums.ComputationMode.INVERSE,
588+
'solver_mode': extended_lengyel_enums.SolverMode.HYBRID,
589+
'seed_impurity_weights': {'Ne': 1.0},
590+
}
591+
inputs.update(inputs_update)
592+
593+
outputs = extended_lengyel_standalone.run_extended_lengyel_standalone(
594+
**inputs
595+
)
596+
597+
numerics = outputs.solver_status.numerics_outcome
598+
599+
# 1. Assert no NaNs in output.
600+
self.assertFalse(np.any(np.isnan(numerics.residual)))
601+
self.assertFalse(np.isnan(outputs.Z_eff_separatrix))
602+
self.assertFalse(np.isnan(outputs.alpha_t))
603+
604+
# 2. Physics outcome should flag the issue.
605+
self.assertEqual(
606+
outputs.solver_status.physics_outcome.item(),
607+
extended_lengyel_solvers.PhysicsOutcome.C_Z_PREFACTOR_NEGATIVE,
608+
)
609+
610+
# 3. Impurities should be clamped to 0.
611+
np.testing.assert_allclose(
612+
outputs.seed_impurity_concentrations['Ne'], 0.0, atol=1e-5
613+
)
614+
552615

553616
if __name__ == '__main__':
554617
absltest.main()

torax/_src/math_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,16 @@ def cumulative_volume_integration(
323323

324324
def safe_divide(y: chex.Array, x: chex.Array) -> chex.Array:
325325
return y / (x + constants.CONSTANTS.eps)
326+
327+
328+
def inverse_softplus(x: jax.Array) -> jax.Array:
  """Inverse of the softplus function.

  Returns y such that softplus(y) = log(1 + exp(y)) = x, i.e.
  y = log(exp(x) - 1) = log(expm1(x)).

  Args:
    x: Softplus output(s) to invert. Values below 1e-20 (including zero and
      negative values, which lie outside the range of softplus) are clipped to
      1e-20 before inversion.

  Returns:
    Array with the same shape as x. For x > 30 the input itself is returned,
    since softplus(y) ~ y in that regime.
  """
  # As x -> 0, y -> -inf. expm1(x) ~ x for small x, so the log argument only
  # becomes non-positive for x <= 0. Clip from below to avoid log(0) or
  # log(negative).
  x_clipped = jnp.maximum(x, 1e-20)

  # For x > 30, softplus(x) ~ x, so the identity is used there.
  is_large = x > 30.0

  # exp(x) overflows in float32 for x > ~88. jnp.where evaluates and
  # differentiates both branches, so an overflowing log(expm1(x)) in the
  # unselected branch would poison the gradient with NaN (0 * inf). Feed the
  # log branch a finite placeholder wherever the identity branch is selected.
  x_safe = jnp.where(is_large, 1.0, x_clipped)

  return jnp.where(is_large, x, jnp.log(jnp.expm1(x_safe)))

torax/_src/tests/math_utils_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,35 @@ def test_cumulative_volume_integration(self, num_cell_grid_points: int):
315315
expected,
316316
)
317317

318+
@parameterized.parameters(1e-14, 1e-6, 1e-4, 0.1)
319+
def test_inverse_softplus_small_values(self, value):
320+
x_val = jnp.array(value)
321+
y_val = math_utils.inverse_softplus(x_val)
322+
x_rec = jax.nn.softplus(y_val)
323+
np.testing.assert_allclose(x_val, x_rec, rtol=1e-6)
324+
325+
@parameterized.parameters(1.0, 5.0, 10.0)
326+
def test_inverse_softplus_medium_values(self, value):
327+
x_val = jnp.array(value)
328+
y_val = math_utils.inverse_softplus(x_val)
329+
x_rec = jax.nn.softplus(y_val)
330+
np.testing.assert_allclose(x_val, x_rec, rtol=1e-6)
331+
332+
@parameterized.parameters(25.0, 50.0, 100.0)
333+
def test_inverse_softplus_large_values(self, value):
334+
x_val = jnp.array(value)
335+
y_val = math_utils.inverse_softplus(x_val)
336+
np.testing.assert_allclose(x_val, y_val, rtol=1e-6)
337+
x_rec = jax.nn.softplus(y_val)
338+
np.testing.assert_allclose(x_val, x_rec, rtol=1e-6)
339+
340+
@parameterized.parameters(-20, -10, -1, 1e-10, 1e-6, 0.1, 1.0, 10.0, 100.0)
341+
def test_softplus_round_trip(self, value):
342+
x = jnp.array(value)
343+
y = jax.nn.softplus(x)
344+
x_rec = math_utils.inverse_softplus(y)
345+
np.testing.assert_allclose(x, x_rec, rtol=1e-6)
346+
318347

319348
if __name__ == '__main__':
320349
absltest.main()

torax/tests/sim_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ class SimTest(sim_test_case.SimTestCase):
230230
(
231231
'test_iterhybrid_predictor_corrector_mavrin_n_e_ratios_lengyel',
232232
'test_iterhybrid_predictor_corrector_mavrin_n_e_ratios_lengyel.py',
233+
_ALL_PROFILES,
234+
1e-8,
233235
),
234236
# Predictor-corrector with Mavrin and n_e_ratios_Z_eff impurity mode.
235237
(

0 commit comments

Comments
 (0)