Skip to content

Commit 1d96a40

Browse files
committed
[#66] Rewarder composition
1 parent 922d3b8 commit 1d96a40

File tree

8 files changed

+481
-6
lines changed

8 files changed

+481
-6
lines changed

docs/source/release_notes.rst

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ Development - |version|
2828
* Optimize communication when all satellites are communicating with each other.
2929
* Enable Vizard visualization of the environment by setting the ``vizard_dir`` and ``vizard_settings``
3030
options in the environment.
31+
* Allow for the specification of multiple rewarders in the environment.
32+
3133

3234

3335
Version 1.0.1

src/bsk_rl/data/__init__.py

+13
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,19 @@
7474
...
7575
)
7676
77+
Multiple reward systems can be added to the environment by instead passing an iterable of
78+
reward systems to the ``rewarder`` argument of the environment constructor:
79+
80+
.. code-block:: python
81+
82+
env = ConstellationTasking(
83+
...,
84+
rewarder=(ScanningTimeReward(), SomeOtherReward()),
85+
...
86+
)
87+
88+
On the backend, this creates a :class:`~bsk_rl.data.composition.ComposedDataStore` that
89+
handles the combination of multiple reward systems.
7790
"""
7891

7992
from bsk_rl.data.base import GlobalReward

src/bsk_rl/data/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def reward(self, new_data_dict: dict[str, Data]) -> dict[str, float]:
193193
self.data += new_data
194194

195195
nonzero_reward = {k: v for k, v in reward.items() if v != 0}
196-
logger.info(f"Data reward: {nonzero_reward}")
196+
logger.info(f"Total reward: {nonzero_reward}")
197197
return reward
198198

199199

src/bsk_rl/data/composition.py

+230
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
"""Data composition classes."""
2+
3+
import logging
4+
from typing import TYPE_CHECKING, Optional
5+
6+
from bsk_rl.data.base import Data, DataStore, GlobalReward
7+
from bsk_rl.sats import Satellite
8+
from bsk_rl.scene.scenario import Scenario
9+
10+
if TYPE_CHECKING:
11+
from bsk_rl.sats import Satellite
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
class ComposedData(Data):
    """Data for composed data types."""

    def __init__(self, *data: Data) -> None:
        """Data for composed data types.

        Args:
            data: Data types to compose.
        """
        self.data = data

    def __add__(self, other: "ComposedData") -> "ComposedData":
        """Combine two units of composed data.

        An empty unit (no sub-data) may be combined with a unit of any size;
        otherwise both units must hold the same number of sub-data, which are
        added element-wise.

        Args:
            other: Another unit of composed data to combine with this one.

        Returns:
            Combined unit of composed data.

        Raises:
            ValueError: If both units are non-empty and hold a different number
                of data types.
        """
        mine, theirs = self.data, other.data
        if len(mine) == 0 and len(theirs) == 0:
            combined = []
        elif len(mine) == 0:
            # Pair every element of ``other`` with a default-constructed
            # instance of its own type.
            combined = [type(d)() + d for d in theirs]
        elif len(theirs) == 0:
            combined = [d + type(d)() for d in mine]
        elif len(mine) == len(theirs):
            combined = [d1 + d2 for d1, d2 in zip(mine, theirs)]
        else:
            raise ValueError(
                "ComposedData units must have the same number of data types."
            )
        return ComposedData(*combined)

    def __getattr__(self, name: str):
        """Search for an attribute in the datas.

        Only invoked when normal attribute lookup fails. The first composed
        data object that has the attribute wins.

        Raises:
            AttributeError: If no composed data object has the attribute.
        """
        # Read ``data`` from the instance dict rather than via ``self.data``:
        # if ``data`` has not been assigned yet (e.g. ``copy``/``pickle``
        # rebuild the object without calling ``__init__``), ``self.data`` would
        # re-enter ``__getattr__`` and recurse until RecursionError.
        for data in self.__dict__.get("data", ()):
            if hasattr(data, name):
                return getattr(data, name)
        raise AttributeError(f"No Data in ComposedData has attribute '{name}'")
56+
57+
58+
class ComposedDataStore(DataStore):
    """DataStore that wraps one sub-datastore per composed data type."""

    data_type = ComposedData

    def pass_data(self) -> None:
        """Pass data to the sub-datastores.

        Assigns the i-th element of the composed data to the i-th sub-datastore.

        :meta private:
        """
        for ds, data in zip(self.datastores, self.data.data):
            ds.data = data

    def __init__(
        self,
        satellite: "Satellite",
        *datastore_types: type[DataStore],
        initial_data: Optional[ComposedData] = None,
    ):
        """DataStore for composed data types.

        Args:
            satellite: Satellite which data is being stored for.
            datastore_types: DataStore types to compose.
            initial_data: Initial data to start the store with. Usually comes from
                :class:`~bsk_rl.data.GlobalReward.initial_data`.
        """
        self.data: ComposedData
        super().__init__(satellite, initial_data)
        # Sub-datastores are built without their own initial data; they receive
        # their share of the composed data through ``pass_data``.
        self.datastores = tuple(ds(satellite) for ds in datastore_types)
        self.pass_data()

    def __getattr__(self, name: str):
        """Search for an attribute in the datastores.

        Only invoked when normal attribute lookup fails. The first sub-datastore
        that has the attribute wins.

        Raises:
            AttributeError: If no sub-datastore has the attribute.
        """
        # Read ``datastores`` from the instance dict rather than via
        # ``self.datastores``: it is assigned only after ``super().__init__``
        # and is absent while ``copy``/``pickle`` rebuild the object, in which
        # case ``self.datastores`` would re-enter ``__getattr__`` and recurse
        # until RecursionError.
        for datastore in self.__dict__.get("datastores", ()):
            if hasattr(datastore, name):
                return getattr(datastore, name)
        raise AttributeError(
            f"No DataStore in ComposedDataStore has attribute '{name}'"
        )

    def get_log_state(self) -> list:
        """Pull information used in determining current data contribution.

        Returns:
            One log state per sub-datastore, in sub-datastore order.
        """
        return [ds.get_log_state() for ds in self.datastores]

    def compare_log_states(self, prev_state: list, new_state: list) -> Data:
        """Generate a unit of composed data based on previous step and current step logs.

        Args:
            prev_state: Per-sub-datastore log states from the previous step.
            new_state: Per-sub-datastore log states from the current step.

        Returns:
            :class:`ComposedData` holding each sub-datastore's comparison result.
        """
        data = [
            ds.compare_log_states(prev, new)
            for ds, prev, new in zip(self.datastores, prev_state, new_state)
        ]
        return ComposedData(*data)

    def update_from_logs(self) -> Data:
        """Update the data store based on collected information.

        Returns:
            The new data reported by the base-class update.
        """
        new_data = super().update_from_logs()
        # Keep the sub-datastores in sync with the updated composed data.
        self.pass_data()
        return new_data

    def update_with_communicated_data(self) -> None:
        """Update the data store based on collected information from other satellites."""
        super().update_with_communicated_data()
        # Keep the sub-datastores in sync with the updated composed data.
        self.pass_data()
120+
121+
122+
class ComposedReward(GlobalReward):
    """Rewarder that sums the rewards of several sub-rewarders."""

    datastore_type = ComposedDataStore

    def pass_data(self) -> None:
        """Pass data to the sub-rewarders.

        Assigns the i-th element of the composed data to the i-th sub-rewarder.

        :meta private:
        """
        for rewarder, data in zip(self.rewarders, self.data.data):
            rewarder.data = data

    def __init__(self, *rewarders: GlobalReward) -> None:
        """Rewarder for composed data types.

        This type can be automatically constructed by passing a tuple of rewarders to
        the environment constructor's `rewarder` argument.

        Args:
            rewarders: Global rewarders to compose.
        """
        super().__init__()
        self.rewarders = rewarders

    def __getattr__(self, name: str):
        """Search for an attribute in the rewarders.

        Only invoked when normal attribute lookup fails. The first sub-rewarder
        that has the attribute wins.

        Raises:
            AttributeError: If no sub-rewarder has the attribute.
        """
        # Read ``rewarders`` from the instance dict rather than via
        # ``self.rewarders``: it is assigned only after ``super().__init__``
        # and is absent while ``copy``/``pickle`` rebuild the object, in which
        # case ``self.rewarders`` would re-enter ``__getattr__`` and recurse
        # until RecursionError.
        for rewarder in self.__dict__.get("rewarders", ()):
            if hasattr(rewarder, name):
                return getattr(rewarder, name)
        raise AttributeError(
            f"No GlobalReward in ComposedReward has attribute '{name}'"
        )

    def reset_pre_sim_init(self) -> None:
        """Handle resetting for all rewarders."""
        super().reset_pre_sim_init()
        for rewarder in self.rewarders:
            rewarder.reset_pre_sim_init()

    def reset_post_sim_init(self) -> None:
        """Handle resetting for all rewarders."""
        super().reset_post_sim_init()
        for rewarder in self.rewarders:
            rewarder.reset_post_sim_init()

    def reset_overwrite_previous(self) -> None:
        """Handle resetting for all rewarders."""
        super().reset_overwrite_previous()
        for rewarder in self.rewarders:
            rewarder.reset_overwrite_previous()

    def link_scenario(self, scenario: Scenario) -> None:
        """Link the rewarder and every sub-rewarder to the scenario."""
        super().link_scenario(scenario)
        for rewarder in self.rewarders:
            rewarder.link_scenario(scenario)

    def initial_data(self, satellite: Satellite) -> ComposedData:
        """Furnish the datastore with :class:`ComposedData`.

        Returns:
            Composed data holding each sub-rewarder's initial data for the
            satellite, in sub-rewarder order.
        """
        return ComposedData(
            *[rewarder.initial_data(satellite) for rewarder in self.rewarders]
        )

    def create_data_store(self, satellite: Satellite) -> None:
        """Create a :class:`ComposedDataStore` for a satellite.

        Also zeroes the cumulative reward for the satellite on this rewarder
        and on every sub-rewarder.
        """
        # TODO support passing kwargs
        satellite.data_store = ComposedDataStore(
            satellite,
            *[r.datastore_type for r in self.rewarders],
            initial_data=self.initial_data(satellite),
        )
        self.cum_reward[satellite.name] = 0.0
        for rewarder in self.rewarders:
            rewarder.cum_reward[satellite.name] = 0.0

    def calculate_reward(
        self, new_data_dict: dict[str, ComposedData]
    ) -> dict[str, float]:
        """Calculate reward for each data type and combine them.

        Args:
            new_data_dict: New composed data per satellite id.

        Returns:
            Summed reward per satellite id. Empty if there is no new data or
            the composed data units hold no sub-data.

        Raises:
            ValueError: If the composed data units do not all hold the same
                number of data types.
        """
        # Nothing to reward; also avoids indexing into an empty collection.
        if not new_data_dict:
            return {}

        data_len = len(next(iter(new_data_dict.values())).data)

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        for data in new_data_dict.values():
            if len(data.data) != data_len:
                raise ValueError(
                    "ComposedData units must have the same number of data types."
                )

        reward = {}
        if data_len != 0:
            for i, rewarder in enumerate(self.rewarders):
                # Hand each sub-rewarder only its own slice of the composed data.
                reward_i = rewarder.calculate_reward(
                    {sat_id: data.data[i] for sat_id, data in new_data_dict.items()}
                )

                # Logging
                nonzero_reward = {k: v for k, v in reward_i.items() if v != 0}
                if len(nonzero_reward) > 0:
                    logger.info(f"{type(rewarder).__name__} reward: {nonzero_reward}")

                for sat_id, sat_reward in reward_i.items():
                    reward[sat_id] = reward.get(sat_id, 0.0) + sat_reward
                    rewarder.cum_reward[sat_id] += sat_reward
        return reward

    def reward(self, new_data_dict: dict[str, ComposedData]) -> dict[str, float]:
        """Return combined reward calculation and update data."""
        reward = super().reward(new_data_dict)
        # Keep the sub-rewarders in sync with the updated composed data.
        self.pass_data()
        return reward
227+
228+
229+
__doc_title__ = "Data Composition"
230+
__all__ = ["ComposedReward", "ComposedDataStore", "ComposedData"]

src/bsk_rl/gym.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from bsk_rl.comm import CommunicationMethod, NoCommunication
1515
from bsk_rl.data import GlobalReward, NoReward
16+
from bsk_rl.data.composition import ComposedReward
1617
from bsk_rl.sats import Satellite
1718
from bsk_rl.scene import Scenario
1819
from bsk_rl.sim import Simulator
@@ -36,7 +37,7 @@ def __init__(
3637
self,
3738
satellites: Union[Satellite, list[Satellite]],
3839
scenario: Optional[Scenario] = None,
39-
rewarder: Optional[GlobalReward] = None,
40+
rewarder: Optional[Union[GlobalReward, list[GlobalReward]]] = None,
4041
world_type: Optional[type[WorldModel]] = None,
4142
world_args: Optional[dict[str, Any]] = None,
4243
communicator: Optional[CommunicationMethod] = None,
@@ -70,7 +71,8 @@ def __init__(
7071
scenario: Environment the satellite is acting in; contains information
7172
about targets, etc. See :ref:`bsk_rl.scene`.
7273
rewarder: Handles recording and rewarding for data collection towards
73-
objectives. See :ref:`bsk_rl.data`.
74+
objectives. Can be a single rewarder or a tuple of multiple rewarders.
75+
See :ref:`bsk_rl.data`.
7476
communicator: Manages communication between satellites. See :ref:`bsk_rl.comm`.
7577
sat_arg_randomizer: For correlated randomization of satellites arguments. Should
7678
be a function that takes a list of satellites and returns a dictionary that
@@ -139,8 +141,6 @@ def __init__(
139141

140142
if scenario is None:
141143
scenario = Scenario()
142-
if rewarder is None:
143-
rewarder = NoReward()
144144

145145
if world_type is None:
146146
world_type = self._minimum_world_model()
@@ -151,7 +151,16 @@ def __init__(
151151

152152
self.scenario = deepcopy(scenario)
153153
self.scenario.link_satellites(self.satellites)
154-
self.rewarder = deepcopy(rewarder)
154+
155+
rewarder = deepcopy(rewarder)
156+
if rewarder is None:
157+
rewarder = NoReward()
158+
if (
159+
isinstance(rewarder, Iterable)
160+
and not type(rewarder).__name__ == "MagicMock"
161+
):
162+
rewarder = ComposedReward(*rewarder)
163+
self.rewarder = rewarder
155164
self.rewarder.link_scenario(self.scenario)
156165

157166
if communicator is None:
+52
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,57 @@
1+
import gymnasium as gym
2+
3+
from bsk_rl import act, data, obs, sats, scene
4+
from bsk_rl.data.composition import ComposedReward
5+
from bsk_rl.utils.orbital import random_orbit
6+
17
# For data models not tested in other tests
28

39
# NoData sufficiently checked in many cases
410

511
# UniqueImageData sufficiently checked in test_int_communication
12+
13+
# from ..test_int_full_environments
14+
15+
16+
class FullFeaturedSatellite(sats.ImagingSatellite):
    """Imaging satellite fixture used by the multi-rewarder test.

    Observes the ``r_BN_P`` property of the dynamics module and the time, and
    exposes a single imaging action.
    """

    observation_spec = [
        obs.SatProperties({"prop": "r_BN_P", "module": "dynamics", "norm": 6e6}),
        obs.Time(),
    ]
    action_spec = [act.Image(n_ahead_image=10)]
22+
23+
24+
def test_multi_rewarder():
    """Check that a tuple of rewarders is composed and the environment steps."""

    def make_satellite(name):
        # Both satellites share the same configuration apart from their name.
        return FullFeaturedSatellite(
            name,
            sat_args=FullFeaturedSatellite.default_sat_args(
                oe=random_orbit,
                imageAttErrorRequirement=0.01,
                imageRateErrorRequirement=0.01,
            ),
        )

    env = gym.make(
        "GeneralSatelliteTasking-v1",
        satellites=[make_satellite(name) for name in ("Sentinel-2A", "Sentinel-2B")],
        scenario=scene.UniformTargets(n_targets=1000),
        rewarder=(data.UniqueImageReward(), data.UniqueImageReward()),
        sim_rate=0.5,
        max_step_duration=1e9,
        time_limit=5700.0,
        disable_env_checker=True,
    )

    # A tuple of rewarders must be wrapped into a single composed rewarder.
    assert isinstance(env.unwrapped.rewarder, ComposedReward)

    env.reset()
    n_steps = 10
    for _ in range(n_steps):
        env.step(env.action_space.sample())

0 commit comments

Comments
 (0)