# NOTE: queue_write(), _prepare_data_entry(), _summarise() and _prepare_hdf5()
# belong in the body of the ContextFile class (damnit/ctxsupport/ctxrunner.py);
# default_summary() is a module-level function in the same file.

def queue_write(self, name, cell, writers):
    """Queue one computed variable for the HDF5 writer threads.

    Each writer gets a ``SummaryToWrite`` entry (if a summary could be
    produced); writers whose ``reduced_only`` flag is false also get a
    ``ToWrite`` entry holding the full data.

    :param name: variable name used for the queued entries.
    :param cell: result object; its ``data`` attribute and its
        ``get_summary()`` / ``summary_attrs()`` methods are used.
    :param writers: iterable of objects exposing ``queue.put()`` and a
        ``reduced_only`` flag. If empty, nothing is prepared or queued.
    """
    writers = list(writers)
    if not writers:
        # Nobody is listening: skip the (possibly expensive) conversions.
        return

    # default_summary() may return None, which _prepare_hdf5() cannot
    # handle; in that case no summary entry is queued for this variable.
    summary_entry = None
    summary = self._summarise(cell)
    if summary is not None:
        summary_val, opts = self._prepare_hdf5(summary)
        summary_entry = SummaryToWrite(
            name, summary_val, cell.summary_attrs(), opts
        )

    # Only convert the full data if at least one writer will store it.
    data_entry = None
    if any(not writer.reduced_only for writer in writers):
        data_entry = self._prepare_data_entry(name, cell.data)

    for writer in writers:
        if summary_entry is not None:
            writer.queue.put(summary_entry)
        if not writer.reduced_only:
            writer.queue.put(data_entry)

def _prepare_data_entry(self, name, obj):
    """Build the ``ToWrite`` entry holding the full data of one variable."""
    OBJ_TYPE_HINT = '_damnit_objtype'
    group_attrs = {}

    if isinstance(obj, xr.DataArray):
        # HDF5 doesn't allow slashes in names :(
        # rename() returns a new DataArray, so the name of the object
        # held by the caller is left untouched. Names may be any hashable,
        # hence the str check before looking for a slash.
        if isinstance(obj.name, str) and "/" in obj.name:
            obj = obj.rename(obj.name.replace("/", "_"))
        obj = _set_encoding(obj)
        group_attrs[OBJ_TYPE_HINT] = DataType.DataArray.value
        return ToWrite(name, obj, group_attrs)

    if isinstance(obj, xr.Dataset):
        vars_names = {}
        for var_name, dataarray in obj.items():
            if isinstance(var_name, str) and "/" in var_name:
                vars_names[var_name] = var_name.replace("/", "_")
            _set_encoding(dataarray)
        obj = obj.rename_vars(vars_names)
        group_attrs[OBJ_TYPE_HINT] = DataType.Dataset.value
        return ToWrite(name, obj, group_attrs)

    if isinstance_no_import(obj, 'matplotlib.figure', 'Figure'):
        value = figure2array(obj)
        group_attrs[OBJ_TYPE_HINT] = DataType.Image.value
    elif isinstance_no_import(obj, 'plotly.graph_objs', 'Figure'):
        # We want to compress plotly figures in HDF5 files,
        # so we need to convert the data to an array of uint8
        value = np.frombuffer(obj.to_json().encode('utf-8'), dtype=np.uint8)
        group_attrs[OBJ_TYPE_HINT] = DataType.PlotlyFigure.value
    elif isinstance(obj, str):
        value = obj
    else:
        value = np.asarray(obj)

    arr, compression_opts = self._prepare_hdf5(value)
    return ToWrite(name, arr, group_attrs, compression_opts)

@staticmethod
def _summarise(cell):
    """Return the summary of *cell*, falling back to ``default_summary()``.

    The result may be None if neither source provides a summary.
    """
    if (summary_val := cell.get_summary()) is not None:
        return summary_val

    # If a summary wasn't specified, try some default fallbacks
    return default_summary(cell.data)

@staticmethod
def _prepare_hdf5(obj):
    """Convert *obj* for storage and choose dataset creation options.

    Returns a ``(value, options)`` tuple. Strings become a scalar array
    with h5py's string dtype, ``PNGData`` becomes a uint8 array of its
    bytes, and anything else is passed through unchanged. Only non-scalar
    numeric or boolean arrays get ``COMPRESSION_OPTS``; everything else
    gets an empty options dict.
    """
    if isinstance(obj, str):
        return np.array(obj, dtype=h5py.string_dtype()), {}
    elif isinstance(obj, PNGData):  # Thumbnail
        return np.frombuffer(obj.data, dtype=np.uint8), {}
    # Anything else should already be an array
    elif obj.ndim > 0 and (
            np.issubdtype(obj.dtype, np.number) or
            np.issubdtype(obj.dtype, np.bool_)):
        return obj, COMPRESSION_OPTS
    else:
        return obj, {}


