Skip to content

Commit be82ba6

Browse files
jcitrinTorax team
authored and
Torax team
committed
Connect MHD and sawtooth configs and models to TORAX.
1. sawtooth_model is now available in step_fn 2. mhd.sawtooth dynamic_runtime_params now available in dynamic_runtime_params_provider Nothing yet is done with the models. Next PRs will focus on model implementation. PiperOrigin-RevId: 739937629
1 parent 943faa0 commit be82ba6

12 files changed

+281
-5
lines changed

torax/config/build_runtime_params.py

+12
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from torax.config import runtime_params_slice
2828
from torax.geometry import geometry
2929
from torax.geometry import geometry_provider as geometry_provider_lib
30+
from torax.mhd import pydantic_model as mhd_pydantic_model
31+
from torax.mhd import runtime_params as mhd_runtime_params
3032
from torax.pedestal_model import pydantic_model as pedestal_pydantic_model
3133
from torax.sources import pydantic_model as sources_pydantic_model
3234
from torax.stepper import pydantic_model as stepper_pydantic_model
@@ -128,6 +130,7 @@ def __init__(
128130
transport: transport_model_pydantic_model.Transport | None = None,
129131
sources: sources_pydantic_model.Sources | None = None,
130132
stepper: stepper_pydantic_model.Stepper | None = None,
133+
mhd: mhd_pydantic_model.MHD | None = None,
131134
torax_mesh: torax_pydantic.Grid1D | None = None,
132135
):
133136
"""Constructs a build_simulation_params.DynamicRuntimeParamsSliceProvider.
@@ -142,6 +145,8 @@ def __init__(
142145
defaults to an empty dict (i.e. no sources).
143146
stepper: The stepper configuration to use. If None, defaults to the
144147
default stepper configuration.
148+
mhd: The mhd configuration to use. If None, defaults to an empty MHD
149+
object.
145150
torax_mesh: The torax mesh to use. If the slice provider doesn't need to
146151
construct any rho interpolated values, this can be None, else an error
147152
will be raised within the constructor of the interpolated variable.
@@ -153,13 +158,15 @@ def __init__(
153158
sources = sources or sources_pydantic_model.Sources()
154159
stepper = stepper or stepper_pydantic_model.Stepper()
155160
pedestal = pedestal or pedestal_pydantic_model.Pedestal()
161+
mhd = mhd or mhd_pydantic_model.MHD()
156162
torax_pydantic.set_grid(sources, torax_mesh, mode='relaxed')
157163
self._torax_mesh = torax_mesh
158164
self._sources = sources
159165
self._runtime_params = runtime_params
160166
self._transport_model = transport
161167
self._stepper = stepper
162168
self._pedestal = pedestal
169+
self._mhd = mhd
163170

164171
@property
165172
def sources(self) -> sources_pydantic_model.Sources:
@@ -196,6 +203,11 @@ def __call__(
196203
profile_conditions=dynamic_general_runtime_params.profile_conditions,
197204
numerics=dynamic_general_runtime_params.numerics,
198205
pedestal=self._pedestal.build_dynamic_params(t),
206+
mhd=mhd_runtime_params.DynamicMHDParams(**{
207+
mhd_model_name: mhd_model_config.build_dynamic_params(t)
208+
for mhd_model_name, mhd_model_config in self._mhd
209+
if mhd_model_config is not None
210+
}),
199211
)
200212

201213

torax/config/runtime_params_slice.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from torax.config import profile_conditions
4444
from torax.geometry import geometry
4545
from torax.geometry import standard_geometry
46+
from torax.mhd import runtime_params as mhd_runtime_params
4647
from torax.pedestal_model import runtime_params as pedestal_model_params
4748
from torax.sources import runtime_params as sources_params
4849
from torax.stepper import runtime_params as stepper_params
@@ -76,10 +77,14 @@ class DynamicRuntimeParamsSlice:
7677
This class contains "slices" of various RuntimeParams attributes defined
7778
throughout TORAX:
7879
79-
- from the "general" runtime params
80+
- from the profile_conditions runtime params
81+
- from the numerics runtime params
82+
- from the plasma_composition runtime params
8083
- from the transport model's runtime params
8184
- from the stepper's runtime params
8285
- from each of the sources' runtime params
86+
- from the pedestal model's runtime params
87+
- from each of the mhd models' runtime params
8388
8489
This class packages all these together for convenience, as it simplifies many
8590
of the internal APIs within TORAX.
@@ -92,6 +97,7 @@ class DynamicRuntimeParamsSlice:
9297
numerics: numerics.DynamicNumerics
9398
sources: Mapping[str, sources_params.DynamicRuntimeParams]
9499
pedestal: pedestal_model_params.DynamicRuntimeParams
100+
mhd: mhd_runtime_params.DynamicMHDParams
95101

96102

97103
@chex.dataclass(frozen=True)

torax/mhd/base.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Base classes for MHD models."""
16+
17+
import chex
18+
from torax.mhd.sawtooth import sawtooth_model
19+
20+
21+
@chex.dataclass
22+
class MHDModels:
23+
"""Container for instantiated MHD model objects."""
24+
25+
sawtooth: sawtooth_model.SawtoothModel | None = None

