Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 79537bd

Browse files
convert to tuple instead of jax.Array
1 parent bd4d94b commit 79537bd

File tree

13 files changed

+26
-24
lines changed

13 files changed

+26
-24
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
strategy:
1414
matrix:
1515
# Select the Python versions to test against
16-
python-version: ['3.8', '3.9', '3.10']
16+
python-version: ['3.8', '3.9', '3.10', '3.11']
1717
steps:
1818
- name: Check out the code
1919
uses: actions/checkout@v3

fortuna/output_calib_model/state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
CalibParams,
1717
OptaxOptimizer,
1818
)
19-
from fortuna.utils.strings import convert_string_to_jnp_array
19+
from fortuna.utils.strings import convert_string_to_tuple
2020

2121

2222
class OutputCalibState(TrainState):
2323
params: CalibParams
2424
mutable: Optional[CalibMutable] = None
25-
encoded_name: jnp.ndarray = convert_string_to_jnp_array("OutputCalibState")
25+
encoded_name: tuple = convert_string_to_tuple("OutputCalibState")
2626

2727
@classmethod
2828
def init(

fortuna/prob_model/posterior/laplace/laplace_state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
)
2121
from fortuna.utils.nested_dicts import nested_pair
2222
from fortuna.utils.strings import (
23-
convert_string_to_jnp_array,
23+
convert_string_to_tuple,
2424
encode_tuple_of_lists_of_strings_to_numpy,
2525
)
2626

@@ -36,7 +36,7 @@ class LaplaceState(PosteriorState):
3636
"""
3737

3838
prior_log_var: float = 0.0
39-
encoded_name: jnp.ndarray = convert_string_to_jnp_array("LaplaceState")
39+
encoded_name: tuple = convert_string_to_tuple("LaplaceState")
4040
_encoded_which_params: Optional[Dict[str, Array]] = None
4141

4242
@classmethod

fortuna/prob_model/posterior/map/map_state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import jax.numpy as jnp
44

55
from fortuna.prob_model.posterior.state import PosteriorState
6-
from fortuna.utils.strings import convert_string_to_jnp_array
6+
from fortuna.utils.strings import convert_string_to_tuple
77

88

99
class MAPState(PosteriorState):
@@ -14,4 +14,4 @@ class MAPState(PosteriorState):
1414
MAP state name encoded as an array.
1515
"""
1616

17-
encoded_name: jnp.ndarray = convert_string_to_jnp_array("MAPState")
17+
encoded_name: tuple = convert_string_to_tuple("MAPState")

fortuna/prob_model/posterior/normalizing_flow/advi/advi_state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
NormalizingFlowState,
1313
)
1414
from fortuna.typing import Array
15-
from fortuna.utils.strings import convert_string_to_jnp_array
15+
from fortuna.utils.strings import convert_string_to_tuple
1616

1717

1818
class ADVIState(NormalizingFlowState):
@@ -23,5 +23,5 @@ class ADVIState(NormalizingFlowState):
2323
ADVI state name encoded as an array.
2424
"""
2525

26-
encoded_name: jnp.ndarray = convert_string_to_jnp_array("ADVIState")
26+
encoded_name: tuple = convert_string_to_tuple("ADVIState")
2727
_encoded_which_params: Optional[Dict[str, List[Array]]] = None

fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
OptaxOptimizer,
1818
)
1919
from fortuna.utils.strings import (
20-
convert_string_to_jnp_array,
20+
convert_string_to_tuple,
2121
encode_tuple_of_lists_of_strings_to_numpy,
2222
)
2323

@@ -30,7 +30,7 @@ class CyclicalSGLDState(PosteriorState):
3030
CyclicalSGLDState state name encoded as an array.
3131
"""
3232

33-
encoded_name: jnp.ndarray = convert_string_to_jnp_array("CyclicalSGLDState")
33+
encoded_name: tuple = convert_string_to_tuple("CyclicalSGLDState")
3434
_encoded_which_params: Optional[Dict[str, List[Array]]] = None
3535