def default_summary(data):
    """Derive a fallback summary for *data* when none was specified.

    * str: returned unchanged.
    * xarray Dataset: a string giving its size in MB.
    * matplotlib figure: the result of ``figure2png()``, rendered at a
      reduced dpi if the figure is larger than THUMBNAIL_SIZE pixels.
    * plotly figure: the result of ``plotly2png()``.
    * numpy array / xarray DataArray: the object itself if 0-dimensional,
      the result of ``generate_thumbnail()`` if 2-dimensional, otherwise a
      ``"dtype: shape"`` string.

    Returns None for any other type.
    """
    if isinstance(data, str):
        return data
    elif isinstance(data, xr.Dataset):
        size = data.nbytes / 1e6
        return f"Dataset ({size:.2f}MB)"
    elif isinstance_no_import(data, 'matplotlib.figure', 'Figure'):
        # For the sake of space and memory we downsample images to a
        # resolution of THUMBNAIL_SIZE pixels on the larger dimension.
        image_shape = data.get_size_inches() * data.dpi
        zoom_ratio = min(1, THUMBNAIL_SIZE / max(image_shape))
        return figure2png(data, dpi=(data.dpi * zoom_ratio))
    elif isinstance_no_import(data, 'plotly.graph_objs', 'Figure'):
        return plotly2png(data)

    elif isinstance(data, (np.ndarray, xr.DataArray)):
        if data.ndim == 0:
            return data
        elif data.ndim == 2:
            return generate_thumbnail(np.nan_to_num(data))
        else:
            return f"{data.dtype}: {data.shape}"

    return None
import os
import fcntl
import logging
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from queue import Queue, Empty
from threading import Thread

import h5py
import h5netcdf
import xarray as xr
from xarray.backends import H5NetCDFStore
from xarray.backends.api import dump_to_store

log = logging.getLogger(__name__)


@dataclass
class ToWrite:
    """One variable's full data, queued for a WriterThread."""
    name: str
    data: object
    attrs: dict
    compression_opts: dict = field(default_factory=dict)


@dataclass
class SummaryToWrite(ToWrite):
    """One variable's summary; written under the ``.reduced`` group."""


