diff --git a/cfgrib/dataset.py b/cfgrib/dataset.py index f4054d1c..900c2840 100644 --- a/cfgrib/dataset.py +++ b/cfgrib/dataset.py @@ -332,9 +332,9 @@ def get_values_in_order(message, shape): class OnDiskArray: index: abc.Index[T.Any, abc.Field] shape: T.Tuple[int, ...] - field_id_index: T.Dict[ - T.Tuple[T.Any, ...], T.List[T.Union[int, T.Tuple[int, int]]] - ] = attr.attrib(repr=False) + field_id_index: T.Dict[T.Tuple[T.Any, ...], T.List[T.Union[int, T.Tuple[int, int]]]] = ( + attr.attrib(repr=False) + ) missing_value: float geo_ndim: int = attr.attrib(default=1, repr=False) dtype = np.dtype("float32") diff --git a/cfgrib/messages.py b/cfgrib/messages.py index f7d725fb..d2d2000a 100644 --- a/cfgrib/messages.py +++ b/cfgrib/messages.py @@ -468,7 +468,10 @@ def subindex(self, filter_by_keys={}, **query): field_ids_index = [] for header_values, field_ids_values in self.field_ids_index: for idx, val in raw_query: - if header_values[idx] != val: + # Ensure that the values to be tested is a list or tuple + if not isinstance(val, (list, tuple)): + val = [val] + if header_values[idx] not in val: break else: field_ids_index.append((header_values, field_ids_values)) diff --git a/tests/sample-data/era5-levels-members.nc b/tests/sample-data/era5-levels-members.nc new file mode 100644 index 00000000..9e52bd77 Binary files /dev/null and b/tests/sample-data/era5-levels-members.nc differ diff --git a/tests/test_30_dataset.py b/tests/test_30_dataset.py index a61cde53..9914e482 100644 --- a/tests/test_30_dataset.py +++ b/tests/test_30_dataset.py @@ -13,6 +13,7 @@ TEST_DATA_SCALAR_TIME = os.path.join(SAMPLE_DATA_FOLDER, "era5-single-level-scalar-time.grib") TEST_DATA_ALTERNATE_ROWS = os.path.join(SAMPLE_DATA_FOLDER, "alternate-scanning.grib") TEST_DATA_MISSING_VALS = os.path.join(SAMPLE_DATA_FOLDER, "fields_with_missing_values.grib") +TEST_DATA_MULTI_PARAMS = os.path.join(SAMPLE_DATA_FOLDER, "multi_param_on_multi_dims.grib") def test_enforce_unique_attributes() -> None: @@ -340,11 +341,30 @@ def test_open_fieldset_ignore_keys() -> None: assert "GRIB_subCentre" not in res.attributes def test_open_file() -> None: + res = dataset.open_file(TEST_DATA) + + assert "t" in res.variables + assert "z" in res.variables + + +def test_open_file_filter_by_keys() -> None: res = dataset.open_file(TEST_DATA, filter_by_keys={"shortName": "t"}) assert "t" in res.variables assert "z" not in res.variables + res = dataset.open_file(TEST_DATA_MULTI_PARAMS) + + assert "t" in res.variables + assert "z" in res.variables + assert "u" in res.variables + + res = dataset.open_file(TEST_DATA_MULTI_PARAMS, filter_by_keys={"shortName": ["t", "z"]}) + + assert "t" in res.variables + assert "z" in res.variables + assert "u" not in res.variables + def test_alternating_rows() -> None: res = dataset.open_file(TEST_DATA_ALTERNATE_ROWS) diff --git a/tests/test_50_xarray_plugin.py b/tests/test_50_xarray_plugin.py index 388e68cb..16d63dd8 100644 --- a/tests/test_50_xarray_plugin.py +++ b/tests/test_50_xarray_plugin.py @@ -10,6 +10,7 @@ SAMPLE_DATA_FOLDER = os.path.join(os.path.dirname(__file__), "sample-data") TEST_DATA = os.path.join(SAMPLE_DATA_FOLDER, "regular_ll_sfc.grib") TEST_DATA_MISSING_VALS = os.path.join(SAMPLE_DATA_FOLDER, "fields_with_missing_values.grib") +TEST_DATA_MULTI_PARAMS = os.path.join(SAMPLE_DATA_FOLDER, "multi_param_on_multi_dims.grib") def test_plugin() -> None: @@ -29,6 +30,30 @@ def test_xr_open_dataset_file() -> None: assert list(ds.data_vars) == ["skt"] +def test_xr_open_dataset_file_filter_by_keys() -> None: + ds = xr.open_dataset(TEST_DATA_MULTI_PARAMS, engine="cfgrib") + + assert "t" in ds.data_vars + assert "z" in ds.data_vars + assert "u" in ds.data_vars + + ds = xr.open_dataset( + TEST_DATA_MULTI_PARAMS, engine="cfgrib", filter_by_keys={"shortName": "t"} + ) + + assert "t" in ds.data_vars + assert "z" not in ds.data_vars + assert "u" not in ds.data_vars + + ds = xr.open_dataset( + TEST_DATA_MULTI_PARAMS, engine="cfgrib", filter_by_keys={"shortName": ["t", "z"]} + ) + + assert "t" in ds.data_vars + assert "z" in ds.data_vars + assert "u" not in ds.data_vars + + def test_xr_open_dataset_file_ignore_keys() -> None: ds = xr.open_dataset(TEST_DATA, engine="cfgrib") assert "GRIB_typeOfLevel" in ds["skt"].attrs