Skip to content

Commit 07b2c63

Browse files
committed
progress
1 parent d99f09e commit 07b2c63

File tree

2 files changed

+94
-18
lines changed

2 files changed

+94
-18
lines changed

src/adam_core/time/tests/test_time.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88
import quivr as qv
99

10-
from ..time import Timestamp
10+
from ..time import NANOS_PER_DAY, Timestamp
1111

1212

1313
class Wrapper(qv.Table):
@@ -699,14 +699,19 @@ def test_Timestamp_rescale(scale1, scale2):
699699
Test rescaling by using round trip calculations
700700
"""
701701
times = Timestamp.from_kwargs(
702-
days=[-57023, 0, 51544, 103088, 164178, 68000, 68000, 68010, 68020], # Spans from ~1702 to ~2308
703-
nanos=[100_000_000, 200_000_000, 300_000_000, 400_000_000, 500_000_000, 1, 2, 3, 4],
702+
days=[-57032, -36525, -2, -1, 0, 51544, 103088, 164178, 68000, 68000, 68010, 68020], # Spans from ~1702 to ~2308
703+
nanos=[50_000, 0, 123, 100_000_000, 200_000_000, 300_000_000, 400_000_000, 500_000_000, 1, 2, 3, 4],
704704
scale=scale1,
705705
)
706706
rescaled = times.rescale(scale2)
707707
round_tripped = rescaled.rescale(scale1)
708708
assert rescaled.scale == scale2
709709
assert round_tripped.scale == scale1
710-
print(times.difference(round_tripped))
711-
assert pc.all(times.equals(round_tripped, precision="us")).as_py()
710+
print(scale1, scale2)
711+
mjds = list(zip(times.mjd().to_pylist(), rescaled.mjd().to_pylist(), round_tripped.mjd().to_pylist()))
712+
days_diff, nanos_diff = times.difference(round_tripped)
713+
diff = list(zip(days_diff.to_pylist(), nanos_diff.to_pylist()))
714+
for i in range(len(mjds)):
715+
print(mjds[i], diff[i])
716+
assert pc.all(times.equals(round_tripped, precision="ns")).as_py()
712717

src/adam_core/time/time.py

Lines changed: 84 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
MJD_EPOCH_IN_TAI = hifitime.Epoch("1858-11-17T00:00:00 TAI")
2222
MJD_EPOCH_IN_UTC = hifitime.Epoch("1858-11-17T00:00:00 UTC")
2323

24+
DAYS_PER_CENTURY = 36_525
25+
NANOS_PER_DAY = 86_400_000_000_000
26+
NANOS_PER_CENTURY = DAYS_PER_CENTURY * NANOS_PER_DAY
27+
2428
class Timestamp(qv.Table):
2529
# Scale, the rate at which time passes:
2630
scale = qv.StringAttribute(default="tai")
@@ -503,30 +507,43 @@ def rescale(self, new_scale: str) -> Timestamp:
503507

504508
# Get only the values needed based on the input scale
505509
if self.scale == "tdb":
506-
day_values = self.jd().to_numpy()
507-
init_epoch_fn = lambda val: hifitime.Epoch.init_from_jde_tdb(float(val))
510+
def init_epoch_fn(days, nanos):
511+
total_nanos = int(days) * NANOS_PER_DAY + int(nanos)
512+
duration_since_mjd_epoch = hifitime.Duration.from_total_nanoseconds(total_nanos)
513+
epoch = MJD_EPOCH_IN_TDB + duration_since_mjd_epoch
514+
return epoch
508515
else: # tai or utc
509-
day_values = self.mjd().to_numpy()
510516
if self.scale == "tai":
511-
init_epoch_fn = lambda val: hifitime.Epoch.init_from_mjd_tai(float(val))
517+
def init_epoch_fn(days, nanos):
518+
total_nanos = int(days) * NANOS_PER_DAY + int(nanos)
519+
duration_since_mjd_epoch = hifitime.Duration.from_total_nanoseconds(total_nanos)
520+
epoch = MJD_EPOCH_IN_TAI + duration_since_mjd_epoch
521+
return epoch
512522
else: # utc
513-
init_epoch_fn = lambda val: hifitime.Epoch.init_from_mjd_utc(float(val))
523+
def init_epoch_fn(days, nanos):
524+
total_nanos = int(days) * NANOS_PER_DAY + int(nanos)
525+
duration_since_mjd_epoch = hifitime.Duration.from_total_nanoseconds(total_nanos)
526+
epoch = MJD_EPOCH_IN_UTC + duration_since_mjd_epoch
527+
return epoch
514528

515529
# Determine the result calculation based on the new scale
516530
if new_scale == "tai":
517-
calculate_result = lambda epoch: epoch.to_mjd_tai_days()
531+
calculate_result = _calculate_tai_result
518532
elif new_scale == "utc":
519-
calculate_result = lambda epoch: epoch.to_mjd_utc_days()
533+
calculate_result = _calculate_utc_result
520534
elif new_scale == "tdb":
521-
calculate_result = lambda epoch: epoch.timedelta(MJD_EPOCH_IN_TDB).to_unit(hifitime.Unit.Day)
535+
calculate_result = _calculate_tdb_result
522536
else:
523537
raise ValueError(f"Unknown scale: {new_scale}")
524538

525539
# Process each timestamp individually
526-
result_values = [calculate_result(init_epoch_fn(val)) for val in day_values]
527-
540+
inputs = zip(self.days.to_numpy(), self.nanos.to_numpy())
541+
init_epochs = [init_epoch_fn(*val) for val in inputs]
542+
result_values = [calculate_result(init_epoch, self.scale) for init_epoch in init_epochs]
543+
result_days = [val[0] for val in result_values]
544+
result_nanos = [val[1] for val in result_values]
528545
# Convert result list back to a Timestamp
529-
return self.from_mjd(pa.array(result_values), scale=new_scale)
546+
return self.from_kwargs(days=result_days, nanos=result_nanos, scale=new_scale)
530547

531548
def link(
532549
self, other: Timestamp, precision: str = "ns"
@@ -554,8 +571,6 @@ def link(
554571
rounded = self.rounded(precision)
555572
other_rounded = other.rounded(precision)
556573

557-
import pdb; pdb.set_trace()
558-
559574
left_keys = {"days": rounded.days, "nanos": rounded.nanos}
560575
right_keys = {"days": other_rounded.days, "nanos": other_rounded.nanos}
561576
return qv.MultiKeyLinkage(self, other, left_keys, right_keys)
@@ -582,3 +597,59 @@ def _duration_arrays_within_tolerance(
582597
pc.greater_equal(pc.abs(delta_nanos), 86400 * 1e9 - max_nanos_deviation),
583598
)
584599
return pc.or_(cond1, cond2)
600+
601+
def _calculate_tai_result(epoch, from_scale: str):
602+
epoch_new_time_scale = epoch.to_time_scale(hifitime.TimeScale.TAI)
603+
tai_duration_since_mjd_epoch = epoch_new_time_scale.timedelta(MJD_EPOCH_IN_TAI)
604+
605+
total_nanos = tai_duration_since_mjd_epoch.total_nanoseconds()
606+
607+
# Integer division rounds towards negative infinity for negative numbers
608+
days = total_nanos // NANOS_PER_DAY
609+
610+
if days < -36525:
611+
total_nanos += NANOS_PER_CENTURY
612+
days = total_nanos // NANOS_PER_DAY
613+
614+
# Modulo will give us a positive remainder
615+
nanos = total_nanos % NANOS_PER_DAY
616+
617+
return int(days), int(nanos)
618+
619+
def _calculate_utc_result(epoch, from_scale: str):
620+
epoch_new_time_scale = epoch.to_time_scale(hifitime.TimeScale.UTC)
621+
utc_duration_since_mjd_epoch = epoch_new_time_scale.timedelta(MJD_EPOCH_IN_UTC)
622+
623+
total_nanos = utc_duration_since_mjd_epoch.total_nanoseconds()
624+
625+
# Integer division rounds towards negative infinity for negative numbers
626+
days = total_nanos // NANOS_PER_DAY
627+
628+
# Special case for the Julian century boundary
629+
if days < -36525:
630+
total_nanos += NANOS_PER_CENTURY
631+
days = total_nanos // NANOS_PER_DAY
632+
633+
# Modulo will give us a positive remainder
634+
nanos = total_nanos % NANOS_PER_DAY
635+
636+
return int(days), int(nanos)
637+
638+
def _calculate_tdb_result(epoch, from_scale: str):
639+
epoch_new_time_scale = epoch.to_time_scale(hifitime.TimeScale.TDB)
640+
tdb_duration_since_mjd_epoch = epoch_new_time_scale.timedelta(MJD_EPOCH_IN_TDB)
641+
642+
total_nanos = tdb_duration_since_mjd_epoch.total_nanoseconds()
643+
644+
# Integer division rounds towards negative infinity for negative numbers
645+
days = total_nanos // NANOS_PER_DAY
646+
647+
# Special case for the Julian century boundary
648+
if days < -36525:
649+
total_nanos += NANOS_PER_CENTURY
650+
days = total_nanos // NANOS_PER_DAY
651+
652+
# Modulo will give us a positive remainder
653+
nanos = total_nanos % NANOS_PER_DAY
654+
655+
return int(days), int(nanos)

0 commit comments

Comments
 (0)