Skip to content

Commit 8c3d118

Browse files
Nush395Torax team
authored andcommitted
Modify JIT of transport model to be defined on hash of TransportModel so that subsequent simulations hit cache.
Also: - remove jitted_transport_model from step function as this will be cache missed in subsequent simulations. - remove inner JIT from qlknn_transport_model as this is now covered by the JIT on __call__. Otherwise this misses a lookup as the function is defined on an object by id. PiperOrigin-RevId: 739945207
1 parent 322ea52 commit 8c3d118

File tree

4 files changed

+6
-12
lines changed

4 files changed

+6
-12
lines changed

torax/orchestration/step_function.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,6 @@ def __init__(
7070
self._time_step_calculator = time_step_calculator
7171
self._transport_model = transport_model
7272
self._pedestal_model = pedestal_model
73-
self._jitted_transport_model = jax_utils.jit(
74-
transport_model.__call__,
75-
)
7673

7774
@property
7875
def pedestal_model(self) -> pedestal_model_lib.PedestalModel:
@@ -239,7 +236,7 @@ def init_time_step_calculator(
239236
pedestal_model_output = self._pedestal_model(
240237
dynamic_runtime_params_slice_t, geo_t, input_state.core_profiles
241238
)
242-
transport_coeffs = self._jitted_transport_model(
239+
transport_coeffs = self._transport_model(
243240
dynamic_runtime_params_slice_t,
244241
geo_t,
245242
input_state.core_profiles,

torax/transport_model/qlknn_transport_model.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,6 @@ def _call_implementation(
212212
)
213213
return self._combined(runtime_config_inputs, geo, core_profiles)
214214

215-
# Wrap in JIT here in order to cache the tracing/compilation of this function.
216-
# We mark self as static because it is a singleton. Other args are pytrees.
217-
# There's no global coordination of calls to transport model so it is called
218-
# 2-4X with the same args. Caching prevents construction of multiple copies of
219-
# identical expressions saving ~30% in compile time.
220-
@functools.partial(jax.jit, static_argnames=['self'])
221215
def _combined(
222216
self,
223217
runtime_config_inputs: QLKNNRuntimeConfigInputs,

torax/transport_model/tests/qlknn_transport_model_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,15 @@ def test_qlknn_transport_model_cache_works(self):
7676
qlknn(
7777
dynamic_runtime_params_slice, geo, core_profiles, pedestal_model_outputs
7878
)
79-
cache_size = qlknn._combined._cache_size() # pylint: disable=protected-access
79+
cache_size = qlknn.__call__._cache_size() # pytype: disable=attribute-error
8080
self.assertGreaterEqual(cache_size, 1)
8181

8282
# Executing again should lead to the same cache entry being used.
8383
qlknn(
8484
dynamic_runtime_params_slice, geo, core_profiles, pedestal_model_outputs
8585
)
8686
self.assertEqual(
87-
qlknn._combined._cache_size(), # pylint: disable=protected-access
87+
qlknn.__call__._cache_size(), # pytype: disable=attribute-error
8888
cache_size,
8989
)
9090

torax/transport_model/transport_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020

2121
import abc
2222
import dataclasses
23+
import functools
2324

2425
import jax
2526
from jax import numpy as jnp
2627
from torax import constants
28+
from torax import jax_utils
2729
from torax import state
2830
from torax.config import runtime_params_slice
2931
from torax.geometry import geometry
@@ -53,6 +55,7 @@ def __setattr__(self, attr, value):
5355
raise AttributeError("TransportModels are immutable.")
5456
return super().__setattr__(attr, value)
5557

58+
@functools.partial(jax_utils.jit, static_argnums=(0,))
5659
def __call__(
5760
self,
5861
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,

0 commit comments

Comments
 (0)