class WriterThread(Thread):
    """Daemon thread writing queued entries into one HDF5 file.

    Entries (``ToWrite`` / ``SummaryToWrite``) are put on ``self.queue``;
    putting ``None`` (see ``stop()``) makes the thread finish. The file is
    only held open, under an exclusive ``lockf`` lock, while there are
    entries to write.
    """

    def __init__(self, file_path, reduced_only=False):
        super().__init__(daemon=True)
        self.file_path = file_path
        self.reduced_only = reduced_only

        # O_CREAT: the file is created here if it does not exist yet
        self.lock_fd = os.open(file_path, os.O_RDWR | os.O_CLOEXEC | os.O_CREAT)
        if os.stat(file_path).st_uid == os.getuid():
            os.chmod(file_path, 0o666)
        self.have_lock = False
        self.queue = Queue()
        self.abort = False
        self.n_reduced = 0
        self.n_main = 0

    def stop(self, abort=False):
        """Ask the thread to finish; *abort* also sets the abort flag."""
        if abort:
            self.abort = True
        self.queue.put(None)

    def _close_lock_fd(self):
        """Close the lock file descriptor, if it is still open."""
        if self.lock_fd != -1:
            os.close(self.lock_fd)
            self.lock_fd = -1

    def get_lock(self):
        """Poll every 0.5 s until the exclusive file lock is acquired.

        Raises SystemExit if the abort flag is set while waiting.
        """
        while True:
            if self.abort:
                raise SystemExit(0)  # exit the thread with no traceback
            try:
                fcntl.lockf(self.lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
                self.have_lock = True
                return
            except (PermissionError, BlockingIOError):
                time.sleep(0.5)

    @contextmanager
    def locked_h5_access(self):
        """Yield ``(h5py.File, h5netcdf.File)`` for the file, under the lock."""
        self.get_lock()
        try:
            with h5py.File(self.file_path, 'r+') as h5f:
                with h5netcdf.File(h5f.id, 'r+') as h5ncf:
                    yield h5f, h5ncf
        finally:
            self.have_lock = False
            # Closing the file above has already released the lock; this is how
            # POSIX process-associated locks work (see lockf & fcntl man pages).
            # We'll do this as well to ensure the lock is released, just in case
            # anything does not behave as expected.
            fcntl.lockf(self.lock_fd, fcntl.LOCK_UN)

    def run(self):
        """Write queued entries until ``None`` is received or abort is set."""
        try:
            while True:
                if (item := self.queue.get()) is None:
                    return

                with self.locked_h5_access() as (h5f, ncf):
                    while True:
                        self._write_one(item, h5f, ncf)

                        # Try to do more writes without reopening file
                        try:
                            if (item := self.queue.get(timeout=0.2)) is None:
                                return
                        except Empty:
                            break  # Nothing waiting; release the lock

                if self.abort:
                    return
        finally:
            self._close_lock_fd()

            log.info("Written %d data & %d summary variables to %s",
                     self.n_main, self.n_reduced, self.file_path)

    def _write_one(self, item: ToWrite, h5f: h5py.File, ncf: h5netcdf.File):
        """Write a single entry, replacing any existing one of that name."""
        if isinstance(item, SummaryToWrite):
            path = f'.reduced/{item.name}'
            if path in h5f:
                del h5f[path]
            ds = h5f.create_dataset(
                path, data=item.data, **item.compression_opts
            )
            ds.attrs.update(item.attrs)
            self.n_reduced += 1
        else:
            if item.name in h5f:
                del h5f[item.name]

            if isinstance(item.data, (xr.Dataset, xr.DataArray)):
                write_xarray_object(item.data, item.name, ncf)
            else:
                path = f"{item.name}/data"
                h5f.create_dataset(
                    path, data=item.data, **item.compression_opts
                )
            # Add group-level attributes
            h5f[item.name].attrs.update(item.attrs)
            self.n_main += 1


def write_xarray_object(obj, group, ncf: h5netcdf.File):
    """Write an xarray DataArray/Dataset into an h5netcdf File"""
    if isinstance(obj, xr.DataArray):
        obj = dataarray_to_dataset_for_netcdf(obj)
    store = H5NetCDFStore(ncf, group=group, mode='a', autoclose=False)
    dump_to_store(obj, store)
    # Don't close the store object - that would also close the file


def dataarray_to_dataset_for_netcdf(self: xr.DataArray):
    """Convert a DataArray to a Dataset for writing as netCDF."""
    # From xarray (DataArray.to_netcdf() method), under Apache License 2.0
    # Copyright 2014-2023, xarray Developers
    from xarray.backends.api import DATAARRAY_NAME, DATAARRAY_VARIABLE

    if self.name is None:
        # If no name is set then use a generic xarray name
        dataset = self.to_dataset(name=DATAARRAY_VARIABLE)
    elif self.name in self.coords or self.name in self.dims:
        # The name is the same as one of the coords names, which netCDF
        # doesn't support, so rename it but keep track of the old name
        dataset = self.to_dataset(name=DATAARRAY_VARIABLE)
        dataset.attrs[DATAARRAY_NAME] = self.name
    else:
        # No problems with the name - so we're fine!
        dataset = self.to_dataset()

    return dataset


@contextmanager
def writer_threads(paths, reduced_paths):
    """Run one WriterThread per output file for the duration of the block.

    Yields the list of threads: those for *paths* first, then those for
    *reduced_paths* (created with ``reduced_only=True``). On exit every
    thread is asked to stop, with the abort flag set if the block raised,
    and is then joined with a timeout.
    """
    threads = []
    try:
        for path in paths:
            threads.append(WriterThread(path))
        for path in reduced_paths:
            threads.append(WriterThread(path, reduced_only=True))
    except BaseException:
        # These threads never ran, so run() will not close their lock fds
        for thread in threads:
            thread._close_lock_fd()
        raise

    error = False
    for thread in threads:
        thread.start()
    try:
        yield threads
    except BaseException:
        # Includes KeyboardInterrupt/SystemExit; only recorded, then re-raised
        error = True
        raise
    finally:
        for thread in threads:
            thread.stop(abort=error)
        for thread in threads:
            # If there was no error, give threads a generous amount of time
            # to do any further writes.
            thread.join(timeout=(5 if error else 120))
            if thread.is_alive():
                log.warning("HDF5 writer thread for %s did not stop properly",
                            thread.file_path)