Skip to content

Commit 117af18

Browse files
sbodensteinTorax team
authored andcommitted
Add dependency on typeguard.
PiperOrigin-RevId: 793588618
1 parent df83283 commit 117af18

File tree

3 files changed

+22
-8
lines changed

3 files changed

+22
-8
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,14 @@ dependencies = [
3333
"netcdf4>=1.7.2",
3434
"h5netcdf>=1.3.0",
3535
"scipy>=1.13.0",
36-
"jaxtyping>=0.2.28",
36+
"jaxtyping>=0.3.2",
3737
"contourpy>=1.2.1",
3838
"eqdsk>=0.4.0",
3939
"pydantic>=2.10.5",
4040
"tqdm>=4.67.0",
4141
"treelib>=1.3.2",
4242
"imas-python>=2.0.1",
43+
"typeguard>=4.4.4",
4344
]
4445

4546
dynamic = ["version"]

torax/_src/array_typing.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,19 @@
1313
# limitations under the License.
1414
# ============================================================================
1515
"""Common types for using jaxtyping in TORAX."""
16-
import chex
16+
from typing import TypeAlias
17+
import jax
1718
import jaxtyping as jt
19+
import numpy as np
20+
import typeguard
1821

19-
ScalarFloat = jt.Float[chex.Array | float, ""]
20-
ScalarBool = jt.Bool[chex.Array | bool, ""]
21-
ScalarInt = jt.Int[chex.Array | int, ""]
22-
ArrayFloat = jt.Float[chex.Array, "rhon"]
23-
ArrayBool = jt.Bool[chex.Array, "rhon"]
22+
Array: TypeAlias = jax.Array | np.ndarray
23+
24+
ScalarFloat = jt.Float[Array | float, ""]
25+
ScalarBool = jt.Bool[Array | bool, ""]
26+
ScalarInt = jt.Int[Array | int, ""]
27+
28+
ArrayFloat = jt.Float[Array, "rhon"]
29+
ArrayBool = jt.Bool[Array, "rhon"]
30+
31+
jaxtyped = jt.jaxtyped(typechecker=typeguard.typechecked)

torax/_src/math_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,12 @@ def cell_to_face(
111111
return face_values
112112

113113

114-
def tridiag(diag: jax.Array, above: jax.Array, below: jax.Array) -> jax.Array:
114+
@array_typing.jaxtyped
115+
def tridiag(
116+
diag: jt.Shaped[array_typing.Array, 'size'],
117+
above: jt.Shaped[array_typing.Array, 'size-1'],
118+
below: jt.Shaped[array_typing.Array, 'size-1'],
119+
) -> jt.Shaped[array_typing.Array, 'size size']:
115120
"""Builds a tridiagonal matrix.
116121
117122
Args:

0 commit comments

Comments
 (0)