diff --git a/demo/demo.ipynb b/demo/demo.ipynb index 33c59343..d995ffc3 100644 --- a/demo/demo.ipynb +++ b/demo/demo.ipynb @@ -1062,7 +1062,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1180,7 +1180,7 @@ " in_file=msv2_name,\n", " out_file=convert_out,\n", " overwrite=True,\n", - " parallel=True,\n", + " parallel_mode=\"partition\",\n", ")" ] }, @@ -9357,7 +9357,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -9746,7 +9746,7 @@ } ], "source": [ - "import numpy as np \n", + "import numpy as np\n", "np.abs(ms_xds.VISIBILITY).max().compute()" ] }, diff --git a/dev/review_fs/review_fs.ipynb b/dev/review_fs/review_fs.ipynb index a67bff01..cc96d580 100644 --- a/dev/review_fs/review_fs.ipynb +++ b/dev/review_fs/review_fs.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -139,7 +139,7 @@ "source": [ "from toolviper.utils.data import download\n", "download(file=\"ALMA_uid___A002_X1003af4_X75a3.split.avg.ms\") #ALMA Mosaic Ephmeris of the Sun.\n", - "download(file=\"VLBA_TL016B_split.ms\") \n", + "download(file=\"VLBA_TL016B_split.ms\")\n", "download(file=\"Antennae_North.cal.lsrk.split.ms\")\n", "download(file=\"SNR_G55_10s.split.ms\")\n", "# download()" @@ -154,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -312,7 +312,7 @@ " in_file=msv2_name,\n", " out_file=convert_out,\n", " overwrite=True,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", ")\n", "\n", "from xradio.measurement_set import open_processing_set\n", @@ -849,7 +849,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -965,7 +965,7 @@ " in_file=msv2_name,\n", " out_file=convert_out,\n", " overwrite=True,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " partition_scheme=[]\n", ")\n", "\n", @@ -1741,7 +1741,7 @@ " in_file=msv2_name,\n", " out_file=convert_out,\n", " overwrite=True,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", ")\n", "\n", "from xradio.measurement_set import open_processing_set\n", @@ -2714,7 +2714,7 @@ " in_file=msv2_name,\n", " out_file=convert_out,\n", " overwrite=True,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " partition_scheme=[]\n", ")\n", "\n", @@ -3550,7 +3550,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -3882,7 +3882,7 @@ " in_file=msv2_name,\n", " out_file=convert_out,\n", " overwrite=True,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", ")\n", "\n", "from xradio.measurement_set import open_processing_set\n", @@ -4476,7 +4476,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -4646,7 +4646,7 @@ " in_file=msv2_name,\n", " out_file=convert_out,\n", " overwrite=True,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " partition_scheme=[]\n", ")\n", "\n", @@ -5244,7 +5244,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -5644,7 +5644,7 @@ " out_file=convert_out,\n", " overwrite=True,\n", " ephemeris_interpolate=True,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", ")\n", "\n", "from xradio.measurement_set import open_processing_set\n", @@ -6700,7 +6700,7 @@ }, 
{ "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -7190,7 +7190,7 @@ " in_file=msv2_name,\n", " out_file=convert_out,\n", " overwrite=True,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " partition_scheme=[]\n", ")\n", "\n", @@ -7201,7 +7201,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -7773,7 +7773,7 @@ } ], "source": [ - "ps[\"ALMA_uid___A002_X1003af4_X75a3.split.avg_09\"].VISIBILITY.field_and_source_xds " + "ps[\"ALMA_uid___A002_X1003af4_X75a3.split.avg_09\"].VISIBILITY.field_and_source_xds" ] }, { @@ -10462,7 +10462,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -11658,7 +11658,7 @@ " in_file=msv2_name,\n", " out_file=convert_out,\n", " overwrite=True,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " ephemeris_interpolate=True,\n", " partition_scheme=[]\n", ")\n", diff --git a/docs/source/measurement_set/guides/ALMA_SD.ipynb b/docs/source/measurement_set/guides/ALMA_SD.ipynb index 1c0d941f..2e56c8e8 100644 --- a/docs/source/measurement_set/guides/ALMA_SD.ipynb +++ b/docs/source/measurement_set/guides/ALMA_SD.ipynb @@ -928,7 +928,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -967,7 +967,7 @@ "convert_msv2_to_processing_set(\n", " in_file=ms_file,\n", " out_file=ps_store,\n", - " parallel=True,\n", + " parallel_mode=\"partition\",\n", " overwrite=True,\n", ")" ] diff --git a/docs/source/measurement_set/guides/ALMA_ephemeris.ipynb b/docs/source/measurement_set/guides/ALMA_ephemeris.ipynb index 85655f83..58dc2cf8 100644 --- a/docs/source/measurement_set/guides/ALMA_ephemeris.ipynb +++ b/docs/source/measurement_set/guides/ALMA_ephemeris.ipynb @@ -121,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "ee52d124-2c17-450b-879f-1f86f0ae265c", "metadata": {}, "outputs": [ @@ -210,7 +210,7 @@ "convert_msv2_to_processing_set(\n", " in_file=ms_file,\n", " out_file=outfile,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " overwrite=True,\n", " main_chunksize=main_chunksize,\n", ")" diff --git a/docs/source/measurement_set/guides/GBT_single_dish.ipynb b/docs/source/measurement_set/guides/GBT_single_dish.ipynb index c7f0282c..4f7f1728 100644 --- a/docs/source/measurement_set/guides/GBT_single_dish.ipynb +++ b/docs/source/measurement_set/guides/GBT_single_dish.ipynb @@ -119,7 +119,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "ee52d124-2c17-450b-879f-1f86f0ae265c", "metadata": {}, "outputs": [ @@ -145,7 +145,7 @@ "convert_msv2_to_processing_set(\n", " in_file=ms_file,\n", " out_file=outfile,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " overwrite=True,\n", " main_chunksize=main_chunksize,\n", ")" diff --git a/docs/source/measurement_set/guides/GMRT.ipynb b/docs/source/measurement_set/guides/GMRT.ipynb index a792eae0..aed315e4 100644 --- a/docs/source/measurement_set/guides/GMRT.ipynb +++ b/docs/source/measurement_set/guides/GMRT.ipynb @@ -109,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "87431f1b-b94d-442a-88b4-936f74be530c", "metadata": {}, "outputs": [ @@ -132,7 +132,7 @@ "convert_msv2_to_processing_set(\n", " in_file=ms_file,\n", " out_file=outfile,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " overwrite=True,\n", " 
main_chunksize=main_chunksize,\n", ")" diff --git a/docs/source/measurement_set/guides/LOFAR.ipynb b/docs/source/measurement_set/guides/LOFAR.ipynb index 26dd2d49..e05aa73a 100644 --- a/docs/source/measurement_set/guides/LOFAR.ipynb +++ b/docs/source/measurement_set/guides/LOFAR.ipynb @@ -123,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "ee52d124-2c17-450b-879f-1f86f0ae265c", "metadata": {}, "outputs": [ @@ -148,7 +148,7 @@ "convert_msv2_to_processing_set(\n", " in_file=ms_file,\n", " out_file=outfile,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " overwrite=True,\n", " main_chunksize=main_chunksize,\n", ")" diff --git a/docs/source/measurement_set/guides/MeerKAT.ipynb b/docs/source/measurement_set/guides/MeerKAT.ipynb index dce2b7e6..2c1b1c10 100644 --- a/docs/source/measurement_set/guides/MeerKAT.ipynb +++ b/docs/source/measurement_set/guides/MeerKAT.ipynb @@ -123,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "ee52d124-2c17-450b-879f-1f86f0ae265c", "metadata": {}, "outputs": [ @@ -151,7 +151,7 @@ "convert_msv2_to_processing_set(\n", " in_file=ms_file,\n", " out_file=outfile,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " overwrite=True,\n", " main_chunksize=main_chunksize,\n", ")" diff --git a/docs/source/measurement_set/guides/SKA_mid.ipynb b/docs/source/measurement_set/guides/SKA_mid.ipynb index 1e174b17..f32f304f 100644 --- a/docs/source/measurement_set/guides/SKA_mid.ipynb +++ b/docs/source/measurement_set/guides/SKA_mid.ipynb @@ -123,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "ee52d124-2c17-450b-879f-1f86f0ae265c", "metadata": {}, "outputs": [ @@ -149,7 +149,7 @@ "convert_msv2_to_processing_set(\n", " in_file=ms_file,\n", " out_file=outfile,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " overwrite=True,\n", " main_chunksize=main_chunksize,\n", ")" diff --git a/docs/source/measurement_set/guides/VLBA.ipynb b/docs/source/measurement_set/guides/VLBA.ipynb index 315d9d4d..6b33fc43 100644 --- a/docs/source/measurement_set/guides/VLBA.ipynb +++ b/docs/source/measurement_set/guides/VLBA.ipynb @@ -139,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "ee52d124-2c17-450b-879f-1f86f0ae265c", "metadata": {}, "outputs": [ @@ -174,7 +174,7 @@ "convert_msv2_to_processing_set(\n", " in_file=ms_file,\n", " out_file=outfile,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " overwrite=True,\n", " main_chunksize=main_chunksize,\n", ")" diff --git a/docs/source/measurement_set/guides/VLBI.ipynb b/docs/source/measurement_set/guides/VLBI.ipynb index 8c2a3cc1..022d8e86 100644 --- a/docs/source/measurement_set/guides/VLBI.ipynb +++ b/docs/source/measurement_set/guides/VLBI.ipynb @@ -133,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "ee52d124-2c17-450b-879f-1f86f0ae265c", "metadata": {}, "outputs": [ @@ -168,7 +168,7 @@ "convert_msv2_to_processing_set(\n", " in_file=ms_file,\n", " out_file=outfile,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " overwrite=True,\n", " main_chunksize=main_chunksize,\n", ")" diff --git a/docs/source/measurement_set/guides/ngEHT.ipynb b/docs/source/measurement_set/guides/ngEHT.ipynb index 759ac300..6c98a0b4 100644 --- a/docs/source/measurement_set/guides/ngEHT.ipynb +++ b/docs/source/measurement_set/guides/ngEHT.ipynb @@ -136,7 +136,7 @@ }, { "cell_type": "code", - 
"execution_count": 16, + "execution_count": null, "id": "ee52d124-2c17-450b-879f-1f86f0ae265c", "metadata": {}, "outputs": [ @@ -165,7 +165,7 @@ "convert_msv2_to_processing_set(\n", " in_file=ms_file,\n", " out_file=outfile,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " overwrite=True,\n", " main_chunksize=main_chunksize,\n", ")" diff --git a/docs/source/measurement_set/tutorials/ps_vis.ipynb b/docs/source/measurement_set/tutorials/ps_vis.ipynb index 2aee0c7b..a2a79b81 100644 --- a/docs/source/measurement_set/tutorials/ps_vis.ipynb +++ b/docs/source/measurement_set/tutorials/ps_vis.ipynb @@ -356,7 +356,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "7febe111-0000-4c51-bb8c-33efe13fd934", "metadata": {}, "outputs": [ @@ -371,8 +371,8 @@ } ], "source": [ - "do_parallel = True\n", - "if do_parallel:\n", + "do_parallel = \"none\"\n", + "if do_parallel == \"none\":\n", " from toolviper import dask\n", " viper_client = toolviper.dask.local_client(cores=suggested_cores)" ] @@ -387,7 +387,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "fcd710fb-c2e0-4ff6-8e03-6c11673eec09", "metadata": {}, "outputs": [ @@ -420,7 +420,7 @@ " in_file=msv2_name,\n", " out_file=convert_out,\n", " overwrite=True,\n", - " parallel=do_parallel,\n", + " parallel_mode=do_parallel,\n", ")" ] }, diff --git a/reviews/review_antenna_xds.ipynb b/reviews/review_antenna_xds.ipynb index be666648..afdf27e2 100644 --- a/reviews/review_antenna_xds.ipynb +++ b/reviews/review_antenna_xds.ipynb @@ -860,7 +860,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -895,7 +895,7 @@ "convert_msv2_to_processing_set(\n", " in_file=in_file,\n", " out_file=out_file,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " overwrite=True,\n", " phase_cal_interpolate=True,\n", ")" @@ -3841,7 +3841,7 @@ "convert_msv2_to_processing_set(\n", " in_file=in_file,\n", " out_file=out_file,\n", - " parallel=True,\n", + " parallel_mode=\"partition\",\n", " overwrite=True,\n", ")" ] diff --git a/reviews/review_field_and_source_xds.ipynb b/reviews/review_field_and_source_xds.ipynb index d4d30163..fff7d9b2 100644 --- a/reviews/review_field_and_source_xds.ipynb +++ b/reviews/review_field_and_source_xds.ipynb @@ -722,7 +722,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "9a82a2e5", "metadata": {}, "outputs": [ @@ -782,12 +782,11 @@ "convert_msv2_to_processing_set(\n", " in_file=in_file,\n", " out_file=out_file,\n", - " parallel=False,\n", + " parallel_mode=\"none\",\n", " overwrite=True,\n", " ephemeris_interpolate=True,\n", " partition_scheme=partition_scheme\n", - ")\n", - "\n" + ")\n" ] }, { @@ -4926,7 +4925,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "id": "f2a63ba3", "metadata": {}, "outputs": [ @@ -4953,7 +4952,7 @@ " 'weight': 'WEIGHT'}\n", "\n", "sel_data_group_set = set(dgs['base'].values())\n", - " \n", + "\n", "data_variables_to_drop = []\n", "for dg in dgs.values():\n", " temp_set = set(dg.values()) - sel_data_group_set\n", @@ -7012,7 +7011,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "d798c392", "metadata": {}, "outputs": [ @@ -7028,13 +7027,13 @@ } ], "source": [ - "if len(partition_scheme)==0: #Largest partitions, \n", + "if len(partition_scheme)==0: #Largest partitions,\n", " msv4_name_ephemeris = \"ALMA_uid___A002_X1003af4_X75a3.split.avg_17\"\n", " msv4_name = 
\"ALMA_uid___A002_X1003af4_X75a3.split.avg_15\"\n", "else: #Partition also by Field_id (default behavior).\n", " msv4_name_ephemeris = \"ALMA_uid___A002_X1003af4_X75a3.split.avg_81\"\n", " msv4_name = \"ALMA_uid___A002_X1003af4_X75a3.split.avg_67\"\n", - " \n", + "\n", "msv4_name" ] }, diff --git a/reviews/review_ps.ipynb b/reviews/review_ps.ipynb index 03d3a359..e66997a6 100644 --- a/reviews/review_ps.ipynb +++ b/reviews/review_ps.ipynb @@ -632,7 +632,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -754,7 +754,7 @@ "convert_msv2_to_processing_set(\n", " in_file=in_file,\n", " out_file=out_file,\n", - " parallel=True,\n", + " parallel_mode=\"partition\",\n", " overwrite=True,\n", " ephemeris_interpolate=True,\n", " partition_scheme=partition_scheme\n", @@ -2771,7 +2771,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2945,8 +2945,7 @@ ], "source": [ "#Note that no selection is applied on the MS data so even if field_name=['Sun_10_10','Sun_10_11'] all the fields are kept.\n", - "ps.sel(field_coords='Ephemeris',field_name=['Sun_10_10','Sun_10_11']).summary() #Select all Ephemeris data and where any of the fields are 'Sun_10_10' or 'Sun_10_11'.\n", - "\n" + "ps.sel(field_coords='Ephemeris',field_name=['Sun_10_10','Sun_10_11']).summary() #Select all Ephemeris data and where any of the fields are 'Sun_10_10' or 'Sun_10_11'.\n" ] }, { diff --git a/src/xradio/measurement_set/_utils/_msv2/_tables/read.py b/src/xradio/measurement_set/_utils/_msv2/_tables/read.py index a9d49aaf..4fb34811 100644 --- a/src/xradio/measurement_set/_utils/_msv2/_tables/read.py +++ b/src/xradio/measurement_set/_utils/_msv2/_tables/read.py @@ -4,6 +4,7 @@ import re from typing import Any, Callable, Dict, List, Tuple, Union +import dask.array as da import numpy as np import pandas as pd import xarray as xr @@ -11,7 +12,7 @@ import astropy.units from casacore import tables -from .table_query import open_query, open_table_ro +from .table_query import open_query, open_table_ro, TableManager from xradio._utils.list_and_array import get_pad_value CASACORE_TO_PD_TIME_CORRECTION = 3_506_716_800.0 @@ -1207,13 +1208,14 @@ def read_col_chunk( return fulldata -def read_col_conversion( - tb_tool: tables.table, +def read_col_conversion_numpy( + table_manager: TableManager, col: str, cshape: Tuple[int], tidxs: np.ndarray, bidxs: np.ndarray, use_table_iter: bool, + time_chunksize: int, ) -> np.ndarray: """ Function to perform delayed reads from table columns when converting @@ -1221,7 +1223,7 @@ def read_col_conversion( Parameters ---------- - tb_tool : tables.table + table_manager : TableManager col : str @@ -1231,6 +1233,8 @@ def read_col_conversion( bidxs : np.ndarray + use_table_iter : bool + Returns ------- np.ndarray @@ -1241,60 +1245,197 @@ def read_col_conversion( # WARNING: Assumes the num_frequencies * num_polarizations < 2**29. If false, # https://github.com/casacore/python-casacore/issues/130 isn't mitigated. + with table_manager.get_table() as tb_tool: + + # Use casacore to get the shape of a row for this column + ################################################################################# + + # getcolshapestring() only works on columns where a row element is an + # array ie. 
fails for TIME + # Assumes the RuntimeError is because the column is a scalar + try: + shape_string = tb_tool.getcolshapestring(col)[0] + # Convert `shape_string` into a tuple that numpy understands + extra_dimensions = tuple( + [ + int(idx) + for idx in shape_string.replace("[", "") + .replace("]", "") + .split(", ") + ] + ) + except RuntimeError: + extra_dimensions = () + + ################################################################################# + + # Get dtype of the column. Only read first row from disk + col_dtype = np.array(tb_tool.col(col)[0]).dtype + # Use a custom/safe fill value (https://github.com/casangi/xradio/issues/219) + fill_value = get_pad_value(col_dtype) + + # Construct a numpy array to populate. `data` has shape (n_times, n_baselines, n_frequencies, n_polarizations) + data = np.full(cshape + extra_dimensions, fill_value, dtype=col_dtype) + + # Use built-in casacore table iterator to populate the data column by unique times. + if use_table_iter: + start_row = 0 + for ts in tb_tool.iter("TIME", sort=False): + num_rows = ts.nrows() + + # Create small temporary array to store the partial column + tmp_arr = np.full( + (num_rows,) + extra_dimensions, fill_value, dtype=col_dtype + ) + + # Note we don't use `getcol()` because it's less safe. See: + # https://github.com/casacore/python-casacore/issues/130#issuecomment-463202373 + ts.getcolnp(col, tmp_arr) + + # Get the slice of rows contained in `tmp_arr`. + # Used to get the relevant integer indexes from `tidxs` and `bidxs` + tmp_slice = slice(start_row, start_row + num_rows) + + # Copy `tmp_arr` into the correct elements of `data` + data[tidxs[tmp_slice], bidxs[tmp_slice]] = tmp_arr + start_row += num_rows + else: + data[tidxs, bidxs] = tb_tool.getcol(col) + + return data + + +def read_col_conversion_dask( + table_manager: TableManager, + col: str, + cshape: Tuple[int], + tidxs: np.ndarray, + bidxs: np.ndarray, + use_table_iter: bool, + time_chunksize: int, +) -> da.Array: + """ + Function to perform delayed reads from table columns when converting + (no need for didxs) + + Parameters + ---------- + table_manager : TableManager + + col : str + + cshape : Tuple[int] + + tidxs : np.ndarray + + bidxs : np.ndarray + + Returns + ------- + da.Array + """ + # Use casacore to get the shape of a row for this column ################################################################################# - # Get the total number of rows in the base measurement set - nrows_total = tb_tool.nrows() + with table_manager.get_table() as tb_tool: + first_row = tb_tool.row(col)[0][col] - # getcolshapestring() only works on columns where a row element is an - # array ie. fails for TIME - # Assumes the RuntimeError is because the column is a scalar - try: - shape_string = tb_tool.getcolshapestring(col)[0] - # Convert `shape_string` into a tuple that numpy understands - extra_dimensions = tuple( - [ - int(idx) - for idx in shape_string.replace("[", "").replace("]", "").split(", ") - ] - ) - except RuntimeError: + if isinstance(first_row, np.ndarray): + extra_dimensions = first_row.shape + + else: extra_dimensions = () + # Use dask primitives to lazily read chunks of data from the MeasurementSet + # Takes inspiration from dask_image https://image.dask.org/en/latest/ ################################################################################# - # Get dtype of the column. 
Only read first row from disk - col_dtype = np.array(tb_tool.col(col)[0]).dtype - # Use a custom/safe fill value (https://github.com/casangi/xradio/issues/219) - fill_value = get_pad_value(col_dtype) + # Get dtype of the column. Wrap in numpy array in case of scalar column + col_dtype = np.array(first_row).dtype + + # Get the number of rows for a single TIME value + num_utimes = cshape[0] + rows_per_time = cshape[1] + + # Calculate the chunks of unique times that give the target chunk sizes + tmp_chunks = da.core.normalize_chunks(time_chunksize, (num_utimes,))[0] + + row_offset = 0 + arr_start_end_rows = [] + for chunk in tmp_chunks: + start = row_offset * rows_per_time + end = (row_offset + chunk) * rows_per_time + + arr_start_end_rows.append((start, end)) + row_offset += chunk + + # Store the start and end rows that should be read for the chunk + arr_start_end_rows = da.from_array(arr_start_end_rows, chunks=(1, 2)) + + # Specify the output chunk shape for `load_col_chunk` + output_chunkshape = (tmp_chunks, cshape[1]) + extra_dimensions + + # Apply `load_col_chunk` to each chunk + data = arr_start_end_rows.map_blocks( + load_col_chunk, + table_manager=table_manager, + col_name=col, + col_dtype=col_dtype, + tidxs=tidxs, + bidxs=bidxs, + rows_per_time=rows_per_time, + cshape=cshape, + extra_dimensions=extra_dimensions, + drop_axis=[1], + new_axis=list(range(1, len(cshape + extra_dimensions))), + meta=np.array([], dtype=col_dtype), + chunks=output_chunkshape, + ) - # Construct a numpy array to populate. `data` has shape (n_times, n_baselines, n_frequencies, n_polarizations) - data = np.full(cshape + extra_dimensions, fill_value, dtype=col_dtype) + return data - # Use built-in casacore table iterator to populate the data column by unique times. - if use_table_iter: - start_row = 0 - for ts in tb_tool.iter("TIME", sort=False): - num_rows = ts.nrows() - # Create small temporary array to store the partial column - tmp_arr = np.full( - (num_rows,) + extra_dimensions, fill_value, dtype=col_dtype - ) + +def load_col_chunk( + x, + table_manager, + col_name, + col_dtype, + tidxs, + bidxs, + rows_per_time, + cshape, + extra_dimensions, +): + start_row = x[0][0] + end_row = x[0][1] + num_rows = end_row - start_row + assert (num_rows % rows_per_time) == 0 + num_utimes = num_rows // rows_per_time + + # Create memory buffer to populate with data from disk + row_data = np.full((num_rows,) + extra_dimensions, np.nan, dtype=col_dtype) + + # Load data from the column + # Release the casacore table as soon as possible + with table_manager.get_table() as tb_tool: + tb_tool.getcolnp(col_name, row_data, startrow=start_row, nrow=num_rows) + + # Initialise reshaped numpy array + reshaped_data = np.full( + (num_utimes, cshape[1]) + extra_dimensions, np.nan, dtype=col_dtype + ) - # Note we don't use `getcol()` because it's less safe. See: - # https://github.com/casacore/python-casacore/issues/130#issuecomment-463202373 - ts.getcolnp(col, tmp_arr) - # Get the slice of rows contained in `tmp_arr`. - # Used to get the relevant integer indexes from `tidxs` and `bidxs` - tmp_slice = slice(start_row, start_row + num_rows) + # Create slice object for readability + slc = slice(start_row, end_row) + tidxs_slc = tidxs[slc] - # Copy `tmp_arr` into correct elements of `tmp_arr` - data[tidxs[tmp_slice], bidxs[tmp_slice]] = tmp_arr + tidxs_slc = ( + tidxs_slc - tidxs_slc[0] + ) # Indices of reshaped_data along time differ from values in tidxs. 
Assumes first time is earliest time + bidxs_slc = bidxs[slc] - # Copy `tmp_arr` into correct elements of `tmp_arr` - data[tidxs[tmp_slice], bidxs[tmp_slice]] = tmp_arr - start_row += num_rows - else: - data[tidxs, bidxs] = tb_tool.getcol(col) + # Populate `reshaped_data` with `row_data` + reshaped_data[tidxs_slc, bidxs_slc] = row_data - return data + return reshaped_data diff --git a/src/xradio/measurement_set/_utils/_msv2/_tables/table_query.py b/src/xradio/measurement_set/_utils/_msv2/_tables/table_query.py index f5075c3c..6d3f4aa2 100644 --- a/src/xradio/measurement_set/_utils/_msv2/_tables/table_query.py +++ b/src/xradio/measurement_set/_utils/_msv2/_tables/table_query.py @@ -22,3 +22,25 @@ def open_query(table: tables.table, query: str) -> Generator[tables.table, None, yield ttq finally: ttq.close() + + +class TableManager: + + def __init__( + self, + infile: str, + taql_where: str = "", + ): + self.infile = infile + self.taql_where = taql_where + self.taql_query = taql_where.replace("where ", "") + + def get_table(self): + # Performance note: + # table.query("(DATA_DESC_ID = 0)") is slightly faster than + # tables.taql("select * from $table (DATA_DESC_ID = 0)") + with tables.table( + self.infile, readonly=True, lockoptions={"option": "usernoread"}, ack=False + ) as mtable: + query = f"select * from $mtable {self.taql_where}" + return tables.taql(query) diff --git a/src/xradio/measurement_set/_utils/_msv2/conversion.py b/src/xradio/measurement_set/_utils/_msv2/conversion.py index 02464e11..fa7ff86d 100644 --- a/src/xradio/measurement_set/_utils/_msv2/conversion.py +++ b/src/xradio/measurement_set/_utils/_msv2/conversion.py @@ -36,11 +36,12 @@ from .._zarr.encoding import add_encoding from .subtables import subt_rename_ids -from ._tables.table_query import open_table_ro, open_query +from ._tables.table_query import open_table_ro, open_query, TableManager from ._tables.read import ( convert_casacore_time, extract_table_attributes, - read_col_conversion, + read_col_conversion_numpy, + read_col_conversion_dask, load_generic_table, ) from ._tables.read_main_table import get_baselines, get_baseline_indices, get_utimes_tol @@ -379,6 +380,10 @@ def calc_used_gb( # TODO: if the didxs are not used in read_col_conversion, remove didxs from here (and convert_and_write_partition) def calc_indx_for_row_split(tb_tool, taql_where): + # Allow TableManager object to be used + if isinstance(tb_tool, TableManager): + tb_tool = tb_tool.get_table() + baselines = get_baselines(tb_tool) col_names = tb_tool.colnames() cshapes = [ @@ -562,10 +567,42 @@ def find_min_max_times(tb_tool: tables.table, taql_where: str) -> tuple: def create_data_variables( - in_file, xds, tb_tool, time_baseline_shape, tidxs, bidxs, didxs, use_table_iter + in_file, + xds, + table_manager, + time_baseline_shape, + tidxs, + bidxs, + didxs, + use_table_iter, + parallel_mode, + main_chunksize, ): + + # Get time chunks + time_chunksize = None + if parallel_mode == "time": + try: + time_chunksize = main_chunksize["time"] + except KeyError: + # If time isn't chunked then `read_col_conversion_dask` is slower than `read_col_conversion_numpy` + logger.warning( + "'time' isn't specified in `main_chunksize`. Defaulting to `parallel_mode = 'none'`." + ) + parallel_mode = "none" + + # Set read_col_conversion from value of `parallel_mode` argument + # TODO: To make this compatible with multi-node conversion, `read_col_conversion_dask` and TableManager must be pickled. 
+ # Casacore will make this difficult + global read_col_conversion + if parallel_mode == "time": + read_col_conversion = read_col_conversion_dask + else: + read_col_conversion = read_col_conversion_numpy + # Create Data Variables - col_names = tb_tool.colnames() + with table_manager.get_table() as tb_tool: + col_names = tb_tool.colnames() main_table_attrs = extract_table_attributes(in_file) main_column_descriptions = main_table_attrs["column_descriptions"] @@ -579,22 +616,24 @@ def create_data_variables( xds = get_weight( xds, col, - tb_tool, + table_manager, time_baseline_shape, tidxs, bidxs, use_table_iter, main_column_descriptions, + time_chunksize, ) else: xds[col_to_data_variable_names[col]] = xr.DataArray( read_col_conversion( - tb_tool, + table_manager, col, time_baseline_shape, tidxs, bidxs, use_table_iter, + time_chunksize, ), dims=col_dims[col], ) @@ -615,12 +654,13 @@ def create_data_variables( xds = get_weight( xds, "WEIGHT", - tb_tool, + table_manager, time_baseline_shape, tidxs, bidxs, use_table_iter, main_column_descriptions, + time_chunksize, ) @@ -653,22 +693,24 @@ def add_missing_data_var_attrs(xds): def get_weight( xds, col, - tb_tool, + table_manager, time_baseline_shape, tidxs, bidxs, use_table_iter, main_column_descriptions, + time_chunksize, ): xds[col_to_data_variable_names[col]] = xr.DataArray( np.tile( read_col_conversion( - tb_tool, + table_manager, col, time_baseline_shape, tidxs, bidxs, use_table_iter, + time_chunksize, )[:, :, None, :], (1, 1, xds.sizes["frequency"], 1), ), @@ -933,6 +975,7 @@ def convert_and_write_partition( sys_cal_interpolate: bool = False, compressor: numcodecs.abc.Codec = numcodecs.Zstd(level=2), storage_backend="zarr", + parallel_mode: str = "none", overwrite: bool = False, ): """_summary_ @@ -969,6 +1012,8 @@ def convert_and_write_partition( _description_, by default numcodecs.Zstd(level=2) storage_backend : str, optional _description_, by default "zarr" + parallel_mode : _type_, optional + _description_ overwrite : bool, optional _description_, by default False @@ -979,365 +1024,358 @@ def convert_and_write_partition( """ taql_where = create_taql_query_where(partition_info) + table_manager = TableManager(in_file, taql_where) ddi = partition_info["DATA_DESC_ID"][0] intents = str(partition_info["OBS_MODE"][0]) start = time.time() - with open_table_ro(in_file) as mtable: - taql_main = f"select * from $mtable {taql_where}" - with open_query(mtable, taql_main) as tb_tool: - - if tb_tool.nrows() == 0: - tb_tool.close() - mtable.close() - return xr.Dataset(), {}, {} + with table_manager.get_table() as tb_tool: + if tb_tool.nrows() == 0: + tb_tool.close() + return xr.Dataset(), {}, {} + + logger.debug("Starting a real convert_and_write_partition") + ( + tidxs, + bidxs, + didxs, + baseline_ant1_id, + baseline_ant2_id, + utime, + ) = calc_indx_for_row_split(tb_tool, taql_where) + time_baseline_shape = (len(utime), len(baseline_ant1_id)) + logger.debug("Calc indx for row split " + str(time.time() - start)) + + observation_id = check_if_consistent( + tb_tool.getcol("OBSERVATION_ID"), "OBSERVATION_ID" + ) - logger.debug("Starting a real convert_and_write_partition") - ( - tidxs, - bidxs, - didxs, - baseline_ant1_id, - baseline_ant2_id, - utime, - ) = calc_indx_for_row_split(tb_tool, taql_where) - time_baseline_shape = (len(utime), len(baseline_ant1_id)) - logger.debug("Calc indx for row split " + str(time.time() - start)) - - observation_id = check_if_consistent( - tb_tool.getcol("OBSERVATION_ID"), "OBSERVATION_ID" + def 
get_observation_info(in_file, observation_id, intents): + generic_observation_xds = load_generic_table( + in_file, + "OBSERVATION", + taql_where=f" where (ROWID() IN [{str(observation_id)}])", + ) - if intents == "None": - intents = "obs_" + str(observation_id) - - return generic_observation_xds["TELESCOPE_NAME"].values[0], intents + if intents == "None": + intents = "obs_" + str(observation_id) + + return generic_observation_xds["TELESCOPE_NAME"].values[0], intents + + telescope_name, intents = get_observation_info(in_file, observation_id, intents) + + start = time.time() + xds = xr.Dataset( + attrs={ + "schema_version": MSV4_SCHEMA_VERSION, + "creator": { + "software_name": "xradio", + "version": importlib.metadata.version("xradio"), + }, + "creation_date": datetime.datetime.now( + datetime.timezone.utc + ).isoformat(), + "type": "visibility", + } + ) - telescope_name, intents = get_observation_info( - in_file, observation_id, intents - ) + # interval = check_if_consistent(tb_tool.getcol("INTERVAL"), "INTERVAL") + interval = tb_tool.getcol("INTERVAL") - start = time.time() - xds = xr.Dataset( - attrs={ - "schema_version": MSV4_SCHEMA_VERSION, - "creator": { - "software_name": "xradio", - "version": importlib.metadata.version("xradio"), - }, - "creation_date": datetime.datetime.now( - datetime.timezone.utc - ).isoformat(), - "type": "visibility", - } + interval_unique = unique_1d(interval) + if len(interval_unique) > 1: + logger.debug( + "Integration time (interval) not consistent in partition, using median." ) + interval = np.median(interval) + else: + interval = interval_unique[0] + + scan_id = np.full(time_baseline_shape, -42, dtype=int) + scan_id[tidxs, bidxs] = tb_tool.getcol("SCAN_NUMBER") + scan_id = np.max(scan_id, axis=1) + + xds = create_coordinates( + xds, + in_file, + ddi, + utime, + interval, + baseline_ant1_id, + baseline_ant2_id, + scan_id, + ) + logger.debug("Time create coordinates " + str(time.time() - start)) + + start = time.time() + main_chunksize = parse_chunksize(main_chunksize, "main", xds) + create_data_variables( + in_file, + xds, + table_manager, + time_baseline_shape, + tidxs, + bidxs, + didxs, + use_table_iter, + parallel_mode, + main_chunksize, + ) - # interval = check_if_consistent(tb_tool.getcol("INTERVAL"), "INTERVAL") + interval = tb_tool.getcol("INTERVAL") - interval_unique = unique_1d(interval) - if len(interval_unique) > 1: - logger.debug( - "Integration time (interval) not consitent in partition, using median." + # Add data_groups + xds, is_single_dish = add_data_groups(xds) + xds = add_missing_data_var_attrs(xds) + + if ( + "WEIGHT" not in xds.data_vars + ): # Some single dish datasets don't have WEIGHT. 
+ if is_single_dish: + xds["WEIGHT"] = xr.DataArray( + np.ones(xds.SPECTRUM.shape, dtype=np.float64), + dims=xds.SPECTRUM.dims, ) - interval = np.median(interval) else: - interval = interval_unique[0] + xds["WEIGHT"] = xr.DataArray( + np.ones(xds.VISIBILITY.shape, dtype=np.float64), + dims=xds.VISIBILITY.dims, + ) - scan_id = np.full(time_baseline_shape, -42, dtype=int) - scan_id[tidxs, bidxs] = tb_tool.getcol("SCAN_NUMBER") - scan_id = np.max(scan_id, axis=1) + logger.debug("Time create data variables " + str(time.time() - start)) - xds = create_coordinates( - xds, - in_file, - ddi, - utime, - interval, - baseline_ant1_id, - baseline_ant2_id, - scan_id, - ) - logger.debug("Time create coordinates " + str(time.time() - start)) + # To constrain the time range to load (in pointing, ephemerides, phase_cal data_vars) + time_min_max = find_min_max_times(tb_tool, taql_where) - start = time.time() - create_data_variables( - in_file, - xds, - tb_tool, - time_baseline_shape, - tidxs, - bidxs, - didxs, - use_table_iter, + # Create ant_xds + start = time.time() + feed_id = unique_1d( + np.concatenate( + [ + unique_1d(tb_tool.getcol("FEED1")), + unique_1d(tb_tool.getcol("FEED2")), + ] ) + ) + antenna_id = unique_1d( + np.concatenate( + [xds["baseline_antenna1_id"].data, xds["baseline_antenna2_id"].data] + ) + ) - # Add data_groups - xds, is_single_dish = add_data_groups(xds) - xds = add_missing_data_var_attrs(xds) - - if ( - "WEIGHT" not in xds.data_vars - ): # Some single dish datasets don't have WEIGHT. - if is_single_dish: - xds["WEIGHT"] = xr.DataArray( - np.ones(xds.SPECTRUM.shape, dtype=np.float64), - dims=xds.SPECTRUM.dims, - ) - else: - xds["WEIGHT"] = xr.DataArray( - np.ones(xds.VISIBILITY.shape, dtype=np.float64), - dims=xds.VISIBILITY.dims, - ) - - logger.debug("Time create data variables " + str(time.time() - start)) - - # To constrain the time range to load (in pointing, ephemerides, phase_cal data_vars) - time_min_max = find_min_max_times(tb_tool, taql_where) + ant_xds = create_antenna_xds( + in_file, + xds.frequency.attrs["spectral_window_id"], + antenna_id, + feed_id, + telescope_name, + xds.polarization, + ) + logger.debug("Time antenna xds " + str(time.time() - start)) - # Create ant_xds - start = time.time() - feed_id = unique_1d( - np.concatenate( - [ - unique_1d(tb_tool.getcol("FEED1")), - unique_1d(tb_tool.getcol("FEED2")), - ] - ) - ) - antenna_id = unique_1d( - np.concatenate( - [xds["baseline_antenna1_id"].data, xds["baseline_antenna2_id"].data] - ) - ) + start = time.time() + gain_curve_xds = create_gain_curve_xds( + in_file, xds.frequency.attrs["spectral_window_id"], ant_xds + ) + logger.debug("Time gain_curve xds " + str(time.time() - start)) - ant_xds = create_antenna_xds( - in_file, - xds.frequency.attrs["spectral_window_id"], - antenna_id, - feed_id, - telescope_name, - xds.polarization, - ) - logger.debug("Time antenna xds " + str(time.time() - start)) + start = time.time() + if phase_cal_interpolate: + phase_cal_interp_time = xds.time.values + else: + phase_cal_interp_time = None + phase_calibration_xds = create_phase_calibration_xds( + in_file, + xds.frequency.attrs["spectral_window_id"], + ant_xds, + time_min_max, + phase_cal_interp_time, + ) + logger.debug("Time phase_calibration xds " + str(time.time() - start)) - start = time.time() - gain_curve_xds = create_gain_curve_xds( - in_file, xds.frequency.attrs["spectral_window_id"], ant_xds - ) - logger.debug("Time gain_curve xds " + str(time.time() - start)) + # Create system_calibration_xds + start = time.time() + if 
sys_cal_interpolate: + sys_cal_interp_time = xds.time.values + else: + sys_cal_interp_time = None + system_calibration_xds = create_system_calibration_xds( + in_file, + xds.frequency, + ant_xds, + sys_cal_interp_time, + ) + logger.debug("Time system_calibration " + str(time.time() - start)) + # Change antenna_ids to antenna_names + with_antenna_partitioning = "ANTENNA1" in partition_info + xds = antenna_ids_to_names( + xds, ant_xds, is_single_dish, with_antenna_partitioning + ) + # but before, keep the name-id arrays, we need them for the pointing and weather xds + ant_xds_name_ids = ant_xds["antenna_name"].set_xindex("antenna_id") + ant_xds_station_name_ids = ant_xds["station"].set_xindex("antenna_id") + # No longer needed after converting to name. + ant_xds = ant_xds.drop_vars("antenna_id") + + # Create weather_xds + start = time.time() + weather_xds = create_weather_xds(in_file, ant_xds_station_name_ids) + logger.debug("Time weather " + str(time.time() - start)) + + # Create pointing_xds + pointing_xds = xr.Dataset() + if with_pointing: start = time.time() + if pointing_interpolate: + pointing_interp_time = xds.time else: + pointing_interp_time = None + pointing_xds = create_pointing_xds( + in_file, ant_xds_name_ids, time_min_max, pointing_interp_time ) + pointing_chunksize = parse_chunksize( + pointing_chunksize, "pointing", pointing_xds ) + add_encoding(pointing_xds, compressor=compressor, chunks=pointing_chunksize) + logger.debug( + "Time pointing (with add compressor and chunking) " + + str(time.time() - start) ) - # but before, keep the name-id arrays, we need them for the pointing and weather xds - ant_xds_name_ids = ant_xds["antenna_name"].set_xindex("antenna_id") - ant_xds_station_name_ids = ant_xds["station"].set_xindex("antenna_id") - # No longer needed after converting to name. 
- ant_xds = ant_xds.drop_vars("antenna_id") - # Create weather_xds - start = time.time() - weather_xds = create_weather_xds(in_file, ant_xds_station_name_ids) - logger.debug("Time weather " + str(time.time() - start)) + start = time.time() - # Create pointing_xds - pointing_xds = xr.Dataset() - if with_pointing: - start = time.time() - if pointing_interpolate: - pointing_interp_time = xds.time - else: - pointing_interp_time = None - pointing_xds = create_pointing_xds( - in_file, ant_xds_name_ids, time_min_max, pointing_interp_time - ) - pointing_chunksize = parse_chunksize( - pointing_chunksize, "pointing", pointing_xds - ) - add_encoding( - pointing_xds, compressor=compressor, chunks=pointing_chunksize - ) - logger.debug( - "Time pointing (with add compressor and chunking) " - + str(time.time() - start) - ) + # Time and frequency should always be increasing + if len(xds.frequency) > 1 and xds.frequency[1] - xds.frequency[0] < 0: + xds = xds.sel(frequency=slice(None, None, -1)) - start = time.time() - - # Time and frequency should always be increasing - if len(xds.frequency) > 1 and xds.frequency[1] - xds.frequency[0] < 0: - xds = xds.sel(frequency=slice(None, None, -1)) + if len(xds.time) > 1 and xds.time[1] - xds.time[0] < 0: + xds = xds.sel(time=slice(None, None, -1)) - if len(xds.time) > 1 and xds.time[1] - xds.time[0] < 0: - xds = xds.sel(time=slice(None, None, -1)) - - # Create field_and_source_xds (combines field, source and ephemeris data into one super dataset) - start = time.time() - if ephemeris_interpolate: - ephemeris_interp_time = xds.time.values - else: - ephemeris_interp_time = None - - # if "FIELD_ID" not in partition_scheme: - # field_id = np.full(time_baseline_shape, -42, dtype=int) - # field_id[tidxs, bidxs] = tb_tool.getcol("FIELD_ID") - # field_id = np.max(field_id, axis=1) - # field_times = utime - # else: - # field_id = check_if_consistent(tb_tool.getcol("FIELD_ID"), "FIELD_ID") - # field_times = None - - field_id = np.full( - time_baseline_shape, -42, dtype=int - ) # -42 used for missing baselines - field_id[tidxs, bidxs] = tb_tool.getcol("FIELD_ID") - field_id = np.max(field_id, axis=1) - field_times = xds.time.values - - # col_unique = unique_1d(col) - # assert len(col_unique) == 1, col_name + " is not consistent." - # return col_unique[0] - - field_and_source_xds, source_id, _num_lines, field_names = ( - create_field_and_source_xds( - in_file, - field_id, - xds.frequency.attrs["spectral_window_id"], - field_times, - is_single_dish, - time_min_max, - ephemeris_interpolate, - ) + # Create field_and_source_xds (combines field, source and ephemeris data into one super dataset) + start = time.time() + if ephemeris_interpolate: + ephemeris_interp_time = xds.time.values + else: + ephemeris_interp_time = None + + # if "FIELD_ID" not in partition_scheme: + # field_id = np.full(time_baseline_shape, -42, dtype=int) + # field_id[tidxs, bidxs] = tb_tool.getcol("FIELD_ID") + # field_id = np.max(field_id, axis=1) + # field_times = utime + # else: + # field_id = check_if_consistent(tb_tool.getcol("FIELD_ID"), "FIELD_ID") + # field_times = None + + field_id = np.full( + time_baseline_shape, -42, dtype=int + ) # -42 used for missing baselines + field_id[tidxs, bidxs] = tb_tool.getcol("FIELD_ID") + field_id = np.max(field_id, axis=1) + field_times = xds.time.values + + # col_unique = unique_1d(col) + # assert len(col_unique) == 1, col_name + " is not consistent." 
+ # return col_unique[0] + + field_and_source_xds, source_id, _num_lines, field_names = ( + create_field_and_source_xds( + in_file, + field_id, + xds.frequency.attrs["spectral_window_id"], + field_times, + is_single_dish, + time_min_max, + ephemeris_interpolate, ) + ) - logger.debug("Time field_and_source_xds " + str(time.time() - start)) - - xds = fix_uvw_frame(xds, field_and_source_xds, is_single_dish) - xds = xds.assign_coords({"field_name": ("time", field_names)}) - - partition_info_misc_fields = { - "scan_name": xds.coords["scan_name"].data, - "intents": intents, - "taql_where": taql_where, - } - if with_antenna_partitioning: - partition_info_misc_fields["antenna_name"] = xds.coords[ - "antenna_name" - ].data[0] - info_dicts = create_info_dicts( - in_file, xds, field_and_source_xds, partition_info_misc_fields, tb_tool - ) - xds.attrs.update(info_dicts) + logger.debug("Time field_and_source_xds " + str(time.time() - start)) + + xds = fix_uvw_frame(xds, field_and_source_xds, is_single_dish) + xds = xds.assign_coords({"field_name": ("time", field_names)}) + + partition_info_misc_fields = { + "scan_name": xds.coords["scan_name"].data, + "intents": intents, + "taql_where": taql_where, + } + if with_antenna_partitioning: + partition_info_misc_fields["antenna_name"] = xds.coords[ + "antenna_name" + ].data[0] + info_dicts = create_info_dicts( + in_file, xds, field_and_source_xds, partition_info_misc_fields, tb_tool + ) + xds.attrs.update(info_dicts) - # xds ready, prepare to write - start = time.time() - main_chunksize = parse_chunksize(main_chunksize, "main", xds) - add_encoding(xds, compressor=compressor, chunks=main_chunksize) - logger.debug("Time add compressor and chunk " + str(time.time() - start)) + # xds ready, prepare to write + start = time.time() + add_encoding(xds, compressor=compressor, chunks=main_chunksize) + logger.debug("Time add compressor and chunk " + str(time.time() - start)) - file_name = os.path.join( - out_file, - pathlib.Path(in_file).name.replace(".ms", "") + "_" + str(ms_v4_id), - ) + file_name = os.path.join( + out_file, + pathlib.Path(in_file).name.replace(".ms", "") + "_" + str(ms_v4_id), + ) - if overwrite: - mode = "w" - else: - mode = "w-" + if overwrite: + mode = "w" + else: + mode = "w-" - if is_single_dish: - xds.attrs["type"] = "spectrum" - xds = xds.drop_vars(["UVW"]) - del xds["uvw_label"] + if is_single_dish: + xds.attrs["type"] = "spectrum" + xds = xds.drop_vars("UVW") + xds = xds.drop_dims("uvw_label") + else: + if xds.attrs["processor_info"]["type"] == "RADIOMETER": + xds.attrs["type"] = "radiometer" else: - if xds.attrs["processor_info"]["type"] == "RADIOMETER": - xds.attrs["type"] = "radiometer" - else: - xds.attrs["type"] = "visibility" - - import sys - - start = time.time() - if storage_backend == "zarr": - xds.to_zarr(store=os.path.join(file_name, "correlated_xds"), mode=mode) - ant_xds.to_zarr(store=os.path.join(file_name, "antenna_xds"), mode=mode) - for group_name in xds.attrs["data_groups"]: - field_and_source_xds.to_zarr( - store=os.path.join( - file_name, f"field_and_source_xds_{group_name}" - ), - mode=mode, - ) + xds.attrs["type"] = "visibility" + + import sys + + start = time.time() + if storage_backend == "zarr": + xds.to_zarr(store=os.path.join(file_name, "correlated_xds"), mode=mode) + ant_xds.to_zarr(store=os.path.join(file_name, "antenna_xds"), mode=mode) + for group_name in xds.attrs["data_groups"]: + field_and_source_xds.to_zarr( + store=os.path.join(file_name, f"field_and_source_xds_{group_name}"), + mode=mode, + ) - if 
with_pointing and len(pointing_xds.data_vars) > 0: - pointing_xds.to_zarr( - store=os.path.join(file_name, "pointing_xds"), mode=mode - ) + if with_pointing and len(pointing_xds.data_vars) > 0: + pointing_xds.to_zarr( + store=os.path.join(file_name, "pointing_xds"), mode=mode + ) - if system_calibration_xds: - system_calibration_xds.to_zarr( - store=os.path.join(file_name, "system_calibration_xds"), - mode=mode, - ) + if system_calibration_xds: + system_calibration_xds.to_zarr( + store=os.path.join(file_name, "system_calibration_xds"), + mode=mode, + ) - if gain_curve_xds: - gain_curve_xds.to_zarr( - store=os.path.join(file_name, "gain_curve_xds"), mode=mode - ) + if gain_curve_xds: + gain_curve_xds.to_zarr( + store=os.path.join(file_name, "gain_curve_xds"), mode=mode + ) - if phase_calibration_xds: - phase_calibration_xds.to_zarr( - store=os.path.join(file_name, "phase_calibration_xds"), - mode=mode, - ) + if phase_calibration_xds: + phase_calibration_xds.to_zarr( + store=os.path.join(file_name, "phase_calibration_xds"), + mode=mode, + ) - if weather_xds: - weather_xds.to_zarr( - store=os.path.join(file_name, "weather_xds"), mode=mode - ) + if weather_xds: + weather_xds.to_zarr( + store=os.path.join(file_name, "weather_xds"), mode=mode + ) - elif storage_backend == "netcdf": - # xds.to_netcdf(path=file_name+"/MAIN", mode=mode) #Does not work - raise - logger.debug("Write data " + str(time.time() - start)) + elif storage_backend == "netcdf": + # xds.to_netcdf(path=file_name+"/MAIN", mode=mode) #Does not work + raise + logger.debug("Write data " + str(time.time() - start)) # logger.info("Saved ms_v4 " + file_name + " in " + str(time.time() - start_with) + "s") diff --git a/src/xradio/measurement_set/convert_msv2_to_processing_set.py b/src/xradio/measurement_set/convert_msv2_to_processing_set.py index 96d2485b..f60d22e7 100644 --- a/src/xradio/measurement_set/convert_msv2_to_processing_set.py +++ b/src/xradio/measurement_set/convert_msv2_to_processing_set.py @@ -62,7 +62,7 @@ def convert_msv2_to_processing_set( use_table_iter: bool = False, compressor: numcodecs.abc.Codec = numcodecs.Zstd(level=2), storage_backend: str = "zarr", - parallel: bool = False, + parallel_mode: str = "none", overwrite: bool = False, ): """Convert a Measurement Set v2 into a Processing Set of Measurement Set v4. @@ -99,14 +99,30 @@ def convert_msv2_to_processing_set( The Blosc compressor to use when saving the converted data to disk using Zarr, by default numcodecs.Zstd(level=2). storage_backend : {"zarr", "netcdf"}, optional The on-disk format to use. "netcdf" is not yet implemented. - parallel : bool, optional - Makes use of Dask to execute conversion in parallel, by default False. + parallel_mode : {"none", "partition", "time"}, optional + Whether to use Dask to execute the conversion in parallel, by default "none" (the conversion runs serially). + The option "partition" parallelises the conversion over the partitions specified by `partition_scheme`. The option "time" can only be used for phased array interferometers, where there are no partitions + in the MS v2; instead, the conversion is parallelised along the time dimension, controlled via the "time" entry of `main_chunksize`. overwrite : bool, optional Whether to overwrite an existing processing set, by default False. """ + # Check `parallel_mode` is valid + try: + assert parallel_mode in ["none", "partition", "time"] + except AssertionError: + logger.warning( + f"`parallel_mode` {parallel_mode} not recognised. Defaulting to 'none'." 
+ ) + parallel_mode = "none" + partitions = create_partitions(in_file, partition_scheme=partition_scheme) logger.info("Number of partitions: " + str(len(partitions))) + if parallel_mode == "time": + assert ( + len(partitions) == 1 + ), "MS v2 contains more than one partition. `parallel_mode = 'time'` not valid." + delayed_list = [] for ms_v4_id, partition_info in enumerate(partitions): @@ -132,7 +148,7 @@ def convert_msv2_to_processing_set( # prepend '0' to ms_v4_id as needed ms_v4_id = f"{ms_v4_id:0>{len(str(len(partitions) - 1))}}" - if parallel: + if parallel_mode == "partition": delayed_list.append( dask.delayed(convert_and_write_partition)( in_file, @@ -149,6 +165,7 @@ def convert_msv2_to_processing_set( phase_cal_interpolate=phase_cal_interpolate, sys_cal_interpolate=sys_cal_interpolate, compressor=compressor, + parallel_mode=parallel_mode, overwrite=overwrite, ) ) @@ -168,8 +185,9 @@ def convert_msv2_to_processing_set( phase_cal_interpolate=phase_cal_interpolate, sys_cal_interpolate=sys_cal_interpolate, compressor=compressor, + parallel_mode=parallel_mode, overwrite=overwrite, ) - if parallel: + if parallel_mode == "partition": dask.compute(delayed_list) diff --git a/tests/stakeholder/test_measure_set_stakeholder.py b/tests/stakeholder/test_measure_set_stakeholder.py index 60042dc7..50efc46b 100644 --- a/tests/stakeholder/test_measure_set_stakeholder.py +++ b/tests/stakeholder/test_measure_set_stakeholder.py @@ -69,7 +69,7 @@ def download_and_convert_msv2_to_processing_set(msv2_name, folder, partition_sch # sys_cal_interpolate=True, use_table_iter=False, overwrite=True, - parallel=False, + parallel_mode="none", ) return ps_name @@ -497,7 +497,7 @@ def test_gmrt(tmp_path): """ ALMA_uid___A002_X1003af4_X75a3.split.avg.ms: An ephemeris mosaic observation of the sun. -ALMA archive file downloaded: https://almascience.nrao.edu/dataPortal/2022.A.00001.S_uid___A002_X1003af4_X75a3.asdm.sdm.tar +ALMA archive file downloaded: https://almascience.nrao.edu/dataPortal/2022.A.00001.S_uid___A002_X1003af4_X75a3.asdm.sdm.tar - Project: 2022.A.00001.S - Member ous id (MOUS): uid://A001/X3571/X130 @@ -513,7 +513,7 @@ def test_gmrt(tmp_path): for subtable in ['FLAG_CMD', 'POINTING', 'CALDEVICE', 'ASDM_CALATMOSPHERE']: tb.open('ALMA_uid___A002_X1003af4_X75a3.split.avg.ms::'+subtable,nomodify=False) - tb.removerows(np.arange(tb.nrows())) + tb.removerows(np.arange(tb.nrows())) tb.flush() tb.done() ```