diff --git a/HISTORY.rst b/HISTORY.rst index f3a0d17e..a6ee2f3a 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -10,8 +10,9 @@ History * Add support for custom RV subclasses; * Use HRU_ID instead of SubId in BasinMaker reservoirs extraction logic; * Added support for Python 3.12 and dropped support for Python3.8. -* Upgraded `raven-hydro` to v0.3.0 and `RavenHydroFramework` to v3.8. +* Added support for `raven-hydro` v0.3.0 and `RavenHydroFramework` to v3.8. * `ravenpy` now requires `xclim` >= v0.48.2, `xarray` >= v2023.11.0, and `pandas` >= 2.2.0. +* Now automatically filters HRUs based on the ``hru_type``. Internal changes ^^^^^^^^^^^^^^^^ diff --git a/docs/notebooks/03_Extracting_forcing_data.ipynb b/docs/notebooks/03_Extracting_forcing_data.ipynb index 757394e7..80844c9b 100644 --- a/docs/notebooks/03_Extracting_forcing_data.ipynb +++ b/docs/notebooks/03_Extracting_forcing_data.ipynb @@ -26,6 +26,7 @@ "\n", "import fsspec # noqa\n", "import intake\n", + "import numpy as np\n", "import s3fs # noqa\n", "import xarray as xr\n", "from clisops.core import subset\n", @@ -207,7 +208,7 @@ " ERA5_pr = ERA5_pr.mean({\"latitude\", \"longitude\"})\n", "\n", " # Ensure that the precipitation is non-negative, which can happen with some reanalysis models.\n", - " ERA5_pr[ERA5_pr < 0] = 0\n", + " ERA5_pr = np.maximum(ERA5_pr, 0)\n", "\n", " # Transform them to a dataset such that they can be written with attributes to netcdf\n", " ERA5_tmin = ERA5_tmin.to_dataset(name=\"tmin\", promote_attrs=True)\n", @@ -282,7 +283,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.10.13" }, "nbdime-conflicts": { "local_diff": [ diff --git a/environment-rtd.yml b/environment-rtd.yml index 8331da1e..5b171329 100644 --- a/environment-rtd.yml +++ b/environment-rtd.yml @@ -4,7 +4,7 @@ channels: - defaults dependencies: - python >=3.9,<3.10 # fixed to reduce solver time - - raven-hydro ==0.2.4 # FIXME: Update when raven-hydro 0.3.0 is available on conda-forge + - raven-hydro >=0.2.4,<1.0 - autodoc-pydantic - click # - clisops # mocked diff --git a/environment.yml b/environment.yml index 6f8dc0a0..cb403252 100644 --- a/environment.yml +++ b/environment.yml @@ -4,7 +4,7 @@ channels: - defaults dependencies: - python >=3.9,<3.13 - - raven-hydro ==0.2.4 # FIXME: Update when raven-hydro 0.3.0 is available on conda-forge + - raven-hydro >=0.2.4,<1.0 - libgcc # for mixing raven-hydro from PyPI with conda environments - affine - black >=24.2.0 diff --git a/pyproject.toml b/pyproject.toml index 7935fab4..a158c71c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ dependencies = [ "platformdirs", "pydantic >=2.0", "pymbolic", - "raven-hydro ==0.3.0", + "raven-hydro >=0.2.4,<1.0", "requests", "scipy", "spotpy", diff --git a/ravenpy/config/base.py b/ravenpy/config/base.py index 93a083b1..0c123cb3 100644 --- a/ravenpy/config/base.py +++ b/ravenpy/config/base.py @@ -1,3 +1,4 @@ +import typing from enum import Enum from textwrap import dedent, indent from typing import Any, Dict, Optional, Sequence, Tuple, Union diff --git a/ravenpy/config/commands.py b/ravenpy/config/commands.py index 1daefb09..f4e1b827 100644 --- a/ravenpy/config/commands.py +++ b/ravenpy/config/commands.py @@ -12,6 +12,7 @@ Tuple, Union, get_args, + get_origin, no_type_check, ) @@ -44,7 +45,7 @@ RootRecord, Sym, ) -from .utils import filter_for, nc_specs +from .utils import filter_for, get_annotations, nc_specs """ Raven commands @@ -315,6 +316,38 @@ class HRUs(ListCommand): ] ) + @field_validator("root", mode="before") + @classmethod + def ignore_unrecognized_hrus(cls, values): + """Ignore HRUs with unrecognized hru_type. + + HRUs are ignored only if all allowed HRU classes define `hru_type`, and if the values passed include it. + """ + import collections + import warnings + + a = cls.model_fields["root"].annotation + + # Annotation should be a sequence + if get_origin(a) != collections.abc.Sequence: + return values + + # Extract allowed HRU types + allowed = [hru.model_fields["hru_type"].default for hru in get_annotations(a)] + + # If some HRU classes do not define rhu_type, skip filtering + if None in allowed: + return values + + allowed.append(None) + + out = [value for value in values if getattr(value, "hru_type", None) in allowed] + if len(out) != len(values): + warnings.warn( + "HRUs with an unrecognized `hru_type` attribute were ignored." + ) + return out + class HRUGroup(FlatCommand): class _Rec(RootRecord): diff --git a/ravenpy/config/emulators/hbvec.py b/ravenpy/config/emulators/hbvec.py index 32700a7d..3ee3a44a 100644 --- a/ravenpy/config/emulators/hbvec.py +++ b/ravenpy/config/emulators/hbvec.py @@ -1,4 +1,5 @@ from typing import Dict, Literal, Sequence, Union +from warnings import warn from pydantic import Field, field_validator from pydantic.dataclasses import dataclass diff --git a/ravenpy/config/utils.py b/ravenpy/config/utils.py index bc9919e4..0ffba335 100644 --- a/ravenpy/config/utils.py +++ b/ravenpy/config/utils.py @@ -1,4 +1,5 @@ import os +import typing from typing import Optional, Sequence, Union import cf_xarray # noqa: F401 @@ -179,3 +180,13 @@ def get_average_annual_runoff( q_year = np.mean(q_year) # [mm/yr] mean over all years in record return q_year + + +def get_annotations(a): + """Return all annotations inside [] or Union[...].""" + + for arg in typing.get_args(a): + if typing.get_origin(arg) == Union: + yield from get_annotations(arg) + else: + yield arg diff --git a/tests/test_rvs.py b/tests/test_rvs.py index 3193e4ab..0354c361 100644 --- a/tests/test_rvs.py +++ b/tests/test_rvs.py @@ -139,3 +139,15 @@ class MyConfig(myRVI, cls): conf.write_rv(workdir=tmp_path) assert conf.run_name == "myRunName" assert "myRunName" in conf._rv("RVI") + + +def test_hru_filter(): + """Test that unrecognized HRU types are filtered out.""" + from ravenpy.config.emulators.gr4jcn import LakeHRU + from ravenpy.config.emulators.hbvec import HRUs, LandHRU + + with pytest.warns(UserWarning): + hrus = HRUs([LandHRU(), LandHRU(), LakeHRU()]) + + # The GR4J lake HRU is not part of the HBVEC config. + assert len(hrus.root) == 2