Skip to content

Commit 0612aba

Browse files
committed
Issue 199: One way communication and broadcast action [WIP]
1 parent 859db94 commit 0612aba

File tree

4 files changed

+69
-13
lines changed

4 files changed

+69
-13
lines changed

docs/source/release_notes.rst

+3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ Development Version
2626
* Improve performance of :class:`~bsk_rl.obs.Eclipse` observations by about 95%.
2727
* Logs a warning if the initial battery charge or buffer level is incompatible with its capacity.
2828
* Optimize communication when all satellites are communicating with each other.
29+
* Allows communication to be one-way. Adds a :class:`~bsk_rl.act.Broadcast` action
30+
that can be used with :class:`~bsk_rl.comm.BroadcastCommunication` to only communicate
31+
data when the action has been called.
2932

3033

3134

src/bsk_rl/act/discrete_actions.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818

1919
class DiscreteActionBuilder(ActionBuilder):
20-
2120
def __init__(self, satellite: "Satellite") -> None:
2221
"""Processes actions for a discrete action space.
2322
@@ -302,5 +301,34 @@ def set_action_override(
302301
return self.image(action, prev_action_key)
303302

304303

304+
class Broadcast(DiscreteAction):
305+
def __init__(
306+
self,
307+
name: str = "action_broadcast",
308+
duration: Optional[float] = None,
309+
):
310+
"""Action to broadcast data to all satellites in the simulation.
311+
312+
Args:
313+
name: Action name.
314+
"""
315+
super().__init__(name=name, n_actions=1)
316+
if duration is None:
317+
duration = 1e9
318+
self.duration = duration
319+
320+
def reset_post_sim_init(self) -> None:
321+
"""Log previous action key."""
322+
super().reset_post_sim_init()
323+
self.broadcast_pending = False
324+
325+
def set_action(self, prev_action_key=None) -> str:
326+
self.broadcast_pending = True
327+
self.satellite.update_timed_terminal_event(
328+
self.simulator.sim_time + self.duration, info=f"for broadcast"
329+
)
330+
return self.name
331+
332+
305333
__doc_title__ = "Discrete Backend"
306334
__all__ = ["DiscreteActionBuilder"]

src/bsk_rl/comm/communication.py

+36-8
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
from abc import ABC, abstractmethod
55
from copy import copy
6-
from itertools import combinations
6+
from itertools import permutations
77
from typing import TYPE_CHECKING
88

99
import numpy as np
@@ -47,7 +47,7 @@ def link_satellites(self, satellites: list["Satellite"]) -> None:
4747

4848
@abstractmethod # pragma: no cover
4949
def communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]:
50-
"""List pairs of satellite that should share data.
50+
"""List pairs of satellite that should share data, in the form (sender, recipient).
5151
5252
To define a new communication type, this method must be implemented.
5353
"""
@@ -68,12 +68,13 @@ def communicate(self) -> None:
6868
f"Communicating data between {len(communication_pairs)} pairs of satellites"
6969
)
7070

71-
if len(communication_pairs) == comb(len(self.satellites), 2):
71+
if len(communication_pairs) == 2 * comb(len(self.satellites), 2):
7272
self._communicate_all()
7373
else:
74-
for sat_1, sat_2 in communication_pairs:
75-
sat_1.data_store.stage_communicated_data(sat_2.data_store.data)
76-
sat_2.data_store.stage_communicated_data(sat_1.data_store.data)
74+
for sat_sender, sat_receiver in communication_pairs:
75+
sat_receiver.data_store.stage_communicated_data(
76+
sat_sender.data_store.data
77+
)
7778
for satellite in self.satellites:
7879
satellite.data_store.update_with_communicated_data()
7980

@@ -116,7 +117,7 @@ def __init__(self, *args, **kwargs) -> None:
116117

117118
def communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]:
118119
"""Return all possible communication pairs."""
119-
return list(combinations(self.satellites, 2))
120+
return list(permutations(self.satellites, 2))
120121

121122

122123
class LOSCommunication(CommunicationMethod):
@@ -177,6 +178,7 @@ def communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]:
177178
for sat_2, logger in logs.items():
178179
if any(logger.hasAccess):
179180
pairs.append((sat_1, sat_2))
181+
pairs.append((sat_2, sat_1))
180182
return pairs
181183

182184
def communicate(self) -> None:
@@ -208,7 +210,7 @@ def communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]:
208210
pairs = []
209211
n_components, labels = connected_components(graph, directed=False)
210212
for comp in range(n_components):
211-
for i_sat_1, i_sat_2 in combinations(np.where(labels == comp)[0], 2):
213+
for i_sat_1, i_sat_2 in permutations(np.where(labels == comp)[0], 2):
212214
pairs.append((self.satellites[i_sat_1], self.satellites[i_sat_2]))
213215
return pairs
214216

@@ -222,4 +224,30 @@ class LOSMultiCommunication(MultiDegreeCommunication, LOSCommunication):
222224
pass
223225

224226

227+
class BroadcastCommunication(CommunicationMethod):
228+
def communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]:
229+
"""Return pairs of satellites that used broadcast action."""
230+
broadcasters = []
231+
for satellite in self.satellites:
232+
if any(
233+
hasattr(act, "broadcast_pending") and act.broadcast_pending
234+
for act in satellite.action_builder.action_spec
235+
):
236+
broadcasters.append(satellite)
237+
238+
pairs = super().communication_pairs()
239+
pairs = [
240+
(sat_sender, sat_receiver)
241+
for sat_sender, sat_receiver in pairs
242+
if sat_sender in broadcasters
243+
]
244+
245+
for satellite in broadcasters:
246+
for act in satellite.action_builder.action_spec:
247+
if hasattr(act, "broadcast_pending"):
248+
act.broadcast_pending = False
249+
250+
return pairs
251+
252+
225253
__all__ = []

tests/unittest/comm/test_communication.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,9 @@ def test_communicate(self):
2121
comms.last_communication_time = 0.0
2222
comms.link_satellites(mock_sats)
2323
comms.communication_pairs = MagicMock(
24-
return_value=[(mock_sats[1], mock_sats[0])]
24+
return_value=[(mock_sats[0], mock_sats[1])]
2525
)
2626
comms.communicate()
27-
mock_sats[0].data_store.stage_communicated_data.assert_called_once_with(
28-
mock_sats[1].data_store.data
29-
)
3027
mock_sats[1].data_store.stage_communicated_data.assert_called_once_with(
3128
mock_sats[0].data_store.data
3229
)

0 commit comments

Comments
 (0)