Skip to content

Commit

Permalink
Use __getstate__ and __setstate__ for propagator serialization (#139)
Browse files Browse the repository at this point in the history
* Use __getstate__ and __setstate__ for propagator serialization
* Give helpful error messages
  • Loading branch information
akoumjian authored Jan 24, 2025
1 parent ee57577 commit 36d9308
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = [
]
description = "Core libraries for the ADAM platform"
readme = "README.md"
requires-python = ">=3.10,<3.13"
requires-python = ">=3.11,<3.13"
classifiers = [
"Operating System :: OS Independent",
"Development Status :: 4 - Beta",
Expand Down
7 changes: 7 additions & 0 deletions src/adam_core/dynamics/tests/test_impacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@


class MockImpactPropagator(Propagator, ImpactMixin):
def __getstate__(self):
state = self.__dict__.copy()
return state

def __setstate__(self, state):
self.__dict__.update(state)

def _propagate_orbits(self, orbits: Orbits, times: Timestamp) -> Orbits:
return orbits

Expand Down
66 changes: 50 additions & 16 deletions src/adam_core/propagator/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,21 @@ def propagation_worker_ray(
idx: npt.NDArray[np.int64],
orbits: OrbitType,
times: OrbitType,
propagator: Type["Propagator"],
**kwargs,
propagator: "Propagator",
) -> OrbitType:
prop = propagator(**kwargs)
orbits_chunk = orbits.take(idx)
propagated = prop._propagate_orbits(orbits_chunk, times)
propagated = propagator._propagate_orbits(orbits_chunk, times)
return propagated

@ray.remote
def ephemeris_worker_ray(
idx: npt.NDArray[np.int64],
orbits: OrbitType,
observers: ObserverType,
propagator: Type["Propagator"],
**kwargs,
propagator: "Propagator",
) -> EphemerisType:
prop = propagator(**kwargs)
orbits_chunk = orbits.take(idx)
ephemeris = prop._generate_ephemeris(orbits_chunk, observers)
ephemeris = propagator._generate_ephemeris(orbits_chunk, observers)
return ephemeris


Expand Down Expand Up @@ -369,8 +365,7 @@ def generate_ephemeris(
idx_chunk,
orbits_ref,
observers_ref,
self.__class__,
**self.__dict__,
self,
)
)

Expand All @@ -393,8 +388,7 @@ def generate_ephemeris(
variant_chunk_idx,
variants_ref,
observers_ref,
self.__class__,
**self.__dict__,
self,
)
)

Expand Down Expand Up @@ -466,6 +460,48 @@ def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitTyp
"""
pass

def __getstate__(self):
"""
Get the state of the propagator.
Subclasses need to define what is picklable for multiprocessing.
e.g.
def __getstate__(self):
state = self.__dict__.copy()
state.pop("_stateful_attribute_that_is_not_pickleable")
return state
"""
raise NotImplementedError(
"Propagator must implement __getstate__ for multiprocessing serialization.\n"
"Example implementation: \n"
"def __getstate__(self):\n"
" state = self.__dict__.copy()\n"
" state.pop('_stateful_attribute_that_is_not_pickleable')\n"
" return state"
)

def __setstate__(self, state):
"""
Set the state of the propagator.
Subclasses need to define what is unpicklable for multiprocessing.
e.g.
def __setstate__(self, state):
self.__dict__.update(state)
self._stateful_attribute_that_is_not_pickleable = None
"""
raise NotImplementedError(
"Propagator must implement __setstate__ for multiprocessing serialization.\n"
"Example implementation: \n"
"def __setstate__(self, state):\n"
" self.__dict__.update(state)\n"
" self._stateful_attribute_that_is_not_pickleable = None"
)

def propagate_orbits(
self,
orbits: Union[OrbitType, ObjectRef],
Expand Down Expand Up @@ -551,8 +587,7 @@ def propagate_orbits(
idx_chunk,
orbits_ref,
times_ref,
self.__class__,
**self.__dict__,
self,
)
)

Expand All @@ -574,8 +609,7 @@ def propagate_orbits(
variant_chunk_idx,
variants_ref,
times_ref,
self.__class__,
**self.__dict__,
self,
)
)

Expand Down
8 changes: 8 additions & 0 deletions src/adam_core/propagator/tests/test_propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@


class MockPropagator(Propagator, EphemerisMixin):

def __getstate__(self):
state = self.__dict__.copy()
return state

def __setstate__(self, state):
self.__dict__.update(state)

# MockPropagator propagates orbits by just setting the time of the orbits.
def _propagate_orbits(self, orbits: Orbits, times: Timestamp) -> Orbits:
all_times = []
Expand Down

0 comments on commit 36d9308

Please sign in to comment.