Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use __getstate__ and __setstate__ for propagator serialization #139

Merged
merged 3 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = [
]
description = "Core libraries for the ADAM platform"
readme = "README.md"
requires-python = ">=3.10,<3.13"
requires-python = ">=3.11,<3.13"
classifiers = [
"Operating System :: OS Independent",
"Development Status :: 4 - Beta",
Expand Down
7 changes: 7 additions & 0 deletions src/adam_core/dynamics/tests/test_impacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@


class MockImpactPropagator(Propagator, ImpactMixin):
def __getstate__(self):
state = self.__dict__.copy()
return state

def __setstate__(self, state):
self.__dict__.update(state)

def _propagate_orbits(self, orbits: Orbits, times: Timestamp) -> Orbits:
return orbits

Expand Down
66 changes: 50 additions & 16 deletions src/adam_core/propagator/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,21 @@ def propagation_worker_ray(
idx: npt.NDArray[np.int64],
orbits: OrbitType,
times: OrbitType,
propagator: Type["Propagator"],
**kwargs,
propagator: "Propagator",
) -> OrbitType:
prop = propagator(**kwargs)
orbits_chunk = orbits.take(idx)
propagated = prop._propagate_orbits(orbits_chunk, times)
propagated = propagator._propagate_orbits(orbits_chunk, times)
return propagated

@ray.remote
def ephemeris_worker_ray(
idx: npt.NDArray[np.int64],
orbits: OrbitType,
observers: ObserverType,
propagator: Type["Propagator"],
**kwargs,
propagator: "Propagator",
) -> EphemerisType:
prop = propagator(**kwargs)
orbits_chunk = orbits.take(idx)
ephemeris = prop._generate_ephemeris(orbits_chunk, observers)
ephemeris = propagator._generate_ephemeris(orbits_chunk, observers)
return ephemeris


Expand Down Expand Up @@ -369,8 +365,7 @@ def generate_ephemeris(
idx_chunk,
orbits_ref,
observers_ref,
self.__class__,
**self.__dict__,
self,
)
)

Expand All @@ -393,8 +388,7 @@ def generate_ephemeris(
variant_chunk_idx,
variants_ref,
observers_ref,
self.__class__,
**self.__dict__,
self,
)
)

Expand Down Expand Up @@ -466,6 +460,48 @@ def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitTyp
"""
pass

def __getstate__(self):
"""
Get the state of the propagator.

Subclasses need to define what is picklable for multiprocessing.

e.g.

def __getstate__(self):
state = self.__dict__.copy()
state.pop("_stateful_attribute_that_is_not_pickleable")
return state
"""
raise NotImplementedError(
"Propagator must implement __getstate__ for multiprocessing serialization.\n"
"Example implementation: \n"
"def __getstate__(self):\n"
" state = self.__dict__.copy()\n"
" state.pop('_stateful_attribute_that_is_not_pickleable')\n"
" return state"
)

def __setstate__(self, state):
"""
Set the state of the propagator.

Subclasses need to define what is unpicklable for multiprocessing.

e.g.

def __setstate__(self, state):
self.__dict__.update(state)
self._stateful_attribute_that_is_not_pickleable = None
"""
raise NotImplementedError(
"Propagator must implement __setstate__ for multiprocessing serialization.\n"
"Example implementation: \n"
"def __setstate__(self, state):\n"
" self.__dict__.update(state)\n"
" self._stateful_attribute_that_is_not_pickleable = None"
)

def propagate_orbits(
self,
orbits: Union[OrbitType, ObjectRef],
Expand Down Expand Up @@ -551,8 +587,7 @@ def propagate_orbits(
idx_chunk,
orbits_ref,
times_ref,
self.__class__,
**self.__dict__,
self,
)
)

Expand All @@ -574,8 +609,7 @@ def propagate_orbits(
variant_chunk_idx,
variants_ref,
times_ref,
self.__class__,
**self.__dict__,
self,
)
)

Expand Down
8 changes: 8 additions & 0 deletions src/adam_core/propagator/tests/test_propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@


class MockPropagator(Propagator, EphemerisMixin):

def __getstate__(self):
state = self.__dict__.copy()
return state

def __setstate__(self, state):
self.__dict__.update(state)

# MockPropagator propagates orbits by just setting the time of the orbits.
def _propagate_orbits(self, orbits: Orbits, times: Timestamp) -> Orbits:
all_times = []
Expand Down
Loading