Skip to content

Commit d7a7cdc

Browse files
Nush395Torax team
authored and
Torax team
committed
Create utility for registering new source model configs.
PiperOrigin-RevId: 742236669
1 parent b2337a7 commit d7a7cdc

File tree

2 files changed

+316
-0
lines changed

2 files changed

+316
-0
lines changed
+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Utilities for registering new pydantic configs."""
15+
from torax.sources import base
16+
from torax.sources import bremsstrahlung_heat_sink as bremsstrahlung_heat_sink_lib
17+
from torax.sources import cyclotron_radiation_heat_sink as cyclotron_radiation_heat_sink_lib
18+
from torax.sources import electron_cyclotron_source as electron_cyclotron_source_lib
19+
from torax.sources import fusion_heat_source as fusion_heat_source_lib
20+
from torax.sources import gas_puff_source as gas_puff_source_lib
21+
from torax.sources import generic_current_source as generic_current_source_lib
22+
from torax.sources import generic_ion_el_heat_source as generic_ion_el_heat_source_lib
23+
from torax.sources import generic_particle_source as generic_particle_source_lib
24+
from torax.sources import ion_cyclotron_source as ion_cyclotron_source_lib
25+
from torax.sources import ohmic_heat_source as ohmic_heat_source_lib
26+
from torax.sources import pellet_source as pellet_source_lib
27+
from torax.sources import pydantic_model as sources_pydantic_model
28+
from torax.sources.impurity_radiation_heat_sink import impurity_radiation_heat_sink as impurity_radiation_heat_sink_lib
29+
from torax.sources.impurity_radiation_heat_sink import impurity_radiation_mavrin_fit as impurity_radiation_mavrin_fit_lib
30+
from torax.torax_pydantic import model_config
31+
32+
33+
def _validate_source_model_config(
34+
source_model_config_class: type[base.SourceModelBase],
35+
source_name: str,
36+
):
37+
"""Validates that the source model config is valid."""
38+
if source_name in ('qei', 'j_bootstrap'):
39+
raise ValueError(
40+
'Cannot register a new source model config for the qei or j_bootstrap'
41+
' sources.'
42+
)
43+
44+
source_model_config = source_model_config_class()
45+
if not hasattr(source_model_config, 'model_function_name'):
46+
raise ValueError(
47+
'The source model config must have a model_function_name attribute.'
48+
)
49+
model_function_name: str = source_model_config.model_function_name
50+
51+
match source_name:
52+
case bremsstrahlung_heat_sink_lib.BremsstrahlungHeatSink.SOURCE_NAME:
53+
default_model_function_name = (
54+
bremsstrahlung_heat_sink_lib.DEFAULT_MODEL_FUNCTION_NAME
55+
)
56+
case (
57+
cyclotron_radiation_heat_sink_lib.CyclotronRadiationHeatSink.SOURCE_NAME
58+
):
59+
default_model_function_name = (
60+
cyclotron_radiation_heat_sink_lib.DEFAULT_MODEL_FUNCTION_NAME
61+
)
62+
case electron_cyclotron_source_lib.ElectronCyclotronSource.SOURCE_NAME:
63+
default_model_function_name = (
64+
electron_cyclotron_source_lib.DEFAULT_MODEL_FUNCTION_NAME
65+
)
66+
case gas_puff_source_lib.GasPuffSource.SOURCE_NAME:
67+
default_model_function_name = (
68+
gas_puff_source_lib.DEFAULT_MODEL_FUNCTION_NAME
69+
)
70+
case generic_particle_source_lib.GenericParticleSource.SOURCE_NAME:
71+
default_model_function_name = (
72+
generic_particle_source_lib.DEFAULT_MODEL_FUNCTION_NAME
73+
)
74+
case pellet_source_lib.PelletSource.SOURCE_NAME:
75+
default_model_function_name = (
76+
pellet_source_lib.DEFAULT_MODEL_FUNCTION_NAME
77+
)
78+
case fusion_heat_source_lib.FusionHeatSource.SOURCE_NAME:
79+
default_model_function_name = (
80+
fusion_heat_source_lib.DEFAULT_MODEL_FUNCTION_NAME
81+
)
82+
case (
83+
generic_ion_el_heat_source_lib.GenericIonElectronHeatSource.SOURCE_NAME
84+
):
85+
default_model_function_name = (
86+
generic_ion_el_heat_source_lib.DEFAULT_MODEL_FUNCTION_NAME
87+
)
88+
case impurity_radiation_heat_sink_lib.ImpurityRadiationHeatSink.SOURCE_NAME:
89+
default_model_function_name = (
90+
impurity_radiation_mavrin_fit_lib.DEFAULT_MODEL_FUNCTION_NAME
91+
)
92+
case ion_cyclotron_source_lib.IonCyclotronSource.SOURCE_NAME:
93+
default_model_function_name = (
94+
ion_cyclotron_source_lib.DEFAULT_MODEL_FUNCTION_NAME
95+
)
96+
case ohmic_heat_source_lib.OhmicHeatSource.SOURCE_NAME:
97+
default_model_function_name = (
98+
ohmic_heat_source_lib.DEFAULT_MODEL_FUNCTION_NAME
99+
)
100+
case generic_current_source_lib.GenericCurrentSource.SOURCE_NAME:
101+
default_model_function_name = (
102+
generic_current_source_lib.DEFAULT_MODEL_FUNCTION_NAME
103+
)
104+
case _:
105+
raise ValueError(f'The source name {source_name} is not supported.')
106+
107+
if model_function_name == default_model_function_name:
108+
raise ValueError(
109+
f'The model function name {model_function_name} must be different from'
110+
f' the default model function name {default_model_function_name} for'
111+
f' the source {source_name}.'
112+
)
113+
114+
115+
def register_source_model_config(
116+
source_model_config_class: type[base.SourceModelBase],
117+
source_name: str,
118+
):
119+
"""Update Pydantic schema to include a source model config.
120+
121+
See torax.torax_pydantic.tests.register_config_test.py for an example of how
122+
to use this function and expected behavior.
123+
124+
Args:
125+
source_model_config_class: The new source model config to register. This
126+
should be a subclass of SourceModelBase that implements the interface and
127+
has a unique `model_function_name`.
128+
source_name: The name of the source to register the model config against.
129+
This should be one of the fields in the Sources pydantic model. For the
130+
two "special" sources ("qei" and "j_bootstrap") registering a new
131+
implementation is not supported.
132+
"""
133+
_validate_source_model_config(source_model_config_class, source_name)
134+
# Update the Sources pydantic model to be aware of the new config.
135+
sources_pydantic_model.Sources.model_fields[
136+
f'{source_name}'
137+
].annotation |= source_model_config_class
138+
# Rebuild the pydantic schema for both the Sources and ToraxConfig models so
139+
# that uses of either will have access to the new config.
140+
sources_pydantic_model.Sources.model_rebuild(force=True)
141+
model_config.ToraxConfig.model_rebuild(force=True)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import copy
15+
import importlib
16+
from typing import Literal
17+
18+
from absl.testing import parameterized
19+
import chex
20+
from torax import array_typing
21+
from torax import state
22+
from torax.config import runtime_params_slice
23+
from torax.geometry import geometry
24+
from torax.sources import base as source_base_pydantic_model
25+
from torax.sources import gas_puff_source as gas_puff_source_lib
26+
from torax.sources import runtime_params
27+
from torax.sources import source as source_lib
28+
from torax.sources import source_profiles
29+
from torax.torax_pydantic import model_config
30+
from torax.torax_pydantic import register_config
31+
from torax.torax_pydantic import torax_pydantic
32+
33+
34+
@chex.dataclass(frozen=True)
35+
class DynamicRuntimeParams(runtime_params.DynamicRuntimeParams):
36+
a: array_typing.ScalarFloat
37+
b: bool
38+
39+
40+
def double_gas_puff_source(
41+
unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
42+
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
43+
geo: geometry.Geometry,
44+
source_name: str,
45+
unused_state: state.CoreProfiles,
46+
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
47+
) -> tuple[chex.Array, ...]:
48+
"""Calculates external source term for n from puffs."""
49+
output = gas_puff_source_lib.calc_puff_source(
50+
unused_static_runtime_params_slice,
51+
dynamic_runtime_params_slice,
52+
geo,
53+
source_name,
54+
unused_state,
55+
unused_calculated_source_profiles,
56+
)
57+
return 2 * output
58+
59+
60+
class NewGasPuffSourceModelConfig(source_base_pydantic_model.SourceModelBase):
61+
"""New source model config."""
62+
model_function_name: Literal['test_model_function'] = 'test_model_function'
63+
a: torax_pydantic.TimeVaryingScalar = torax_pydantic.ValidatedDefault(1.0)
64+
b: bool = False
65+
66+
@property
67+
def model_func(self) -> source_lib.SourceProfileFunction:
68+
return double_gas_puff_source
69+
70+
def build_source(self) -> source_lib.Source:
71+
return gas_puff_source_lib.GasPuffSource(model_func=self.model_func)
72+
73+
def build_dynamic_params(
74+
self,
75+
t: chex.Numeric,
76+
) -> DynamicRuntimeParams:
77+
return DynamicRuntimeParams(
78+
a=self.a.get_value(t),
79+
b=self.b,
80+
prescribed_values=self.prescribed_values.get_value(t),
81+
)
82+
83+
84+
class DuplicateGasPuffSourceModelConfig(
85+
source_base_pydantic_model.SourceModelBase
86+
):
87+
# Name that is already registered.
88+
model_function_name: Literal['calc_puff_source'] = 'calc_puff_source'
89+
a: torax_pydantic.TimeVaryingScalar = torax_pydantic.ValidatedDefault(1.0)
90+
b: bool = False
91+
92+
@property
93+
def model_func(self) -> source_lib.SourceProfileFunction:
94+
return double_gas_puff_source
95+
96+
def build_source(self) -> source_lib.Source:
97+
return gas_puff_source_lib.GasPuffSource(model_func=self.model_func)
98+
99+
def build_dynamic_params(
100+
self,
101+
t: chex.Numeric,
102+
) -> DynamicRuntimeParams:
103+
return DynamicRuntimeParams(
104+
a=self.a.get_value(t),
105+
b=self.b,
106+
prescribed_values=self.prescribed_values.get_value(t),
107+
)
108+
109+
110+
class RegisterConfigTest(parameterized.TestCase):
111+
112+
def test_register_source_model_config(self):
113+
config_name = 'test_iterhybrid_rampup'
114+
test_config_path = '.tests.test_data.' + config_name
115+
config_module = importlib.import_module(test_config_path, 'torax')
116+
config = copy.deepcopy(config_module.CONFIG)
117+
# Register the new source model config against the gas puff source.
118+
register_config.register_source_model_config(
119+
NewGasPuffSourceModelConfig, 'gas_puff_source'
120+
)
121+
122+
# Load the original config and check the gas puff source is expected type.
123+
config_pydantic = model_config.ToraxConfig.from_dict(config)
124+
gas_puff_source_config = config_pydantic.sources.gas_puff_source
125+
self.assertIsInstance(
126+
gas_puff_source_config, gas_puff_source_lib.GasPuffSourceConfig
127+
)
128+
gas_puff_source = gas_puff_source_config.build_source()
129+
self.assertIsInstance(gas_puff_source, gas_puff_source_lib.GasPuffSource)
130+
dynamic_params = gas_puff_source_config.build_dynamic_params(t=0.0)
131+
self.assertIsInstance(
132+
dynamic_params, gas_puff_source_lib.DynamicGasPuffRuntimeParams
133+
)
134+
135+
# Now modify the original config to use the new config.
136+
del config['sources']['gas_puff_source']
137+
config['sources']['gas_puff_source'] = {
138+
'model_function_name': 'test_model_function', # new registered name.
139+
'a': 2.0,
140+
}
141+
config_pydantic = model_config.ToraxConfig.from_dict(config)
142+
# Check we build the correct config.
143+
new_gas_puff_config = config_pydantic.sources.gas_puff_source
144+
self.assertIsInstance(new_gas_puff_config, NewGasPuffSourceModelConfig)
145+
# Check the dynamic params are built correctly.
146+
new_dynamic_params = new_gas_puff_config.build_dynamic_params(t=0.0)
147+
self.assertIsInstance(new_dynamic_params, DynamicRuntimeParams)
148+
self.assertEqual(new_dynamic_params.a, 2.0)
149+
self.assertEqual(new_dynamic_params.b, False)
150+
151+
def test_error_thrown_if_model_function_name_is_already_registered(self):
152+
with self.assertRaises(ValueError):
153+
register_config.register_source_model_config(
154+
DuplicateGasPuffSourceModelConfig, 'gas_puff_source'
155+
)
156+
157+
@parameterized.parameters('qei', 'j_bootstrap')
158+
def test_error_thrown_if_using_special_source(self, special_source):
159+
with self.assertRaisesRegex(
160+
ValueError,
161+
'Cannot register a new source model config for the qei or j_bootstrap'
162+
' sources.',
163+
):
164+
register_config.register_source_model_config(
165+
NewGasPuffSourceModelConfig, special_source
166+
)
167+
168+
def test_error_thrown_if_source_not_supported(self):
169+
with self.assertRaisesRegex(
170+
ValueError,
171+
'The source name foo_source is not supported.',
172+
):
173+
register_config.register_source_model_config(
174+
NewGasPuffSourceModelConfig, 'foo_source'
175+
)

0 commit comments

Comments
 (0)