|
32 | 32 |
|
33 | 33 |
|
34 | 34 | class GeneralSatelliteTasking(Env, Generic[SatObs, SatAct]): |
35 | | - |
36 | 35 | def __init__( |
37 | 36 | self, |
38 | 37 | satellites: Union[Satellite, list[Satellite]], |
@@ -314,7 +313,6 @@ def _get_obs(self) -> MultiSatObs: |
314 | 313 | tuple: Joint observation |
315 | 314 | """ |
316 | 315 | if self.generate_obs_retasking_only: |
317 | | - |
318 | 316 | return tuple( |
319 | 317 | ( |
320 | 318 | satellite.get_obs() |
@@ -698,9 +696,14 @@ def step( |
698 | 696 | terminated = self._get_terminated() |
699 | 697 | truncated = self._get_truncated() |
700 | 698 | info = self._get_info() |
701 | | - logger.info(f"Step reward: {reward}") |
702 | | - logger.info(f"Episode terminated: {terminated}") |
703 | | - logger.info(f"Episode truncated: {truncated}") |
| 699 | + nonzero_reward = {k: v for k, v in reward.items() if v != 0} |
| 700 | + logger.info(f"Step reward: {nonzero_reward}") |
| 701 | + if any(terminated.values()): |
| 702 | + terminated_true = [k for k, v in terminated.items() if v] |
| 703 | + logger.info(f"Episode terminated: {terminated_true}") |
| 704 | + if any(truncated.values()): |
| 705 | + truncated_true = [k for k, v in truncated.items() if v] |
| 706 | + logger.info(f"Episode truncated: {truncated_true}") |
704 | 707 | logger.debug(f"Step info: {info}") |
705 | 708 | logger.debug(f"Step observation: {observation}") |
706 | 709 | return observation, reward, terminated, truncated, info |
|
0 commit comments