Skip to content

Commit 5f6ac11

Browse files
jcitrinTorax team
authored andcommitted
x-lengyel: Avoid unphysical negative c_z in inverse solver.
In underpowered scenarios, the "solution" for inverse mode could be negative seeded impurities, if even with zero impurities the target temperature is below the requested target temperature. There was a bug where the solver was allowed to explore these negative values and then propagate them into the state, leading to NaNs. Two separate fixes where done here: 1. Clip the seeded impurity ratios to zero in all places where it can be defined in Divertor1D states 2. Change the residual calculation to compare the calculated c_z to the state vector c_z (which is allowed to explore to negative values). This maintains non-zero residuals and a gradient signal. The solver converges to "negative" impurities which are clipped to zero 3. Fix associated bug where if Z_eff=1 (no impurities) then we got a division by zero in the Jacobian. Just needed to add small value to Zeff within a formula. PiperOrigin-RevId: 860647237
1 parent 3e9f3c6 commit 5f6ac11

File tree

5 files changed

+89
-19
lines changed

5 files changed

+89
-19
lines changed

torax/_src/edge/divertor_sol_1d.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,9 +421,10 @@ def calc_alpha_t(
421421
nu_ee = jnp.exp(log_nu_ee)
422422

423423
# Z_eff correction to transform electron-electron collisions to ion-electron
424-
# collisions. Equation B2 in Eich 2020
424+
# collisions. Equation B2 in Eich 2020. Adding a small addition to Z_eff to
425+
# avoid numerical issues with the gradient at Z_eff=1 (no impurities).
425426
Z_eff_correction = (1.0 - 0.569) * jnp.exp(
426-
-(((Z_eff_separatrix - 1.0) / 3.25) ** 0.85)
427+
-(((Z_eff_separatrix - 1.0 + constants.CONSTANTS.eps) / 3.25) ** 0.85)
427428
) + 0.569
428429

429430
nu_ei = nu_ee * Z_eff_correction * Z_eff_separatrix

torax/_src/edge/extended_lengyel_solvers.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,11 @@ def body_fun(_, carry):
101101
# Solve for the impurity concentration required to achieve the target
102102
# temperature for a given q_parallel. This also updates the divertor and
103103
# separatrix Z_eff values in sol_model, used downstream.
104-
current_sol_model.state.c_z_prefactor, physics_outcome = (
105-
_solve_for_c_z_prefactor(sol_model=current_sol_model)
104+
c_z_prefactor, physics_outcome = _solve_for_c_z_prefactor(
105+
sol_model=current_sol_model
106106
)
107+
# Clip to physical values (non-negative impurity concentration).
108+
current_sol_model.state.c_z_prefactor = jnp.maximum(c_z_prefactor, 0.0)
107109

108110
# Update alpha_t for the next loop iteration.
109111
current_sol_model.state.alpha_t = divertor_sol_1d_lib.calc_alpha_t(
@@ -342,11 +344,15 @@ def inverse_mode_newton_solver(
342344
)
343345

344346
# 4. Construct final model.
347+
# Clip c_z_prefactor to 0.0 if it is negative (unphysical solution).
348+
# Negative values are allowed during the solve process (to ensure smooth
349+
# gradients for the solver), but the final physical state must have
350+
# non-negative concentrations.
345351
final_state = divertor_sol_1d_lib.ExtendedLengyelState(
346352
q_parallel=jnp.exp(x_root[0]),
347353
alpha_t=jax.nn.softplus(x_root[1]),
348354
kappa_e=jnp.exp(x_root[2]),
349-
c_z_prefactor=x_root[3],
355+
c_z_prefactor=jnp.maximum(x_root[3], 0.0),
350356
T_e_target=fixed_Tt,
351357
)
352358

@@ -355,6 +361,7 @@ def inverse_mode_newton_solver(
355361
)
356362

357363
# 5. Re-calculate physics outcome at final state to return the physics_outcome
364+
# This uses the clipped (physical) c_z_prefactor.
358365
_, physics_outcome = _solve_for_c_z_prefactor(sol_model=final_sol_model)
359366

360367
solver_status = ExtendedLengyelSolverStatus(
@@ -468,11 +475,15 @@ def _inverse_residual(
468475
) -> jax.Array:
469476
"""Calculates the residual vector for Inverse Mode F(x) = 0."""
470477
# 1. Construct physical state from vector guess.
478+
# Note: c_z_prefactor is clipped to be non-negative for physics calculations
479+
# to avoid NaNs (e.g. in Z_eff -> alpha_t). The solver is allowed to explore
480+
# negative values in x_vec[3] to properly find the root (even if unphysical),
481+
# but the state used for consistent physics checks must be valid.
471482
current_state = divertor_sol_1d_lib.ExtendedLengyelState(
472483
q_parallel=jnp.exp(x_vec[0]),
473484
alpha_t=jax.nn.softplus(x_vec[1]),
474485
kappa_e=jnp.exp(x_vec[2]),
475-
c_z_prefactor=x_vec[3],
486+
c_z_prefactor=jnp.maximum(x_vec[3], 0.0),
476487
T_e_target=fixed_Tt,
477488
)
478489

@@ -510,7 +521,10 @@ def _inverse_residual(
510521
r_qp = jnp.log(qp_calc_safe) - x_vec[0]
511522
r_at = at_calc_safe - current_state.alpha_t
512523
r_ke = jnp.log(ke_calc_safe) - x_vec[2]
513-
r_cz = cz_calc - current_state.c_z_prefactor
524+
# Residual for c_z compares the calculated required c_z against the
525+
# *raw* solver guess x_vec[3], not the clipped state value.
526+
# This provides a gradient signal even when x_vec[3] is negative.
527+
r_cz = cz_calc - x_vec[3]
514528

515529
return jnp.stack([r_qp, r_at, r_ke, r_cz])
516530

@@ -629,12 +643,6 @@ def _solve_for_c_z_prefactor(
629643
PhysicsOutcome.SUCCESS,
630644
)
631645

632-
# c_z is related to impurity density which physically cannot be negative.
633-
# The natural floor of c_z_prefactor is zero.
634-
c_z_prefactor = jnp.where(
635-
status == PhysicsOutcome.SUCCESS, c_z_prefactor, 0.0
636-
)
637-
638646
return c_z_prefactor, status
639647

640648

torax/_src/edge/tests/extended_lengyel_solver_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,11 @@ def test_unsuccessful_solve_for_c_z(self):
153153
calculated_c_z, status = extended_lengyel_solvers._solve_for_c_z_prefactor(
154154
sol_model=sol_model,
155155
)
156-
expected_c_z = 0.0
157156

158157
self.assertEqual(
159158
status, extended_lengyel_solvers.PhysicsOutcome.C_Z_PREFACTOR_NEGATIVE
160159
)
161-
np.testing.assert_allclose(
162-
calculated_c_z,
163-
expected_c_z,
164-
)
160+
self.assertLess(calculated_c_z, 0.0)
165161

166162
def test_inverse_unsuccessful_newton_solve_but_successful_hybrid_solve(self):
167163
# The initial guess state is deliberately set far from the solution, by

torax/_src/edge/tests/extended_lengyel_standalone_test.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from unittest import mock
1616
from absl.testing import absltest
17+
from absl.testing import parameterized
1718
import numpy as np
1819
from torax._src.edge import extended_lengyel_defaults
1920
from torax._src.edge import extended_lengyel_enums
@@ -23,7 +24,7 @@
2324
# pylint: disable=invalid-name
2425

2526

26-
class ExtendedLengyelTest(absltest.TestCase):
27+
class ExtendedLengyelTest(parameterized.TestCase):
2728

2829
def test_run_extended_lengyel_model_inverse_mode_fixed_point(self):
2930
"""Integration test for the full extended_lengyel model in inverse mode."""
@@ -549,6 +550,68 @@ def test_validate_inputs_for_computation_mode(self):
549550
seed_impurity_weights={},
550551
)
551552

553+
@parameterized.named_parameters(
554+
('low_ip', {'plasma_current': 2.0e6, 'power_crossing_separatrix': 10e6}),
555+
(
556+
'low_power',
557+
{'plasma_current': 15.0e6, 'power_crossing_separatrix': 1.0e6},
558+
),
559+
)
560+
def test_underpowered_scenario(self, inputs_update):
561+
"""Test scenario where input power is too low to reach target temperature.
562+
563+
This uses unrealistically low inputs for ITER-like scenarios (Ip or P_SOL)
564+
which results in required impurity concentration being negative
565+
(physically impossible).
566+
The solver should report this via physics_outcome and a non-zero residual.
567+
568+
Args:
569+
inputs_update: Dictionary of input parameters to override defaults.
570+
"""
571+
inputs = {
572+
'T_e_target': 5.0,
573+
'power_crossing_separatrix': 10e6,
574+
'separatrix_electron_density': 3e19,
575+
'main_ion_charge': 1.0,
576+
'mean_ion_charge_state': 1.0,
577+
'fixed_impurity_concentrations': {},
578+
'magnetic_field_on_axis': 5.3,
579+
'plasma_current': 15.0e6,
580+
'connection_length_target': 50.0,
581+
'connection_length_divertor': 10.0,
582+
'major_radius': 6.2,
583+
'minor_radius': 2.0,
584+
'elongation_psi95': 1.7,
585+
'triangularity_psi95': 0.33,
586+
'average_ion_mass': 2.0,
587+
'computation_mode': extended_lengyel_enums.ComputationMode.INVERSE,
588+
'solver_mode': extended_lengyel_enums.SolverMode.HYBRID,
589+
'seed_impurity_weights': {'Ne': 1.0},
590+
}
591+
inputs.update(inputs_update)
592+
593+
outputs = extended_lengyel_standalone.run_extended_lengyel_standalone(
594+
**inputs
595+
)
596+
597+
numerics = outputs.solver_status.numerics_outcome
598+
599+
# 1. Assert no NaNs in output.
600+
self.assertFalse(np.any(np.isnan(numerics.residual)))
601+
self.assertFalse(np.isnan(outputs.Z_eff_separatrix))
602+
self.assertFalse(np.isnan(outputs.alpha_t))
603+
604+
# 2. Physics outcome should flag the issue.
605+
self.assertEqual(
606+
outputs.solver_status.physics_outcome.item(),
607+
extended_lengyel_solvers.PhysicsOutcome.C_Z_PREFACTOR_NEGATIVE,
608+
)
609+
610+
# 3. Impurities should be clamped to 0.
611+
np.testing.assert_allclose(
612+
outputs.seed_impurity_concentrations['Ne'], 0.0, atol=1e-5
613+
)
614+
552615

553616
if __name__ == '__main__':
554617
absltest.main()

torax/tests/sim_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ class SimTest(sim_test_case.SimTestCase):
230230
(
231231
'test_iterhybrid_predictor_corrector_mavrin_n_e_ratios_lengyel',
232232
'test_iterhybrid_predictor_corrector_mavrin_n_e_ratios_lengyel.py',
233+
_ALL_PROFILES,
234+
1e-8,
233235
),
234236
# Predictor-corrector with Mavrin and n_e_ratios_Z_eff impurity mode.
235237
(

0 commit comments

Comments
 (0)