Skip to content

Commit 859db94

Browse files
committed
Issue #207: Support Gymnasium 1.0.0
1 parent 45b941e commit 859db94

File tree

5 files changed

+20
-14
lines changed

5 files changed

+20
-14
lines changed

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ requires-python = ">=3.10.0"
1515
license = { text = "MIT" }
1616
dependencies = [
1717
"Deprecated",
18-
"gymnasium<1.0.0",
18+
"gymnasium",
1919
"numpy",
2020
"pandas",
2121
"pettingzoo",
@@ -29,7 +29,7 @@ dependencies = [
2929

3030
[project.optional-dependencies]
3131
docs = ["ipykernel", "ipywidgets", "nbdime", "nbsphinx", "sphinx-rtd-theme"]
32-
rllib = ["dm_tree", "pyarrow", "ray[rllib]", "scikit-image", "torch", "typer"]
32+
rllib = ["dm_tree", "pyarrow", "ray[rllib]==2.35.0", "scikit-image", "torch", "typer"]
3333

3434
[project.scripts]
3535
finish_install = "bsk_rl.finish_install:pck_install"

src/bsk_rl/obs/observations.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,17 @@ def nested_obs_to_space(obs_dict):
3737
)
3838
elif isinstance(obs_dict, list):
3939
return spaces.Box(
40-
low=-1e16, high=1e16, shape=(len(obs_dict),), dtype=np.float64
40+
low=-1e16, high=1e16, shape=(len(obs_dict),), dtype=np.float32
4141
)
4242
elif isinstance(obs_dict, (float, int)):
43-
return spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float64)
43+
return spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float32)
4444
elif isinstance(obs_dict, np.ndarray):
45-
return spaces.Box(low=-1e16, high=1e16, shape=obs_dict.shape, dtype=np.float64)
45+
return spaces.Box(low=-1e16, high=1e16, shape=obs_dict.shape, dtype=np.float32)
4646
else:
4747
raise TypeError(f"Cannot convert {obs_dict} to gym space.")
4848

4949

5050
class ObservationBuilder:
51-
5251
def __init__(self, satellite: "Satellite", obs_type: type = np.ndarray) -> None:
5352
"""Satellite subclass for composing observations.
5453
@@ -312,7 +311,6 @@ def _r_LB_H(sat, opp):
312311

313312

314313
class OpportunityProperties(Observation):
315-
316314
_fn_map = {
317315
"priority": lambda sat, opp: opp["object"].priority,
318316
"r_LP_P": lambda sat, opp: opp["r_LP_P"],

tests/integration/test_int_gym_env.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def test_action_space(self):
3535
assert self.env.action_space == spaces.Discrete(1)
3636

3737
def test_observation_space(self):
38-
assert self.env.observation_space == spaces.Box(-1e16, 1e16, (1,))
38+
assert self.env.observation_space == spaces.Box(
39+
-1e16, 1e16, (1,), dtype=np.float32
40+
)
3941

4042
def test_step(self):
4143
observation, reward, terminated, truncated, info = self.env.step(0)
@@ -124,7 +126,10 @@ def test_action_space(self):
124126

125127
def test_observation_space(self):
126128
assert self.env.observation_space == spaces.Tuple(
127-
(spaces.Box(-1e16, 1e16, (1,)), spaces.Box(-1e16, 1e16, (1,)))
129+
(
130+
spaces.Box(-1e16, 1e16, (1,), dtype=np.float32),
131+
spaces.Box(-1e16, 1e16, (1,), dtype=np.float32),
132+
)
128133
)
129134

130135
def test_step(self):

tests/unittest/obs/test_observations.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -69,23 +69,23 @@ def test_obs_cache(self):
6969
[
7070
(
7171
np.array([1]),
72-
spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float64),
72+
spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float32),
7373
),
7474
(
7575
np.array([1, 2]),
76-
spaces.Box(low=-1e16, high=1e16, shape=(2,), dtype=np.float64),
76+
spaces.Box(low=-1e16, high=1e16, shape=(2,), dtype=np.float32),
7777
),
7878
(
7979
{"a": 1, "b": {"c": 1}},
8080
spaces.Dict(
8181
{
8282
"a": spaces.Box(
83-
low=-1e16, high=1e16, shape=(1,), dtype=np.float64
83+
low=-1e16, high=1e16, shape=(1,), dtype=np.float32
8484
),
8585
"b": spaces.Dict(
8686
{
8787
"c": spaces.Box(
88-
low=-1e16, high=1e16, shape=(1,), dtype=np.float64
88+
low=-1e16, high=1e16, shape=(1,), dtype=np.float32
8989
)
9090
}
9191
),

tests/unittest/test_gym_env.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from unittest.mock import MagicMock, patch
22

3+
import numpy as np
34
import pytest
45
from gymnasium import spaces
56

@@ -132,7 +133,9 @@ def test_get_obs_retasking_only(self):
132133
satellites=[
133134
MagicMock(
134135
get_obs=MagicMock(return_value=[i + 1]),
135-
observation_space=spaces.Box(-1e9, 1e9, shape=(1,)),
136+
observation_space=spaces.Box(
137+
-1e9, 1e9, shape=(1,), dtype=np.float32
138+
),
136139
requires_retasking=(i == 1),
137140
)
138141
for i in range(3)

0 commit comments

Comments
 (0)