Skip to content

Commit

Permalink
Nt/seed propagate (#135)
Browse files Browse the repository at this point in the history
* add a seed for variant orbits in propagate_orbits

* Ensure seed is passed to all VariantOrbits.create and sort varianats before collapse
  • Loading branch information
akoumjian authored Jan 10, 2025
1 parent 7eeb381 commit 9e21f2c
Showing 1 changed file with 32 additions and 10 deletions.
42 changes: 32 additions & 10 deletions src/adam_core/propagator/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def generate_ephemeris(
num_samples: int = 1000,
chunk_size: int = 100,
max_processes: Optional[int] = 1,
seed: Optional[int] = None,
) -> Ephemeris:
"""
Generate ephemerides for each orbit in orbits as observed by each observer
Expand Down Expand Up @@ -376,18 +377,20 @@ def generate_ephemeris(
# Add variants to futures (if we have any)
if covariance is True and not orbits.coordinates.covariance.is_all_nan():
variants = VariantOrbits.create(
orbits, method=covariance_method, num_samples=num_samples
orbits,
method=covariance_method,
num_samples=num_samples,
seed=seed,
)

# Add variants to object store
variants_ref = ray.put(variants)

idx = np.arange(0, len(variants))
for variant_chunk in _iterate_chunks(idx, chunk_size):
idx_chunk = ray.put(variant_chunk)
for variant_chunk_idx in _iterate_chunks(idx, chunk_size):
futures.append(
ephemeris_worker_ray.remote(
idx_chunk,
variant_chunk_idx,
variants_ref,
observers_ref,
self.__class__,
Expand Down Expand Up @@ -420,7 +423,10 @@ def generate_ephemeris(

if covariance is True and not orbits.coordinates.covariance.is_all_nan():
variants = VariantOrbits.create(
orbits, method=covariance_method, num_samples=num_samples
orbits,
method=covariance_method,
num_samples=num_samples,
seed=seed,
)
ephemeris_variants = self._generate_ephemeris(variants, observers)
else:
Expand Down Expand Up @@ -471,6 +477,7 @@ def propagate_orbits(
num_samples: int = 1000,
chunk_size: int = 100,
max_processes: Optional[int] = 1,
seed: Optional[int] = None,
) -> Orbits:
"""
Propagate each orbit in orbits to each time in times.
Expand Down Expand Up @@ -552,16 +559,19 @@ def propagate_orbits(
# Add variants to propagate to futures
if covariance is True and not orbits.coordinates.covariance.is_all_nan():
variants = VariantOrbits.create(
orbits, method=covariance_method, num_samples=num_samples
orbits,
method=covariance_method,
num_samples=num_samples,
seed=seed,
)

variants_ref = ray.put(variants)

idx = np.arange(0, len(variants))
for variant_chunk in _iterate_chunks(idx, chunk_size):
idx_chunk = ray.put(variant_chunk)
for variant_chunk_idx in _iterate_chunks(idx, chunk_size):
futures.append(
propagation_worker_ray.remote(
idx_chunk,
variant_chunk_idx,
variants_ref,
times_ref,
self.__class__,
Expand All @@ -587,6 +597,10 @@ def propagate_orbits(
propagated = qv.concatenate(propagated_list)
if len(variants_list) > 0:
propagated_variants = qv.concatenate(variants_list)
# sort by variant_id and time
propagated_variants = propagated_variants.sort_by(
["variant_id", "coordinates.time.days", "coordinates.time.nanos"]
)
else:
propagated_variants = None

Expand All @@ -595,9 +609,17 @@ def propagate_orbits(

if covariance is True and not orbits.coordinates.covariance.is_all_nan():
variants = VariantOrbits.create(
orbits, method=covariance_method, num_samples=num_samples
orbits,
method=covariance_method,
num_samples=num_samples,
seed=seed,
)

propagated_variants = self._propagate_orbits(variants, times)
# sort by variant_id and time
propagated_variants = propagated_variants.sort_by(
["variant_id", "coordinates.time.days", "coordinates.time.nanos"]
)
else:
propagated_variants = None

Expand Down

0 comments on commit 9e21f2c

Please sign in to comment.