Skip to content

Commit

Permalink
Merge branch 'main' into tqdm_support
Browse files Browse the repository at this point in the history
  • Loading branch information
tgross03 authored Jul 18, 2024
2 parents 79ddd58 + ed7ad00 commit 21ce8ff
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 26 deletions.
2 changes: 2 additions & 0 deletions docs/changes/31.maintenance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- use observation class to pass sampling options to the fits writer
- include writer in tests
44 changes: 24 additions & 20 deletions pyvisgen/fits/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pyvisgen.layouts.layouts as layouts


def create_vis_hdu(data, conf, layout="vlba", source_name="sim-source-0"):
def create_vis_hdu(data, obs, layout="vlba", source_name="sim-source-0"):
u = data.u

v = data.v
Expand All @@ -23,7 +23,7 @@ def create_vis_hdu(data, conf, layout="vlba", source_name="sim-source-0"):

BASELINE = data.base_num

INTTIM = np.repeat(np.array(conf["corr_int_time"], dtype=">f4"), len(u))
INTTIM = np.repeat(np.array(obs.int_time, dtype=">f4"), len(u))

# visibility data
values = data.get_values()
Expand All @@ -36,10 +36,10 @@ def create_vis_hdu(data, conf, layout="vlba", source_name="sim-source-0"):
# in dim 4 = IFs , dim = 1, dim 4 = number of jones, 3 = real, imag, weight

# wcs
ra = conf["fov_center_ra"]
dec = conf["fov_center_dec"]
freq = (conf["ref_frequency"] * un.Hz).value
freq_d = (conf["bandwidths"][0] * un.Hz).value
ra = obs.ra.cpu().numpy().item()
dec = obs.dec.cpu().numpy().item()
freq = obs.ref_frequency.cpu().numpy().item()
freq_d = obs.bandwidths[0].cpu().numpy().item()

ws = wcs.WCS(naxis=7)
ws.wcs.crpix = [1, 1, 1, 1, 1, 1, 1]
Expand Down Expand Up @@ -87,7 +87,7 @@ def create_vis_hdu(data, conf, layout="vlba", source_name="sim-source-0"):
hdu_vis.header.comments["PTYPE6"] = "Relative Julian date ?"
hdu_vis.header.comments["PTYPE7"] = "Integration time"

date_obs = conf["scan_start"].date().strftime("%Y-%m-%d")
date_obs = obs.start.strftime("%Y-%m-%d")

date_map = Time.now().to_value(format="iso", subfmt="date")

Expand Down Expand Up @@ -165,24 +165,28 @@ def create_time_hdu(data):
return hdu_time


def create_frequency_hdu(conf):
def create_frequency_hdu(obs):
FRQSEL = np.array([1], dtype=">i4")
col1 = fits.Column(name="FRQSEL", format="1J", unit=" ", array=FRQSEL)

IF_FREQ = np.array(
[np.array(conf["frequency_offsets"])],
[np.array(obs.frequency_offsets.cpu().numpy())],
dtype=">f8",
) # start with 0, add ch_with per IF
col2 = fits.Column(
name="IF FREQ", format=str(IF_FREQ.shape[-1]) + "D", unit="Hz", array=IF_FREQ
)

CH_WIDTH = np.repeat(np.array([conf["bandwidths"]], dtype=">f4"), 1, axis=1)
CH_WIDTH = np.repeat(
np.array([obs.bandwidths.cpu().numpy()], dtype=">f4"), 1, axis=1
)
col3 = fits.Column(
name="CH WIDTH", format=str(CH_WIDTH.shape[-1]) + "E", unit="Hz", array=CH_WIDTH
)

TOTAL_BANDWIDTH = np.repeat(np.array([conf["bandwidths"]], dtype=">f4"), 1, axis=1)
TOTAL_BANDWIDTH = np.repeat(
np.array([obs.bandwidths.cpu().numpy()], dtype=">f4"), 1, axis=1
)
col4 = fits.Column(
name="TOTAL BANDWIDTH",
format=str(TOTAL_BANDWIDTH.shape[-1]) + "E",
Expand Down Expand Up @@ -220,8 +224,8 @@ def create_frequency_hdu(conf):
return hdu_freq


def create_antenna_hdu(conf):
array = layouts.get_array_layout(conf["layout"], writer=True)
def create_antenna_hdu(obs):
array = layouts.get_array_layout(obs.layout, writer=True)

ANNAME = np.chararray(len(array), itemsize=8, unicode=True)
ANNAME[:] = array["station_name"].values
Expand Down Expand Up @@ -288,8 +292,8 @@ def create_antenna_hdu(conf):
)
hdu_ant = fits.BinTableHDU.from_columns(coldefs_ant)

freq = (conf["ref_frequency"] * un.Hz).value
ref_date = Time(conf["scan_start"].isoformat(), format="isot")
freq = (obs.ref_frequency.cpu().numpy() * un.Hz).value
ref_date = obs.start

from astropy.utils import iers

