File tree Expand file tree Collapse file tree 3 files changed +22
-8
lines changed
Expand file tree Collapse file tree 3 files changed +22
-8
lines changed Original file line number Diff line number Diff line change @@ -33,13 +33,14 @@ dependencies = [
3333 " netcdf4>=1.7.2" ,
3434 " h5netcdf>=1.3.0" ,
3535 " scipy>=1.13.0" ,
36- " jaxtyping>=0.2.28 " ,
36+ " jaxtyping>=0.3.2 " ,
3737 " contourpy>=1.2.1" ,
3838 " eqdsk>=0.4.0" ,
3939 " pydantic>=2.10.5" ,
4040 " tqdm>=4.67.0" ,
4141 " treelib>=1.3.2" ,
4242 " imas-python>=2.0.1" ,
43+ " typeguard>=4.4.4" ,
4344]
4445
4546dynamic = [" version" ]
Original file line number Diff line number Diff line change 1313# limitations under the License.
1414# ============================================================================
1515"""Common types for using jaxtyping in TORAX."""
16- import chex
16+ from typing import TypeAlias
17+ import jax
1718import jaxtyping as jt
19+ import numpy as np
20+ import typeguard
1821
19- ScalarFloat = jt .Float [chex .Array | float , "" ]
20- ScalarBool = jt .Bool [chex .Array | bool , "" ]
21- ScalarInt = jt .Int [chex .Array | int , "" ]
22- ArrayFloat = jt .Float [chex .Array , "rhon" ]
23- ArrayBool = jt .Bool [chex .Array , "rhon" ]
22+ Array : TypeAlias = jax .Array | np .ndarray
23+
24+ ScalarFloat = jt .Float [Array | float , "" ]
25+ ScalarBool = jt .Bool [Array | bool , "" ]
26+ ScalarInt = jt .Int [Array | int , "" ]
27+
28+ ArrayFloat = jt .Float [Array , "rhon" ]
29+ ArrayBool = jt .Bool [Array , "rhon" ]
30+
31+ jaxtyped = jt .jaxtyped (typechecker = typeguard .typechecked )
Original file line number Diff line number Diff line change @@ -111,7 +111,12 @@ def cell_to_face(
111111 return face_values
112112
113113
114- def tridiag (diag : jax .Array , above : jax .Array , below : jax .Array ) -> jax .Array :
114+ @array_typing .jaxtyped
115+ def tridiag (
116+ diag : jt .Shaped [array_typing .Array , 'size' ],
117+ above : jt .Shaped [array_typing .Array , 'size-1' ],
118+ below : jt .Shaped [array_typing .Array , 'size-1' ],
119+ ) -> jt .Shaped [array_typing .Array , 'size size' ]:
115120 """Builds a tridiagonal matrix.
116121
117122 Args:
You can’t perform that action at this time.
0 commit comments