Skip to content

Commit

Permalink
convert to tuple instead of jax.Array
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso committed Nov 20, 2023
1 parent bd4d94b commit 79537bd
Show file tree
Hide file tree
Showing 13 changed files with 26 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
# Select the Python versions to test against
python-version: ['3.8', '3.9', '3.10']
python-version: ['3.8', '3.9', '3.10', '3.11']
steps:
- name: Check out the code
uses: actions/checkout@v3
Expand Down
4 changes: 2 additions & 2 deletions fortuna/output_calib_model/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
CalibParams,
OptaxOptimizer,
)
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import convert_string_to_tuple


class OutputCalibState(TrainState):
params: CalibParams
mutable: Optional[CalibMutable] = None
encoded_name: jnp.ndarray = convert_string_to_jnp_array("OutputCalibState")
encoded_name: tuple = convert_string_to_tuple("OutputCalibState")

@classmethod
def init(
Expand Down
4 changes: 2 additions & 2 deletions fortuna/prob_model/posterior/laplace/laplace_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from fortuna.utils.nested_dicts import nested_pair
from fortuna.utils.strings import (
convert_string_to_jnp_array,
convert_string_to_tuple,
encode_tuple_of_lists_of_strings_to_numpy,
)

Expand All @@ -36,7 +36,7 @@ class LaplaceState(PosteriorState):
"""

prior_log_var: float = 0.0
encoded_name: jnp.ndarray = convert_string_to_jnp_array("LaplaceState")
encoded_name: tuple = convert_string_to_tuple("LaplaceState")
_encoded_which_params: Optional[Dict[str, Array]] = None

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions fortuna/prob_model/posterior/map/map_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import jax.numpy as jnp

from fortuna.prob_model.posterior.state import PosteriorState
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import convert_string_to_tuple


class MAPState(PosteriorState):
Expand All @@ -14,4 +14,4 @@ class MAPState(PosteriorState):
MAP state name encoded as an array.
"""

encoded_name: jnp.ndarray = convert_string_to_jnp_array("MAPState")
encoded_name: tuple = convert_string_to_tuple("MAPState")
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
NormalizingFlowState,
)
from fortuna.typing import Array
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import convert_string_to_tuple


class ADVIState(NormalizingFlowState):
Expand All @@ -23,5 +23,5 @@ class ADVIState(NormalizingFlowState):
ADVI state name encoded as an array.
"""

encoded_name: jnp.ndarray = convert_string_to_jnp_array("ADVIState")
encoded_name: tuple = convert_string_to_tuple("ADVIState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
OptaxOptimizer,
)
from fortuna.utils.strings import (
convert_string_to_jnp_array,
convert_string_to_tuple,
encode_tuple_of_lists_of_strings_to_numpy,
)

Expand All @@ -30,7 +30,7 @@ class CyclicalSGLDState(PosteriorState):
CyclicalSGLDState state name encoded as an array.
"""

encoded_name: jnp.ndarray = convert_string_to_jnp_array("CyclicalSGLDState")
encoded_name: tuple = convert_string_to_tuple("CyclicalSGLDState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
OptaxOptimizer,
)
from fortuna.utils.strings import (
convert_string_to_jnp_array,
convert_string_to_tuple,
encode_tuple_of_lists_of_strings_to_numpy,
)

Expand All @@ -30,7 +30,7 @@ class SGHMCState(PosteriorState):
SGHMC state name encoded as an array.
"""

encoded_name: jnp.ndarray = convert_string_to_jnp_array("SGHMCState")
encoded_name: tuple = convert_string_to_tuple("SGHMCState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions fortuna/prob_model/posterior/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
OptaxOptimizer,
Params,
)
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import convert_string_to_tuple


class PosteriorState(TrainState):
Expand All @@ -33,7 +33,7 @@ class PosteriorState(TrainState):
calib_mutable: Optional[CalibMutable] = None
grad_accumulated: Optional[jnp.ndarray] = None
dynamic_scale: Optional[dynamic_scale.DynamicScale] = None
encoded_name: jnp.ndarray = convert_string_to_jnp_array("PosteriorState")
encoded_name: tuple = convert_string_to_tuple("PosteriorState")

@classmethod
def init(
Expand Down
4 changes: 2 additions & 2 deletions fortuna/prob_model/posterior/swag/swag_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Array,
OptaxOptimizer,
)
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import convert_string_to_tuple


class SWAGState(PosteriorState):
Expand All @@ -35,7 +35,7 @@ class SWAGState(PosteriorState):
mean: Optional[jnp.ndarray] = None
std: Optional[jnp.ndarray] = None
dev: Optional[jnp.ndarray] = None
encoded_name: jnp.ndarray = convert_string_to_jnp_array("SWAGState")
encoded_name: tuple = convert_string_to_tuple("SWAGState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions fortuna/training/train_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import jax.numpy as jnp

from fortuna.typing import Params
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import convert_string_to_tuple


class TrainState(train_state.TrainState):
encoded_name: jnp.ndarray = convert_string_to_jnp_array("TrainState")
encoded_name: tuple = convert_string_to_tuple("TrainState")
frozen_params: Optional[Params] = None
dynamic_scale: Optional[dynamic_scale.DynamicScale] = None

Expand Down
4 changes: 2 additions & 2 deletions fortuna/utils/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from fortuna.typing import Array


def convert_string_to_jnp_array(s: str) -> jnp.ndarray:
return jnp.array([ord(c) for c in s])
def convert_string_to_tuple(s: str) -> Tuple:
return tuple([ord(c) for c in s])


def convert_string_to_np_array(s: str) -> np.ndarray:
Expand Down
6 changes: 4 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ documentation = "https://aws-fortuna.readthedocs.io/en/latest/"
packages = [{include = "fortuna"}]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
python = ">=3.8,<3.12"
flax = "^0.6.2"
optax = "^0.1.3"
matplotlib = "^3.6.2"
Expand Down

0 comments on commit 79537bd

Please sign in to comment.