Skip to content

Commit

Permalink
Return propagated orbits with original origins and frame based on inp…
Browse files Browse the repository at this point in the history
…ut orbit_id
  • Loading branch information
akoumjian committed Aug 12, 2024
1 parent 13165ed commit 4e248a2
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 10 deletions.
7 changes: 0 additions & 7 deletions src/adam_core/dynamics/impacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@

logger = logging.getLogger(__name__)

# Test to see that at least one impact-enabled propagator is
# installed and if not print a warning
if importlib.util.find_spec("adam_core.propagator.adam_assist") is None:
logger.warning(
"No impact-enabled propagator installed. Impact calculations will not be possible."
)

RAY_INSTALLED = False
try:
import ray
Expand Down
29 changes: 28 additions & 1 deletion src/adam_core/propagator/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import numpy.typing as npt
import pyarrow.compute as pc
import quivr as qv

from ..constants import Constants as c
Expand Down Expand Up @@ -479,6 +480,9 @@ def propagate_orbits(
propagated : `~adam_core.orbits.orbits.Orbits`
Propagated orbits.
"""



if max_processes is None or max_processes > 1:
propagated_list: List[Orbits] = []
variants_list: List[VariantOrbits] = []
Expand Down Expand Up @@ -575,6 +579,29 @@ def propagate_orbits(
if propagated_variants is not None:
propagated = propagated_variants.collapse(propagated)

return propagated.sort_by(
# Return the results with the original origin and frame
# Preserve the original output origin for the input orbits
# by orbit id
final_results = None
unique_origins = pc.unique(orbits.coordinates.origin.code)
for origin_code in unique_origins:
origin_orbits = orbits.select("coordinates.origin.code", origin_code)
result_origin_orbits = propagated.where(
pc.field("orbit_id").isin(origin_orbits.orbit_id)
)
partial_results = result_origin_orbits.set_column(
"coordinates",
transform_coordinates(
result_origin_orbits.coordinates,
origin_out=OriginCodes[origin_code.as_py()],
frame_out=orbits.coordinates.frame,
),
)
if final_results is None:
final_results = partial_results
else:
final_results = qv.concatenate([final_results, partial_results])

return final_results.sort_by(
["orbit_id", "coordinates.time.days", "coordinates.time.nanos"]
)
48 changes: 46 additions & 2 deletions src/adam_core/propagator/tests/test_propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import quivr as qv

from ...coordinates.cartesian import CartesianCoordinates
from ...coordinates.origin import Origin
from ...coordinates.origin import Origin, OriginCodes
from ...coordinates.transform import transform_coordinates
from ...observers.observers import Observers
from ...orbits.ephemeris import Ephemeris
from ...orbits.orbits import Orbits
Expand All @@ -21,8 +22,19 @@ def _propagate_orbits(self, orbits: Orbits, times: Timestamp) -> Orbits:
repeated_time = qv.concatenate([t] * len(orbits))
orbits.coordinates.time = repeated_time
all_times.append(orbits)
all_times = qv.concatenate(all_times)

# Artifically change origin to test that it is preserved in the final output
output = all_times.set_column(
"coordinates",
transform_coordinates(
all_times.coordinates,
origin_out=OriginCodes["SATURN_BARYCENTER"],
frame_out="equatorial",
),
)

return qv.concatenate(all_times)
return output

# MockPropagator generated ephemeris by just subtracting the state from
# the state of the observers
Expand Down Expand Up @@ -105,3 +117,35 @@ def test_propagator_multiple_workers_ray():
have = prop.generate_ephemeris(orbits_ref, observers_ref, max_processes=4)

assert len(have) == len(orbits) * len(times)


def test_propagate_different_origins():
"""
Test that we are returning propagated orbits with their original origins
"""
orbits = Orbits.from_kwargs(
orbit_id=["1", "2"],
object_id=["1", "2"],
coordinates=CartesianCoordinates.from_kwargs(
x=[1, 1],
y=[1, 1],
z=[1, 1],
vx=[1, 1],
vy=[1, 1],
vz=[1, 1],
time=Timestamp.from_mjd([60000, 60000], scale="tdb"),
frame="ecliptic",
origin=Origin.from_kwargs(code=["SOLAR_SYSTEM_BARYCENTER", "EARTH_MOON_BARYCENTER"]),
),
)

prop = MockPropagator()
propagated_orbits = prop.propagate_orbits(orbits, Timestamp.from_mjd([60001, 60002, 60003], scale="tdb"))
orbit_one_results = propagated_orbits.select("orbit_id", "1")
orbit_two_results = propagated_orbits.select("orbit_id", "2")
# Assert that the origin codes for each set of results is unique
# and that it matches the original input
assert len(orbit_one_results.coordinates.origin.code.unique()) == 1
assert orbit_one_results.coordinates.origin.code.unique()[0].as_py() == "SOLAR_SYSTEM_BARYCENTER"
assert len(orbit_two_results.coordinates.origin.code.unique()) == 1
assert orbit_two_results.coordinates.origin.code.unique()[0].as_py() == "EARTH_MOON_BARYCENTER"

0 comments on commit 4e248a2

Please sign in to comment.