torax/mhd/pydantic_model.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Pydantic config for MHD models."""
16+
17+
import pydantic
18+
from torax.mhd import base
19+
from torax.mhd.sawtooth import pydantic_model as sawtooth_pydantic_model
20+
from torax.torax_pydantic import torax_pydantic
21+
22+
23+
class MHD(torax_pydantic.BaseModelFrozen):
24+
"""Config for MHD models.
25+
26+
Attributes:
27+
sawtooth: Config for sawtooth models.
28+
"""
29+
30+
sawtooth: sawtooth_pydantic_model.SawtoothConfig | None = pydantic.Field(
31+
default=None
32+
)
33+
34+
def build_mhd_models(self) -> base.MHDModels:
35+
"""Builds and returns a container with instantiated MHD model objects."""
36+
37+
return base.MHDModels(
38+
sawtooth=self.sawtooth.build_model()
39+
if self.sawtooth is not None
40+
else None,
41+
)

torax/mhd/runtime_params.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Container for MHD model dynamic runtime params."""
16+
17+
import chex
18+
from torax.mhd.sawtooth import runtime_params as sawtooth_runtime_params
19+
20+
21+
@chex.dataclass(frozen=True)
22+
class DynamicMHDParams:
23+
"""Container for dynamic parameters of all configured MHD models."""
24+
25+
sawtooth: sawtooth_runtime_params.DynamicRuntimeParams | None = None

torax/mhd/sawtooth/pydantic_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class SawtoothConfig(torax_pydantic.BaseModelFrozen):
4545
)
4646
crash_step_duration: torax_pydantic.Second = 1e-3
4747

