Skip to content

Commit b2337a7

Browse files
Nush395Torax team
authored and
Torax team
committed
Make all sources discriminated on model function name.
PiperOrigin-RevId: 742220027
1 parent 721e184 commit b2337a7

21 files changed

+232
-148
lines changed

torax/sources/base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323
class SourceModelBase(torax_pydantic.BaseModelFrozen, abc.ABC):
2424
"""Base model holding parameters common to all source models.
2525
26+
Subclasses should define the `model_function_name` attribute as a `Literal`
27+
string. This string should match the name of the function that calculates the
28+
source profile. This is used as an identifier for the model function in the
29+
source config for Pydantic to "discriminate" against. This should be given a
30+
unique value for each source model function implementation.
31+
2632
Attributes:
2733
mode: Defines how the source values are computed (from a model, from a file,
2834
etc.)
@@ -39,7 +45,6 @@ class SourceModelBase(torax_pydantic.BaseModelFrozen, abc.ABC):
3945
default here is a vector of all zeros along for all rho and time, and the
4046
output vector is along the cell grid.
4147
"""
42-
4348
mode: runtime_params.Mode = runtime_params.Mode.ZERO
4449
is_explicit: bool = False
4550
prescribed_values: torax_pydantic.TimeVaryingArray = (

torax/sources/bootstrap_current_source.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""bootstrap current source profile."""
1616
import dataclasses
17-
from typing import ClassVar, Literal
17+
from typing import ClassVar
1818

1919
import chex
2020
import jax
@@ -51,7 +51,6 @@ class BootstrapCurrentSource(source.Source):
5151
"""
5252

5353
SOURCE_NAME: ClassVar[str] = 'j_bootstrap'
54-
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_neoclassical'
5554

5655
@property
5756
def source_name(self) -> str:
@@ -136,8 +135,6 @@ class BootstrapCurrentSourceConfig(base.SourceModelBase):
136135
Attributes:
137136
bootstrap_mult: Multiplication factor for bootstrap current.
138137
"""
139-
140-
source_name: Literal['j_bootstrap'] = 'j_bootstrap'
141138
bootstrap_mult: float = 1.0
142139
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED
143140

torax/sources/bremsstrahlung_heat_sink.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
"""Bremsstrahlung heat sink for electron heat equation.."""
1818
import dataclasses
19-
from typing import ClassVar, Literal
19+
from typing import ClassVar, Final, Literal
2020

2121
import chex
2222
import jax
@@ -31,6 +31,12 @@
3131
from torax.sources import source_profiles
3232

3333

34+
# Default value for the model function to be used for the Bremsstrahlung heat
35+
# sink. This is also used as an identifier for the model function in the default
36+
# source config for Pydantic to "discriminate" against.
37+
DEFAULT_MODEL_FUNCTION_NAME: Final[str] = 'bremsstrahlung_model_func'
38+
39+
3440
@chex.dataclass(frozen=True)
3541
class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
3642
use_relativistic_correction: bool
@@ -121,7 +127,6 @@ class BremsstrahlungHeatSink(source.Source):
121127
"""Brehmsstrahlung heat sink for electron heat equation."""
122128

123129
SOURCE_NAME: ClassVar[str] = 'bremsstrahlung_heat_sink'
124-
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'bremsstrahlung_model_func'
125130
model_func: source.SourceProfileFunction = bremsstrahlung_model_func
126131

127132
@property
@@ -139,8 +144,9 @@ class BremsstrahlungHeatSinkConfig(base.SourceModelBase):
139144
Attributes:
140145
use_relativistic_correction: Whether to use relativistic correction.
141146
"""
142-
143-
source_name: Literal['bremsstrahlung_heat_sink'] = 'bremsstrahlung_heat_sink'
147+
model_function_name: Literal['bremsstrahlung_model_func'] = (
148+
'bremsstrahlung_model_func'
149+
)
144150
use_relativistic_correction: bool = False
145151
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED
146152

torax/sources/cyclotron_radiation_heat_sink.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535
import typing_extensions
3636

3737

38+
# Default value for the model function to be used for the Cyclotron radiation
39+
# heat sink source. This is also used as an identifier for the model function in
40+
# the default source config for Pydantic to "discriminate" against.
41+
DEFAULT_MODEL_FUNCTION_NAME: str = 'cyclotron_radiation_albajar'
42+
43+
3844
@chex.dataclass(frozen=True)
3945
class StaticRuntimeParams(runtime_params_lib.StaticRuntimeParams):
4046
beta_min: float
@@ -359,7 +365,6 @@ class CyclotronRadiationHeatSink(source.Source):
359365
"""Cyclotron radiation heat sink for electron heat equation."""
360366

361367
SOURCE_NAME: ClassVar[str] = 'cyclotron_radiation_heat_sink'
362-
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'cyclotron_radiation_albajar'
363368
model_func: source.SourceProfileFunction = cyclotron_radiation_albajar
364369

365370
@property
@@ -386,9 +391,8 @@ class CyclotronRadiationHeatSinkConfig(base.SourceModelBase):
386391
beta_grid_size: The number of points to use in the grid search for the best
387392
fit of the temperature function.
388393
"""
389-
390-
source_name: Literal['cyclotron_radiation_heat_sink'] = (
391-
'cyclotron_radiation_heat_sink'
394+
model_function_name: Literal['cyclotron_radiation_albajar'] = (
395+
'cyclotron_radiation_albajar'
392396
)
393397
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED
394398
wall_reflection_coeff: float = 0.9

torax/sources/electron_cyclotron_source.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
from torax.sources import source_profiles
3131
from torax.torax_pydantic import torax_pydantic
3232

33-
InterpolatedVarTimeRhoInput = (
34-
runtime_params_lib.interpolated_param.InterpolatedVarTimeRhoInput
35-
)
33+
34+
# Default value for the model function to be used for the electron cyclotron
35+
# source. This is also used as an identifier for the model function in
36+
# the default source config for Pydantic to "discriminate" against.
37+
DEFAULT_MODEL_FUNCTION_NAME: str = "calc_heating_and_current"
3638

3739

3840
@chex.dataclass(frozen=True)
@@ -122,7 +124,6 @@ class ElectronCyclotronSource(source.Source):
122124
"""Electron cyclotron source for the Te and Psi equations."""
123125

124126
SOURCE_NAME: ClassVar[str] = "electron_cyclotron_source"
125-
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = "calc_heating_and_current"
126127
model_func: source.SourceProfileFunction = calc_heating_and_current
127128

128129
@property
@@ -146,9 +147,8 @@ class ElectronCyclotronSourceConfig(base.SourceModelBase):
146147
location
147148
gaussian_ec_total_power: Gaussian EC total power
148149
"""
149-
150-
source_name: Literal["electron_cyclotron_source"] = (
151-
"electron_cyclotron_source"
150+
model_function_name: Literal["calc_heating_and_current"] = (
151+
"calc_heating_and_current"
152152
)
153153
cd_efficiency: torax_pydantic.TimeVaryingArray = (
154154
torax_pydantic.ValidatedDefault({0.0: {0.0: 0.2, 1.0: 0.2}})

torax/sources/fusion_heat_source.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@
3131
from torax.sources import source_profiles
3232

3333

34+
# Default value for the model function to be used for the fusion heat
35+
# source. This is also used as an identifier for the model function in
36+
# the default source config for Pydantic to "discriminate" against.
37+
DEFAULT_MODEL_FUNCTION_NAME: str = 'fusion_heat_model_func'
38+
39+
3440
def calc_fusion(
3541
geo: geometry.Geometry,
3642
core_profiles: state.CoreProfiles,
@@ -162,7 +168,6 @@ class FusionHeatSource(source.Source):
162168
"""Fusion heat source for both ion and electron heat."""
163169

164170
SOURCE_NAME: ClassVar[str] = 'fusion_heat_source'
165-
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'fusion_heat_model_func'
166171
model_func: source.SourceProfileFunction = fusion_heat_model_func
167172

168173
@property
@@ -179,7 +184,9 @@ def affected_core_profiles(self) -> tuple[source.AffectedCoreProfile, ...]:
179184

180185
class FusionHeatSourceConfig(base.SourceModelBase):
181186
"""Configuration for the FusionHeatSource."""
182-
source_name: Literal['fusion_heat_source'] = 'fusion_heat_source'
187+
model_function_name: Literal['fusion_heat_model_func'] = (
188+
'fusion_heat_model_func'
189+
)
183190
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED
184191

185192
@property

torax/sources/gas_puff_source.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929
from torax.torax_pydantic import torax_pydantic
3030

3131

32+
# Default value for the model function to be used for the gas puff
33+
# source. This is also used as an identifier for the model function in
34+
# the default source config for Pydantic to "discriminate" against.
35+
DEFAULT_MODEL_FUNCTION_NAME: str = 'calc_puff_source'
36+
37+
3238
# pylint: disable=invalid-name
3339
@chex.dataclass(frozen=True)
3440
class DynamicGasPuffRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
@@ -66,7 +72,6 @@ class GasPuffSource(source.Source):
6672
"""Gas puff source for the ne equation."""
6773

6874
SOURCE_NAME: ClassVar[str] = 'gas_puff_source'
69-
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_puff_source'
7075
model_func: source.SourceProfileFunction = calc_puff_source
7176

7277
@property
@@ -87,8 +92,7 @@ class GasPuffSourceConfig(base.SourceModelBase):
8792
[normalized radial coord]
8893
S_puff_tot: total gas puff particles/s
8994
"""
90-
91-
source_name: Literal['gas_puff_source'] = 'gas_puff_source'
95+
model_function_name: Literal['calc_puff_source'] = 'calc_puff_source'
9296
puff_decay_length: torax_pydantic.TimeVaryingScalar = (
9397
torax_pydantic.ValidatedDefault(0.05)
9498
)

torax/sources/generic_current_source.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@
3131
from torax.torax_pydantic import torax_pydantic
3232

3333

34+
# Default value for the model function to be used for the generic current
35+
# source. This is also used as an identifier for the model function in
36+
# the default source config for Pydantic to "discriminate" against.
37+
DEFAULT_MODEL_FUNCTION_NAME: str = 'calc_generic_current'
38+
39+
3440
# pylint: disable=invalid-name
3541
@chex.dataclass(frozen=True)
3642
class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
@@ -106,7 +112,6 @@ class GenericCurrentSource(source.Source):
106112
"""A generic current density source profile."""
107113

108114
SOURCE_NAME: ClassVar[str] = 'generic_current_source'
109-
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_current'
110115
model_func: source.SourceProfileFunction = calculate_generic_current
111116

112117
@property
@@ -129,8 +134,7 @@ class GenericCurrentSourceConfig(source_base.SourceModelBase):
129134
use_absolute_current: Toggles if external current is provided absolutely or
130135
as a fraction of Ip.
131136
"""
132-
133-
source_name: Literal['generic_current_source'] = 'generic_current_source'
137+
model_function_name: Literal['calc_generic_current'] = 'calc_generic_current'
134138
Iext: torax_pydantic.TimeVaryingScalar = torax_pydantic.ValidatedDefault(3.0)
135139
fext: torax_pydantic.TimeVaryingScalar = torax_pydantic.ValidatedDefault(0.2)
136140
wext: torax_pydantic.TimeVaryingScalar = torax_pydantic.ValidatedDefault(0.05)

torax/sources/generic_ion_el_heat_source.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929
from torax.torax_pydantic import torax_pydantic
3030

3131

32+
# Default value for the model function to be used for the electron cyclotron
33+
# source. This is also used as an identifier for the model function in
34+
# the default source config for Pydantic to "discriminate" against.
35+
DEFAULT_MODEL_FUNCTION_NAME: str = 'default_formula'
36+
37+
3238
# pylint: disable=invalid-name
3339
@chex.dataclass(frozen=True)
3440
class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
@@ -103,7 +109,6 @@ class GenericIonElectronHeatSource(source.Source):
103109
"""Generic heat source for both ion and electron heat."""
104110

105111
SOURCE_NAME: ClassVar[str] = 'generic_ion_el_heat_source'
106-
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'default_formula'
107112
model_func: source.SourceProfileFunction = default_formula
108113

109114
@property
@@ -128,10 +133,7 @@ class GenericIonElHeatSourceConfig(base.SourceModelBase):
128133
el_heat_fraction: Electron heating fraction
129134
absorption_fraction: Fraction of absorbed power
130135
"""
131-
132-
source_name: Literal['generic_ion_el_heat_source'] = (
133-
'generic_ion_el_heat_source'
134-
)
136+
model_function_name: Literal['default_formula'] = 'default_formula'
135137
w: torax_pydantic.TimeVaryingScalar = torax_pydantic.ValidatedDefault(0.25)
136138
rsource: torax_pydantic.TimeVaryingScalar = torax_pydantic.ValidatedDefault(
137139
0.0

torax/sources/generic_particle_source.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@
2828
from torax.torax_pydantic import torax_pydantic
2929

3030

31+
# Default value for the model function to be used for the generic particle
32+
# source. This is also used as an identifier for the model function in
33+
# the default source config for Pydantic to "discriminate" against.
34+
DEFAULT_MODEL_FUNCTION_NAME: str = 'calc_generic_particle_source'
35+
36+
3137
# pylint: disable=invalid-name
3238
def calc_generic_particle_source(
3339
unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
@@ -63,7 +69,6 @@ class GenericParticleSource(source.Source):
6369
"""Neutral-beam injection source for the ne equation."""
6470

6571
SOURCE_NAME: ClassVar[str] = 'generic_particle_source'
66-
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_particle_source'
6772
model_func: source.SourceProfileFunction = calc_generic_particle_source
6873

6974
@property
@@ -93,8 +98,9 @@ class GenericParticleSourceConfig(base.SourceModelBase):
9398
mode: Defines how the source values are computed (from a model, from a file,
9499
etc.)
95100
"""
96-
97-
source_name: Literal['generic_particle_source'] = 'generic_particle_source'
101+
model_function_name: Literal['calc_generic_particle_source'] = (
102+
'calc_generic_particle_source'
103+
)
98104
particle_width: torax_pydantic.TimeVaryingScalar = (
99105
torax_pydantic.ValidatedDefault(0.25)
100106
)

torax/sources/impurity_radiation_heat_sink/impurity_radiation_constant_fraction.py

-5
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@
2828
from torax.sources.impurity_radiation_heat_sink import impurity_radiation_heat_sink
2929
from torax.torax_pydantic import torax_pydantic
3030

31-
MODEL_FUNCTION_NAME = 'radially_constant_fraction_of_Pin'
32-
3331

3432
def radially_constant_fraction_of_Pin( # pylint: disable=invalid-name
3533
unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
@@ -101,9 +99,6 @@ class ImpurityRadiationHeatSinkConstantFractionConfig(base.SourceModelBase):
10199
fraction_of_total_power_density: Fraction of total power density to be
102100
absorbed by the impurity.
103101
"""
104-
source_name: Literal['impurity_radiation_heat_sink'] = (
105-
'impurity_radiation_heat_sink'
106-
)
107102
model_function_name: Literal['radially_constant_fraction_of_Pin'] = (
108103
'radially_constant_fraction_of_Pin'
109104
)

torax/sources/impurity_radiation_heat_sink/impurity_radiation_mavrin_fit.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535
from torax.sources.impurity_radiation_heat_sink import impurity_radiation_heat_sink
3636

3737

38-
MODEL_FUNCTION_NAME = 'impurity_radiation_mavrin_fit'
38+
# Default value for the model function to be used for the impurity radiation
39+
# source. This is also used as an identifier for the model function in
40+
# the source config for Pydantic to "discriminate" against.
41+
DEFAULT_MODEL_FUNCTION_NAME: str = 'impurity_radiation_mavrin_fit'
3942

4043
# Polynomial fit coefficients from A. A. Mavrin (2018):
4144
# Improved fits of coronal radiative cooling rates for high-temperature plasmas,
@@ -224,10 +227,6 @@ class ImpurityRadiationHeatSinkMavrinFitConfig(base.SourceModelBase):
224227
Attributes:
225228
radiation_multiplier: Multiplier for the impurity radiation profile.
226229
"""
227-
228-
source_name: Literal['impurity_radiation_heat_sink'] = (
229-
'impurity_radiation_heat_sink'
230-
)
231230
model_function_name: Literal['impurity_radiation_mavrin_fit'] = (
232231
'impurity_radiation_mavrin_fit'
233232
)

torax/sources/ion_cyclotron_source.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@
4141
# Internal import.
4242

4343

44+
# Default value for the model function to be used for the ion cyclotron
45+
# source. This is also used as an identifier for the model function in
46+
# the default source config for Pydantic to "discriminate" against.
47+
DEFAULT_MODEL_FUNCTION_NAME: str = 'icrh_model_func'
48+
49+
4450
# Environment variable for the TORIC NN model. Used if the model path
4551
# is not set in the config.
4652
_MODEL_PATH_ENV_VAR: Final[str] = 'TORIC_NN_MODEL_PATH'
@@ -429,7 +435,6 @@ class IonCyclotronSource(source.Source):
429435
"""Ion cyclotron source with surrogate model."""
430436

431437
SOURCE_NAME: ClassVar[str] = 'ion_cyclotron_source'
432-
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'icrh_model_func'
433438

434439
@property
435440
def source_name(self) -> str:
@@ -457,8 +462,7 @@ class IonCyclotronSourceConfig(base.SourceModelBase):
457462
Ptot: Total heating power [W].
458463
absorption_fraction: Fraction of absorbed power.
459464
"""
460-
461-
source_name: Literal['ion_cyclotron_source'] = 'ion_cyclotron_source'
465+
model_function_name: Literal['icrh_model_func'] = 'icrh_model_func'
462466
wall_inner: torax_pydantic.Meter = 1.24
463467
wall_outer: torax_pydantic.Meter = 2.43
464468
frequency: torax_pydantic.TimeVaryingScalar = torax_pydantic.ValidatedDefault(

0 commit comments

Comments
 (0)