Skip to content

Commit d1e0b91

Browse files
committed
Issue #58: Expose state vector element names
1 parent fa6e6f9 commit d1e0b91

File tree

3 files changed

+34
-10
lines changed

3 files changed

+34
-10
lines changed

src/bsk_rl/envs/general_satellite_tasking/scenario/sat_observations.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,14 @@ def obs_dict(self):
5353
@property
5454
def obs_ndarray(self):
5555
"""Numpy vector observation format."""
56-
return vectorize_nested_dict(self.obs_dict)
56+
_, obs = vectorize_nested_dict(self.obs_dict)
57+
return obs
58+
59+
@property
60+
def obs_array_keys(self):
61+
"""Utility to get the keys of the obs_ndarray."""
62+
keys, _ = vectorize_nested_dict(self.obs_dict)
63+
return keys
5764

5865
@property
5966
def obs_list(self):

src/bsk_rl/envs/general_satellite_tasking/utils/functional.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,25 @@ def collect_default_args(object: object) -> dict[str, Any]:
7777
return defaults
7878

7979

80-
def vectorize_nested_dict(dictionary: dict) -> np.ndarray:
81-
"""Flattens a dictionary of dicts, arrays, and scalars into a single vector."""
80+
def vectorize_nested_dict(dictionary: dict) -> tuple[list[str], np.ndarray]:
81+
"""Flattens a dictionary of dictionaries, arrays, and scalars into a vector."""
82+
keys = list(dictionary.keys())
8283
values = list(dictionary.values())
8384
for i, value in enumerate(values):
8485
if isinstance(value, np.ndarray):
8586
values[i] = value.flatten()
87+
keys[i] = [keys[i] + f"[{j}]" for j in range(len(value.flatten()))]
88+
elif isinstance(value, list):
89+
keys[i] = [keys[i] + f"[{j}]" for j in range(len(value))]
8690
elif isinstance(value, (float, int)):
8791
values[i] = [value]
92+
keys[i] = [keys[i]]
8893
elif isinstance(value, dict):
89-
values[i] = vectorize_nested_dict(value)
94+
prepend = keys[i]
95+
keys[i], values[i] = vectorize_nested_dict(value)
96+
keys[i] = [prepend + "." + key for key in keys[i]]
9097

91-
return np.concatenate(values)
98+
return list(np.concatenate(keys)), np.concatenate(values)
9299

93100

94101
def aliveness_checker(func: Callable[..., bool]) -> Callable[..., bool]:

tests/unittest/envs/general_satellite_tasking/utils/test_functional.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,24 @@ class C13(self.C1, self.C3):
7676

7777

7878
@pytest.mark.parametrize(
79-
"input,output",
79+
"input,outkeys,outvec",
8080
[
81-
({"a": np.array([1]), "b": 2, "c": [3]}, np.array([1, 2, 3])),
82-
({"a": {"b": 1, "c": 2}, "d": 3}, np.array([1, 2, 3])),
81+
(
82+
{"alpha": np.array([1]), "b": 2, "c": [3]},
83+
["alpha[0]", "b", "c[0]"],
84+
np.array([1, 2, 3]),
85+
),
86+
(
87+
{"a": {"b": 1, "charlie": 2}, "d": 3},
88+
["a.b", "a.charlie", "d"],
89+
np.array([1, 2, 3]),
90+
),
8391
],
8492
)
85-
def test_vectorize_nested_dict(input, output):
86-
assert np.equal(output, functional.vectorize_nested_dict(input)).all()
93+
def test_vectorize_nested_dict(input, outkeys, outvec):
94+
keys, vec = functional.vectorize_nested_dict(input)
95+
assert np.equal(outvec, vec).all()
96+
assert outkeys == keys
8797

8898

8999
class TestAlivenessChecker:

0 commit comments

Comments
 (0)