Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions torax/_src/edge/divertor_sol_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,10 @@ def calc_alpha_t(
nu_ee = jnp.exp(log_nu_ee)

# Z_eff correction to transform electron-electron collisions to ion-electron
# collisions. Equation B2 in Eich 2020
# collisions. Equation B2 in Eich 2020. Adding a small addition to Z_eff to
# avoid numerical issues with the gradient at Z_eff=1 (no impurities).
Z_eff_correction = (1.0 - 0.569) * jnp.exp(
-(((Z_eff_separatrix - 1.0) / 3.25) ** 0.85)
-(((Z_eff_separatrix - 1.0 + constants.CONSTANTS.eps) / 3.25) ** 0.85)
) + 0.569

nu_ei = nu_ee * Z_eff_correction * Z_eff_separatrix
Expand Down
30 changes: 19 additions & 11 deletions torax/_src/edge/extended_lengyel_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,11 @@ def body_fun(_, carry):
# Solve for the impurity concentration required to achieve the target
# temperature for a given q_parallel. This also updates the divertor and
# separatrix Z_eff values in sol_model, used downstream.
current_sol_model.state.c_z_prefactor, physics_outcome = (
_solve_for_c_z_prefactor(sol_model=current_sol_model)
c_z_prefactor, physics_outcome = _solve_for_c_z_prefactor(
sol_model=current_sol_model
)
# Clip to physical values (non-negative impurity concentration).
current_sol_model.state.c_z_prefactor = jnp.maximum(c_z_prefactor, 0.0)

# Update alpha_t for the next loop iteration.
current_sol_model.state.alpha_t = divertor_sol_1d_lib.calc_alpha_t(
Expand Down Expand Up @@ -342,11 +344,15 @@ def inverse_mode_newton_solver(
)

# 4. Construct final model.
# Clip c_z_prefactor to 0.0 if it is negative (unphysical solution).
# Negative values are allowed during the solve process (to ensure smooth
# gradients for the solver), but the final physical state must have
# non-negative concentrations.
final_state = divertor_sol_1d_lib.ExtendedLengyelState(
q_parallel=jnp.exp(x_root[0]),
alpha_t=jax.nn.softplus(x_root[1]),
kappa_e=jnp.exp(x_root[2]),
c_z_prefactor=x_root[3],
c_z_prefactor=jnp.maximum(x_root[3], 0.0),
T_e_target=fixed_Tt,
)

Expand All @@ -355,6 +361,7 @@ def inverse_mode_newton_solver(
)

# 5. Re-calculate physics outcome at final state to return the physics_outcome
# This uses the clipped (physical) c_z_prefactor.
_, physics_outcome = _solve_for_c_z_prefactor(sol_model=final_sol_model)

solver_status = ExtendedLengyelSolverStatus(
Expand Down Expand Up @@ -468,11 +475,15 @@ def _inverse_residual(
) -> jax.Array:
"""Calculates the residual vector for Inverse Mode F(x) = 0."""
# 1. Construct physical state from vector guess.
# Note: c_z_prefactor is clipped to be non-negative for physics calculations
# to avoid NaNs (e.g. in Z_eff -> alpha_t). The solver is allowed to explore
# negative values in x_vec[3] to properly find the root (even if unphysical),
# but the state used for consistent physics checks must be valid.
current_state = divertor_sol_1d_lib.ExtendedLengyelState(
q_parallel=jnp.exp(x_vec[0]),
alpha_t=jax.nn.softplus(x_vec[1]),
kappa_e=jnp.exp(x_vec[2]),
c_z_prefactor=x_vec[3],
c_z_prefactor=jnp.maximum(x_vec[3], 0.0),
T_e_target=fixed_Tt,
)

Expand Down Expand Up @@ -510,7 +521,10 @@ def _inverse_residual(
r_qp = jnp.log(qp_calc_safe) - x_vec[0]
r_at = at_calc_safe - current_state.alpha_t
r_ke = jnp.log(ke_calc_safe) - x_vec[2]
r_cz = cz_calc - current_state.c_z_prefactor
# Residual for c_z compares the calculated required c_z against the
# *raw* solver guess x_vec[3], not the clipped state value.
# This provides a gradient signal even when x_vec[3] is negative.
r_cz = cz_calc - x_vec[3]

return jnp.stack([r_qp, r_at, r_ke, r_cz])

Expand Down Expand Up @@ -629,12 +643,6 @@ def _solve_for_c_z_prefactor(
PhysicsOutcome.SUCCESS,
)

# c_z is related to impurity density which physically cannot be negative.
# The natural floor of c_z_prefactor is zero.
c_z_prefactor = jnp.where(
status == PhysicsOutcome.SUCCESS, c_z_prefactor, 0.0
)

return c_z_prefactor, status


Expand Down
6 changes: 1 addition & 5 deletions torax/_src/edge/tests/extended_lengyel_solver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,11 @@ def test_unsuccessful_solve_for_c_z(self):
calculated_c_z, status = extended_lengyel_solvers._solve_for_c_z_prefactor(
sol_model=sol_model,
)
expected_c_z = 0.0

self.assertEqual(
status, extended_lengyel_solvers.PhysicsOutcome.C_Z_PREFACTOR_NEGATIVE
)
np.testing.assert_allclose(
calculated_c_z,
expected_c_z,
)
self.assertLess(calculated_c_z, 0.0)

def test_inverse_unsuccessful_newton_solve_but_successful_hybrid_solve(self):
# The initial guess state is deliberately set far from the solution, by
Expand Down
65 changes: 64 additions & 1 deletion torax/_src/edge/tests/extended_lengyel_standalone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from torax._src.edge import extended_lengyel_defaults
from torax._src.edge import extended_lengyel_enums
Expand All @@ -23,7 +24,7 @@
# pylint: disable=invalid-name


class ExtendedLengyelTest(absltest.TestCase):
class ExtendedLengyelTest(parameterized.TestCase):

def test_run_extended_lengyel_model_inverse_mode_fixed_point(self):
"""Integration test for the full extended_lengyel model in inverse mode."""
Expand Down Expand Up @@ -549,6 +550,68 @@ def test_validate_inputs_for_computation_mode(self):
seed_impurity_weights={},
)

@parameterized.named_parameters(
('low_ip', {'plasma_current': 2.0e6, 'power_crossing_separatrix': 10e6}),
(
'low_power',
{'plasma_current': 15.0e6, 'power_crossing_separatrix': 1.0e6},
),
)
def test_underpowered_scenario(self, inputs_update):
"""Test scenario where input power is too low to reach target temperature.

This uses unrealistically low inputs for ITER-like scenarios (Ip or P_SOL)
which results in required impurity concentration being negative
(physically impossible).
The solver should report this via physics_outcome and a non-zero residual.

Args:
inputs_update: Dictionary of input parameters to override defaults.
"""
inputs = {
'T_e_target': 5.0,
'power_crossing_separatrix': 10e6,
'separatrix_electron_density': 3e19,
'main_ion_charge': 1.0,
'mean_ion_charge_state': 1.0,
'fixed_impurity_concentrations': {},
'magnetic_field_on_axis': 5.3,
'plasma_current': 15.0e6,
'connection_length_target': 50.0,
'connection_length_divertor': 10.0,
'major_radius': 6.2,
'minor_radius': 2.0,
'elongation_psi95': 1.7,
'triangularity_psi95': 0.33,
'average_ion_mass': 2.0,
'computation_mode': extended_lengyel_enums.ComputationMode.INVERSE,
'solver_mode': extended_lengyel_enums.SolverMode.HYBRID,
'seed_impurity_weights': {'Ne': 1.0},
}
inputs.update(inputs_update)

outputs = extended_lengyel_standalone.run_extended_lengyel_standalone(
**inputs
)

numerics = outputs.solver_status.numerics_outcome

# 1. Assert no NaNs in output.
self.assertFalse(np.any(np.isnan(numerics.residual)))
self.assertFalse(np.isnan(outputs.Z_eff_separatrix))
self.assertFalse(np.isnan(outputs.alpha_t))

# 2. Physics outcome should flag the issue.
self.assertEqual(
outputs.solver_status.physics_outcome.item(),
extended_lengyel_solvers.PhysicsOutcome.C_Z_PREFACTOR_NEGATIVE,
)

# 3. Impurities should be clamped to 0.
np.testing.assert_allclose(
outputs.seed_impurity_concentrations['Ne'], 0.0, atol=1e-5
)


if __name__ == '__main__':
absltest.main()
2 changes: 2 additions & 0 deletions torax/tests/sim_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ class SimTest(sim_test_case.SimTestCase):
(
'test_iterhybrid_predictor_corrector_mavrin_n_e_ratios_lengyel',
'test_iterhybrid_predictor_corrector_mavrin_n_e_ratios_lengyel.py',
_ALL_PROFILES,
1e-8,
),
# Predictor-corrector with Mavrin and n_e_ratios_Z_eff impurity mode.
(
Expand Down
Loading