Skip to content

Commit

Permalink
Merge pull request #384 from ecmwf/filter-by-keys-list
Browse files Browse the repository at this point in the history
provide list for filter_by_keys
  • Loading branch information
EddyCMWF authored Jun 26, 2024
2 parents b098bc5 + 151037b commit 011a5d1
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 4 deletions.
6 changes: 3 additions & 3 deletions cfgrib/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,9 @@ def get_values_in_order(message, shape):
class OnDiskArray:
index: abc.Index[T.Any, abc.Field]
shape: T.Tuple[int, ...]
field_id_index: T.Dict[
T.Tuple[T.Any, ...], T.List[T.Union[int, T.Tuple[int, int]]]
] = attr.attrib(repr=False)
field_id_index: T.Dict[T.Tuple[T.Any, ...], T.List[T.Union[int, T.Tuple[int, int]]]] = (
attr.attrib(repr=False)
)
missing_value: float
geo_ndim: int = attr.attrib(default=1, repr=False)
dtype = np.dtype("float32")
Expand Down
5 changes: 4 additions & 1 deletion cfgrib/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,10 @@ def subindex(self, filter_by_keys={}, **query):
field_ids_index = []
for header_values, field_ids_values in self.field_ids_index:
for idx, val in raw_query:
if header_values[idx] != val:
# Ensure that the values to be tested is a list or tuple
if not isinstance(val, (list, tuple)):
val = [val]
if header_values[idx] not in val:
break
else:
field_ids_index.append((header_values, field_ids_values))
Expand Down
Binary file added tests/sample-data/era5-levels-members.nc
Binary file not shown.
20 changes: 20 additions & 0 deletions tests/test_30_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
TEST_DATA_SCALAR_TIME = os.path.join(SAMPLE_DATA_FOLDER, "era5-single-level-scalar-time.grib")
TEST_DATA_ALTERNATE_ROWS = os.path.join(SAMPLE_DATA_FOLDER, "alternate-scanning.grib")
TEST_DATA_MISSING_VALS = os.path.join(SAMPLE_DATA_FOLDER, "fields_with_missing_values.grib")
TEST_DATA_MULTI_PARAMS = os.path.join(SAMPLE_DATA_FOLDER, "multi_param_on_multi_dims.grib")


def test_enforce_unique_attributes() -> None:
Expand Down Expand Up @@ -340,11 +341,30 @@ def test_open_fieldset_ignore_keys() -> None:
assert "GRIB_subCentre" not in res.attributes

def test_open_file() -> None:
res = dataset.open_file(TEST_DATA)

assert "t" in res.variables
assert "z" in res.variables


def test_open_file_filter_by_keys() -> None:
res = dataset.open_file(TEST_DATA, filter_by_keys={"shortName": "t"})

assert "t" in res.variables
assert "z" not in res.variables

res = dataset.open_file(TEST_DATA_MULTI_PARAMS)

assert "t" in res.variables
assert "z" in res.variables
assert "u" in res.variables

res = dataset.open_file(TEST_DATA_MULTI_PARAMS, filter_by_keys={"shortName": ["t", "z"]})

assert "t" in res.variables
assert "z" in res.variables
assert "u" not in res.variables


def test_alternating_rows() -> None:
res = dataset.open_file(TEST_DATA_ALTERNATE_ROWS)
Expand Down
25 changes: 25 additions & 0 deletions tests/test_50_xarray_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SAMPLE_DATA_FOLDER = os.path.join(os.path.dirname(__file__), "sample-data")
TEST_DATA = os.path.join(SAMPLE_DATA_FOLDER, "regular_ll_sfc.grib")
TEST_DATA_MISSING_VALS = os.path.join(SAMPLE_DATA_FOLDER, "fields_with_missing_values.grib")
TEST_DATA_MULTI_PARAMS = os.path.join(SAMPLE_DATA_FOLDER, "multi_param_on_multi_dims.grib")


def test_plugin() -> None:
Expand All @@ -29,6 +30,30 @@ def test_xr_open_dataset_file() -> None:
assert list(ds.data_vars) == ["skt"]


def test_xr_open_dataset_file_filter_by_keys() -> None:
ds = xr.open_dataset(TEST_DATA_MULTI_PARAMS, engine="cfgrib")

assert "t" in ds.data_vars
assert "z" in ds.data_vars
assert "u" in ds.data_vars

ds = xr.open_dataset(
TEST_DATA_MULTI_PARAMS, engine="cfgrib", filter_by_keys={"shortName": "t"}
)

assert "t" in ds.data_vars
assert "z" not in ds.data_vars
assert "u" not in ds.data_vars

ds = xr.open_dataset(
TEST_DATA_MULTI_PARAMS, engine="cfgrib", filter_by_keys={"shortName": ["t", "z"]}
)

assert "t" in ds.data_vars
assert "z" in ds.data_vars
assert "u" not in ds.data_vars


def test_xr_open_dataset_file_ignore_keys() -> None:
ds = xr.open_dataset(TEST_DATA, engine="cfgrib")
assert "GRIB_typeOfLevel" in ds["skt"].attrs
Expand Down

0 comments on commit 011a5d1

Please sign in to comment.