48-
def build_sawtooth_model(self) -> sawtooth_model.SawtoothModel:
48+
def build_model(self) -> sawtooth_model.SawtoothModel:
4949
return sawtooth_model.SawtoothModel(
5050
trigger_model=self.trigger_model_config.build_trigger_model(),
5151
redistribution_model=self.redistribution_model_config.build_redistribution_model(),

torax/mhd/sawtooth/sawtooth_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __call__(
5151

5252

5353
class SawtoothModel:
54-
"""Container for sawtooth trigger and redistribution models."""
54+
"""Sawtooth trigger and redistribution, and carries out sawtooth step."""
5555

5656
def __init__(
5757
self,

torax/mhd/sawtooth/simple_redistribution.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __call__(
3939
return core_profiles
4040

4141

42-
@chex.dataclass
42+
@chex.dataclass(frozen=True)
4343
class DynamicRuntimeParams(runtime_params.RedistributionDynamicRuntimeParams):
4444
# TODO(b/317360481): implement redistribution model. For now, does nothing.
4545
"""Dynamic runtime params for simple redistribution model."""
+147
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from absl.testing import absltest
16+
from absl.testing import parameterized
17+
from torax.config import build_runtime_params
18+
from torax.config import runtime_params as general_runtime_params
19+
from torax.geometry import pydantic_model as geometry_pydantic_model
20+
from torax.mhd import pydantic_model as mhd_pydantic_model
21+
from torax.mhd import runtime_params as mhd_runtime_params
22+
from torax.mhd.sawtooth import pydantic_model as sawtooth_pydantic_model
23+
from torax.mhd.sawtooth import runtime_params as sawtooth_runtime_params
24+
from torax.mhd.sawtooth import sawtooth_model
25+
from torax.stepper import pydantic_model as stepper_pydantic_model
26+
from torax.tests.test_lib import default_sources
27+
from torax.torax_pydantic import model_config
28+
29+
30+
class MHDPydanticModelTest(parameterized.TestCase):
31+
"""Tests for the MHD Pydantic model and dynamic params construction."""
32+
33+
def setUp(self):
34+
super().setUp()
35+
self.geo = geometry_pydantic_model.CircularConfig().build_geometry()
36+
self.runtime_params = general_runtime_params.GeneralRuntimeParams()
37+
self.sources = default_sources.get_default_sources()
38+
self.stepper = stepper_pydantic_model.Stepper()
39+
40+
def test_no_mhd_config(self):
41+
"""Tests the case where the 'mhd' key is entirely absent."""
42+
config_dict = {
43+
'runtime_params': {},
44+
'geometry': {'geometry_type': 'circular'},
45+
'pedestal': {},
46+
'sources': {},
47+
'stepper': {},
48+
'time_step_calculator': {},
49+
'transport': {},
50+
}
51+
torax_config = model_config.ToraxConfig.from_dict(config_dict)
52+
53+
self.assertIs(torax_config.mhd, None)
54+
provider = build_runtime_params.DynamicRuntimeParamsSliceProvider(
55+
runtime_params=self.runtime_params,
56+
sources=self.sources,
57+
stepper=self.stepper,
58+
torax_mesh=self.geo.torax_mesh,
59+
mhd=torax_config.mhd,
60+
)
61+
dynamic_slice = provider(t=0.0)
62+
self.assertIsInstance(
63+
dynamic_slice.mhd, mhd_runtime_params.DynamicMHDParams
64+
)
65+
self.assertIs(dynamic_slice.mhd.sawtooth, None)
66+
67+
def test_empty_mhd_config(self):
68+
"""Tests the case where 'mhd' key exists but is an empty dict."""
69+
config_dict = {
70+
'runtime_params': {},
71+
'geometry': {'geometry_type': 'circular'},
72+
'pedestal': {},
73+
'sources': {},
74+
'stepper': {},
75+
'time_step_calculator': {},
76+
'transport': {},
77+
'mhd': {},
78+
}
79+
torax_config = model_config.ToraxConfig.from_dict(config_dict)
80+
81+
self.assertIsInstance(torax_config.mhd, mhd_pydantic_model.MHD)
82+
assert isinstance(torax_config.mhd, mhd_pydantic_model.MHD)
83+
mhd_models = torax_config.mhd.build_mhd_models()
84+
self.assertIs(mhd_models.sawtooth, None)
85+
provider = build_runtime_params.DynamicRuntimeParamsSliceProvider(
86+
runtime_params=self.runtime_params,
87+
sources=self.sources,
88+
stepper=self.stepper,
89+
torax_mesh=self.geo.torax_mesh,
90+
mhd=torax_config.mhd,
91+
)
92+
dynamic_slice = provider(t=0.0)
93+
self.assertIsInstance(
94+
dynamic_slice.mhd, mhd_runtime_params.DynamicMHDParams
95+
)
96+
self.assertIs(dynamic_slice.mhd.sawtooth, None)
97+
98+
def test_mhd_config_with_sawtooth(self):
99+
"""Tests the case with a valid sawtooth configuration."""
100+
config_dict = {
101+
'runtime_params': {},
102+
'geometry': {'geometry_type': 'circular'},
103+
'pedestal': {},
104+
'sources': {},
105+
'stepper': {},
106+
'time_step_calculator': {},
107+
'transport': {},
108+
'mhd': {
109+
'sawtooth': {
110+
'trigger_model_config': {'trigger_model_type': 'simple'},
111+
'redistribution_model_config': {
112+
'redistribution_model_type': 'simple'
113+
},
114+
'minimum_radius': 0.06,
115+
}
116+
},
117+
}
118+
torax_config = model_config.ToraxConfig.from_dict(config_dict)
119+
120+
self.assertIsInstance(torax_config.mhd, mhd_pydantic_model.MHD)
121+
assert torax_config.mhd is not None
122+
self.assertIsInstance(
123+
torax_config.mhd.sawtooth, sawtooth_pydantic_model.SawtoothConfig
124+
)
125+
126+
mhd_models = torax_config.mhd.build_mhd_models()
127+
self.assertIn('sawtooth', mhd_models)
128+
self.assertIsInstance(mhd_models['sawtooth'], sawtooth_model.SawtoothModel)
129+
130+
provider = build_runtime_params.DynamicRuntimeParamsSliceProvider(
131+
runtime_params=self.runtime_params,
132+
sources=self.sources,
133+
stepper=self.stepper,
134+
torax_mesh=self.geo.torax_mesh,
135+
mhd=torax_config.mhd,
136+
)
137+
dynamic_slice = provider(t=0.0)
138+
self.assertIn('sawtooth', dynamic_slice.mhd)
139+
sawtooth_dynamic_params = dynamic_slice.mhd.sawtooth
140+
self.assertIsInstance(
141+
sawtooth_dynamic_params, sawtooth_runtime_params.DynamicRuntimeParams
142+
)
143+
self.assertEqual(sawtooth_dynamic_params.minimum_radius, 0.06)
144+
145+
146+
if __name__ == '__main__':
147+
absltest.main()

torax/orchestration/run_simulation.py

+8
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,18 @@ def run_simulation(
5555
pedestal_model=pedestal_model,
5656
)
5757

58+
mhd_models = (
59+
torax_config.mhd.build_mhd_models()
60+
if torax_config.mhd is not None
61+
else None
62+
)
63+
5864
step_fn = step_function.SimulationStepFn(
5965
stepper=stepper,
6066
time_step_calculator=torax_config.time_step_calculator.time_step_calculator,
6167
transport_model=transport_model,
6268
pedestal_model=pedestal_model,
69+
mhd_models=mhd_models,
6370
)
6471

6572
static_runtime_params_slice = (
@@ -78,6 +85,7 @@ def run_simulation(
7885
transport=torax_config.transport,
7986
sources=torax_config.sources,
8087
stepper=torax_config.stepper,
88+
mhd=torax_config.mhd,
8189
torax_mesh=torax_config.geometry.build_provider.torax_mesh,
8290
)
8391
)

0 commit comments

Comments
 (0)