Expand Down Expand Up @@ -325,7 +329,7 @@ def create_antenna_hdu(conf):
hdu_ant.header["UT1UTC"] = (iers_b.ut1_utc(ref_date).value, "UT1 - UTC (sec)")
hdu_ant.header["DATUTC"] = (0, "time system - UTC (sec)") # missing
hdu_ant.header["TIMSYS"] = ("UTC", "Time system")
hdu_ant.header["ARRNAM"] = (conf["layout"], "Array name")
hdu_ant.header["ARRNAM"] = (obs.layout, "Array name")
hdu_ant.header["XYZHAND"] = ("RIGHT", "Handedness of station coordinates")
hdu_ant.header["FRAME"] = ("????", "Coordinate frame, FOR IGNORANCE")
hdu_ant.header["NUMORB"] = (0, "Number orbital parameters in table (n orb)")
Expand Down Expand Up @@ -360,11 +364,11 @@ def create_antenna_hdu(conf):
return hdu_ant


def create_hdu_list(data, conf):
def create_hdu_list(data, obs):
warnings.filterwarnings("ignore", module="astropy.io.fits")
vis_hdu = create_vis_hdu(data, conf)
vis_hdu = create_vis_hdu(data, obs)
time_hdu = create_time_hdu(data)
freq_hdu = create_frequency_hdu(conf)
ant_hdu = create_antenna_hdu(conf)
freq_hdu = create_frequency_hdu(obs)
ant_hdu = create_antenna_hdu(obs)
hdu_list = fits.HDUList([vis_hdu, time_hdu, freq_hdu, ant_hdu])
return hdu_list
10 changes: 5 additions & 5 deletions pyvisgen/simulation/data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,21 @@ def simulate_data_set(config, slurm=False, job_id=None, n=None):
if len(SI.shape) == 2:
SI = SI.unsqueeze(0)

obs, samp_ops = create_observation(conf)
obs = create_observation(conf)
vis_data = vis_loop(obs, SI, noisy=conf["noisy"], mode=conf["mode"])
hdu_list = writer.create_hdu_list(vis_data, samp_ops)
hdu_list = writer.create_hdu_list(vis_data, obs)
hdu_list.writeto(out, overwrite=True)

else:
for i in tqdm(range(len(data))):
SIs = get_images(data, i)

for j, SI in enumerate(tqdm(SIs)):
obs, samp_ops = create_observation(conf)
obs = create_observation(conf)
vis_data = vis_loop(obs, SI, noisy=conf["noisy"], mode=conf["mode"])

out = out_path / Path("vis_" + str(j + len(SIs) * i) + ".fits")
hdu_list = writer.create_hdu_list(vis_data, samp_ops)
hdu_list = writer.create_hdu_list(vis_data, obs)
hdu_list.writeto(out, overwrite=True)


Expand Down Expand Up @@ -103,7 +103,7 @@ def create_observation(conf):
dense=dense,
sensitivity_cut=rc["sensitivity_cut"],
)
return obs, rc
return obs


def create_sampling_rc(conf):
Expand Down
1 change: 1 addition & 0 deletions pyvisgen/simulation/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def __init__(
self.sensitivity_cut = sensitivity_cut
self.device = torch.device(device)

self.layout = array_layout
self.array = layouts.get_array_layout(array_layout)
self.num_baselines = int(
len(self.array.st_num) * (len(self.array.st_num) - 1) / 2
Expand Down
56 changes: 56 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from setuptools import find_packages, setup

setup(
name="pyvisgen",
version="0.2.0",
description="Simulate radio interferometer observations \
and visibility generation with the RIME formalism.",
url="https://github.com/radionets-project/pyvisgen",
author="Kevin Schmidt, Felix Geyer, Stefan Fröse",
author_email="[email protected]",
license="MIT",
include_package_data=True,
packages=find_packages(),
install_requires=[
"numpy",
"matplotlib",
"ipython",
"scipy",
"pandas",
"toml",
"pytest",
"pytest-cov",
"jupyter",
"astroplan",
"torch",
"tqdm",
"numexpr",
"click",
"h5py",
"natsort",
"pre-commit",
],
setup_requires=["pytest-runner"],
tests_require=["pytest"],
zip_safe=False,
entry_points={
"console_scripts": [
"pyvisgen_create_dataset = pyvisgen.simulation.scripts.create_dataset:main",
],
},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3 :: Only",
"Topic :: Scientific/Engineering :: Astronomy",
"Topic :: Scientific/Engineering :: Physics",
"Topic :: Scientific/Engineering :: Information Analysis",
],
)
9 changes: 8 additions & 1 deletion tests/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ def test_create_sampling_rc():
def test_vis_loop():
import torch

import pyvisgen.fits.writer as writer
from pyvisgen.simulation.data_set import create_observation
from pyvisgen.simulation.visibility import vis_loop
from pyvisgen.utils.data import load_bundles, open_bundles

bundles = load_bundles(conf["in_path"])
obs, samp_ops = create_observation(conf)
obs = create_observation(conf)
# num_active_telescopes = test_opts(samp_ops)
data = open_bundles(bundles[0])
SI = torch.tensor(data[0])[None]
Expand All @@ -57,3 +58,9 @@ def test_vis_loop():
# num_vis_calc = vis_data.base_num[vis_data.date == vis_data.date[0]].shape[0]
# dunno what's going on here
# assert num_vis_theory == num_vis_calc
#

out_path = Path(conf["out_path_fits"])
out = out_path / Path("vis_0.fits")
hdu_list = writer.create_hdu_list(vis_data, obs)
hdu_list.writeto(out, overwrite=True)

0 comments on commit 21ce8ff

Please sign in to comment.