From 3e21763665dd82fa016f0cf53e75bb38c4c1aeef Mon Sep 17 00:00:00 2001 From: zmoon Date: Fri, 4 Nov 2022 11:08:00 -0600 Subject: [PATCH] raqms: Add `surf_only` option --- monetio/models/raqms.py | 19 ++++++++++++++----- tests/test_raqms.py | 13 +++++++++++++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/monetio/models/raqms.py b/monetio/models/raqms.py index 853e2b92..e19d2a68 100644 --- a/monetio/models/raqms.py +++ b/monetio/models/raqms.py @@ -8,7 +8,7 @@ import xarray as xr -def open_dataset(fname): +def open_dataset(fname, *, surf_only=False): """Open a single dataset from RAQMS output. Currently expects netCDF file format. Parameters @@ -27,14 +27,12 @@ def open_dataset(fname): ) ds = xr.open_dataset(names[0], drop_variables=["theta"]) - ds = _fix_grid(ds) - ds = _fix_time(ds) - ds = _fix_pres(ds) + ds = _fix(ds, surf_only=surf_only) return ds -def open_mfdataset(fname): +def open_mfdataset(fname, *, surf_only=False): """Open a multiple file dataset from RAQMS output. Parameters @@ -55,10 +53,21 @@ def open_mfdataset(fname): ) ds = xr.open_mfdataset(names, concat_dim="time", drop_variables=["theta"], combine="nested") + ds = _fix(ds, surf_only=surf_only) + + return ds + + +def _fix(ds, *, surf_only): ds = _fix_grid(ds) ds = _fix_time(ds) ds = _fix_pres(ds) + if surf_only: + ds = ds.isel(z=0).expand_dims("z") + + ds = ds.transpose("time", "z", "y", "x") + return ds diff --git a/tests/test_raqms.py b/tests/test_raqms.py index 9d67496c..7fb6b101 100644 --- a/tests/test_raqms.py +++ b/tests/test_raqms.py @@ -36,6 +36,8 @@ def _test_ds(ds): assert (ds["dp_pa"].mean(dim=("time", "y", "x")) > 1000).all() assert 100000 > ds["surfpres_pa"].mean() > 95000 + assert tuple(ds.o3vmr.dims) == ("time", "z", "y", "x") + def test_open_dataset(): ds = raqms.open_dataset(TEST_FP) @@ -55,3 +57,14 @@ def test_open_dataset_bad(): def test_open_mfdataset_bad(): with pytest.raises(ValueError, match="^File format "): raqms.open_mfdataset("asdf") + + +@pytest.mark.parametrize( + "fn", + ["open_dataset", "open_mfdataset"], +) +def test_surf_only(fn): + ds = getattr(raqms, fn)(TEST_FP, surf_only=True) + assert set(ds.dims) == {"time", "z", "y", "x"} + assert tuple(ds.o3vmr.dims) == ("time", "z", "y", "x") + assert ds.sizes["z"] == 1