diff --git a/cfgrib/messages.py b/cfgrib/messages.py index f7d725fb..d2d2000a 100644 --- a/cfgrib/messages.py +++ b/cfgrib/messages.py @@ -468,7 +468,10 @@ def subindex(self, filter_by_keys={}, **query): field_ids_index = [] for header_values, field_ids_values in self.field_ids_index: for idx, val in raw_query: - if header_values[idx] != val: + # Ensure that the values to be tested is a list or tuple + if not isinstance(val, (list, tuple)): + val = [val] + if header_values[idx] not in val: break else: field_ids_index.append((header_values, field_ids_values)) diff --git a/tests/test_30_dataset.py b/tests/test_30_dataset.py index 5523d3ee..9c3acbcb 100644 --- a/tests/test_30_dataset.py +++ b/tests/test_30_dataset.py @@ -13,6 +13,7 @@ TEST_DATA_SCALAR_TIME = os.path.join(SAMPLE_DATA_FOLDER, "era5-single-level-scalar-time.grib") TEST_DATA_ALTERNATE_ROWS = os.path.join(SAMPLE_DATA_FOLDER, "alternate-scanning.grib") TEST_DATA_MISSING_VALS = os.path.join(SAMPLE_DATA_FOLDER, "fields_with_missing_values.grib") +TEST_DATA_MULTI_PARAMS = os.path.join(SAMPLE_DATA_FOLDER, "multi_param_on_multi_dims.grib") def test_enforce_unique_attributes() -> None: @@ -304,10 +305,28 @@ def test_open_fieldset_computed_keys() -> None: def test_open_file() -> None: + res = dataset.open_file(TEST_DATA) + + assert "t" in res.variables + assert "z" in res.variables + +def test_open_file_filter_by_keys_list() -> None: res = dataset.open_file(TEST_DATA, filter_by_keys={"shortName": "t"}) assert "t" in res.variables assert "z" not in res.variables + + res = dataset.open_file(TEST_DATA_MULTI_PARAMS) + + assert "t" in res.variables + assert "z" in res.variables + assert "u" in res.variables + + res = dataset.open_file(TEST_DATA_MULTI_PARAMS, filter_by_keys={"shortName": ["t", "z"]}) + + assert "t" in res.variables + assert "z" in res.variables + assert "u" not in res.variables def test_alternating_rows() -> None: diff --git a/tests/test_50_xarray_plugin.py b/tests/test_50_xarray_plugin.py index d638af62..afd2560c 100644 --- a/tests/test_50_xarray_plugin.py +++ b/tests/test_50_xarray_plugin.py @@ -10,6 +10,7 @@ SAMPLE_DATA_FOLDER = os.path.join(os.path.dirname(__file__), "sample-data") TEST_DATA = os.path.join(SAMPLE_DATA_FOLDER, "regular_ll_sfc.grib") TEST_DATA_MISSING_VALS = os.path.join(SAMPLE_DATA_FOLDER, "fields_with_missing_values.grib") +TEST_DATA_MULTI_PARAMS = os.path.join(SAMPLE_DATA_FOLDER, "multi_param_on_multi_dims.grib") def test_plugin() -> None: @@ -29,6 +30,27 @@ def test_xr_open_dataset_file() -> None: assert list(ds.data_vars) == ["skt"] +def test_xr_open_dataset_file_filter_by_keys() -> None: + ds = xr.open_dataset(TEST_DATA_MULTI_PARAMS, engine="cfgrib") + + assert "t" in ds.data_vars + assert "z" in ds.data_vars + assert "u" in ds.data_vars + + ds = xr.open_dataset(TEST_DATA_MULTI_PARAMS, engine="cfgrib", filter_by_keys={"shortName": "t"}) + + assert "t" in ds.data_vars + assert "z" not in ds.data_vars + assert "u" not in ds.data_vars + + ds = xr.open_dataset(TEST_DATA_MULTI_PARAMS, engine="cfgrib", filter_by_keys={"shortName": ["t", "z"]}) + + assert "t" in ds.data_vars + assert "z" in ds.data_vars + assert "u" not in ds.data_vars + + + def test_xr_open_dataset_dict() -> None: fieldset = { -10: {