3636
@classmethod

fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
OptaxOptimizer,
1818
)
1919
from fortuna.utils.strings import (
20-
convert_string_to_jnp_array,
20+
convert_string_to_tuple,
2121
encode_tuple_of_lists_of_strings_to_numpy,
2222
)
2323

@@ -30,7 +30,7 @@ class SGHMCState(PosteriorState):
3030
SGHMC state name encoded as an array.
3131
"""
3232

33-
encoded_name: jnp.ndarray = convert_string_to_jnp_array("SGHMCState")
33+
encoded_name: tuple = convert_string_to_tuple("SGHMCState")
3434
_encoded_which_params: Optional[Dict[str, List[Array]]] = None
3535

3636
@classmethod

fortuna/prob_model/posterior/state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
OptaxOptimizer,
1919
Params,
2020
)
21-
from fortuna.utils.strings import convert_string_to_jnp_array
21+
from fortuna.utils.strings import convert_string_to_tuple
2222

2323

2424
class PosteriorState(TrainState):
@@ -33,7 +33,7 @@ class PosteriorState(TrainState):
3333
calib_mutable: Optional[CalibMutable] = None
3434
grad_accumulated: Optional[jnp.ndarray] = None
3535
dynamic_scale: Optional[dynamic_scale.DynamicScale] = None
36-
encoded_name: jnp.ndarray = convert_string_to_jnp_array("PosteriorState")
36+
encoded_name: tuple = convert_string_to_tuple("PosteriorState")
3737

3838
@classmethod
3939
def init(

fortuna/prob_model/posterior/swag/swag_state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
Array,
1616
OptaxOptimizer,
1717
)
18-
from fortuna.utils.strings import convert_string_to_jnp_array
18+
from fortuna.utils.strings import convert_string_to_tuple
1919

2020

2121
class SWAGState(PosteriorState):
@@ -35,7 +35,7 @@ class SWAGState(PosteriorState):
3535
mean: Optional[jnp.ndarray] = None
3636
std: Optional[jnp.ndarray] = None
3737
dev: Optional[jnp.ndarray] = None
38-
encoded_name: jnp.ndarray = convert_string_to_jnp_array("SWAGState")
38+
encoded_name: tuple = convert_string_to_tuple("SWAGState")
3939
_encoded_which_params: Optional[Dict[str, List[Array]]] = None
4040

4141
@classmethod

fortuna/training/train_state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
import jax.numpy as jnp
1313

1414
from fortuna.typing import Params
15-
from fortuna.utils.strings import convert_string_to_jnp_array
15+
from fortuna.utils.strings import convert_string_to_tuple
1616

1717

1818
class TrainState(train_state.TrainState):
19-
encoded_name: jnp.ndarray = convert_string_to_jnp_array("TrainState")
19+
encoded_name: tuple = convert_string_to_tuple("TrainState")
2020
frozen_params: Optional[Params] = None
2121
dynamic_scale: Optional[dynamic_scale.DynamicScale] = None
2222

fortuna/utils/strings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from fortuna.typing import Array
1313

1414

15-
def convert_string_to_jnp_array(s: str) -> jnp.ndarray:
16-
return jnp.array([ord(c) for c in s])
15+
def convert_string_to_tuple(s: str) -> Tuple:
16+
return tuple([ord(c) for c in s])
1717

1818

1919
def convert_string_to_np_array(s: str) -> np.ndarray:

poetry.lock

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ documentation = "https://aws-fortuna.readthedocs.io/en/latest/"
99
packages = [{include = "fortuna"}]
1010

1111
[tool.poetry.dependencies]
12-
python = ">=3.8,<3.11"
12+
python = ">=3.8,<3.12"
1313
flax = "^0.6.2"
1414
optax = "^0.1.3"
1515
matplotlib = "^3.6.2"

0 commit comments

Comments
 (0)