diff --git a/CHANGES.rst b/CHANGES.rst index b3f162f95..f550884d1 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -28,7 +28,8 @@ Breaking changes * The previously deprecated function ``xclim.ensembles.change_significance`` has been removed. (:pull:`1737`). * Indicators ``snw_season_length`` and ``snd_season_length`` have been modified, see above. * The `hargeaves85`/`hg85` method for the ``potential_evapotranspiration`` indicator and indice has been modified for precision and consistency with recent academic literature. (:issue:`1710`, :pull:`1723`). - +* The `__getitem__` method of ``xclim.core.indicator.Parameter`` instances has been removed. Accessing members of ``Parameters`` now uniquely uses dot notation. (:pull:`1721`). +* The obsolete function wrapper for generating Indicators ``xclim.core.utils.wrapped_partial`` has been removed. (:pull:`1721`). Bug fixes ^^^^^^^^^ @@ -39,6 +40,7 @@ Bug fixes * Fixed "agreement fraction" in ``robustness_fractions`` to distinguish between negative change and no change. Added "negative" and "changed negative" fractions (:issue:`1690`, :pull:`1711`). * ``make_criteria`` now skips columns with NaNs across all realizations. (:pull:`1713`). * Fixed bug QuantileDeltaMapping adjustment not working for seasonal grouping (:issue:`1704`, :pull:`1716`). +* The codebase has been adjusted to address several (~400) `mypy`-related errors attributable to inaccurate function call signatures and variable name shadowing. (:issue:`1719`, :pull:`1721`). Internal changes ^^^^^^^^^^^^^^^^ @@ -48,6 +50,7 @@ Internal changes * Added the `tox-gh` dependency to the development installation recipe. This will soon be required for running the `tox` test ensemble on GitHub Workflows. (:pull:`1709`). * Added the `vulture` static code analysis tool for finding dead code to the development dependency list and linters (makefile, tox and pre-commit hooks). (:pull:`1717`). * Added error message when using `xclim.indices.stats.dist_method` with `nnlf` and included note in docstring. (:issue:`1683`, :pull:`1714`). +* PEP8 rule `N802` is now enabled in the `ruff` formatter. Function names should follow `Snake case `_, with rare exceptions. (:pull:`1721`). 
v0.48.2 (2024-02-26) -------------------- diff --git a/docs/notebooks/extendxclim.ipynb b/docs/notebooks/extendxclim.ipynb index e046a92a2..7e190eb36 100644 --- a/docs/notebooks/extendxclim.ipynb +++ b/docs/notebooks/extendxclim.ipynb @@ -540,7 +540,6 @@ "outputs": [], "source": [ "from xclim.core.indicator import build_indicator_module, registry\n", - "from xclim.core.utils import wrapped_partial\n", "\n", "mapping = dict(\n", " egg_cooking_season=registry[\"MAXIMUM_CONSECUTIVE_WARM_DAYS\"](\n", diff --git a/environment.yml b/environment.yml index 3e01a4cbc..6fd159c02 100644 --- a/environment.yml +++ b/environment.yml @@ -53,6 +53,7 @@ dependencies: - nc-time-axis - netCDF4 >=1.4 - notebook + - pandas-stubs - platformdirs - pooch - pre-commit diff --git a/pyproject.toml b/pyproject.toml index 1a5e6e2d9..9bb8856cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ dev = [ "nbqa", "nbval", "netCDF4 >=1.4", + "pandas-stubs>=2.2", "platformdirs >=3.2", "pre-commit >=2.9", "pybtex", @@ -153,7 +154,7 @@ values = [ [tool.codespell] skip = 'xclim/data/*.json,docs/_build,docs/notebooks/xclim_training/*.ipynb,docs/references.bib,__pycache__,*.nc,*.png,*.gz,*.whl' -ignore-words-list = "absolue,astroid,bloc,bui,callendar,degreee,environnement,hanel,inferrable,lond,nam,nd,ressources,vas" +ignore-words-list = "absolue,astroid,bloc,bui,callendar,degreee,environnement,hanel,inferrable,lond,nam,nd,ressources,sie,vas" [tool.coverage.run] relative_files = true @@ -207,24 +208,20 @@ python_version = 3.9 show_error_codes = true warn_return_any = true warn_unused_configs = true +plugins = ["numpy.typing.mypy_plugin"] [[tool.mypy.overrides]] module = [ "boltons.*", - "bottleneck.*", "cftime.*", - "clisops.core.subset.*", - "dask.*", - "lmoments3.*", - "matplotlib.*", + "jsonpickle.*", "numba.*", - "numpy.*", - "pandas.*", - "pint.*", + "pytest_socket.*", "SBCK.*", "scipy.*", - "sklearn.cluster.*", - "xarray.*", + "sklearn.*", + "statsmodels.*", + "yamale.*", "yaml.*" ] ignore_missing_imports = true @@ -273,6 +270,7 @@ select = [ "D", "E", "F", + "N802", "W" ] @@ -299,7 +297,7 @@ max-complexity = 20 [tool.ruff.lint.per-file-ignores] "docs/*.py" = ["D100", "D101", "D102", "D103"] -"tests/*.py" = ["D100", "D101", "D102", "D103"] +"tests/*.py" = ["D100", "D101", "D102", "D103", "N802"] "xclim/**/__init__.py" = ["F401", "F403"] "xclim/core/indicator.py" = ["D214", "D405", "D406", "D407", "D411"] "xclim/core/locales.py" = ["E501", "W505"] diff --git a/tests/test_indicators.py b/tests/test_indicators.py index bca8bc69c..19bf80df9 100644 --- a/tests/test_indicators.py +++ b/tests/test_indicators.py @@ -225,7 +225,7 @@ def test_opt_vars(tasmin_series, tasmax_series): tx = tasmax_series(np.zeros(365)) multiOptVar(tasmin=tn, tasmax=tx) - assert multiOptVar.parameters["tasmin"]["kind"] == InputKind.OPTIONAL_VARIABLE + assert multiOptVar.parameters["tasmin"].kind == InputKind.OPTIONAL_VARIABLE def test_registering(): @@ -265,8 +265,10 @@ def test_module(): """Translations are keyed according to the module where the indicators are defined.""" assert atmos.tg_mean.__module__.split(".")[2] == "atmos" # Virtual module also are stored under xclim.indicators - assert xclim.indicators.cf.fg.__module__ == "xclim.indicators.cf" - assert xclim.indicators.icclim.GD4.__module__ == "xclim.indicators.icclim" + assert xclim.indicators.cf.fg.__module__ == "xclim.indicators.cf" # noqa: F821 + assert ( + xclim.indicators.icclim.GD4.__module__ == "xclim.indicators.icclim" + ) # noqa: F821 def 
test_temp_unit_conversion(tas_series): @@ -377,7 +379,7 @@ def test_multiindicator(tas_series): compute=uniindtemp_compute, ) with pytest.raises(ValueError, match="Indicator minmaxtemp4 was wrongly defined"): - tmin, tmax = ind(tas, freq="YS") + _tmin, _tmax = ind(tas, freq="YS") def test_missing(tas_series): @@ -480,7 +482,7 @@ def test_all_parameters_understood(official_indicators): for identifier, ind in official_indicators.items(): indinst = ind.get_instance() for name, param in indinst.parameters.items(): - if param["kind"] == InputKind.OTHER_PARAMETER: + if param.kind == InputKind.OTHER_PARAMETER: problems.add((identifier, name)) # this one we are ok with. if problems - { @@ -587,15 +589,15 @@ def test_parsed_doc(): assert "tas" in xclim.atmos.liquid_precip_accumulation.parameters params = xclim.atmos.drought_code.parameters - assert params["tas"]["description"] == "Noon temperature." - assert params["tas"]["units"] == "[temperature]" - assert params["tas"]["kind"] is InputKind.VARIABLE - assert params["tas"]["default"] == "tas" - assert params["snd"]["default"] is None - assert params["snd"]["kind"] is InputKind.OPTIONAL_VARIABLE - assert params["snd"]["units"] == "[length]" - assert params["season_method"]["kind"] is InputKind.STRING - assert params["season_method"]["choices"] == {"GFWED", None, "WF93", "LA08"} + assert params["tas"].description == "Noon temperature." + assert params["tas"].units == "[temperature]" + assert params["tas"].kind is InputKind.VARIABLE + assert params["tas"].default == "tas" + assert params["snd"].default is None + assert params["snd"].kind is InputKind.OPTIONAL_VARIABLE + assert params["snd"].units == "[length]" + assert params["season_method"].kind is InputKind.STRING + assert params["season_method"].choices == {"GFWED", None, "WF93", "LA08"} def test_default_formatter(): @@ -655,14 +657,14 @@ def test_input_dataset(open_dataset): ds = open_dataset("ERA5/daily_surface_cancities_1990-1993.nc") # Use defaults - out = xclim.atmos.daily_temperature_range(freq="YS", ds=ds) + _ = xclim.atmos.daily_temperature_range(freq="YS", ds=ds) # Use non-defaults (inverted on purpose) with xclim.set_options(cf_compliance="log"): - out = xclim.atmos.daily_temperature_range("tasmax", "tasmin", freq="YS", ds=ds) + _ = xclim.atmos.daily_temperature_range("tasmax", "tasmin", freq="YS", ds=ds) # Use a mix - out = xclim.atmos.daily_temperature_range(tasmax=ds.tasmax, freq="YS", ds=ds) + _ = xclim.atmos.daily_temperature_range(tasmax=ds.tasmax, freq="YS", ds=ds) # Inexistent variable: dsx = ds.drop_vars("tasmin") diff --git a/tests/test_utils.py b/tests/test_utils.py index f679cd1d9..c5e614185 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,8 +2,6 @@ # Test for utils from __future__ import annotations -from inspect import signature - import numpy as np import xarray as xr @@ -12,7 +10,6 @@ ensure_chunk_size, nan_calc_percentiles, walk_map, - wrapped_partial, ) from xclim.testing.helpers import test_timeseries as _test_timeseries @@ -24,29 +21,6 @@ def test_walk_map(): assert o["b"]["c"] == 0 -def test_wrapped_partial(): - def func(a, b=1, c=1): - """Docstring""" - return (a, b, c) - - newf = wrapped_partial(func, b=2) - assert list(signature(newf).parameters.keys()) == ["a", "c"] - assert newf(1) == (1, 2, 1) - - newf = wrapped_partial(func, suggested=dict(c=2), b=2) - assert list(signature(newf).parameters.keys()) == ["a", "c"] - assert newf(1) == (1, 2, 2) - assert newf.__doc__ == func.__doc__ - - def func(a, b=1, c=1, **kws): # pylint: 
disable=function-redefined - """Docstring""" - return a, b, c - - newf = wrapped_partial(func, suggested=dict(c=2), a=2, b=2) - assert list(signature(newf).parameters.keys()) == ["c", "kws"] - assert newf() == (2, 2, 2) - - def test_ensure_chunk_size(): da = xr.DataArray(np.zeros((20, 21, 20)), dims=("x", "y", "z")) diff --git a/xclim/__init__.py b/xclim/__init__.py index cd97c0107..3644e5266 100644 --- a/xclim/__init__.py +++ b/xclim/__init__.py @@ -2,10 +2,7 @@ from __future__ import annotations -try: - from importlib.resources import files as _files -except ImportError: - from importlib_resources import files as _files +import importlib.resources as _resources from xclim import indices from xclim.core import units # noqa @@ -19,15 +16,13 @@ __version__ = "0.48.3-dev.15" -_module_data = _files("xclim.data") +with _resources.as_file(_resources.files("xclim.data")) as _module_data: + # Load official locales + for filename in _module_data.glob("??.json"): + # Only select .json and not ..json + load_locale(filename, filename.stem) -# Load official locales -for filename in _module_data.glob("??.json"): - # Only select .json and not ..json - load_locale(filename, filename.stem) - - -# Virtual modules creation: -build_indicator_module_from_yaml(_module_data / "icclim", mode="raise") -build_indicator_module_from_yaml(_module_data / "anuclim", mode="raise") -build_indicator_module_from_yaml(_module_data / "cf", mode="raise") + # Virtual modules creation: + build_indicator_module_from_yaml(_module_data / "icclim", mode="raise") + build_indicator_module_from_yaml(_module_data / "anuclim", mode="raise") + build_indicator_module_from_yaml(_module_data / "cf", mode="raise") diff --git a/xclim/analog.py b/xclim/analog.py index a60c92086..bdc60f7c8 100644 --- a/xclim/analog.py +++ b/xclim/analog.py @@ -60,12 +60,12 @@ def spatial_analogs( # Create the target DataArray # drop any (sub-)index along "dist_dim" that could conflict with target, and rename it. # The drop is the simplest solution that is compatible with both xarray <=2022.3.0 and >2022.3.1 - candidates = candidates.to_array("_indices", "candidates").rename( + candidate_array = candidates.to_array("_indices", "candidates").rename( {dist_dim: "_dist_dim"} ) - if isinstance(candidates.indexes["_dist_dim"], pd.MultiIndex): - candidates = candidates.drop_vars( - ["_dist_dim"] + candidates.indexes["_dist_dim"].names, + if isinstance(candidate_array.indexes["_dist_dim"], pd.MultiIndex): + candidate_array = candidate_array.drop_vars( + ["_dist_dim"] + candidate_array.indexes["_dist_dim"].names, # in xarray <= 2022.3.0 the sub-indexes are not listed as separate coords, # instead, they are dropped when the multiindex is dropped. errors="ignore", @@ -78,8 +78,8 @@ def spatial_analogs( f"Method `{method}` is not implemented. Available methods are: {','.join(metrics.keys())}." 
) from e - if candidates.chunks is not None: - candidates = candidates.chunk({"_indices": -1}) + if candidate_array.chunks is not None: + candidate_array = candidate_array.chunk({"_indices": -1}) if target_array.chunks is not None: target_array = target_array.chunk({"_indices": -1}) @@ -87,7 +87,7 @@ def spatial_analogs( diss = xr.apply_ufunc( metric_func, target_array, - candidates, + candidate_array, input_core_dims=[(dist_dim, "_indices"), ("_dist_dim", "_indices")], output_core_dims=[()], vectorize=True, diff --git a/xclim/cli.py b/xclim/cli.py index e6d8fec6f..9cddfe75c 100644 --- a/xclim/cli.py +++ b/xclim/cli.py @@ -19,12 +19,14 @@ from xclim.testing.helpers import TESTDATA_BRANCH, populate_testing_data from xclim.testing.utils import _default_cache_dir, publish_release_notes, show_versions +distributed = False try: - from dask.distributed import Client, progress # pylint: disable=ungrouped-imports + from dask.distributed import Client, progress + + distributed = True except ImportError: # Distributed is not a dependency of xclim - Client = None - progress = None + pass def _get_indicator(indicator_name): @@ -432,7 +434,7 @@ def cli(ctx, **kwargs): kwargs["input"] = kwargs["input"][0] if kwargs["dask_nthreads"] is not None: - if Client is None: + if not distributed: raise click.BadOptionUsage( "dask_nthreads", "Dask's distributed scheduler is not installed, only the " diff --git a/xclim/core/calendar.py b/xclim/core/calendar.py index 0d82955b0..3927b1075 100644 --- a/xclim/core/calendar.py +++ b/xclim/core/calendar.py @@ -9,7 +9,7 @@ import datetime as pydt from collections.abc import Sequence -from typing import Any +from typing import Any, TypeVar import cftime import numpy as np @@ -80,6 +80,9 @@ uniform_calendars = ("noleap", "all_leap", "365_day", "366_day", "360_day") +DataType = TypeVar("DataType", xr.DataArray, xr.Dataset) + + def days_in_year(year: int, calendar: str = "default") -> int: """Return the number of days in the input year according to the input calendar.""" return ( @@ -333,7 +336,7 @@ def convert_calendar( missing: Any | None = None, doy: bool | str = False, dim: str = "time", -) -> xr.DataArray | xr.Dataset: +) -> DataType: """Convert a DataArray/Dataset to another calendar using the specified method. By default, only converts the individual timestamps, does not modify any data except in dropping invalid/surplus dates or inserting missing dates. @@ -350,30 +353,32 @@ def convert_calendar( Parameters ---------- source : xr.DataArray or xr.Dataset - Input array/dataset with a time coordinate of a valid dtype (datetime64 or a cftime.datetime). + Input array/dataset with a time coordinate of a valid dtype (datetime64 or a cftime.datetime). target : xr.DataArray or str - Either a calendar name or the 1D time coordinate to convert to. - If an array is provided, the output will be reindexed using it and in that case, days in `target` - that are missing in the converted `source` are filled by `missing` (which defaults to NaN). + Either a calendar name or the 1D time coordinate to convert to. + If an array is provided, the output will be reindexed using it and in that case, days in `target` + that are missing in the converted `source` are filled by `missing` (which defaults to NaN). align_on : {None, 'date', 'year', 'random'} - Must be specified when either source or target is a `360_day` calendar, ignored otherwise. See Notes. + Must be specified when either source or target is a `360_day` calendar, ignored otherwise. See Notes. 
missing : Any, optional - A value to use for filling in dates in the target that were missing in the source. - If `target` is a string, default (None) is not to fill values. If it is an array, default is to fill with NaN. + A value to use for filling in dates in the target that were missing in the source. + If `target` is a string, default (None) is not to fill values. If it is an array, default is to fill with NaN. doy: bool or {'year', 'date'} - If not False, variables flagged as "dayofyear" (with a `is_dayofyear==1` attribute) are converted to the new calendar too. - Can be a string, which will be passed as the `align_on` argument of :py:func:`convert_doy`. If True, `year` is passed. + If not False, variables flagged as "dayofyear" (with a `is_dayofyear==1` attribute) are converted to the new calendar too. + Can be a string, which will be passed as the `align_on` argument of :py:func:`convert_doy`. + If True, `year` is passed. dim : str - Name of the time coordinate. + Name of the time coordinate. Returns ------- xr.DataArray or xr.Dataset - Copy of source with the time coordinate converted to the target calendar. - If `target` is given as an array, the output is reindexed to it, with fill value `missing`. - If `target` was a string and `missing` was None (default), invalid dates in the new calendar are dropped, but missing dates are not inserted. - If `target` was a string and `missing` was given, then start, end and frequency of the new time axis are inferred and - the output is reindexed to that a new array. + Copy of source with the time coordinate converted to the target calendar. + If `target` is given as an array, the output is reindexed to it, with fill value `missing`. + If `target` was a string and `missing` was None (default), invalid dates in the new calendar are dropped, + but missing dates are not inserted. + If `target` was a string and `missing` was given, then start, end and frequency of the new time axis are + inferred and the output is reindexed to that new array. Notes ----- @@ -563,7 +568,7 @@ def interp_calendar( return out -def ensure_cftime_array(time: Sequence) -> np.ndarray: +def ensure_cftime_array(time: Sequence) -> np.ndarray | Sequence[cftime.datetime]: """Convert an input 1D array to a numpy array of cftime objects. Python's datetime are converted to cftime.DatetimeGregorian ("standard" calendar). @@ -1495,7 +1500,7 @@ def select_time( doy_bounds: tuple[int, int] | None = None, date_bounds: tuple[str, str] | None = None, include_bounds: bool | tuple[bool, bool] = True, -) -> xr.DataArray | xr.Dataset: +) -> DataType: """Select entries according to a time period.
This conveniently improves xarray's :py:meth:`xarray.DataArray.where` and @@ -1557,16 +1562,16 @@ def select_time( if N == 0: return da - def get_doys(start, end, inclusive): - if start <= end: - doys = np.arange(start, end + 1) + def _get_doys(_start, _end, _inclusive): + if _start <= _end: + _doys = np.arange(_start, _end + 1) else: - doys = np.concatenate((np.arange(start, 367), np.arange(0, end + 1))) - if not inclusive[0]: - doys = doys[1:] - if not inclusive[1]: - doys = doys[:-1] - return doys + _doys = np.concatenate((np.arange(_start, 367), np.arange(0, _end + 1))) + if not _inclusive[0]: + _doys = _doys[1:] + if not _inclusive[1]: + _doys = _doys[:-1] + return _doys if isinstance(include_bounds, bool): include_bounds = (include_bounds, include_bounds) @@ -1582,7 +1587,7 @@ def get_doys(start, end, inclusive): mask = da.time.dt.month.isin(month) elif doy_bounds is not None: - mask = da.time.dt.dayofyear.isin(get_doys(*doy_bounds, include_bounds)) + mask = da.time.dt.dayofyear.isin(_get_doys(*doy_bounds, include_bounds)) elif date_bounds is not None: # This one is a bit trickier. @@ -1598,15 +1603,20 @@ def get_doys(start, end, inclusive): calendar = "all_leap" # Get doy of date, this is now safe because the calendar is uniform. - doys = get_doys( - to_cftime_datetime("2000-" + start, calendar).dayofyr, - to_cftime_datetime("2000-" + end, calendar).dayofyr, + doys = _get_doys( + to_cftime_datetime(f"2000-{start}", calendar).dayofyr, + to_cftime_datetime(f"2000-{end}", calendar).dayofyr, include_bounds, ) mask = time.time.dt.dayofyear.isin(doys) # Needed if we converted calendar, this puts back the correct coord mask["time"] = da.time + else: + raise ValueError( + "Must provide either `season`, `month`, `doy_bounds` or `date_bounds`." + ) + return da.where(mask, drop=drop) diff --git a/xclim/core/dataflags.py b/xclim/core/dataflags.py index 0529f5f39..666a50f74 100644 --- a/xclim/core/dataflags.py +++ b/xclim/core/dataflags.py @@ -43,7 +43,7 @@ class DataQualityException(Exception): Message prepended to the error messages. 
""" - flag_array: xarray.Dataset = None + flag_array: xarray.Dataset | None = None def __init__( self, @@ -639,9 +639,9 @@ def get_variable_name(function, kwargs): def _missing_vars(function, dataset: xarray.Dataset, var_provided: str): """Handle missing variables in passed datasets.""" sig = signature(function) - sig = sig.parameters + sig_params = sig.parameters extra_vars = {} - for arg, val in sig.items(): + for arg, val in sig_params.items(): if arg in ["da", var_provided]: continue kind = infer_kind_from_parameter(val) @@ -746,7 +746,7 @@ def ecad_compliant( xarray.DataArray or xarray.Dataset or None """ flags = xarray.Dataset() - history = [] + history: list[str] = [] for var in ds.data_vars: df = data_flags(ds[var], ds, dims=dims) for flag_name, flag_data in df.data_vars.items(): @@ -773,7 +773,7 @@ def ecad_compliant( if raise_flags: if np.any([flags[dv] for dv in flags.data_vars]): raise DataQualityException(flags) - return + return None ecad_flag = xarray.DataArray( # TODO: Test for this change concerning data of type None in dataflag variables diff --git a/xclim/core/formatting.py b/xclim/core/formatting.py index dbbde8f44..ff443dcbb 100644 --- a/xclim/core/formatting.py +++ b/xclim/core/formatting.py @@ -356,7 +356,7 @@ def merge_attributes( def update_history( hist_str: str, - *inputs_list: Sequence[xr.DataArray | xr.Dataset], + *inputs_list: xr.DataArray | xr.Dataset, new_name: str | None = None, **inputs_kws: xr.DataArray | xr.Dataset, ): @@ -368,12 +368,12 @@ def update_history( ---------- hist_str : str The string describing what has been done on the data. - new_name : Optional[str] - The name of the newly created variable or dataset to prefix hist_msg. - \*inputs_list : Sequence[Union[xr.DataArray, xr.Dataset]] + \*inputs_list : xr.DataArray or xr.Dataset The datasets or variables that were used to produce the new object. Inputs given that way will be prefixed by their "name" attribute if available. - \*\*inputs_kws : Union[xr.DataArray, xr.Dataset] + new_name : str, optional + The name of the newly created variable or dataset to prefix hist_msg. + \*\*inputs_kws : xr.DataArray or xr.Dataset Mapping from names to the datasets or variables that were used to produce the new object. Inputs given that way will be prefixes by the passed name. diff --git a/xclim/core/indicator.py b/xclim/core/indicator.py index 6eacdb45c..54d25642f 100644 --- a/xclim/core/indicator.py +++ b/xclim/core/indicator.py @@ -190,8 +190,6 @@ class Parameter: False >>> p.description 'A simple number' - >>> p["description"] # Same as above, for convenience. - 'A simple number' """ _empty = _empty @@ -218,12 +216,12 @@ def is_parameter_dict(cls, other: dict) -> bool: cls.__dataclass_fields__.keys() # pylint: disable=no-member ) - def __getitem__(self, key) -> str: - """Return an item in retro-compatible fashion.""" - try: - return getattr(self, key) - except AttributeError as err: - raise KeyError(key) from err + # def __getitem__(self, key) -> str: + # """Return an item in retro-compatible fashion.""" + # try: + # return str(getattr(self, key)) + # except AttributeError as err: + # raise KeyError(key) from err def __contains__(self, key) -> bool: """Imitate previous behaviour where "units" and "choices" were missing, instead of being "_empty".""" @@ -455,7 +453,7 @@ def __new__(cls, **kwds): # noqa: C901 # parameters has already been update above. 
kwds["compute"] = declare_units( **{ - inv_var_map[k]: m["units"] + inv_var_map[k]: m.units for k, m in parameters.items() if "units" in m and k in inv_var_map } @@ -1613,23 +1611,24 @@ def build_indicator_module( Parameters ---------- name : str - New module name. If it already exists, the module is extended with the passed objects, - overwriting those with same names. + New module name. If it already exists, the module is extended with the passed objects, + overwriting those with same names. objs : dict[str, Indicator] - Mapping of the indicators to put in the new module. Keyed by the name they will take in that module. + Mapping of the indicators to put in the new module. Keyed by the name they will take in that module. doc : str - Docstring of the new module. Defaults to a simple header. Invalid if the module already exists. + Docstring of the new module. Defaults to a simple header. Invalid if the module already exists. reload : bool - If reload is True and the module already exists, it is first removed before being rebuilt. - If False (default), indicators are added or updated, but not removed. + If reload is True and the module already exists, it is first removed before being rebuilt. + If False (default), indicators are added or updated, but not removed. Returns ------- ModuleType - A indicator module built from a mapping of Indicators. + A indicator module built from a mapping of Indicators. """ from xclim import indicators + out: ModuleType if hasattr(indicators, name): if doc is not None: warnings.warn( @@ -1675,35 +1674,35 @@ def build_indicator_module_from_yaml( # noqa: C901 Parameters ---------- filename : PathLike - Path to a YAML file or to the stem of all module files. See Notes for behaviour when passing a basename only. + Path to a YAML file or to the stem of all module files. See Notes for behaviour when passing a basename only. name : str, optional - The name of the new or existing module, defaults to the basename of the file. - (e.g: `atmos.yml` -> `atmos`) + The name of the new or existing module, defaults to the basename of the file. + (e.g: `atmos.yml` -> `atmos`) indices : Mapping of callables or module or path, optional - A mapping or module of indice functions or a python file declaring such a file. - When creating the indicator, the name in the `index_function` field is first sought - here, then the indicator class will search in xclim.indices.generic and finally in xclim.indices. - translations : Mapping of dicts or path, optional - Translated metadata for the new indicators. Keys of the mapping must be 2-char language tags. - Values can be translations dictionaries as defined in :ref:`internationalization:Internationalization`. - They can also be a path to a json file defining the translations. + A mapping or module of indice functions or a python file declaring such a file. + When creating the indicator, the name in the `index_function` field is first sought + here, then the indicator class will search in xclim.indices.generic and finally in xclim.indices. + translations : Mapping of dicts or path, optional + Translated metadata for the new indicators. Keys of the mapping must be 2-char language tags. + Values can be translations dictionaries as defined in :ref:`internationalization:Internationalization`. + They can also be a path to a json file defining the translations. mode : {'raise', 'warn', 'ignore'} - How to deal with broken indice definitions. + How to deal with broken indice definitions. 
encoding : str - The encoding used to open the `.yaml` and `.json` files. - It defaults to UTF-8, overriding python's mechanism which is machine dependent. + The encoding used to open the `.yaml` and `.json` files. + It defaults to UTF-8, overriding python's mechanism which is machine dependent. reload : bool - If reload is True and the module already exists, it is first removed before being rebuilt. - If False (default), indicators are added or updated, but not removed. + If reload is True and the module already exists, it is first removed before being rebuilt. + If False (default), indicators are added or updated, but not removed. validate : bool or path - If True (default), the yaml module is validated against xclim's schema. - Can also be the path to a yml schema against which to validate. - Or False, in which case validation is simply skipped. + If True (default), the yaml module is validated against xclim's schema. + Can also be the path to a yml schema against which to validate. + Or False, in which case validation is simply skipped. Returns ------- ModuleType - A submodule of `pym:mod:`xclim.indicators`. + A submodule of :py:mod:`xclim.indicators`. Notes ----- @@ -1769,17 +1768,17 @@ def build_indicator_module_from_yaml( # noqa: C901 if isinstance(indices, (str, Path)): indices = load_module(indices, name=module_name) + _translations: dict[str, dict] = {} if not filepath.suffix and translations is None: # No suffix mean we try to automatically detect the json files. - translations = {} for locfile in filepath.parent.glob(f"{filepath.stem}.*.json"): locale = locfile.suffixes[0][1:] - translations[locale] = read_locale_file( + _translations[locale] = read_locale_file( locfile, module=module_name, encoding=encoding ) elif translations is not None: # A mapping was passed, we read paths is any. - translations = { + _translations = { lng: ( read_locale_file(trans, module=module_name, encoding=encoding) if isinstance(trans, (str, Path)) @@ -1861,8 +1860,8 @@ def _merge_attrs(dbase, dextra, attr, sep): mod = build_indicator_module(module_name, objs=mapping, doc=doc, reload=reload) # If there are translations, load them - if translations: - for locale, loc_dict in translations.items(): + if _translations: + for locale, loc_dict in _translations.items(): load_locale(loc_dict, locale) return mod diff --git a/xclim/core/locales.py b/xclim/core/locales.py index 21fd564ce..b0be2eef5 100644 --- a/xclim/core/locales.py +++ b/xclim/core/locales.py @@ -48,7 +48,7 @@ import json import warnings -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from copy import deepcopy from pathlib import Path @@ -229,7 +229,7 @@ def __init__(self, locale): def read_locale_file( filename, module: str | None = None, encoding: str = "UTF8" -) -> dict: +) -> dict[str, dict]: """Read a locale file (.json) and return its dictionary. Parameters ---------- @@ -243,6 +243,7 @@ def read_locale_file( The encoding to use when reading the file. Defaults to UTF-8, overriding python's default mechanism which is machine dependent. """ + locdict: dict[str, dict] with open(filename, encoding=encoding) as f: locdict = json.load(f) @@ -254,12 +255,12 @@ def read_locale_file( return locdict -def load_locale(locdata: str | Path | Mapping[str, dict], locale: str): +def load_locale(locdata: str | Path | dict[str, dict], locale: str): """Load translations from a json file into xclim.
Parameters ---------- - locdata : str or dictionary + locdata : str or Path or dictionary Either a loaded locale dictionary or a path to a json file. locale : str The locale name (IETF tag). diff --git a/xclim/core/units.py b/xclim/core/units.py index edd22e8a9..7f432eab7 100644 --- a/xclim/core/units.py +++ b/xclim/core/units.py @@ -11,14 +11,9 @@ import logging import warnings from copy import deepcopy - -try: - from importlib.resources import files -except ImportError: - from importlib_resources import files - +from importlib.resources import files from inspect import _empty, signature # noqa -from typing import Any, Callable +from typing import Any, Callable, Literal, cast import cf_xarray.units import numpy as np @@ -103,6 +98,8 @@ _CONVERSIONS = {} +# FIXME: This needs to be properly annotated for mypy compliance. +# See: https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators def _register_conversion(conversion, direction): """Register a conversion function to be automatically picked up in `convert_units_to`. @@ -114,7 +111,7 @@ def _register_conversion(conversion, direction): "Automatic conversion functions must have a corresponding section in xclim/data/variables.yml" ) - def _func_register(func): + def _func_register(func: Callable) -> Callable: _CONVERSIONS[(conversion, direction)] = func return func @@ -143,7 +140,8 @@ def units2pint(value: xr.DataArray | str | units.Quantity) -> pint.Unit: elif isinstance(value, xr.DataArray): unit = value.attrs["units"] elif isinstance(value, units.Quantity): - return value.units + # This is a pint.PlainUnit, which is not the same as a pint.Unit + return cast(pint.Unit, value.units) else: raise NotImplementedError(f"Value of type `{type(value)}` not supported.") @@ -224,8 +222,8 @@ def pint_multiply( f = f.to(out_units) else: f = f.to_reduced_units() - out = da * f.magnitude - out.attrs["units"] = pint2cfunits(f.units) + out: xr.DataArray = da * f.magnitude + out = out.assign_attrs(units=pint2cfunits(f.units)) return out @@ -252,11 +250,12 @@ def str2pint(val: str) -> pint.Quantity: return units.Quantity(1, units2pint(val)) +# FIXME: The typing here is difficult to determine, as Generics cannot be used to track the type of the output. def convert_units_to( # noqa: C901 source: Quantified, target: Quantified | units.Unit, - context: str | None = None, -) -> Quantified: + context: Literal["infer", "hydro", "none"] | None = None, +) -> xr.DataArray | float: """Convert a mathematical expression into a value with the same units as a DataArray. If the dimensionalities of source and target units differ, automatic CF conversions @@ -268,15 +267,15 @@ def convert_units_to( # noqa: C901 The value to be converted, e.g. '4C' or '1 mm/d'. target : str or xr.DataArray or units.Quantity or units.Unit Target array of values to which units must conform. - context : str, optional + context : {"infer", "hydro", "none"}, optional The unit definition context. Default: None. If "infer", it will be inferred with :py:func:`xclim.core.units.infer_context` using the standard name from the `source` or, if none is found, from the `target`. - This means that the 'hydro' context could be activated if any one of the standard names allows it. + This means that the "hydro" context could be activated if any one of the standard names allows it. Returns ------- - str or xr.DataArray or units.Quantity + xr.DataArray or float The source value converted to target's units. The outputted type is always similar to `source` initial type. 
Attributes are preserved unless an automatic CF conversion is performed, @@ -314,13 +313,15 @@ def convert_units_to( # noqa: C901 else: context = "none" + m: float if isinstance(source, str): q = str2pint(source) # Return magnitude of converted quantity. This is going to fail if units are not compatible. - return q.to(target_unit, context).m - + m = q.to(target_unit, context).m + return m if isinstance(source, units.Quantity): - return source.to(target_unit, context).m + m = source.to(target_unit, context).m + return m if isinstance(source, xr.DataArray): source_unit = units2pint(source) @@ -356,14 +357,15 @@ def convert_units_to( # noqa: C901 else: source_unit = units2pint(source) + out: xr.DataArray if source_unit == target_unit: # The units are the same, but the symbol may not be. - source.attrs["units"] = target_cf_unit - return source + out = source.assign_attrs(units=target_cf_unit) + return out with units.context(context or "none"): out = source.copy(data=units.convert(source.data, source_unit, target_unit)) - out.attrs["units"] = target_cf_unit + out = out.assign_attrs(units=target_cf_unit) return out # TODO remove backwards compatibility of int/float thresholds after v1.0 release @@ -373,7 +375,9 @@ def convert_units_to( # noqa: C901 raise NotImplementedError(f"Source of type `{type(source)}` is not supported.") -def cf_conversion(standard_name: str, conversion: str, direction: str) -> str | None: +def cf_conversion( + standard_name: str, conversion: str, direction: Literal["to", "from"] +) -> str | None: """Get the standard name of the specific conversion for the given standard name. Parameters @@ -397,7 +401,8 @@ def cf_conversion(standard_name: str, conversion: str, direction: str) -> str | i = ["to", "from"].index(direction) for names in CF_CONVERSIONS[conversion]["valid_names"]: if names[i] == standard_name: - return names[int(not i)] + cf_name: str = names[int(not i)] + return cf_name return None @@ -567,7 +572,7 @@ def _rate_and_amount_converter( """Internal converter for :py:func:`xclim.core.units.rate2amount` and :py:func:`xclim.core.units.amount2rate`.""" m = 1 u = None # Default to assume a non-uniform axis - label = "lower" + label: Literal["lower", "upper"] = "lower" # Default to "lower" label for diff time = da[dim] try: @@ -603,6 +608,7 @@ def _rate_and_amount_converter( else: m, u = multi, FREQ_UNITS[base] + out: xr.DataArray # Freq is month, season or year, which are not constant units, or simply freq is not inferrable. 
if u is None: # Get sampling period lengths in nanoseconds @@ -641,10 +647,10 @@ def _rate_and_amount_converter( old_name, "amount2rate", "to" if to == "rate" else "from" ) ): - out.attrs["standard_name"] = new_name + out = out.assign_attrs(standard_name=new_name) if out_units: - out = convert_units_to(out, out_units) + out = cast(xr.DataArray, convert_units_to(out, out_units)) return out @@ -815,7 +821,7 @@ def amount2lwethickness( if old_name and (new_name := cf_conversion(old_name, "amount2lwethickness", "to")): out.attrs["standard_name"] = new_name if out_units: - out = convert_units_to(out, out_units) + out = cast(xr.DataArray, convert_units_to(out, out_units)) return out @@ -853,13 +859,13 @@ def lwethickness2amount( ): out.attrs["standard_name"] = new_name if out_units: - out = convert_units_to(out, out_units) + out = cast(xr.DataArray, convert_units_to(out, out_units)) return out def _flux_and_rate_converter( da: xr.DataArray, - density: Quantified | str, + density: Quantified, to: str = "rate", out_units: str | None = None, ) -> xr.DataArray: @@ -874,9 +880,11 @@ def _flux_and_rate_converter( raise ValueError("Argument `to` must be one of 'rate' or 'flux'.") in_u = units2pint(da) - density_u = ( - str2pint(density).units if isinstance(density, str) else units2pint(density) - ) + if isinstance(density, str): + density_u = str2pint(density).units + else: + density_u = units2pint(density) + if out_units: out_u = str2pint(out_units).units @@ -891,9 +899,9 @@ def _flux_and_rate_converter( else: out_u = in_u * density_u**density_exp - density = convert_units_to(density, (out_u / in_u) ** density_exp) - out = (da * density**density_exp).assign_attrs(da.attrs) - out.attrs["units"] = pint2cfunits(out_u) + density_conv = convert_units_to(density, (out_u / in_u) ** density_exp) + out: xr.DataArray = (da * density_conv**density_exp).assign_attrs(da.attrs) + out = out.assign_attrs(units=pint2cfunits(out_u)) if "standard_name" in out.attrs.keys(): out.attrs.pop("standard_name") return out @@ -1003,7 +1011,9 @@ def flux2rate( @datacheck -def check_units(val: str | xr.DataArray | None, dim: str | xr.DataArray | None) -> None: +def check_units( + val: str | xr.DataArray | None, dim: str | xr.DataArray | None = None +) -> None: """Check that units are compatible with dimensions, otherwise raise a `ValidationError`. Parameters @@ -1016,10 +1026,17 @@ def check_units(val: str | xr.DataArray | None, dim: str | xr.DataArray | None) if dim is None or val is None: return + if isinstance(dim, xr.DataArray): + _dim = str(dim.dims[0]) + else: + _dim = dim + # In case val is a DataArray, we try to get a standard_name - context = infer_context( - standard_name=getattr(val, "standard_name", None), dimension=dim - ) + if hasattr(val, "attrs"): + standard_name = val.attrs.get("standard_name", None) + else: + standard_name = None + context = infer_context(standard_name=standard_name, dimension=_dim) # Issue originally introduced in https://github.com/hgrecco/pint/issues/1486 # Should be resolved in pint v0.24. 
See: https://github.com/hgrecco/pint/issues/1913 @@ -1037,13 +1054,21 @@ def check_units(val: str | xr.DataArray | None, dim: str | xr.DataArray | None) raise TypeError("Please set units explicitly using a string.") try: - dim_units = str2pint(dim) if isinstance(dim, str) else units2pint(dim) + dim_units: pint.Unit | pint.Quantity + if isinstance(dim, str): + dim_units = str2pint(dim) + else: + dim_units = units2pint(dim) expected = dim_units.dimensionality except pint.UndefinedUnitError: # Raised when it is not understood, we assume it was a dimensionality expected = units.get_dimensionality(dim.replace("dimensionless", "")) - val_units = str2pint(val) if isinstance(val, str) else units2pint(val) + val_units: pint.Unit | pint.Quantity + if isinstance(val, str): + val_units = str2pint(val) + else: + val_units = units2pint(val) val_dim = val_units.dimensionality if val_dim == expected: @@ -1080,6 +1105,8 @@ def _check_output_has_units(out: xr.DataArray | tuple[xr.DataArray]) -> None: outd.attrs["units"] = ensure_cf_units(outd.attrs["units"]) +# FIXME: This needs to be properly annotated for mypy compliance. +# See: https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators def declare_relative_units(**units_by_name) -> Callable: r"""Function decorator checking the units of arguments. @@ -1193,6 +1220,8 @@ def wrapper(*args, **kwargs): return dec +# FIXME: This needs to be properly annotated for mypy compliance. +# See: https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators def declare_units(**units_by_name) -> Callable: r"""Create a decorator to check units of function arguments. @@ -1294,7 +1323,9 @@ def ensure_delta(unit: str) -> str: return delta_unit -def infer_context(standard_name: str | None = None, dimension: str | None = None): +def infer_context( + standard_name: str | None = None, dimension: str | None = None +) -> str: """Return units context based on either the variable's standard name or the pint dimension. Valid standard names for the hydro context are those including the terms "rainfall", diff --git a/xclim/core/utils.py b/xclim/core/utils.py index 75bd62918..13e180b7d 100644 --- a/xclim/core/utils.py +++ b/xclim/core/utils.py @@ -13,15 +13,9 @@ import os import warnings from collections import defaultdict +from collections.abc import Sequence from enum import IntEnum -from functools import partial - -try: - from importlib.resources import files -except ImportError: - from importlib_resources import files - -from collections.abc import Mapping, Sequence +from importlib.resources import as_file, files from inspect import Parameter, _empty # noqa from io import StringIO from pathlib import Path @@ -29,7 +23,6 @@ import numpy as np import xarray as xr -from boltons.funcutils import update_wrapper from dask import array as dsk from pint import Quantity from yaml import safe_dump, safe_load @@ -45,9 +38,10 @@ #: Type annotation for thresholds and other not-exactly-a-variable quantities Quantified = TypeVar("Quantified", xr.DataArray, str, Quantity) -with (files("xclim.data") / "variables.yml").open() as f: - VARIABLES = safe_load(f)["variables"] - """Official variables definitions. +with as_file(files("xclim.data")) as data_dir: + with (data_dir / "variables.yml").open() as f: + VARIABLES = safe_load(f)["variables"] + """Official variables definitions. 
A mapping from variable name to a dict with the following keys: @@ -72,53 +66,6 @@ } -def wrapped_partial(func: Callable, suggested: dict | None = None, **fixed) -> Callable: - r"""Wrap a function, updating its signature but keeping its docstring. - - Parameters - ---------- - func : Callable - The function to be wrapped - suggested : dict, optional - Keyword arguments that should have new default values but still appear in the signature. - \*\*fixed - Keyword arguments that should be fixed by the wrapped and removed from the signature. - - Returns - ------- - Callable - - Examples - -------- - >>> from inspect import signature - >>> def func(a, b=1, c=1): - ... print(a, b, c) - ... - >>> newf = wrapped_partial(func, b=2) - >>> signature(newf) - - >>> newf(1) - 1 2 1 - >>> newf = wrapped_partial(func, suggested=dict(c=2), b=2) - >>> signature(newf) - - >>> newf(1) - 1 2 2 - """ - suggested = suggested or {} - partial_func = partial(func, **suggested, **fixed) - - fully_wrapped = update_wrapper( - partial_func, func, injected=list(fixed.keys()), hide_wrapped=True # noqa - ) - - # Store all injected params, - injected = getattr(func, "_injected", {}).copy() - injected.update(fixed) - fully_wrapped._injected = injected - return fully_wrapped - - def deprecated(from_version: str | None, suggested: str | None = None) -> Callable: """Mark an index as deprecated and optionally suggest a replacement. @@ -226,7 +173,7 @@ class MissingVariableError(ValueError): """Error raised when a dataset is passed to an indicator but one of the needed variable is missing.""" -def ensure_chunk_size(da: xr.DataArray, **minchunks: Mapping[str, int]) -> xr.DataArray: +def ensure_chunk_size(da: xr.DataArray, **minchunks: int) -> xr.DataArray: r"""Ensure that the input DataArray has chunks of at least the given size. If only one chunk is too small, it is merged with an adjacent chunk. @@ -305,18 +252,25 @@ def uses_dask(*das: xr.DataArray | xr.Dataset) -> bool: def calc_perc( arr: np.ndarray, - percentiles: Sequence[float] = None, + percentiles: Sequence[float] | None = None, alpha: float = 1.0, beta: float = 1.0, copy: bool = True, ) -> np.ndarray: """Compute percentiles using nan_calc_percentiles and move the percentiles' axis to the end.""" if percentiles is None: - percentiles = [50.0] + _percentiles = [50.0] + else: + _percentiles = percentiles return np.moveaxis( nan_calc_percentiles( - arr=arr, percentiles=percentiles, axis=-1, alpha=alpha, beta=beta, copy=copy + arr=arr, + percentiles=_percentiles, + axis=-1, + alpha=alpha, + beta=beta, + copy=copy, ), source=0, destination=-1, @@ -333,13 +287,15 @@ def nan_calc_percentiles( ) -> np.ndarray: """Convert the percentiles to quantiles and compute them using _nan_quantile.""" if percentiles is None: - percentiles = [50.0] + _percentiles = [50.0] + else: + _percentiles = percentiles if copy: # bootstrapping already works on a data's copy # doing it again is extremely costly, especially with dask. arr = arr.copy() - quantiles = np.array([per / 100.0 for per in percentiles]) + quantiles = np.array([per / 100.0 for per in _percentiles]) return _nan_quantile(arr, quantiles, axis, alpha, beta) @@ -553,15 +509,15 @@ def raise_warn_or_log( stacklevel : int Stacklevel when warning. Relative to the call of this function (1 is added). 
""" - msg = msg or getattr(err, "msg", f"Failed with {err!r}.") + message = msg or getattr(err, "msg", f"Failed with {err!r}.") if mode == "ignore": pass elif mode == "log": - logger.info(msg) + logger.info(message) elif mode == "warn": - warnings.warn(msg, stacklevel=stacklevel + 1) + warnings.warn(message, stacklevel=stacklevel + 1) else: # mode == "raise" - raise err from err_type(msg) + raise err from err_type(message) class InputKind(IntEnum): diff --git a/xclim/ensembles/_base.py b/xclim/ensembles/_base.py index 2b9a5987f..237368678 100644 --- a/xclim/ensembles/_base.py +++ b/xclim/ensembles/_base.py @@ -425,9 +425,10 @@ def _ens_align_datasets( if isinstance(datasets, str): datasets = glob(datasets) - ds_all = [] + ds_all: list[xr.Dataset] = [] calendars = [] for i, n in enumerate(datasets): + ds: xr.Dataset if multifile: ds = xr.open_mfdataset(n, combine="by_coords", **xr_kwargs) else: diff --git a/xclim/indices/_agro.py b/xclim/indices/_agro.py index cdfff68e7..8719739a4 100644 --- a/xclim/indices/_agro.py +++ b/xclim/indices/_agro.py @@ -2,6 +2,7 @@ from __future__ import annotations import warnings +from typing import cast import numpy as np import xarray @@ -123,7 +124,7 @@ def corn_heat_units( mask_tasmin = tasmin > thresh_tasmin mask_tasmax = tasmax > thresh_tasmax - chu = ( + chu: xarray.DataArray = ( xarray.where(mask_tasmin, 1.8 * (tasmin - thresh_tasmin), 0) + xarray.where( mask_tasmax, @@ -132,7 +133,7 @@ def corn_heat_units( ) ) / 2 - chu.attrs["units"] = "" + chu = chu.assign_attrs(units="") return chu @@ -281,7 +282,7 @@ def huglin_index( else: raise NotImplementedError(f"'{method}' method is not implemented.") - hi = (((tas + tasmax) / 2) - thresh).clip(min=0) * k + hi: xarray.DataArray = (((tas + tasmax) / 2) - thresh).clip(min=0) * k hi = ( select_time( hi, date_bounds=(start_date, end_date), include_bounds=(True, False) @@ -290,7 +291,7 @@ def huglin_index( .sum() * k_aggregated ) - hi.attrs["units"] = "" + hi = hi.assign_attrs(units="") return hi @@ -446,9 +447,9 @@ def biologically_effective_degree_days( else: raise NotImplementedError() - bedd = ((((tasmin + tasmax) / 2) - thresh_tasmin).clip(min=0) * k + tr_adj).clip( - max=max_daily_degree_days - ) + bedd: xarray.DataArray = ( + (((tasmin + tasmax) / 2) - thresh_tasmin).clip(min=0) * k + tr_adj + ).clip(max=max_daily_degree_days) bedd = ( select_time( @@ -459,7 +460,7 @@ def biologically_effective_degree_days( * k_aggregated ) - bedd.attrs["units"] = "K days" + bedd = bedd.assign_attrs(units="K days") return bedd @@ -537,8 +538,8 @@ def cool_night_index( tasmin = tasmin.where(months == month, drop=True) - cni = tasmin.resample(time=freq).mean(keep_attrs=True) - cni.attrs["units"] = "degC" + cni: xarray.DataArray = tasmin.resample(time=freq).mean(keep_attrs=True) + cni = cni.assign_attrs(units="degC") return cni @@ -715,6 +716,8 @@ def dryness_index( * (pr_masked / 5).clip(max=evspsblpot.time.dt.daysinmonth) ) + di_north: xarray.DataArray | None = None + di_south: xarray.DataArray | None = None # Dryness index if has_north: di_north = wo + (pr_masked - t_v - e_s).resample(time="YS-JAN").sum() @@ -724,14 +727,17 @@ def dryness_index( di_south = di_south.shift(time=1).isel(time=slice(1, None)) di_south["time"] = di_south.indexes["time"].shift(-6, "MS") + di: xarray.DataArray if has_north and has_south: - di = di_north.where(lat >= 0, di_south) # noqa + di = di_north.where(lat >= 0, di_south) elif has_north: di = di_north # noqa elif has_south: di = di_south # noqa + else: + raise ValueError("No hemisphere data 
found.") - di.attrs["units"] = "mm" # noqa + di = di.assign_attrs(units="mm") return di @@ -792,8 +798,8 @@ def latitude_temperature_index( lat_mask = (abs(lat) >= 0) & (abs(lat) <= lat_factor) lat_coeff = xarray.where(lat_mask, lat_factor - abs(lat), 0) - lti = mtwm * lat_coeff - lti.attrs["units"] = "" + lti: xarray.DataArray = mtwm * lat_coeff + lti = lti.assign_attrs(units="") return lti @@ -895,9 +901,8 @@ def water_budget( if xarray.infer_freq(pet.time) == "MS": pr = pr.resample(time="MS").mean(dim="time", keep_attrs=True) - out = pr - pet - - out.attrs["units"] = pr.attrs["units"] + out: xarray.DataArray = pr - pet + out = out.assign_attrs(units=pr.attrs["units"]) return out @@ -923,7 +928,7 @@ def rain_season( date_min_end: DayOfYearStr = "09-01", date_max_end: DayOfYearStr = "12-31", freq="YS-JAN", -): +) -> tuple[xarray.DataArray, xarray.DataArray, xarray.DataArray]: """Find the length of the rain season and the day of year of its start and its end. The rain season begins when two conditions are met: 1) There must be a number of wet days with precipitations above @@ -1013,22 +1018,26 @@ def _get_first_run(run_positions, start_date, end_date): ) # Find the start of the rain season - def _get_first_run_start(pram): - last_doy = pram.indexes["time"][-1].strftime("%m-%d") - pram = select_time(pram, date_bounds=(date_min_start, last_doy)) + def _get_first_run_start(_pram): + last_doy = _pram.indexes["time"][-1].strftime("%m-%d") + _pram = select_time(_pram, date_bounds=(date_min_start, last_doy)) # First condition: Start with enough precipitation - da_start = pram.rolling({"time": window_wet_start}).sum() >= thresh_wet_start + da_start = _pram.rolling({"time": window_wet_start}).sum() >= thresh_wet_start # Second condition: No dry period after if method_dry_start == "per_day": - da_stop = pram <= thresh_dry_start + da_stop = _pram <= thresh_dry_start window_dry = window_dry_start elif method_dry_start == "total": - da_stop = pram.rolling({"time": window_dry_start}).sum() <= thresh_dry_start + da_stop = ( + _pram.rolling({"time": window_dry_start}).sum() <= thresh_dry_start + ) # equivalent to rolling forward in time instead, i.e. end date will be at beginning of dry run da_stop = da_stop.shift({"time": -(window_dry_start - 1)}, fill_value=False) window_dry = 1 + else: + raise ValueError(f"Unknown method_dry_start: {method_dry_start}.") # First and second condition combined in a run length events = rl.extract_events(da_start, 1, da_stop, window_dry) @@ -1038,58 +1047,62 @@ def _get_first_run_start(pram): # Find the end of the rain season # FIXME: This function mixes local and parent-level variables. It should be refactored. - def _get_first_run_end(pram): + def _get_first_run_end(_pram): if method_dry_end == "per_day": - da_stop = pram <= thresh_dry_end + da_stop = _pram <= thresh_dry_end run_positions = rl.rle(da_stop) >= window_dry_end elif method_dry_end == "total": run_positions = ( - pram.rolling({"time": window_dry_end}).sum() <= thresh_dry_end + _pram.rolling({"time": window_dry_end}).sum() <= thresh_dry_end ) + else: + raise ValueError(f"Unknown method_dry_end: {method_dry_end}.") return _get_first_run(run_positions, date_min_end, date_max_end) # Get start, end and length of rain season. Written as a function so it can be resampled # FIXME: This function mixes local and parent-level variables. It should be refactored. 
- def _get_rain_season(pram): - start = _get_first_run_start(pram) + def _get_rain_season(_pram): + start = _get_first_run_start(_pram) # masking value before start of the season (end of season should be after) # Get valid integer indexer of the day after the first run starts. # `start != NaN` only possible if a condition on next few time steps is respected. # Thus, `start+1` exists if `start != NaN` start_ind = (start + 1).fillna(-1).astype(int) - mask = pram * np.NaN + mask = _pram * np.NaN # Put "True" on the day of run start mask[{"time": start_ind}] = 1 # Mask back points without runs, propagate the True mask = mask.where(start.notnull()).ffill("time") mask = mask.notnull() - end = _get_first_run_end(pram.where(mask)) + end = _get_first_run_end(_pram.where(mask)) - length = xarray.where(end.notnull(), end - start, pram["time"].size - start) + length = xarray.where(end.notnull(), end - start, _pram["time"].size - start) # converting to doy - crd = pram.time.dt.dayofyear + crd = _pram.time.dt.dayofyear start = rl.lazy_indexing(crd, start) end = rl.lazy_indexing(crd, end) - out = xarray.Dataset( + _out = xarray.Dataset( { "rain_season_start": start, "rain_season_end": end, "rain_season_length": length, } ) - return out + return _out # Compute rain season, attribute units - out = pram.resample(time=freq).map(_get_rain_season) - out["rain_season_start"].attrs["units"] = "" - out["rain_season_end"].attrs["units"] = "" - out["rain_season_length"].attrs["units"] = "days" - out["rain_season_start"].attrs["is_dayofyear"] = np.int32(1) - out["rain_season_end"].attrs["is_dayofyear"] = np.int32(1) - return out["rain_season_start"], out["rain_season_end"], out["rain_season_length"] + out = cast(xarray.Dataset, pram.resample(time=freq).map(_get_rain_season)) + rain_season_start = out.rain_season_start.assign_attrs( + units="", is_dayofyear=np.int32(1) + ) + rain_season_end = out.rain_season_end.assign_attrs( + units="", is_dayofyear=np.int32(1) + ) + rain_season_length = out.rain_season_length.assign_attrs(units="days") + return rain_season_start, rain_season_end, rain_season_length @declare_units( @@ -1236,11 +1249,11 @@ def standardized_precipitation_index( if paramsd != template.sizes: params = params.broadcast_like(template) - spi = standardized_index(pr, params) - spi.attrs = params.attrs - spi.attrs["freq"] = (freq or xarray.infer_freq(spi.time)) or "undefined" - spi.attrs["window"] = window - spi.attrs["units"] = "" + spi: xarray.DataArray = standardized_index(pr, params) + spi = spi.assign_attrs(params.attrs) + spi = spi.assign_attrs(freq=(freq or xarray.infer_freq(spi.time)) or "undefined") + spi = spi.assign_attrs(window=window) + spi = spi.assign_attrs(units="") return spi @@ -1349,7 +1362,7 @@ def standardized_precipitation_evapotranspiration_index( if wb_cal is not None: wb_cal = wb_cal + offset - spei = standardized_precipitation_index( + spei: xarray.DataArray = standardized_precipitation_index( wb, wb_cal, freq, window, dist, method, cal_start, cal_end, params, **indexer ) @@ -1395,9 +1408,10 @@ def qian_weighted_mean_average( units = tas.attrs["units"] weights = xarray.DataArray([0.0625, 0.25, 0.375, 0.25, 0.0625], dims=["window"]) - weighted_mean = tas.rolling({dim: 5}, center=True).construct("window").dot(weights) - - weighted_mean.attrs["units"] = units + weighted_mean: xarray.DataArray = ( + tas.rolling({dim: 5}, center=True).construct("window").dot(weights) + ) + weighted_mean = weighted_mean.assign_attrs(units=units) return weighted_mean @@ -1499,9 +1513,11 @@ def 
effective_growing_degree_days( ) deg_days = (tas - thresh).clip(min=0) - egdd = aggregate_between_dates(deg_days, start=start, end=end, freq=freq) - - return to_agg_units(egdd, tas, op="integral") + egdd: xarray.DataArray = aggregate_between_dates( + deg_days, start=start, end=end, freq=freq + ) + egdd = to_agg_units(egdd, tas, op="integral") + return egdd @declare_units(tasmin="[temperature]") @@ -1548,9 +1564,9 @@ def hardiness_zones( ) tn_min_rolling = tn_min(tasmin, freq=freq).rolling(time=window).mean() - zones = get_zones( + zones: xarray.DataArray = get_zones( tn_min_rolling, zone_min=zone_min, zone_max=zone_max, zone_step=zone_step ) - zones.attrs["units"] = "" + zones = zones.assign_attrs(units="") return zones diff --git a/xclim/indices/_anuclim.py b/xclim/indices/_anuclim.py index 940c86650..76256c984 100644 --- a/xclim/indices/_anuclim.py +++ b/xclim/indices/_anuclim.py @@ -1,7 +1,7 @@ # noqa: D100 from __future__ import annotations -from typing import Callable +from typing import Callable, cast import numpy as np import xarray @@ -95,8 +95,8 @@ def isothermality( """ dtr = daily_temperature_range(tasmin=tasmin, tasmax=tasmax, freq=freq) etr = extreme_temperature_range(tasmin=tasmin, tasmax=tasmax, freq=freq) - iso = dtr / etr * 100 - iso.attrs["units"] = "%" + iso: xarray.DataArray = dtr / etr * 100 + iso = iso.assign_attrs(units="%") return iso @@ -461,8 +461,9 @@ def prcptot( Total {freq} precipitation. """ thresh = convert_units_to(thresh, pr, context="hydro") - pram = rate2amount(pr.where(pr >= thresh, 0)) - return pram.resample(time=freq).sum().assign_attrs(units=pram.units) + pram: xarray.DataArray = rate2amount(pr.where(pr >= thresh, 0)) + pram = pram.resample(time=freq).sum().assign_attrs(units=pram.units) + return pram @declare_units(pr="[precipitation]") @@ -506,9 +507,9 @@ def prcptot_wetdry_period( ) op = _np_ops[op] - return getattr(pram.resample(time=freq), op)(dim="time").assign_attrs( - units=pram.units - ) + pwp: xarray.DataArray = getattr(pram.resample(time=freq), op)(dim="time") + pwp = pwp.assign_attrs(units=pram.units) + return pwp def _anuclim_coeff_var(arr: xarray.DataArray, freq: str = "YS") -> xarray.DataArray: @@ -542,63 +543,82 @@ def _from_other_arg( ds = xarray.Dataset(data_vars={"criteria": criteria, "output": output}) dim = "time" - def get_other_op(dataset): + def _get_other_op(dataset: xarray.Dataset) -> xarray.DataArray: all_nans = dataset.criteria.isnull().all(dim=dim) index = op(dataset.criteria.where(~all_nans, 0), dim=dim) - return lazy_indexing(dataset.output, index=index, dim=dim).where(~all_nans) + other_op = lazy_indexing(dataset.output, index=index, dim=dim).where(~all_nans) + return other_op - return ds.resample(time=freq).map(get_other_op) + resampled = ds.resample(time=freq) + # Manually casting here since the mapping returns a DataArray and not a Dataset + out = cast(xarray.DataArray, resampled.map(_get_other_op)) + return out def _to_quarter( pr: xarray.DataArray | None = None, tas: xarray.DataArray | None = None, ) -> xarray.DataArray: - """Convert daily, weekly or monthly time series to quarterly time series according to ANUCLIM specifications.""" - if tas is not None and pr is not None: + """Convert daily, weekly or monthly time series to quarterly time series according to ANUCLIM specifications. + + Parameters + ---------- + pr : xarray.DataArray, optional + Total precipitation flux [mm d-1], [mm week-1], [mm month-1] or similar. + tas : xarray.DataArray, optional + Mean temperature at daily, weekly, or monthly frequency. 
+ + Returns + ------- + xarray.DataArray + Quarterly time series. + """ + if pr is not None and tas is not None: raise ValueError("Supply only one variable, 'tas' (exclusive) or 'pr'.") + elif tas is not None: + ts_var = tas + elif pr is not None: + ts_var = pr + else: + raise ValueError("Supply one variable, `tas` or `pr`.") - freq = xarray.infer_freq((tas if tas is not None else pr).time) + freq = xarray.infer_freq(ts_var.time) if freq is None: raise ValueError("Can't infer sampling frequency of the input data.") + freq_upper = freq.upper() - if freq.upper().startswith("D"): + if freq_upper.startswith("D"): if tas is not None: - tas = tg_mean(tas, freq="7D") - - if pr is not None: + ts_var = tg_mean(ts_var, freq="7D") + else: # Accumulate on a week # Ensure units are back to a "rate" for rate2amount below - pr = convert_units_to( - precip_accumulation(pr, freq="7D"), "mm", context="hydro" + ts_var = precip_accumulation(ts_var, freq="7D") + ts_var = convert_units_to(ts_var, "mm", context="hydro").assign_attrs( + units="mm/week" ) - pr.attrs["units"] = "mm/week" - - freq = "W" - - if freq.upper().startswith("W"): + freq_upper = "W" + if freq_upper.startswith("W"): window = 13 - - elif freq.upper().startswith("M"): + elif freq_upper.startswith("M"): window = 3 - else: raise NotImplementedError( f'Unknown input time frequency "{freq}": must be one of "D", "W" or "M".' ) + ts_var = ensure_chunk_size(ts_var, time=np.ceil(window / 2)) if tas is not None: - tas = ensure_chunk_size(tas, time=np.ceil(window / 2)) - out = tas.rolling(time=window, center=False).mean(skipna=False) - out.attrs = tas.attrs + out = ts_var.rolling(time=window, center=False).mean(skipna=False) + out_units = ts_var.units elif pr is not None: - pr = ensure_chunk_size(pr, time=np.ceil(window / 2)) - pram = rate2amount(pr) + pram = rate2amount(ts_var) out = pram.rolling(time=window, center=False).sum() - out.attrs = pr.attrs - out.attrs["units"] = pram.units + out_units = pram.units else: raise ValueError("No variables supplied.") + out = out.assign_attrs(ts_var.attrs) + out = out.assign_attrs(units=out_units) out = ensure_chunk_size(out, time=-1) return out diff --git a/xclim/indices/_conversion.py b/xclim/indices/_conversion.py index 50cc2ad32..0b2efd9b4 100644 --- a/xclim/indices/_conversion.py +++ b/xclim/indices/_conversion.py @@ -1,6 +1,8 @@ # noqa: D100 from __future__ import annotations +from typing import cast + import numpy as np import xarray as xr from numba import float32, float64, vectorize # noqa @@ -54,10 +56,6 @@ ] -def _deaccumulate(ds: xr.DataArray) -> xr.DataArray: - """Deaccumulate units.""" - - @declare_units(tas="[temperature]", tdps="[temperature]", hurs="[]") def humidex( tas: xr.DataArray, @@ -73,9 +71,9 @@ def humidex( ---------- tas : xarray.DataArray Air temperature. - tdps : xarray.DataArray, + tdps : xarray.DataArray, optional Dewpoint temperature, used to compute the vapour pressure. - hurs : xarray.DataArray + hurs : xarray.DataArray, optional Relative humidity, used as an alternative way to compute the vapour pressure if the dewpoint temperature is not available. 
@@ -138,9 +136,12 @@ def humidex( tasC = convert_units_to(tas, "celsius") e = hurs / 100 * 6.112 * 10 ** (7.5 * tasC / (tasC + 237.7)) + else: + raise ValueError("Either `tdps` or `hurs` must be provided.") + # Temperature delta due to humidity in delta_degC - h = 5 / 9 * (e - 10) - h.attrs["units"] = "delta_degree_Celsius" + h: xr.DataArray = 5 / 9 * (e - 10) + h = h.assign_attrs(units="delta_degree_Celsius") # Get delta_units for output du = (1 * units2pint(tas) - 0 * units2pint(tas)).units @@ -148,7 +149,7 @@ def humidex( # Add the delta to the input temperature out = h + tas - out.attrs["units"] = tas.units + out = out.assign_attrs(units=tas.units) return out @@ -276,8 +277,8 @@ def uas_vas_2_sfcwind( wind_thresh = convert_units_to(calm_wind_thresh, "m/s") # Wind speed is the hypotenuse of "uas" and "vas" - wind = np.hypot(uas, vas) - wind.attrs["units"] = "m s-1" + wind = cast(xr.DataArray, np.hypot(uas, vas)) + wind = wind.assign_attrs(units="m s-1") # Calculate the angle wind_from_dir_math = np.degrees(np.arctan2(vas, uas)) @@ -398,6 +399,8 @@ def saturation_vapor_pressure( thresh = convert_units_to("0 K", "K") tas = convert_units_to(tas, "K") ref_is_water = tas > thresh + + e_sat: xr.DataArray if method in ["sonntag90", "SO90"]: e_sat = xr.where( ref_is_water, @@ -480,7 +483,7 @@ def saturation_vapor_pressure( f"Method {method} is not in ['sonntag90', 'tetens30', 'goffgratch46', 'wmo08', 'its90']" ) - e_sat.attrs["units"] = "Pa" + e_sat = e_sat.assign_attrs(units="Pa") return e_sat @@ -581,6 +584,7 @@ def relative_humidity( ---------- :cite:cts:`bohren_atmospheric_1998,lawrence_relationship_2005` """ + hurs: xr.DataArray if method in ("bohren98", "BA90"): if tdps is None: raise ValueError("To use method 'bohren98' (BA98), dewpoint must be given.") @@ -597,7 +601,7 @@ def relative_humidity( tas=tas, ice_thresh=ice_thresh, method=method ) hurs = 100 * e_sat_dt / e_sat_t # type: ignore - else: + elif huss is not None and ps is not None: ps = convert_units_to(ps, "Pa") huss = convert_units_to(huss, "") tas = convert_units_to(tas, "K") @@ -607,12 +611,14 @@ def relative_humidity( w = huss / (1 - huss) w_sat = 0.62198 * e_sat / (ps - e_sat) # type: ignore hurs = 100 * w / w_sat + else: + raise ValueError("`huss` and `ps` must be provided if `tdps` is not given.") if invalid_values == "clip": hurs = hurs.clip(0, 100) elif invalid_values == "mask": hurs = hurs.where((hurs <= 100) & (hurs >= 0)) - hurs.attrs["units"] = "%" + hurs = hurs.assign_attrs(units="%") return hurs @@ -703,7 +709,7 @@ def specific_humidity( w_sat = 0.62198 * e_sat / (ps - e_sat) # type: ignore w = w_sat * hurs - q = w / (1 + w) + q: xr.DataArray = w / (1 + w) if invalid_values is not None: q_sat = w_sat / (1 + w_sat) @@ -711,7 +717,7 @@ def specific_humidity( q = q.clip(0, q_sat) elif invalid_values == "mask": q = q.where((q <= q_sat) & (q >= 0)) - q.attrs["units"] = "" + q = q.assign_attrs(units="") return q @@ -771,8 +777,8 @@ def specific_humidity_from_dewpoint( e = saturation_vapor_pressure(tas=tdps, method=method) # vapour pressure [Pa] ps = convert_units_to(ps, "Pa") # total air pressure - q = ε * e / (ps - e * (1 - ε)) - q.attrs["units"] = "" + q: xr.DataArray = ε * e / (ps - e * (1 - ε)) + q = q.assign_attrs(units="") return q @@ -820,6 +826,7 @@ def snowfall_approximation( ---------- :cite:cts:`verseghy_class_2009,melton_atmosphericvarscalcf90_2019` """ + prsn: xr.DataArray if method == "binary": thresh = convert_units_to(thresh, tas) prsn = pr.where(tas <= thresh, 0) @@ -868,7 +875,7 @@ def 
snowfall_approximation( else: raise ValueError(f"Method {method} not one of 'binary', 'brown' or 'auer'.") - prsn.attrs["units"] = pr.attrs["units"] + prsn = prsn.assign_attrs(units=pr.attrs["units"]) return prsn @@ -909,8 +916,10 @@ def rain_approximation( -------- snowfall_approximation """ - prra = pr - snowfall_approximation(pr, tas, thresh=thresh, method=method) - prra.attrs["units"] = pr.attrs["units"] + prra: xr.DataArray = pr - snowfall_approximation( + pr, tas, thresh=thresh, method=method + ) + prra = prra.assign_attrs(units=pr.attrs["units"]) return prra @@ -949,9 +958,11 @@ def snd_to_snw( :cite:cts:`sturm_swe_2010` """ density = snr if (snr is not None) else const - snw = rate2flux(snd, density=density, out_units=out_units).rename("snw") + snw: xr.DataArray = rate2flux(snd, density=density, out_units=out_units).rename( + "snw" + ) # TODO: Leave this operation to rate2flux? Maybe also the variable renaming above? - snw.attrs["standard_name"] = "surface_snow_amount" + snw = snw.assign_attrs(standard_name="surface_snow_amount") return snw @@ -990,8 +1001,10 @@ def snw_to_snd( :cite:cts:`sturm_swe_2010` """ density = snr if (snr is not None) else const - snd = flux2rate(snw, density=density, out_units=out_units).rename("snd") - snd.attrs["standard_name"] = "surface_snow_thickness" + snd: xr.DataArray = flux2rate(snw, density=density, out_units=out_units).rename( + "snd" + ) + snd = snd.assign_attrs(standard_name="surface_snow_thickness") return snd @@ -1033,7 +1046,9 @@ def prsn_to_prsnd( :cite:cts:`frei_snowfall_2018, cbcl_climate_2020` """ density = snr if snr else const - prsnd = flux2rate(prsn, density=density, out_units=out_units).rename("prsnd") + prsnd: xr.DataArray = flux2rate(prsn, density=density, out_units=out_units).rename( + "prsnd" + ) return prsnd @@ -1073,8 +1088,10 @@ def prsnd_to_prsn( :cite:cts:`frei_snowfall_2018, cbcl_climate_2020` """ density = snr if snr else const - prsn = rate2flux(prsnd, density=density, out_units=out_units).rename("prsn") - prsn.attrs["standard_name"] = "snowfall_flux" + prsn: xr.DataArray = rate2flux(prsnd, density=density, out_units=out_units).rename( + "prsn" + ) + prsn = prsn.assign_attrs(standard_name="snowfall_flux") return prsn @@ -1097,10 +1114,8 @@ def longwave_upwelling_radiation_from_net_downwelling( Surface upwelling thermal radiation (rlus). """ rls = convert_units_to(rls, rlds) - - rlus = rlds - rls - - rlus.attrs["units"] = rlds.units + rlus: xr.DataArray = rlds - rls + rlus = rlus.assign_attrs(units=rlds.units) return rlus @@ -1123,10 +1138,8 @@ def shortwave_upwelling_radiation_from_net_downwelling( Surface upwelling solar radiation (rsus). """ rss = convert_units_to(rss, rsds) - - rsus = rsds - rss - - rsus.attrs["units"] = rsds.units + rsus: xr.DataArray = rsds - rss + rsus = rsus.assign_attrs(units=rsds.units) return rsus @@ -1183,7 +1196,6 @@ def wind_chill_index( W = T + \frac{-1.59 + 0.1345 * T}{5} * V - Both equations are invalid for temperature over 0°C in the canadian method. 
The american Wind Chill Temperature index (WCT), as defined by USA's National Weather Service, is computed when @@ -1203,7 +1215,7 @@ def wind_chill_index( sfcWind = convert_units_to(sfcWind, "km/h") V = sfcWind**0.16 - W = 13.12 + 0.6215 * tas - 11.37 * V + 0.3965 * tas * V + W: xr.DataArray = 13.12 + 0.6215 * tas - 11.37 * V + 0.3965 * tas * V if method.upper() == "CAN": W = xr.where(sfcWind < 5, tas + sfcWind * (-1.59 + 0.1345 * tas) / 5, W) @@ -1214,7 +1226,7 @@ def wind_chill_index( mask = {"CAN": tas <= 0, "US": (sfcWind > 4.828032) & (tas <= 10)} W = W.where(mask[method.upper()]) - W.attrs["units"] = "degC" + W = W.assign_attrs(units="degC") return W @@ -1235,12 +1247,12 @@ def clausius_clapeyron_scaled_precipitation( Difference in temperature between a baseline climatology and another climatology. pr_baseline : xarray.DataArray Baseline precipitation to adjust with Clausius-Clapeyron. - cc_scale_factor : float (default = 1.07) - Clausius Clapeyron scale factor. + cc_scale_factor : float + Clausius Clapeyron scale factor. (default = 1.07). Returns ------- - DataArray + xarray.DataArray Baseline precipitation scaled to other climatology using Clausius-Clapeyron relationship. Notes @@ -1263,13 +1275,12 @@ def clausius_clapeyron_scaled_precipitation( delta_tas = convert_units_to(delta_tas, "delta_degreeC") # Calculate scaled precipitation. - pr_out = pr_baseline * (cc_scale_factor**delta_tas) - pr_out.attrs["units"] = pr_baseline.attrs["units"] - + pr_out: xr.DataArray = pr_baseline * (cc_scale_factor**delta_tas) + pr_out = pr_out.assign_attrs(units=pr_baseline.attrs["units"]) return pr_out -def _get_D_from_M(time): +def _get_D_from_M(time): # noqa: N802 start = time[0].dt.strftime("%Y-%m-01").item() yrmn = time[-1].dt.strftime("%Y-%m").item() end = f"{yrmn}-{time[-1].dt.daysinmonth.item()}" @@ -1395,33 +1406,36 @@ def potential_evapotranspiration( """ # noqa: E501 # ^ Ignoring "line too long" as it comes from un-splittable constructs if lat is None: - lat = _gather_lat(tasmin if tas is None else tas) + _lat = _gather_lat(tasmin if tas is None else tas) + else: + _lat = lat + pet: xr.DataArray if method in ["baierrobertson65", "BR65"]: - tasmin = convert_units_to(tasmin, "degF") - tasmax = convert_units_to(tasmax, "degF") + _tasmin = convert_units_to(tasmin, "degF") + _tasmax = convert_units_to(tasmax, "degF") re = extraterrestrial_solar_radiation( - tasmin.time, lat, chunks=tasmin.chunksizes + _tasmin.time, _lat, chunks=_tasmin.chunksizes ) re = convert_units_to(re, "cal cm-2 day-1") # Baier et Robertson(1965) formula - out = 0.094 * ( - -87.03 + 0.928 * tasmax + 0.933 * (tasmax - tasmin) + 0.0486 * re + pet = 0.094 * ( + -87.03 + 0.928 * _tasmax + 0.933 * (_tasmax - _tasmin) + 0.0486 * re ) - out = out.clip(0) + pet = pet.clip(0) elif method in ["hargreaves85", "HG85"]: - tasmin = convert_units_to(tasmin, "degC") - tasmax = convert_units_to(tasmax, "degC") + _tasmin = convert_units_to(tasmin, "degC") + _tasmax = convert_units_to(tasmax, "degC") if tas is None: - tas = (tasmin + tasmax) / 2 + _tas = (_tasmin + _tasmax) / 2 else: - tas = convert_units_to(tas, "degC") + _tas = convert_units_to(tas, "degC") ra = extraterrestrial_solar_radiation( - tasmin.time, lat, chunks=tasmin.chunksizes + _tasmin.time, _lat, chunks=_tasmin.chunksizes ) ra = convert_units_to(ra, "MJ m-2 d-1") @@ -1429,52 +1443,53 @@ def potential_evapotranspiration( ra = ra * 0.408 # Hargreaves and Samani (1985) formula - out = 0.0023 * ra * (tas + 17.8) * (tasmax - tasmin) ** 0.5 - out = out.clip(0) + pet = 0.0023 * 
ra * (_tas + 17.8) * (_tasmax - _tasmin) ** 0.5 + pet = pet.clip(0) elif method in ["droogersallen02", "DA02"]: - tasmin = convert_units_to(tasmin, "degC") - tasmax = convert_units_to(tasmax, "degC") - pr = convert_units_to(pr, "mm/month", context="hydro") + _tasmin = convert_units_to(tasmin, "degC") + _tasmax = convert_units_to(tasmax, "degC") + _pr = convert_units_to(pr, "mm/month", context="hydro") if tas is None: - tas = (tasmin + tasmax) / 2 + _tas = (_tasmin + _tasmax) / 2 else: - tas = convert_units_to(tas, "degC") + _tas = convert_units_to(tas, "degC") - tasmin = tasmin.resample(time="MS").mean() - tasmax = tasmax.resample(time="MS").mean() - tas = tas.resample(time="MS").mean() - pr = pr.resample(time="MS").mean() + _tasmin = _tasmin.resample(time="MS").mean() + _tasmax = _tasmax.resample(time="MS").mean() + _tas = _tas.resample(time="MS").mean() + _pr = _pr.resample(time="MS").mean() # Monthly accumulated radiation - time_d = _get_D_from_M(tasmin.time) - ra = extraterrestrial_solar_radiation(time_d, lat) + time_d = _get_D_from_M(_tasmin.time) + ra = extraterrestrial_solar_radiation(time_d, _lat) ra = convert_units_to(ra, "MJ m-2 d-1") ra = ra.resample(time="MS").sum() # Is used to convert the radiation to evaporation equivalents in mm (kg/MJ) ra = ra * 0.408 - tr = tasmax - tasmin + tr = _tasmax - _tasmin tr = tr.where(tr > 0, 0) # Droogers and Allen (2002) formula - ab = tr - 0.0123 * pr - out = 0.0013 * ra * (tas + 17.0) * ab**0.76 - out = xr.where(np.isnan(ab**0.76), 0, out) - out = out.clip(0) # mm/month + ab = tr - 0.0123 * _pr + pet = 0.0013 * ra * (_tas + 17.0) * ab**0.76 + pet = xr.where(np.isnan(ab**0.76), 0, pet) + pet = pet.clip(0) # mm/month elif method in ["mcguinnessbordne05", "MB05"]: if tas is None: - tasmin = convert_units_to(tasmin, "degC") - tasmax = convert_units_to(tasmax, "degC") - tas = (tasmin + tasmax) / 2 - tas.attrs["units"] = "degC" + _tasmin = convert_units_to(tasmin, "degC") + _tasmax = convert_units_to(tasmax, "degC") + _tas: xr.DataArray = (_tasmin + _tasmax) / 2 + _tas = _tas.assign_attrs(units="degC") + else: + _tas = convert_units_to(tas, "degC") - tas = convert_units_to(tas, "degC") - tasK = convert_units_to(tas, "K") + tasK = convert_units_to(_tas, "K") ext_rad = extraterrestrial_solar_radiation( - tas.time, lat, solar_constant="1367 W m-2", chunks=tas.chunksizes + _tas.time, _lat, solar_constant="1367 W m-2", chunks=_tas.chunksizes ) latentH = 4185.5 * (751.78 - 0.5655 * tasK) radDIVlat = ext_rad / latentH @@ -1484,30 +1499,30 @@ def potential_evapotranspiration( a = peta b = petb - out = radDIVlat * a * tas + radDIVlat * b + pet = radDIVlat * a * _tas + radDIVlat * b elif method in ["thornthwaite48", "TW48"]: if tas is None: - tasmin = convert_units_to(tasmin, "degC") - tasmax = convert_units_to(tasmax, "degC") - tas = (tasmin + tasmax) / 2 + _tasmin = convert_units_to(tasmin, "degC") + _tasmax = convert_units_to(tasmax, "degC") + _tas = (_tasmin + _tasmax) / 2 else: - tas = convert_units_to(tas, "degC") - tas = tas.clip(0) - tas = tas.resample(time="MS").mean(dim="time") + _tas = convert_units_to(tas, "degC") + _tas = _tas.clip(0) + _tas = _tas.resample(time="MS").mean(dim="time") # Thornthwaite measures half-days - time_d = _get_D_from_M(tas.time) - dl = day_lengths(time_d, lat) / 12 + time_d = _get_D_from_M(_tas.time) + dl = day_lengths(time_d, _lat) / 12 dl_m = dl.resample(time="MS").mean(dim="time") # annual heat index - id_m = (tas / 5) ** 1.514 + id_m = (_tas / 5) ** 1.514 id_y = id_m.resample(time="YS").sum(dim="time") tas_idy_a = [] 
- for base_time, indexes in tas.resample(time="YS").groups.items(): - tas_y = tas.isel(time=indexes) + for base_time, indexes in _tas.resample(time="YS").groups.items(): + tas_y = _tas.isel(time=indexes) id_v = id_y.sel(time=base_time) a = 6.75e-7 * id_v**3 - 7.71e-5 * id_v**2 + 0.01791 * id_v + 0.49239 @@ -1517,23 +1532,25 @@ def potential_evapotranspiration( tas_idy_a = xr.concat(tas_idy_a, dim="time") # Thornthwaite(1948) formula - out = 1.6 * dl_m * tas_idy_a # cm/month - out = 10 * out # mm/month + pet = 1.6 * dl_m * tas_idy_a # cm/month + pet = 10 * pet # mm/month elif method in ["allen98", "FAO_PM98"]: - tasmax = convert_units_to(tasmax, "degC") - tasmin = convert_units_to(tasmin, "degC") - - # wind speed at two meters - wa2 = wind_speed_height_conversion(sfcWind, h_source="10 m", h_target="2 m") - wa2 = convert_units_to(wa2, "m s-1") + _tasmax = convert_units_to(tasmax, "degC") + _tasmin = convert_units_to(tasmin, "degC") + if sfcWind is None: + raise ValueError("Wind speed is required for Allen98 method.") + else: + # wind speed at two meters + wa2 = wind_speed_height_conversion(sfcWind, h_source="10 m", h_target="2 m") + wa2 = convert_units_to(wa2, "m s-1") with xr.set_options(keep_attrs=True): # mean temperature [degC] - tas_m = (tasmax + tasmin) / 2 + tas_m = (_tasmax + _tasmin) / 2 # mean saturation vapour pressure [kPa] es = (1 / 2) * ( - saturation_vapor_pressure(tasmax) + saturation_vapor_pressure(tasmin) + saturation_vapor_pressure(_tasmax) + saturation_vapor_pressure(_tasmin) ) es = convert_units_to(es, "kPa") # mean actual vapour pressure [kPa] @@ -1552,7 +1569,7 @@ def potential_evapotranspiration( # height = 0.12m, surface resistance = 70 s m-1, albedo = 0.23 # Surface resistance implies a ``moderately dry soil surface resulting from # about a weekly irrigation frequency'' - out = ( + pet = ( 0.408 * delta * (Rn - G) + gamma * (900 / (tas_m + 273)) * wa2 * (es - ea) ) / (delta + gamma * (1 + 0.34 * wa2)) @@ -1560,9 +1577,10 @@ def potential_evapotranspiration( else: raise NotImplementedError(f"'{method}' method is not implemented.") - out.attrs["units"] = "mm" - rate = amount2rate(out, out_units="mm/d") - return convert_units_to(rate, "kg m-2 s-1", context="hydro") + pet = pet.assign_attrs(units="mm") + rate = amount2rate(pet, out_units="mm/d") + out: xr.DataArray = convert_units_to(rate, "kg m-2 s-1", context="hydro") + return out @vectorize( @@ -1891,7 +1909,7 @@ def universal_thermal_climate_index( delta = mrt - tas pa = convert_units_to(e_sat, "kPa") * convert_units_to(hurs, "1") - utci = xr.apply_ufunc( + utci: xr.DataArray = xr.apply_ufunc( _utci, tas, sfcWind, @@ -2050,18 +2068,22 @@ def mean_radiant_temperature( fp = 0.308 * np.cos(gamma * 0.988 - (gamma**2 / 50000)) i_star = xr.where(csza > 0.001, rsds_direct / csza, 0) - mrt = np.power( - ( - (1 / 5.67e-8) # Stefan-Boltzmann constant - * ( - 0.5 * rlds - + 0.5 * rlus - + (0.7 / 0.97) * (0.5 * rsds_diffuse + 0.5 * rsus + fp * i_star) - ) + mrt = cast( + xr.DataArray, + np.power( + ( + (1 / 5.67e-8) # Stefan-Boltzmann constant + * ( + 0.5 * rlds + + 0.5 * rlus + + (0.7 / 0.97) * (0.5 * rsds_diffuse + 0.5 * rsus + fp * i_star) + ) + ), + 0.25, ), - 0.25, ) - return mrt.assign_attrs({"units": "K"}) + mrt = mrt.assign_attrs({"units": "K"}) + return mrt @declare_units(wind_speed="[speed]", h="[length]", h_r="[length]") @@ -2108,8 +2130,8 @@ def wind_profile( if method == "power_law": alpha = kwds.pop("alpha", 1 / 7) - out = wind_speed * (h / h_r) ** alpha - out.attrs["units"] = wind_speed.attrs["units"] + out: 
xr.DataArray = wind_speed * (h / h_r) ** alpha + out = out.assign_attrs(units=wind_speed.attrs["units"]) return out else: raise NotImplementedError(f"Method {method} not implemented.") @@ -2205,8 +2227,8 @@ def wind_power_potential( v = wind_speed * f - out = xr.apply_ufunc(_wind_power_factor, v, cut_in, rated, cut_out) - out.attrs["units"] = "" + out: xr.DataArray = xr.apply_ufunc(_wind_power_factor, v, cut_in, rated, cut_out) + out = out.assign_attrs(units="") return out diff --git a/xclim/indices/_multivariate.py b/xclim/indices/_multivariate.py index bb1769292..6785f2d3c 100644 --- a/xclim/indices/_multivariate.py +++ b/xclim/indices/_multivariate.py @@ -1,7 +1,7 @@ # noqa: D100 from __future__ import annotations -from typing import Callable +from typing import Callable, cast import numpy as np import xarray @@ -87,10 +87,10 @@ def cold_spell_duration_index( window : int Minimum number of days with temperature below threshold to qualify as a cold spell. freq : str - Resampling frequency. + Resampling frequency. resample_before_rl : bool - Determines if the resampling should take place before or after the run - length encoding (or a similar algorithm) is applied to runs. + Determines if the resampling should take place before or after the run + length encoding (or a similar algorithm) is applied to runs. bootstrap : bool Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator. Bootstrapping is only useful when the percentiles are computed on a part of the studied sample. @@ -170,15 +170,15 @@ def cold_and_dry_days( Parameters ---------- tas : xarray.DataArray - Mean daily temperature values + Mean daily temperature values pr : xarray.DataArray - Daily precipitation. + Daily precipitation. tas_per : xarray.DataArray - First quartile of daily mean temperature computed by month. + First quartile of daily mean temperature computed by month. pr_per : xarray.DataArray - First quartile of daily total precipitation computed by month. + First quartile of daily total precipitation computed by month. freq : str - Resampling frequency. + Resampling frequency. Warnings -------- @@ -188,7 +188,7 @@ def cold_and_dry_days( Returns ------- xarray.DataArray - The total number of days when cold and dry conditions coincide. + The total number of days when cold and dry conditions coincide. Notes ----- @@ -210,8 +210,10 @@ def cold_and_dry_days( thresh = resample_doy(pr_per, pr) pr25 = pr < thresh - cold_and_dry = np.logical_and(tg25, pr25).resample(time=freq).sum(dim="time") - return to_agg_units(cold_and_dry, tas, "count") + cold_and_dry = cast(xarray.DataArray, np.logical_and(tg25, pr25)) + resampled = cold_and_dry.resample(time=freq).sum(dim="time") + out = to_agg_units(resampled, tas, "count") + return out @declare_units( @@ -234,15 +236,15 @@ def warm_and_dry_days( Parameters ---------- tas : xarray.DataArray - Mean daily temperature values + Mean daily temperature values pr : xarray.DataArray - Daily precipitation. + Daily precipitation. tas_per : xarray.DataArray - Third quartile of daily mean temperature computed by month. + Third quartile of daily mean temperature computed by month. pr_per : xarray.DataArray - First quartile of daily total precipitation computed by month. + First quartile of daily total precipitation computed by month. freq : str - Resampling frequency. + Resampling frequency. Warnings -------- @@ -252,7 +254,7 @@ def warm_and_dry_days( Returns ------- xarray.DataArray, - The total number of days when warm and dry conditions coincide. 
+ The total number of days when warm and dry conditions coincide. Notes ----- @@ -274,8 +276,10 @@ def warm_and_dry_days( thresh = resample_doy(pr_per, pr) pr25 = pr < thresh - warm_and_dry = np.logical_and(tg75, pr25).resample(time=freq).sum(dim="time") - return to_agg_units(warm_and_dry, tas, "count") + warm_and_dry = cast(xarray.DataArray, np.logical_and(tg75, pr25)) + resampled = warm_and_dry.resample(time=freq).sum(dim="time") + out = to_agg_units(resampled, tas, "count") + return out @declare_units( @@ -298,15 +302,15 @@ def warm_and_wet_days( Parameters ---------- tas : xarray.DataArray - Mean daily temperature values + Mean daily temperature values pr : xarray.DataArray - Daily precipitation. + Daily precipitation. tas_per : xarray.DataArray - Third quartile of daily mean temperature computed by month. + Third quartile of daily mean temperature computed by month. pr_per : xarray.DataArray - Third quartile of daily total precipitation computed by month. + Third quartile of daily total precipitation computed by month. freq : str - Resampling frequency. + Resampling frequency. Warnings -------- @@ -316,7 +320,7 @@ def warm_and_wet_days( Returns ------- xarray.DataArray - The total number of days when warm and wet conditions coincide. + The total number of days when warm and wet conditions coincide. Notes ----- @@ -337,8 +341,10 @@ def warm_and_wet_days( thresh = resample_doy(pr_per, pr) pr75 = pr > thresh - warm_and_wet = np.logical_and(tg75, pr75).resample(time=freq).sum(dim="time") - return to_agg_units(warm_and_wet, tas, "count") + warm_and_wet = cast(xarray.DataArray, np.logical_and(tg75, pr75)) + resampled = warm_and_wet.resample(time=freq).sum(dim="time") + out = to_agg_units(resampled, tas, "count") + return out @declare_units( @@ -366,20 +372,20 @@ def cold_and_wet_days( Parameters ---------- tas : xarray.DataArray - Mean daily temperature values + Mean daily temperature values pr : xarray.DataArray - Daily precipitation. + Daily precipitation. tas_per : xarray.DataArray - First quartile of daily mean temperature computed by month. + First quartile of daily mean temperature computed by month. pr_per : xarray.DataArray - Third quartile of daily total precipitation computed by month. + Third quartile of daily total precipitation computed by month. freq : str - Resampling frequency. + Resampling frequency. Returns ------- xarray.DataArray - The total number of days when cold and wet conditions coincide. + The total number of days when cold and wet conditions coincide. Notes ----- @@ -400,8 +406,10 @@ def cold_and_wet_days( thresh = resample_doy(pr_per, pr) pr75 = pr > thresh - cold_and_wet = np.logical_and(tg25, pr75).resample(time=freq).sum(dim="time") - return to_agg_units(cold_and_wet, tas, "count") + cold_and_wet = cast(xarray.DataArray, np.logical_and(tg25, pr75)) + resampled = cold_and_wet.resample(time=freq).sum(dim="time") + out = to_agg_units(resampled, tas, "count") + return out @declare_units( @@ -430,31 +438,31 @@ def multiday_temperature_swing( Parameters ---------- tasmin : xarray.DataArray - Minimum daily temperature. + Minimum daily temperature. tasmax : xarray.DataArray - Maximum daily temperature. + Maximum daily temperature. thresh_tasmin : Quantified - The temperature threshold needed to trigger a freeze event. + The temperature threshold needed to trigger a freeze event. thresh_tasmax : Quantified - The temperature threshold needed to trigger a thaw event. + The temperature threshold needed to trigger a thaw event. 
window : int - The minimal length of spells to be included in the statistics. + The minimal length of spells to be included in the statistics. op : {'mean', 'sum', 'max', 'min', 'std', 'count'} - The statistical operation to use when reducing the list of spell lengths. + The statistical operation to use when reducing the list of spell lengths. op_tasmin : {"<", "<=", "lt", "le"} - Comparison operation for tasmin. Default: "<=". + Comparison operation for tasmin. Default: "<=". op_tasmax : {">", ">=", "gt", "ge"} - Comparison operation for tasmax. Default: ">". + Comparison operation for tasmax. Default: ">". freq : str - Resampling frequency. + Resampling frequency. resample_before_rl : bool - Determines if the resampling should take place before or after the run - length encoding (or a similar algorithm) is applied to runs. + Determines if the resampling should take place before or after the run + length encoding (or a similar algorithm) is applied to runs. Returns ------- xarray.DataArray, [time] - {freq} {op} length of diurnal temperature cycles exceeding thresholds. + {freq} {op} length of diurnal temperature cycles exceeding thresholds. Notes ----- @@ -510,22 +518,22 @@ def daily_temperature_range( Parameters ---------- tasmin : xarray.DataArray - Minimum daily temperature. + Minimum daily temperature. tasmax : xarray.DataArray - Maximum daily temperature. + Maximum daily temperature. freq : str - Resampling frequency. + Resampling frequency. op : {'min', 'max', 'mean', 'std'} or func - Reduce operation. Can either be a DataArray method or a function that can be applied to a DataArray. + Reduce operation. Can either be a DataArray method or a function that can be applied to a DataArray. Returns ------- xarray.DataArray, [same units as tasmin] - The average variation in daily temperature range for the given time period. + The average variation in daily temperature range for the given time period. Notes ----- - For a default calculation using `op='mean'` : + For a default calculation using `op='mean'`: Let :math:`TX_{ij}` and :math:`TN_{ij}` be the daily maximum and minimum temperature at day :math:`i` of period :math:`j`. Then the mean diurnal temperature range in period :math:`j` is: @@ -553,16 +561,16 @@ def daily_temperature_range_variability( Parameters ---------- tasmin : xarray.DataArray - Minimum daily temperature. + Minimum daily temperature. tasmax : xarray.DataArray - Maximum daily temperature. + Maximum daily temperature. freq : str - Resampling frequency. + Resampling frequency. Returns ------- xarray.DataArray, [same units as tasmin] - The average day-to-day variation in daily temperature range for the given time period. + The average day-to-day variation in daily temperature range for the given time period. Notes ----- @@ -592,16 +600,16 @@ def extreme_temperature_range( Parameters ---------- tasmin : xarray.DataArray - Minimum daily temperature. + Minimum daily temperature. tasmax : xarray.DataArray - Maximum daily temperature. + Maximum daily temperature. freq : str - Resampling frequency. + Resampling frequency. Returns ------- xarray.DataArray, [same units as tasmin] - Extreme intra-period temperature range for the given time period. + Extreme intra-period temperature range for the given time period. 
Notes ----- @@ -618,7 +626,7 @@ def extreme_temperature_range( out = tx_max - tn_min u = str2pint(tasmax.units) - out.attrs["units"] = pint2cfunits(u - u) + out = out.assign_attrs(units=pint2cfunits(u - u)) return out @@ -646,27 +654,27 @@ def heat_wave_frequency( Parameters ---------- tasmin : xarray.DataArray - Minimum daily temperature. + Minimum daily temperature. tasmax : xarray.DataArray - Maximum daily temperature. + Maximum daily temperature. thresh_tasmin : Quantified - The minimum temperature threshold needed to trigger a heatwave event. + The minimum temperature threshold needed to trigger a heatwave event. thresh_tasmax : Quantified - The maximum temperature threshold needed to trigger a heatwave event. - window: int - Minimum number of days with temperatures above thresholds to qualify as a heatwave. + The maximum temperature threshold needed to trigger a heatwave event. + window : int + Minimum number of days with temperatures above thresholds to qualify as a heatwave. freq : str - Resampling frequency. - op: {">", ">=", "gt", "ge"} - Comparison operation. Default: ">". + Resampling frequency. + op : {">", ">=", "gt", "ge"} + Comparison operation. Default: ">". resample_before_rl : bool - Determines if the resampling should take place before or after the run - length encoding (or a similar algorithm) is applied to runs. + Determines if the resampling should take place before or after the run + length encoding (or a similar algorithm) is applied to runs. Returns ------- xarray.DataArray, [dimensionless] - Number of heatwave at the requested frequency. + Number of heatwave at the requested frequency. Notes ----- @@ -696,7 +704,7 @@ def heat_wave_frequency( window=window, freq=freq, ) - out.attrs["units"] = "" + out = out.assign_attrs(units="") return out @@ -726,27 +734,27 @@ def heat_wave_max_length( Parameters ---------- tasmin : xarray.DataArray - Minimum daily temperature. + Minimum daily temperature. tasmax : xarray.DataArray - Maximum daily temperature. + Maximum daily temperature. thresh_tasmin : Quantified - The minimum temperature threshold needed to trigger a heatwave event. + The minimum temperature threshold needed to trigger a heatwave event. thresh_tasmax : Quantified - The maximum temperature threshold needed to trigger a heatwave event. + The maximum temperature threshold needed to trigger a heatwave event. window : int - Minimum number of days with temperatures above thresholds to qualify as a heatwave. + Minimum number of days with temperatures above thresholds to qualify as a heatwave. freq : str - Resampling frequency. + Resampling frequency. op : {">", ">=", "gt", "ge"} - Comparison operation. Default: ">". + Comparison operation. Default: ">". resample_before_rl : bool - Determines if the resampling should take place before or after the run - length encoding (or a similar algorithm) is applied to runs. + Determines if the resampling should take place before or after the run + length encoding (or a similar algorithm) is applied to runs. Returns ------- xarray.DataArray, [time] - Maximum length of heatwave at the requested frequency. + Maximum length of heatwave at the requested frequency. Notes ----- @@ -805,27 +813,27 @@ def heat_wave_total_length( Parameters ---------- tasmin : xarray.DataArray - Minimum daily temperature. + Minimum daily temperature. tasmax : xarray.DataArray - Maximum daily temperature. + Maximum daily temperature. thresh_tasmin : str - The minimum temperature threshold needed to trigger a heatwave event. 
+        The minimum temperature threshold needed to trigger a heatwave event.
     thresh_tasmax : str
-        The maximum temperature threshold needed to trigger a heatwave event.
+        The maximum temperature threshold needed to trigger a heatwave event.
     window : int
-        Minimum number of days with temperatures above thresholds to qualify as a heatwave.
+        Minimum number of days with temperatures above thresholds to qualify as a heatwave.
     freq : str
-        Resampling frequency.
+        Resampling frequency.
     op: {">", ">=", "gt", "ge"}
-        Comparison operation. Default: ">".
+        Comparison operation. Default: ">".
     resample_before_rl : bool
-        Determines if the resampling should take place before or after the run
-        length encoding (or a similar algorithm) is applied to runs.
+        Determines if the resampling should take place before or after the run
+        length encoding (or a similar algorithm) is applied to runs.

     Returns
     -------
     xarray.DataArray, [time]
-        Total length of heatwave at the requested frequency.
+        Total length of heatwave at the requested frequency.

     Notes
     -----
@@ -870,20 +878,20 @@ def liquid_precip_ratio(
     Parameters
     ----------
     pr : xarray.DataArray
-        Mean daily precipitation flux.
+        Mean daily precipitation flux.
     prsn : xarray.DataArray, optional
-        Mean daily solid precipitation flux.
+        Mean daily solid precipitation flux.
     tas : xarray.DataArray, optional
-        Mean daily temperature.
+        Mean daily temperature.
     thresh : Quantified
-        Threshold temperature under which precipitation is assumed to be solid.
+        Threshold temperature under which precipitation is assumed to be solid.
     freq : str
-        Resampling frequency.
+        Resampling frequency.

     Returns
     -------
     xarray.DataArray, [dimensionless]
-        Ratio of rainfall to total precipitation.
+        Ratio of rainfall to total precipitation.

     Notes
     -----
@@ -908,7 +916,7 @@ def liquid_precip_ratio(
     tot = pr.resample(time=freq).sum(dim="time")
     rain = tot - prsn.resample(time=freq).sum(dim="time")
     ratio = rain / tot
-    ratio.attrs["units"] = ""
+    ratio = ratio.assign_attrs(units="")
     return ratio


@@ -971,7 +979,8 @@ def precip_accumulation(
     elif phase == "solid":
         pr = snowfall_approximation(pr, tas=tas, thresh=thresh, method="binary")
     pram = rate2amount(pr)
-    return pram.resample(time=freq).sum(dim="time").assign_attrs(units=pram.units)
+    pram = pram.resample(time=freq).sum(dim="time").assign_attrs(units=pram.units)
+    return pram


 @declare_units(pr="[precipitation]", tas="[temperature]", thresh="[temperature]")
@@ -992,20 +1001,20 @@ def precip_average(
     Parameters
     ----------
     pr : xarray.DataArray
-        Mean daily precipitation flux.
+        Mean daily precipitation flux.
     tas : xarray.DataArray, optional
-        Mean, maximum or minimum daily temperature.
+        Mean, maximum or minimum daily temperature.
     phase : {None, 'liquid', 'solid'}
-        Which phase to consider, "liquid" or "solid", if None (default), both are considered.
+        Which phase to consider, "liquid" or "solid", if None (default), both are considered.
     thresh : Quantified
-        Threshold of `tas` over which the precipication is assumed to be liquid rain.
+        Threshold of `tas` over which the precipitation is assumed to be liquid rain.
     freq : str
-        Resampling frequency.
+        Resampling frequency.

     Returns
     -------
     xarray.DataArray, [length]
-        The averaged daily precipitation at the given time frequency for the given phase.
+        The averaged daily precipitation at the given time frequency for the given phase.
Notes ----- @@ -1033,7 +1042,8 @@ def precip_average( elif phase == "solid": pr = snowfall_approximation(pr, tas=tas, thresh=thresh, method="binary") pram = rate2amount(pr) - return pram.resample(time=freq).mean(dim="time").assign_attrs(units=pram.units) + pram = pram.resample(time=freq).mean(dim="time").assign_attrs(units=pram.units) + return pram # FIXME: Resample after run length? @@ -1052,18 +1062,18 @@ def rain_on_frozen_ground_days( Parameters ---------- pr : xarray.DataArray - Mean daily precipitation flux. + Mean daily precipitation flux. tas : xarray.DataArray - Mean daily temperature. + Mean daily temperature. thresh : Quantified - Precipitation threshold to consider a day as a rain event. + Precipitation threshold to consider a day as a rain event. freq : str - Resampling frequency. + Resampling frequency. Returns ------- xarray.DataArray, [time] - The number of rain on frozen ground events per period. + The number of rain on frozen ground events per period. Notes ----- @@ -1118,20 +1128,20 @@ def high_precip_low_temp( Parameters ---------- pr : xarray.DataArray - Mean daily precipitation flux. + Mean daily precipitation flux. tas : xarray.DataArray - Daily mean, minimum or maximum temperature. + Daily mean, minimum or maximum temperature. pr_thresh : Quantified - Precipitation threshold to exceed. + Precipitation threshold to exceed. tas_thresh : Quantified - Temperature threshold not to exceed. + Temperature threshold not to exceed. freq : str - Resampling frequency. + Resampling frequency. Returns ------- xarray.DataArray, [time] - Count of days with high precipitation and low temperatures. + Count of days with high precipitation and low temperatures. Example ------- @@ -1168,23 +1178,23 @@ def days_over_precip_thresh( Parameters ---------- pr : xarray.DataArray - Mean daily precipitation flux. + Mean daily precipitation flux. pr_per : xarray.DataArray - Percentile of wet day precipitation flux. Either computed daily (one value per day - of year) or computed over a period (one value per spatial point). + Percentile of wet day precipitation flux. Either computed daily (one value per day + of year) or computed over a period (one value per spatial point). thresh : Quantified - Precipitation value over which a day is considered wet. + Precipitation value over which a day is considered wet. freq : str - Resampling frequency. + Resampling frequency. bootstrap : bool - Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator. - Bootstrapping is only useful when the percentiles are computed on a part of the studied sample. - This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with - the rest of the time series. - Keep bootstrap to False when there is no common period, it would give wrong results - plus, bootstrapping is computationally expensive. - op: {">", ">=", "gt", "ge"} - Comparison operation. Default: ">". + Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator. + Bootstrapping is only useful when the percentiles are computed on a part of the studied sample. + This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with + the rest of the time series. + Keep bootstrap to False when there is no common period, it would give wrong results + plus, bootstrapping is computationally expensive. + op : {">", ">=", "gt", "ge"} + Comparison operation. Default: ">". 
Returns ------- @@ -1229,28 +1239,28 @@ def fraction_over_precip_thresh( Parameters ---------- pr : xarray.DataArray - Mean daily precipitation flux. + Mean daily precipitation flux. pr_per : xarray.DataArray - Percentile of wet day precipitation flux. Either computed daily (one value per day - of year) or computed over a period (one value per spatial point). + Percentile of wet day precipitation flux. Either computed daily (one value per day + of year) or computed over a period (one value per spatial point). thresh : Quantified - Precipitation value over which a day is considered wet. + Precipitation value over which a day is considered wet. freq : str - Resampling frequency. + Resampling frequency. bootstrap : bool - Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator. - Bootstrapping is only useful when the percentiles are computed on a part of the studied sample. - This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with - the rest of the time series. - Keep bootstrap to False when there is no common period, it would give wrong results - plus, bootstrapping is computationally expensive. - op: {">", ">=", "gt", "ge"} - Comparison operation. Default: ">". + Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator. + Bootstrapping is only useful when the percentiles are computed on a part of the studied sample. + This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with + the rest of the time series. + Keep bootstrap to False when there is no common period, it would give wrong results + plus, bootstrapping is computationally expensive. + op : {">", ">=", "gt", "ge"} + Comparison operation. Default: ">". Returns ------- xarray.DataArray, [dimensionless] - Fraction of precipitation over threshold during wet days. + Fraction of precipitation over threshold during wet days. """ pr_per = convert_units_to(pr_per, pr, context="hydro") @@ -1295,25 +1305,25 @@ def tg90p( Parameters ---------- tas : xarray.DataArray - Mean daily temperature. + Mean daily temperature. tas_per : xarray.DataArray - 90th percentile of daily mean temperature. + 90th percentile of daily mean temperature. freq : str - Resampling frequency. + Resampling frequency. bootstrap : bool - Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator. - Bootstrapping is only useful when the percentiles are computed on a part of the studied sample. - This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with - the rest of the time series. - Keep bootstrap to False when there is no common period, it would give wrong results - plus, bootstrapping is computationally expensive. - op: {">", ">=", "gt", "ge"} - Comparison operation. Default: ">". + Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator. + Bootstrapping is only useful when the percentiles are computed on a part of the studied sample. + This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with + the rest of the time series. + Keep bootstrap to False when there is no common period, it would give wrong results + plus, bootstrapping is computationally expensive. + op : {">", ">=", "gt", "ge"} + Comparison operation. Default: ">". Returns ------- xarray.DataArray, [time] - Count of days with daily mean temperature below the 10th percentile [days]. 
+        Count of days with daily mean temperature above the 90th percentile [days].

     Notes
     -----
@@ -1353,25 +1363,25 @@ def tg10p(
     Parameters
     ----------
     tas : xarray.DataArray
-        Mean daily temperature.
+        Mean daily temperature.
     tas_per : xarray.DataArray
-        10th percentile of daily mean temperature.
+        10th percentile of daily mean temperature.
     freq : str
-        Resampling frequency.
+        Resampling frequency.
     bootstrap : bool
-        Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator.
-        Bootstrapping is only useful when the percentiles are computed on a part of the studied sample.
-        This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with
-        the rest of the time series.
-        Keep bootstrap to False when there is no common period, it would give wrong results
-        plus, bootstrapping is computationally expensive.
-    op: {"<", "<=", "lt", "le"}
-        Comparison operation. Default: "<".
+        Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator.
+        Bootstrapping is only useful when the percentiles are computed on a part of the studied sample.
+        This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with
+        the rest of the time series.
+        Keep bootstrap to False when there is no common period, it would give wrong results
+        plus, bootstrapping is computationally expensive.
+    op : {"<", "<=", "lt", "le"}
+        Comparison operation. Default: "<".

     Returns
     -------
     xarray.DataArray, [time]
-        Count of days with daily mean temperature below the 10th percentile [days].
+        Count of days with daily mean temperature below the 10th percentile [days].

     Notes
     -----
@@ -1411,25 +1421,25 @@ def tn90p(
     Parameters
     ----------
     tasmin : xarray.DataArray
-        Minimum daily temperature.
+        Minimum daily temperature.
     tasmin_per : xarray.DataArray
-        90th percentile of daily minimum temperature.
+        90th percentile of daily minimum temperature.
     freq : str
-        Resampling frequency.
+        Resampling frequency.
     bootstrap : bool
-        Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator.
-        Bootstrapping is only useful when the percentiles are computed on a part of the studied sample.
-        This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with
-        the rest of the time series.
-        Keep bootstrap to False when there is no common period, it would give wrong results
-        plus, bootstrapping is computationally expensive.
-    op: {">", ">=", "gt", "ge"}
-        Comparison operation. Default: ">".
+        Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator.
+        Bootstrapping is only useful when the percentiles are computed on a part of the studied sample.
+        This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with
+        the rest of the time series.
+        Keep bootstrap to False when there is no common period, it would give wrong results
+        plus, bootstrapping is computationally expensive.
+    op : {">", ">=", "gt", "ge"}
+        Comparison operation. Default: ">".

     Returns
     -------
     xarray.DataArray, [time]
-        Count of days with daily minimum temperature below the 10th percentile [days].
+        Count of days with daily minimum temperature above the 90th percentile [days].

     Notes
     -----
@@ -1469,20 +1479,20 @@ def tn10p(
     Parameters
     ----------
     tasmin : xarray.DataArray
-        Mean daily temperature.
+        Minimum daily temperature.
     tasmin_per : xarray.DataArray
-        10th percentile of daily minimum temperature.
+        10th percentile of daily minimum temperature.
     freq : str
-        Resampling frequency.
+        Resampling frequency.
     bootstrap : bool
-        Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator.
-        Bootstrapping is only useful when the percentiles are computed on a part of the studied sample.
-        This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with
-        the rest of the time series.
-        Keep bootstrap to False when there is no common period, it would give wrong results
-        plus, bootstrapping is computationally expensive.
-    op: {"<", "<=", "lt", "le"}
-        Comparison operation. Default: "<".
+        Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator.
+        Bootstrapping is only useful when the percentiles are computed on a part of the studied sample.
+        This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with
+        the rest of the time series.
+        Keep bootstrap to False when there is no common period, it would give wrong results
+        plus, bootstrapping is computationally expensive.
+    op : {"<", "<=", "lt", "le"}
+        Comparison operation. Default: "<".

     Returns
     -------
@@ -1527,25 +1537,25 @@ def tx90p(
     Parameters
     ----------
     tasmax : xarray.DataArray
-        Maximum daily temperature.
+        Maximum daily temperature.
     tasmax_per : xarray.DataArray
-        90th percentile of daily maximum temperature.
+        90th percentile of daily maximum temperature.
     freq : str
-        Resampling frequency.
+        Resampling frequency.
     bootstrap : bool
-        Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator.
-        Bootstrapping is only useful when the percentiles are computed on a part of the studied sample.
-        This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with
-        the rest of the time series.
-        Keep bootstrap to False when there is no common period, it would give wrong results
-        plus, bootstrapping is computationally expensive.
-    op: {">", ">=", "gt", "ge"}
-        Comparison operation. Default: ">".
+        Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator.
+        Bootstrapping is only useful when the percentiles are computed on a part of the studied sample.
+        This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with
+        the rest of the time series.
+        Keep bootstrap to False when there is no common period, it would give wrong results
+        plus, bootstrapping is computationally expensive.
+    op : {">", ">=", "gt", "ge"}
+        Comparison operation. Default: ">".

     Returns
     -------
     xarray.DataArray, [time]
-        Count of days with daily maximum temperature below the 10th percentile [days].
+        Count of days with daily maximum temperature above the 90th percentile [days].

     Notes
     -----
@@ -1585,25 +1595,25 @@ def tx10p(
     Parameters
     ----------
     tasmax : xarray.DataArray
-        Maximum daily temperature.
+        Maximum daily temperature.
     tasmax_per : xarray.DataArray
-        10th percentile of daily maximum temperature.
+        10th percentile of daily maximum temperature.
     freq : str
-        Resampling frequency.
+        Resampling frequency.
     bootstrap : bool
-        Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator.
-        Bootstrapping is only useful when the percentiles are computed on a part of the studied sample.
-        This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with
-        the rest of the time series.
-        Keep bootstrap to False when there is no common period, it would give wrong results
-        plus, bootstrapping is computationally expensive.
- op: {"<", "<=", "lt", "le"} - Comparison operation. Default: "<". + Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator. + Bootstrapping is only useful when the percentiles are computed on a part of the studied sample. + This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with + the rest of the time series. + Keep bootstrap to False when there is no common period, it would give wrong results + plus, bootstrapping is computationally expensive. + op : {"<", "<=", "lt", "le"} + Comparison operation. Default: "<". Returns ------- xarray.DataArray, [time] - Count of days with daily maximum temperature below the 10th percentile [days]. + Count of days with daily maximum temperature below the 10th percentile [days]. Notes ----- @@ -1648,17 +1658,17 @@ def tx_tn_days_above( Parameters ---------- tasmin : xarray.DataArray - Minimum daily temperature. + Minimum daily temperature. tasmax : xarray.DataArray - Maximum daily temperature. + Maximum daily temperature. thresh_tasmin : Quantified - Threshold temperature for tasmin on which to base evaluation. + Threshold temperature for tasmin on which to base evaluation. thresh_tasmax : Quantified - Threshold temperature for tasmax on which to base evaluation. + Threshold temperature for tasmax on which to base evaluation. freq : str - Resampling frequency. - op: {">", ">=", "gt", "ge"} - Comparison operation. Default: ">". + Resampling frequency. + op : {">", ">=", "gt", "ge"} + Comparison operation. Default: ">". Returns ------- @@ -1715,30 +1725,30 @@ def warm_spell_duration_index( Parameters ---------- tasmax : xarray.DataArray - Maximum daily temperature. + Maximum daily temperature. tasmax_per : xarray.DataArray - percentile(s) of daily maximum temperature. + percentile(s) of daily maximum temperature. window : int - Minimum number of days with temperature above threshold to qualify as a warm spell. + Minimum number of days with temperature above threshold to qualify as a warm spell. freq : str - Resampling frequency. + Resampling frequency. resample_before_rl : bool - Determines if the resampling should take place before or after the run - length encoding (or a similar algorithm) is applied to runs. + Determines if the resampling should take place before or after the run + length encoding (or a similar algorithm) is applied to runs. bootstrap : bool - Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator. - Bootstrapping is only useful when the percentiles are computed on a part of the studied sample. - This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with - the rest of the time series. - Keep bootstrap to False when there is no common period, it would give wrong results - plus, bootstrapping is computationally expensive. - op: {">", ">=", "gt", "ge"} - Comparison operation. Default: ">". + Flag to run bootstrapping of percentiles. Used by percentile_bootstrap decorator. + Bootstrapping is only useful when the percentiles are computed on a part of the studied sample. + This period, common to percentiles and the sample must be bootstrapped to avoid inhomogeneities with + the rest of the time series. + Keep bootstrap to False when there is no common period, it would give wrong results + plus, bootstrapping is computationally expensive. + op : {">", ">=", "gt", "ge"} + Comparison operation. Default: ">". Returns ------- xarray.DataArray, [time] - Warm spell duration index. + Warm spell duration index. 
References ---------- @@ -1789,22 +1799,23 @@ def winter_rain_ratio( Parameters ---------- pr : xarray.DataArray - Mean daily precipitation flux. + Mean daily precipitation flux. prsn : xarray.DataArray, optional - Mean daily solid precipitation flux. + Mean daily solid precipitation flux. tas : xarray.DataArray, optional - Mean daily temperature. + Mean daily temperature. freq : str - Resampling frequency. + Resampling frequency. Returns ------- xarray.DataArray - Ratio of rainfall to total precipitation during winter months (DJF). + Ratio of rainfall to total precipitation during winter months (DJF). """ ratio = liquid_precip_ratio(pr, prsn, tas, freq=freq) winter = ratio.indexes["time"].month == 12 - return ratio.sel(time=winter) + ratio = ratio.sel(time=winter) + return ratio @declare_units( @@ -1825,22 +1836,22 @@ def blowing_snow( Parameters ---------- snd : xarray.DataArray - Surface snow depth. + Surface snow depth. sfcWind : xr.DataArray - Wind velocity + Wind velocity snd_thresh : Quantified - Threshold on net snowfall accumulation over the last `window` days. + Threshold on net snowfall accumulation over the last `window` days. sfcWind_thresh : Quantified - Wind speed threshold. + Wind speed threshold. window : int - Period over which snow is accumulated before comparing against threshold. + Period over which snow is accumulated before comparing against threshold. freq : str - Resampling frequency. + Resampling frequency. Returns ------- xarray.DataArray - Number of days when snowfall and wind speeds are above respective thresholds. + Number of days when snowfall and wind speeds are above respective thresholds. """ snd_thresh = convert_units_to(snd_thresh, snd) sfcWind_thresh = convert_units_to(sfcWind_thresh, sfcWind) # noqa @@ -1852,5 +1863,5 @@ def blowing_snow( cond = (snow >= snd_thresh) * (sfcWind >= sfcWind_thresh) * 1 out = cond.resample(time=freq).sum(dim="time") - out.attrs["units"] = to_agg_units(out, snd, "count") + out = out.assign_attrs(units=to_agg_units(out, snd, "count")) return out diff --git a/xclim/indices/_simple.py b/xclim/indices/_simple.py index e59db0f22..c97d714a4 100644 --- a/xclim/indices/_simple.py +++ b/xclim/indices/_simple.py @@ -535,7 +535,9 @@ def snow_depth( @declare_units(sfcWind="[speed]") -def sfcWind_max(sfcWind: xarray.DataArray, freq: str = "YS") -> xarray.DataArray: +def sfcWind_max( # noqa: N802 + sfcWind: xarray.DataArray, freq: str = "YS" +) -> xarray.DataArray: r"""Highest daily mean wind speed. The maximum of daily mean wind speed. @@ -574,7 +576,9 @@ def sfcWind_max(sfcWind: xarray.DataArray, freq: str = "YS") -> xarray.DataArray @declare_units(sfcWind="[speed]") -def sfcWind_mean(sfcWind: xarray.DataArray, freq: str = "YS") -> xarray.DataArray: +def sfcWind_mean( # noqa: N802 + sfcWind: xarray.DataArray, freq: str = "YS" +) -> xarray.DataArray: r"""Mean of daily mean wind speed. Resample the original daily mean wind speed series by taking the mean over each period. @@ -615,7 +619,9 @@ def sfcWind_mean(sfcWind: xarray.DataArray, freq: str = "YS") -> xarray.DataArra @declare_units(sfcWind="[speed]") -def sfcWind_min(sfcWind: xarray.DataArray, freq: str = "YS") -> xarray.DataArray: +def sfcWind_min( # noqa: N802 + sfcWind: xarray.DataArray, freq: str = "YS" +) -> xarray.DataArray: r"""Lowest daily mean wind speed. The minimum of daily mean wind speed. 
@@ -654,7 +660,9 @@ def sfcWind_min(sfcWind: xarray.DataArray, freq: str = "YS") -> xarray.DataArray @declare_units(sfcWindmax="[speed]") -def sfcWindmax_max(sfcWindmax: xarray.DataArray, freq: str = "YS") -> xarray.DataArray: +def sfcWindmax_max( # noqa: N802 + sfcWindmax: xarray.DataArray, freq: str = "YS" +) -> xarray.DataArray: r"""Highest maximum wind speed. The maximum of daily maximum wind speed. @@ -696,7 +704,9 @@ def sfcWindmax_max(sfcWindmax: xarray.DataArray, freq: str = "YS") -> xarray.Dat @declare_units(sfcWindmax="[speed]") -def sfcWindmax_mean(sfcWindmax: xarray.DataArray, freq: str = "YS") -> xarray.DataArray: +def sfcWindmax_mean( # noqa: N802 + sfcWindmax: xarray.DataArray, freq: str = "YS" +) -> xarray.DataArray: r"""Mean of daily maximum wind speed. Resample the original daily maximum wind speed series by taking the mean over each period. @@ -738,7 +748,9 @@ def sfcWindmax_mean(sfcWindmax: xarray.DataArray, freq: str = "YS") -> xarray.Da @declare_units(sfcWindmax="[speed]") -def sfcWindmax_min(sfcWindmax: xarray.DataArray, freq: str = "YS") -> xarray.DataArray: +def sfcWindmax_min( # noqa: N802 + sfcWindmax: xarray.DataArray, freq: str = "YS" +) -> xarray.DataArray: r"""Lowest daily maximum wind speed. The minimum of daily maximum wind speed. diff --git a/xclim/indices/_threshold.py b/xclim/indices/_threshold.py index f08b27761..150c55b04 100644 --- a/xclim/indices/_threshold.py +++ b/xclim/indices/_threshold.py @@ -15,6 +15,7 @@ rate2amount, str2pint, to_agg_units, + units2pint, ) from xclim.core.utils import DayOfYearStr, Quantified @@ -292,8 +293,9 @@ def cold_spell_max_length( rl.longest_run, freq=freq, ) - out = max_l.where(max_l >= window, 0) - return to_agg_units(out, tas, "count") + max_window = max_l.where(max_l >= window, 0) + out = to_agg_units(max_window, tas, "count") + return out @declare_units(tas="[temperature]", thresh="[temperature]") @@ -382,13 +384,16 @@ def snd_season_end( thresh = convert_units_to(thresh, snd) cond = snd >= thresh - out = ( + resampled = ( cond.resample(time=freq) .map(rl.season, window=window, dim="time", coord="dayofyear") .end ) - out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(snd)) - return out.where(~valid) + resampled = resampled.assign_attrs( + units="", is_dayofyear=np.int32(1), calendar=get_calendar(snd) + ) + snd_se = resampled.where(~valid) + return snd_se @declare_units(snw="[mass]/[area]", thresh="[mass]/[area]") @@ -428,13 +433,16 @@ def snw_season_end( thresh = convert_units_to(thresh, snw) cond = snw >= thresh - out = ( + resampled = ( cond.resample(time=freq) .map(rl.season, window=window, dim="time", coord="dayofyear") .end ) - out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(snw)) - return out.where(~valid) + resampled.attrs.update( + units="", is_dayofyear=np.int32(1), calendar=get_calendar(snw) + ) + snw_se = resampled.where(~valid) + return snw_se @declare_units(snd="[length]", thresh="[length]") @@ -474,7 +482,7 @@ def snd_season_start( thresh = convert_units_to(thresh, snd) cond = snd >= thresh - out = ( + resampled = ( cond.resample(time=freq) .map( rl.season, @@ -484,8 +492,11 @@ def snd_season_start( ) .start ) - out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(snd)) - return out.where(~valid) + resampled.attrs.update( + units="", is_dayofyear=np.int32(1), calendar=get_calendar(snd) + ) + snd_ss = resampled.where(~valid) + return snd_ss @declare_units(snw="[mass]/[area]", thresh="[mass]/[area]") @@ -526,7 +537,7 @@ def 
snw_season_start( thresh = convert_units_to(thresh, snw) cond = snw >= thresh - out = ( + resampled = ( cond.resample(time=freq) .map( rl.season, @@ -536,8 +547,11 @@ def snw_season_start( ) .start ) - out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(snw)) - return out.where(~valid) + resampled.attrs.update( + units="", is_dayofyear=np.int32(1), calendar=get_calendar(snw) + ) + snw_ss = resampled.where(~valid) + return snw_ss @declare_units(snd="[length]", thresh="[length]") @@ -577,12 +591,13 @@ def snd_season_length( thresh = convert_units_to(thresh, snd) cond = snd >= thresh - out = ( + snd_sl = ( cond.resample(time=freq) .map(rl.season, window=window, dim="time", coord="dayofyear") .length ) - return to_agg_units(out.where(~valid), snd, "count") + snd_sl = to_agg_units(snd_sl.where(~valid), snd, "count") + return snd_sl @declare_units(snw="[mass]/[area]", thresh="[mass]/[area]") @@ -622,12 +637,13 @@ def snw_season_length( thresh = convert_units_to(thresh, snw) cond = snw >= thresh - out = ( + snw_sl = ( cond.resample(time=freq) .map(rl.season, window=window, dim="time", coord="dayofyear") .length ) - return to_agg_units(out.where(~valid), snw, "count") + snw_sl = to_agg_units(snw_sl.where(~valid), snw, "count") + return snw_sl @declare_units(snd="[length]", thresh="[length]") @@ -666,10 +682,9 @@ def snd_storm_days( acc = snd.diff(dim="time") # Winter storm condition - out = threshold_count(acc, ">=", thresh, freq) - - out.attrs["units"] = to_agg_units(out, snd, "count") - return out + snd_sd = threshold_count(acc, ">=", thresh, freq) + snd_sd = snd_sd.assign_attrs(units=to_agg_units(snd_sd, snd, "count")) + return snd_sd @declare_units(snw="[mass]/[area]", thresh="[mass]/[area]") @@ -708,10 +723,9 @@ def snw_storm_days( acc = snw.diff(dim="time") # Winter storm condition - out = threshold_count(acc, ">=", thresh, freq) - - out.attrs["units"] = to_agg_units(out, snw, "count") - return out + snw_sd = threshold_count(acc, ">=", thresh, freq) + snw_sd = snw_sd.assign_attrs(units=to_agg_units(snw_sd, snw, "count")) + return snw_sd @declare_units(pr="[precipitation]", thresh="[precipitation]") @@ -778,15 +792,17 @@ def daily_pr_intensity( # get number of wetdays over period wd = wetdays(pr, thresh=thresh, freq=freq) - out = s / wd + dpr_int = s / wd # Issue originally introduced in https://github.com/hgrecco/pint/issues/1486 # Should be resolved in pint v0.24. See: https://github.com/hgrecco/pint/issues/1913 with warnings.catch_warnings(): warnings.simplefilter("ignore", category=DeprecationWarning) - out.attrs["units"] = f"{str2pint(pram.units) / str2pint(wd.units):~}" + dpr_int = dpr_int.assign_attrs( + units=f"{str2pint(pram.units) / str2pint(wd.units):~}" + ) - return out + return dpr_int @declare_units(pr="[precipitation]", thresh="[precipitation]") @@ -826,9 +842,9 @@ def dry_days( \sum PR_{ij} < Threshold [mm/day] """ thresh = convert_units_to(thresh, pr, context="hydro") - out = threshold_count(pr, op, thresh, freq, constrain=("<", "<=")) - out = to_agg_units(out, pr, "count") - return out + count = threshold_count(pr, op, thresh, freq, constrain=("<", "<=")) + dd = to_agg_units(count, pr, "count") + return dd @declare_units(pr="[precipitation]", thresh="[precipitation]") @@ -849,10 +865,10 @@ def maximum_consecutive_wet_days( thresh : Quantified Threshold precipitation on which to base evaluation. freq : str - Resampling frequency. + Resampling frequency. 
resample_before_rl : bool - Determines if the resampling should take place before or after the run - length encoding (or a similar algorithm) is applied to runs. + Determines if the resampling should take place before or after the run + length encoding (or a similar algorithm) is applied to runs. Returns ------- @@ -875,14 +891,14 @@ def maximum_consecutive_wet_days( thresh = convert_units_to(thresh, pr, "hydro") cond = pr > thresh - out = rl.resample_and_rl( + mcwd = rl.resample_and_rl( cond, resample_before_rl, rl.longest_run, freq=freq, ) - out = to_agg_units(out, pr, "count") - return out + mcwd = to_agg_units(mcwd, pr, "count") + return mcwd @declare_units(tas="[temperature]", thresh="[temperature]") @@ -918,7 +934,8 @@ def cooling_degree_days( where :math:`[P]` is 1 if :math:`P` is true, and 0 if false. """ - return cumulative_difference(tas, threshold=thresh, op=">", freq=freq) + cd = cumulative_difference(tas, threshold=thresh, op=">", freq=freq) + return cd @declare_units(tas="[temperature]", thresh="[temperature]") @@ -952,7 +969,8 @@ def growing_degree_days( GD4_j = \sum_{i=1}^I (TG_{ij}-{4} | TG_{ij} > {4}℃) """ - return cumulative_difference(tas, threshold=thresh, op=">", freq=freq) + cd = cumulative_difference(tas, threshold=thresh, op=">", freq=freq) + return cd @declare_units(tas="[temperature]", thresh="[temperature]") @@ -1558,7 +1576,7 @@ def first_day_temperature_below( """ # noqa - return first_day_threshold_reached( + fdtb = first_day_threshold_reached( tas, threshold=thresh, op=op, @@ -1567,6 +1585,7 @@ def first_day_temperature_below( freq=freq, constrain=("<", "<="), ) + return fdtb @declare_units(tas="[temperature]", thresh="[temperature]") @@ -1620,7 +1639,7 @@ def first_day_temperature_above( where :math:`w` is the number of days the temperature threshold should be exceeded, and :math:`[P]` is 1 if :math:`P` is true, and 0 if false. """ - return first_day_threshold_reached( + fdtr = first_day_threshold_reached( tas, threshold=thresh, op=op, @@ -1629,6 +1648,7 @@ def first_day_temperature_above( freq=freq, constrain=(">", ">="), ) + return fdtr @declare_units(prsn="[precipitation]", thresh="[precipitation]") @@ -1832,13 +1852,16 @@ def snowfall_frequency( """ # High threshold here just needs to be a big value. 
It is converted to same units as # so that a warning message won't be triggered just because of this value - thresh_units = pint2cfunits(str2pint(thresh)) - high = f"{convert_units_to('1E6 kg m-2 s-1', thresh_units, context='hydro')} {thresh_units}" + thresh_units = pint2cfunits(units2pint(thresh)) + high_thresh = convert_units_to("1E6 kg m-2 s-1", thresh_units, context="hydro") + high = f"{high_thresh} {thresh_units}" + snow_days = days_with_snow(prsn, low=thresh, high=high, freq=freq) total_days = prsn.resample(time=freq).count(dim="time") snow_freq = snow_days / total_days * 100 snow_freq = snow_freq.assign_attrs(**snow_days.attrs) - snow_freq.attrs["units"] = "%" + # overwrite snow_days units + snow_freq = snow_freq.assign_attrs(units="%") return snow_freq @@ -1887,9 +1910,9 @@ def snowfall_intensity( cond = lwe_prsn >= thresh mean = lwe_prsn.where(cond).resample(time=freq).mean(dim="time") - out = mean.fillna(0) - - return out.assign_attrs(units=lwe_prsn.units) + snow_int = mean.fillna(0) + snow_int = snow_int.assign_attrs(units=lwe_prsn.units) + return snow_int @declare_units(tasmax="[temperature]", thresh="[temperature]") @@ -1975,7 +1998,8 @@ def heating_degree_days( HD17_j = \sum_{i=1}^{I} (17℃ - TG_{ij}) | TG_{ij} < 17℃) """ - return cumulative_difference(tas, threshold=thresh, op="<", freq=freq) + hdd = cumulative_difference(tas, threshold=thresh, op="<", freq=freq) + return hdd @declare_units(tasmax="[temperature]", thresh="[temperature]") @@ -2679,7 +2703,7 @@ def maximum_consecutive_frost_days( where :math:`[P]` is 1 if :math:`P` is true, and 0 if false. Note that this formula does not handle sequences at the start and end of the series, but the numerical algorithm does. """ - return cold_spell_max_length( + csml: xarray.DataArray = cold_spell_max_length( tasmin, thresh=thresh, window=1, @@ -2687,6 +2711,7 @@ def maximum_consecutive_frost_days( op="<", resample_before_rl=resample_before_rl, ) + return csml @declare_units(pr="[precipitation]", thresh="[precipitation]") @@ -2734,13 +2759,14 @@ def maximum_consecutive_dry_days( """ t = convert_units_to(thresh, pr, context="hydro") group = pr < t - out = rl.resample_and_rl( + resampled = rl.resample_and_rl( group, resample_before_rl, rl.longest_run, freq=freq, ) - return to_agg_units(out, pr, "count") + mcdd = to_agg_units(resampled, pr, "count") + return mcdd @declare_units(tasmin="[temperature]", thresh="[temperature]") @@ -2790,7 +2816,7 @@ def maximum_consecutive_frost_free_days( where :math:`[P]` is 1 if :math:`P` is true, and 0 if false. Note that this formula does not handle sequences at the start and end of the series, but the numerical algorithm does. """ - return frost_free_spell_max_length( + mcffd = frost_free_spell_max_length( tasmin, thresh=thresh, window=1, @@ -2798,6 +2824,7 @@ def maximum_consecutive_frost_free_days( op=">=", resample_before_rl=resample_before_rl, ) + return mcffd @declare_units(tasmax="[temperature]", thresh="[temperature]") @@ -2843,7 +2870,7 @@ def maximum_consecutive_tx_days( where :math:`[P]` is 1 if :math:`P` is true, and 0 if false. Note that this formula does not handle sequences at the start and end of the series, but the numerical algorithm does. 
""" - return hot_spell_max_length( + mctxd = hot_spell_max_length( tasmax, thresh=thresh, window=1, @@ -2851,6 +2878,7 @@ def maximum_consecutive_tx_days( op=">", resample_before_rl=resample_before_rl, ) + return mctxd @declare_units(siconc="[]", areacello="[area]", thresh="[]") @@ -2886,9 +2914,9 @@ def sea_ice_area( """ t = convert_units_to(thresh, siconc) factor = convert_units_to("100 pct", siconc) - out = xarray.dot(siconc.where(siconc >= t, 0), areacello) / factor - out.attrs["units"] = areacello.units - return out + sia = xarray.dot(siconc.where(siconc >= t, 0), areacello) / factor + sia = sia.assign_attrs(units=areacello.units) + return sia @declare_units(siconc="[]", areacello="[area]", thresh="[]") @@ -2923,9 +2951,9 @@ def sea_ice_extent( "What is the difference between sea ice area and extent?" - :cite:cts:`nsidc_frequently_2008` """ t = convert_units_to(thresh, siconc) - out = xarray.dot(siconc >= t, areacello) - out.attrs["units"] = areacello.units - return out + sie = xarray.dot(siconc >= t, areacello) + sie = sie.assign_attrs(units=areacello.units) + return sie @declare_units(sfcWind="[speed]", thresh="[speed]") @@ -3102,9 +3130,11 @@ def _exceedance_date(grp): ) return xarray.where((cumsum <= sum_thresh).all("time"), never_reached_val, out) - out = c.clip(0).resample(time=freq).map(_exceedance_date) - out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(tas)) - return out + dded = c.clip(0).resample(time=freq).map(_exceedance_date) + dded = dded.assign_attrs( + units="", is_dayofyear=np.int32(1), calendar=get_calendar(tas) + ) + return dded @declare_units(pr="[precipitation]", thresh="[length]") diff --git a/xclim/indices/fire/_cffwis.py b/xclim/indices/fire/_cffwis.py index 8012e7ebc..7d1e99db5 100644 --- a/xclim/indices/fire/_cffwis.py +++ b/xclim/indices/fire/_cffwis.py @@ -158,9 +158,9 @@ "overwintering_drought_code", ] -default_params = dict( - temp_start_thresh=(12, "degC"), - temp_end_thresh=(5, "degC"), +default_params: dict[str, int | float | tuple[float, str]] = dict( + temp_start_thresh=(12.0, "degC"), + temp_end_thresh=(5.0, "degC"), snow_thresh=(0.01, "m"), temp_condition_days=3, snow_condition_days=3, @@ -388,9 +388,13 @@ def _duff_moisture_code( @vectorize(nopython=True) -def _drought_code( - t: np.ndarray, p: np.ndarray, mth: np.ndarray, lat: float, dc0: float -) -> np.ndarray: # pragma: no cover +def _drought_code( # pragma: no cover + t: np.ndarray, + p: np.ndarray, + mth: np.ndarray, + lat: float, + dc0: float, +) -> np.ndarray: """Compute the drought code over one time step. 
Parameters @@ -411,10 +415,10 @@ def _drought_code( array_like Drought code at the current timestep """ - fl = _day_length_factor(lat, mth) + fl = _day_length_factor(lat, mth) # type: ignore if t < -2.8: - t = -2.8 + t = -2.8 # type: ignore pe = (0.36 * (t + 2.8) + fl) / 2 # *Eq.22*# pe = max(pe, 0.0) @@ -431,7 +435,7 @@ def _drought_code( dc = pe else: # f p <= 2.8: dc = dc0 + pe - return dc + return dc # type: ignore def initial_spread_index(ws: np.ndarray, ffmc: np.ndarray) -> np.ndarray: @@ -451,7 +455,7 @@ def initial_spread_index(ws: np.ndarray, ffmc: np.ndarray) -> np.ndarray: """ mo = 147.2 * (101.0 - ffmc) / (59.5 + ffmc) # *Eq.1*# ff = 19.1152 * np.exp(mo * -0.1386) * (1.0 + (mo**5.31) / 49300000.0) # *Eq.25*# - isi = ff * np.exp(0.05039 * ws) # *Eq.26*# + isi: np.ndarray = ff * np.exp(0.05039 * ws) # *Eq.26*# return isi @@ -503,7 +507,7 @@ def fire_weather_index(isi, bui): return fwi -def daily_severity_rating(fwi: np.ndarray) -> np.ndarry: +def daily_severity_rating(fwi: np.ndarray) -> np.ndarray: """Daily severity rating. Parameters @@ -548,6 +552,7 @@ def _overwintering_drought_code(DCf, wpr, a, b, minDC): # pragma: no cover # SECTION 2 : Iterators +# FIXME: default_params should be supplied within the logic of the function. def _fire_season( tas: np.ndarray, snd: np.ndarray | None = None, @@ -1056,15 +1061,16 @@ def fire_weather_ufunc( # noqa: C901 ) # Arg order : tas, pr, hurs, sfcWind, snd, mth, lat, season_mask, dc0, dmc0, ffmc0, winter_pr # 0 1 2 3 4 5 6 7 8 9 10 11 - args = [None] * 12 - input_core_dims = [[]] * 12 + args: list[xr.DataArray | None] = [None] * 12 + input_core_dims: list[list[str | None]] = [[]] * 12 # Verification of all arguments for i, (arg, name, usedby, has_time_dim) in enumerate(needed_args): if any([ind in indexes + [season_method] for ind in usedby]): if arg is None: raise TypeError( - f"Missing input argument {name} for index combination {indexes} with fire season method '{season_method}'" + f"Missing input argument {name} for index combination {indexes} " + f"with fire season method '{season_method}'." ) args[i] = arg input_core_dims[i] = ["time"] if has_time_dim else [] @@ -1078,17 +1084,14 @@ def fire_weather_ufunc( # noqa: C901 raise ValueError("'dry_start' must be one of None, 'CFS' or 'GFWED'.") # Always pass the previous codes. 
- if dc0 is None: - dc0 = xr.full_like(tas.isel(time=0), np.nan) - if dmc0 is None: - dmc0 = xr.full_like(tas.isel(time=0), np.nan) - if ffmc0 is None: - ffmc0 = xr.full_like(tas.isel(time=0), np.nan) - args[8:11] = [dc0, dmc0, ffmc0] + _dc0 = xr.full_like(tas.isel(time=0), np.nan) if dc0 is None else dc0 + _dmc0 = xr.full_like(tas.isel(time=0), np.nan) if dmc0 is None else dmc0 + _ffmc0 = xr.full_like(tas.isel(time=0), np.nan) if ffmc0 is None else ffmc0 + args[8:11] = [_dc0, _dmc0, _ffmc0] # Output config from the current indexes list outputs = indexes - output_dtypes = [tas.dtype] * len(indexes) + output_dtypes: list[np.dtype] = [tas.dtype] * len(indexes) output_core_dims = len(indexes) * [("time",)] if season_mask is not None: diff --git a/xclim/indices/fire/_ffdi.py b/xclim/indices/fire/_ffdi.py index b986fe2e3..847652f31 100644 --- a/xclim/indices/fire/_ffdi.py +++ b/xclim/indices/fire/_ffdi.py @@ -239,7 +239,7 @@ def keetch_byram_drought_index( :cite:cts:`ffdi-keetch_1968,ffdi-finkele_2006,ffdi-holgate_2017,ffdi-dolling_2005` """ - def _keetch_byram_drought_index_pass(pr, tasmax, pr_annual, kbdi0): + def _keetch_byram_drought_index_pass(_pr, _tasmax, _pr_annual, _kbdi0): """Pass inputs on to guvectorized function `_keetch_byram_drought_index`. This function is actually only required as `xr.apply_ufunc` will not receive @@ -249,7 +249,7 @@ def _keetch_byram_drought_index_pass(pr, tasmax, pr_annual, kbdi0): -------- DO NOT CALL DIRECTLY, use `keetch_byram_drought_index` instead. """ - return _keetch_byram_drought_index(pr, tasmax, pr_annual, kbdi0) + return _keetch_byram_drought_index(_pr, _tasmax, _pr_annual, _kbdi0) pr = convert_units_to(pr, "mm/day", context="hydro") tasmax = convert_units_to(tasmax, "C") @@ -259,7 +259,7 @@ def _keetch_byram_drought_index_pass(pr, tasmax, pr_annual, kbdi0): else: kbdi0 = xr.full_like(pr.isel(time=0), 0) - kbdi = xr.apply_ufunc( + kbdi: xr.DataArray = xr.apply_ufunc( _keetch_byram_drought_index_pass, pr, tasmax, @@ -270,7 +270,7 @@ def _keetch_byram_drought_index_pass(pr, tasmax, pr_annual, kbdi0): dask="parallelized", output_dtypes=[pr.dtype], ) - kbdi.attrs["units"] = "mm/day" + kbdi = kbdi.assign_attrs(units="mm/day") return kbdi @@ -317,7 +317,7 @@ def griffiths_drought_factor( :cite:cts:`ffdi-griffiths_1999,ffdi-finkele_2006,ffdi-holgate_2017` """ - def _griffiths_drought_factor_pass(pr, smd, lim): + def _griffiths_drought_factor_pass(_pr, _smd, _lim): """Pass inputs on to guvectorized function `_griffiths_drought_factor`. This function is actually only required as xr.apply_ufunc will not receive @@ -327,7 +327,7 @@ def _griffiths_drought_factor_pass(pr, smd, lim): -------- DO NOT CALL DIRECTLY, use `griffiths_drought_factor` instead. """ - return _griffiths_drought_factor(pr, smd, lim) + return _griffiths_drought_factor(_pr, _smd, _lim) pr = convert_units_to(pr, "mm/day", context="hydro") smd = convert_units_to(smd, "mm/day") @@ -339,17 +339,17 @@ def _griffiths_drought_factor_pass(pr, smd, lim): else: raise ValueError(f"{limiting_func} is not a valid input for `limiting_func`") - df = xr.apply_ufunc( + df: xr.DataArray = xr.apply_ufunc( _griffiths_drought_factor_pass, pr, smd, - kwargs=dict(lim=lim), + kwargs=dict(_lim=lim), input_core_dims=[["time"], ["time"]], output_core_dims=[["time"]], dask="parallelized", output_dtypes=[pr.dtype], ) - df.attrs["units"] = "" + df = df.assign_attrs(units="") # First non-zero entry is at the 19th time point since df is calculated # from a 20-day rolling window. Make prior points NaNs. 
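# Reviewer sketch (not part of the patch): a minimal, self-contained
# illustration of the pass-through wrapper pattern touched in the _ffdi.py
# hunks above. `xr.apply_ufunc` hands the core arrays to a thin wrapper whose
# parameters are prefixed with underscores so they do not shadow the outer
# DataArray names (one of the shadowing fixes in this PR). The kernel
# `_double` is hypothetical and only stands in for the guvectorized
# drought-factor kernels.
import numpy as np
import xarray as xr


def _double(arr: np.ndarray) -> np.ndarray:
    """Toy kernel standing in for a guvectorized function."""
    return arr * 2


def _double_pass(_arr):
    """Pass inputs on to `_double`; underscored args avoid shadowing `da`."""
    return _double(_arr)


da = xr.DataArray(np.arange(5.0), dims="time")
out: xr.DataArray = xr.apply_ufunc(
    _double_pass,
    da,
    input_core_dims=[["time"]],
    output_core_dims=[["time"]],
    dask="parallelized",
    output_dtypes=[da.dtype],
)
out = out.assign_attrs(units="")  # attrs set via assign_attrs, as in the diff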
diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py index 07b109bc2..76d987024 100644 --- a/xclim/indices/generic.py +++ b/xclim/indices/generic.py @@ -13,6 +13,7 @@ import cftime import numpy as np +import xarray import xarray as xr from xarray.coding.cftime_offsets import _MONTH_ABBREVIATIONS # noqa @@ -943,7 +944,7 @@ def first_day_threshold_reached( cond = compare(data, op, threshold, constrain=constrain) - out = cond.resample(time=freq).map( + out: xarray.DataArray = cond.resample(time=freq).map( rl.first_run_after_date, window=window, date=after_date, @@ -1045,8 +1046,8 @@ def get_zones( else: bins = convert_units_to(bins, da) - def _get_zone(da): - return np.digitize(da, bins) - 1 + def _get_zone(_da): + return np.digitize(_da, bins) - 1 zones = xr.apply_ufunc(_get_zone, da, dask="parallelized") @@ -1060,7 +1061,9 @@ def _get_zone(da): return zones -def detrend(ds, dim="time", deg=1): +def detrend( + ds: xr.DataArray | xr.Dataset, dim="time", deg=1 +) -> xr.DataArray | xr.Dataset: """Detrend data along a given dimension computing a polynomial trend of a given order. Parameters @@ -1074,7 +1077,7 @@ def detrend(ds, dim="time", deg=1): Returns ------- - detrended : xr.Dataset or xr.DataArray + xr.Dataset or xr.DataArray Same as `ds`, but with its trend removed (subtracted). """ if isinstance(ds, xr.Dataset): diff --git a/xclim/indices/helpers.py b/xclim/indices/helpers.py index 64d5716f5..454819052 100644 --- a/xclim/indices/helpers.py +++ b/xclim/indices/helpers.py @@ -7,7 +7,9 @@ from __future__ import annotations +from collections.abc import Mapping from inspect import stack +from typing import Any import cf_xarray # noqa: F401, pylint: disable=unused-import import cftime @@ -363,8 +365,8 @@ def extraterrestrial_solar_radiation( times: xr.DataArray, lat: xr.DataArray, solar_constant: Quantified = "1361 W m-2", - method="spencer", - chunks: dict[str, int] | None = None, + method: str = "spencer", + chunks: Mapping[Any, tuple] | None = None, ) -> xr.DataArray: """Extraterrestrial solar radiation. @@ -383,7 +385,7 @@ def extraterrestrial_solar_radiation( method : {'spencer', 'simple'} Which method to use when computing the solar declination and the eccentricity correction factor. See :py:func:`solar_declination` and :py:func:`eccentricity_correction_factor`. - chunks : dictionary + chunks : dict When `times` and `lat` originate from coordinates of a large chunked dataset, passing the dataset's chunks here will ensure the computation is chunked as well. diff --git a/xclim/indices/run_length.py b/xclim/indices/run_length.py index 4c24d19da..b3015230a 100644 --- a/xclim/indices/run_length.py +++ b/xclim/indices/run_length.py @@ -83,7 +83,7 @@ def resample_and_rl( freq: str, dim: str = "time", **kwargs, -) -> xr.DataArray | xr.Dataset: +) -> xr.DataArray: """Wrap run length algorithms to control if resampling occurs before or after the algorithms. Parameters diff --git a/xclim/indices/stats.py b/xclim/indices/stats.py index 9b72e34e6..598932f11 100644 --- a/xclim/indices/stats.py +++ b/xclim/indices/stats.py @@ -514,7 +514,7 @@ def _fit_start(x, dist: str, **fitkwargs: Any) -> tuple[tuple, dict]: return (), {} -def _dist_method_1D( +def _dist_method_1D( # noqa: N802 *args, dist: str | scipy.stats.rv_continuous, function: str, **kwargs: Any ) -> xr.DataArray: r"""Statistical function for given argument on given distribution initialized with params. 
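# Reviewer sketch (not part of the patch): the `np.digitize`-based zoning used
# by `get_zones` in the generic.py hunk above, with made-up bins and data.
# Each value maps to the index of the bin it falls in, minus one so the first
# zone is 0; the underscore-prefixed wrapper argument avoids shadowing `da`.
import numpy as np
import xarray as xr

bins = np.array([0.0, 10.0, 20.0, 30.0])
da = xr.DataArray([5.0, 12.0, 25.0, 31.0], dims="x")


def _get_zone(_da):
    # digitize against the outer `bins`; subtract 1 so zones start at 0
    return np.digitize(_da, bins) - 1


zones = xr.apply_ufunc(_get_zone, da, dask="parallelized")
# zones.values -> array([0, 1, 2, 3])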
diff --git a/xclim/sdba/_adjustment.py b/xclim/sdba/_adjustment.py index 37c37be44..31b9cd874 100644 --- a/xclim/sdba/_adjustment.py +++ b/xclim/sdba/_adjustment.py @@ -21,7 +21,7 @@ from .processing import escore -def _adapt_freq_hist(ds, adapt_freq_thresh): +def _adapt_freq_hist(ds: xr.Dataset, adapt_freq_thresh: str): """Adapt frequency of null values of `hist` in order to match `ref`.""" with units.context(infer_context(ds.ref.attrs.get("standard_name"))): thresh = convert_units_to(adapt_freq_thresh, ds.ref) @@ -36,7 +36,14 @@ def _adapt_freq_hist(ds, adapt_freq_thresh): hist_q=[Grouper.PROP, "quantiles"], scaling=[Grouper.PROP], ) -def dqm_train(ds, *, dim, kind, quantiles, adapt_freq_thresh) -> xr.Dataset: +def dqm_train( + ds: xr.Dataset, + *, + dim: str, + kind: str, + quantiles: np.ndarray, + adapt_freq_thresh: str | None = None, +) -> xr.Dataset: """Train step on one group. Notes @@ -45,9 +52,24 @@ def dqm_train(ds, *, dim, kind, quantiles, adapt_freq_thresh) -> xr.Dataset: ref : training target hist : training data + Parameters + ---------- + ds : xr.Dataset + The dataset containing the training data. + dim : str + The dimension along which to compute the quantiles. + kind : str + The kind of correction to compute. See :py:func:`xclim.sdba.utils.get_correction`. + quantiles : array-like + The quantiles to compute. adapt_freq_thresh : str | None Threshold for frequency adaptation. See :py:class:`xclim.sdba.processing.adapt_freq` for details. Default is None, meaning that frequency adaptation is not performed. + + Returns + ------- + xr.Dataset + The dataset containing the adjustment factors, the quantiles over the training data, and the scaling factor. """ hist = _adapt_freq_hist(ds, adapt_freq_thresh) if adapt_freq_thresh else ds.hist @@ -69,16 +91,40 @@ def dqm_train(ds, *, dim, kind, quantiles, adapt_freq_thresh) -> xr.Dataset: af=[Grouper.PROP, "quantiles"], hist_q=[Grouper.PROP, "quantiles"], ) -def eqm_train(ds, *, dim, kind, quantiles, adapt_freq_thresh) -> xr.Dataset: +def eqm_train( + ds: xr.Dataset, + *, + dim: str, + kind: str, + quantiles: np.ndarray, + adapt_freq_thresh: str | None = None, +) -> xr.Dataset: """EQM: Train step on one group. + Notes + ----- Dataset variables: ref : training target hist : training data + Parameters + ---------- + ds : xr.Dataset + The dataset containing the training data. + dim : str + The dimension along which to compute the quantiles. + kind : str + The kind of correction to compute. See :py:func:`xclim.sdba.utils.get_correction`. + quantiles : array-like + The quantiles to compute. adapt_freq_thresh : str | None Threshold for frequency adaptation. See :py:class:`xclim.sdba.processing.adapt_freq` for details. Default is None, meaning that frequency adaptation is not performed. + + Returns + ------- + xr.Dataset + The dataset containing the adjustment factors and the quantiles over the training data. """ hist = _adapt_freq_hist(ds, adapt_freq_thresh) if adapt_freq_thresh else ds.hist ref_q = nbu.quantile(ds.ref, quantiles, dim) @@ -90,13 +136,35 @@ def eqm_train(ds, *, dim, kind, quantiles, adapt_freq_thresh) -> xr.Dataset: @map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[]) -def qm_adjust(ds, *, group, interp, extrapolation, kind) -> xr.Dataset: +def qm_adjust( + ds: xr.Dataset, *, group: Grouper, interp: str, extrapolation: str, kind: str +) -> xr.Dataset: """QM (DQM and EQM): Adjust step on one block. + Notes + ----- Dataset variables: af : Adjustment factors hist_q : Quantiles over the training data sim : Data to adjust. 
+ + Parameters + ---------- + ds : xr.Dataset + The dataset containing the data to adjust. + group : Grouper + The grouper object. + interp : str + The interpolation method to use. + extrapolation : str + The extrapolation method to use. + kind : str + The kind of correction to compute. See :py:func:`xclim.sdba.utils.get_correction`. + + Returns + ------- + xr.Dataset + The adjusted data. """ af = u.interp_on_quantiles( ds.sim, @@ -107,19 +175,50 @@ def qm_adjust(ds, *, group, interp, extrapolation, kind) -> xr.Dataset: extrapolation=extrapolation, ) - scen = u.apply_correction(ds.sim, af, kind).rename("scen") - return scen.to_dataset() + scen: xr.DataArray = u.apply_correction(ds.sim, af, kind).rename("scen") + out = scen.to_dataset() + return out @map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[], trend=[]) -def dqm_adjust(ds, *, group, interp, kind, extrapolation, detrend): +def dqm_adjust( + ds: xr.Dataset, + *, + group: Grouper, + interp: str, + kind: str, + extrapolation: str, + detrend: int | PolyDetrend, +) -> xr.Dataset: """DQM adjustment on one block. + Notes + ----- Dataset variables: scaling : Scaling factor between ref and hist af : Adjustment factors hist_q : Quantiles over the training data sim : Data to adjust + + Parameters + ---------- + ds : xr.Dataset + The dataset containing the data to adjust. + group : Grouper + The grouper object. + interp : str + The interpolation method to use. + kind : str + The kind of correction to compute. See :py:func:`xclim.sdba.utils.get_correction`. + extrapolation : str + The extrapolation method to use. + detrend : int | PolyDetrend + The degree of the polynomial detrending to apply. If 0, no detrending is applied. + + Returns + ------- + xr.Dataset + The adjusted data and the trend. """ scaled_sim = u.apply_correction( ds.sim, @@ -133,10 +232,12 @@ def dqm_adjust(ds, *, group, interp, kind, extrapolation, detrend): ) if isinstance(detrend, int): - detrend = PolyDetrend(degree=detrend, kind=kind, group=group) + detrending = PolyDetrend(degree=detrend, kind=kind, group=group) + else: + detrending = detrend - detrend = detrend.fit(scaled_sim) - ds["sim"] = detrend.detrend(scaled_sim) + detrending = detrending.fit(scaled_sim) + ds["sim"] = detrending.detrend(scaled_sim) scen = qm_adjust.func( ds, group=group, @@ -144,16 +245,18 @@ def dqm_adjust(ds, *, group, interp, kind, extrapolation, detrend): extrapolation=extrapolation, kind=kind, ).scen - scen = detrend.retrend(scen) + scen = detrending.retrend(scen) - out = xr.Dataset({"scen": scen, "trend": detrend.ds.trend}) + out = xr.Dataset({"scen": scen, "trend": detrending.ds.trend}) return out @map_blocks(reduces=[Grouper.PROP, "quantiles"], scen=[], sim_q=[]) -def qdm_adjust(ds, *, group, interp, extrapolation, kind) -> xr.Dataset: +def qdm_adjust(ds: xr.Dataset, *, group, interp, extrapolation, kind) -> xr.Dataset: """QDM: Adjust process on one block. + Notes + ----- Dataset variables: af : Adjustment factors hist_q : Quantiles over the training data @@ -177,9 +280,11 @@ def qdm_adjust(ds, *, group, interp, extrapolation, kind) -> xr.Dataset: af=[Grouper.PROP], hist_thresh=[Grouper.PROP], ) -def loci_train(ds, *, group, thresh) -> xr.Dataset: +def loci_train(ds: xr.Dataset, *, group, thresh) -> xr.Dataset: """LOCI: Train on one block. 
+ Notes + ----- Dataset variables: ref : training target hist : training data @@ -200,9 +305,11 @@ def loci_train(ds, *, group, thresh) -> xr.Dataset: @map_blocks(reduces=[Grouper.PROP], scen=[]) -def loci_adjust(ds, *, group, thresh, interp) -> xr.Dataset: +def loci_adjust(ds: xr.Dataset, *, group, thresh, interp) -> xr.Dataset: """LOCI: Adjust on one block. + Notes + ----- Dataset variables: hist_thresh : Hist's equivalent thresh from ref sim : Data to adjust @@ -210,35 +317,44 @@ def loci_adjust(ds, *, group, thresh, interp) -> xr.Dataset: sth = u.broadcast(ds.hist_thresh, ds.sim, group=group, interp=interp) factor = u.broadcast(ds.af, ds.sim, group=group, interp=interp) with xr.set_options(keep_attrs=True): - scen = (factor * (ds.sim - sth) + thresh).clip(min=0) - return scen.rename("scen").to_dataset() + scen: xr.DataArray = ( + (factor * (ds.sim - sth) + thresh).clip(min=0).rename("scen") + ) + out = scen.to_dataset() + return out @map_groups(af=[Grouper.PROP]) -def scaling_train(ds, *, dim, kind) -> xr.Dataset: +def scaling_train(ds: xr.Dataset, *, dim, kind) -> xr.Dataset: """Scaling: Train on one group. + Notes + ----- Dataset variables: ref : training target hist : training data """ mhist = ds.hist.mean(dim) mref = ds.ref.mean(dim) - af = u.get_correction(mhist, mref, kind) - return af.rename("af").to_dataset() + af: xr.DataArray = u.get_correction(mhist, mref, kind).rename("af") + out = af.to_dataset() + return out @map_blocks(reduces=[Grouper.PROP], scen=[]) -def scaling_adjust(ds, *, group, interp, kind) -> xr.Dataset: +def scaling_adjust(ds: xr.Dataset, *, group, interp, kind) -> xr.Dataset: """Scaling: Adjust on one block. + Notes + ----- Dataset variables: - af: Adjustment factors. + af : Adjustment factors. sim : Data to adjust. """ af = u.broadcast(ds.af, ds.sim, group=group, interp=interp) - scen = u.apply_correction(ds.sim, af, kind) - return scen.rename("scen").to_dataset() + scen: xr.DataArray = u.apply_correction(ds.sim, af, kind).rename("scen") + out = scen.to_dataset() + return out def npdf_transform(ds: xr.Dataset, **kwargs) -> xr.Dataset: @@ -373,8 +489,37 @@ def _extremes_train_1d(ref, hist, ref_params, *, q_thresh, cluster_thresh, dist, @map_blocks( reduces=["time"], px_hist=["quantiles"], af=["quantiles"], thresh=[Grouper.PROP] ) -def extremes_train(ds, *, group, q_thresh, cluster_thresh, dist, quantiles): - """Train extremes for a given variable series.""" +def extremes_train( + ds: xr.Dataset, + *, + group: Grouper, + q_thresh: float, + cluster_thresh: float, + dist, + quantiles: np.ndarray, +) -> xr.Dataset: + """Train extremes for a given variable series. + + Parameters + ---------- + ds : xr.Dataset + Dataset containing the reference and historical data. + group : Grouper + The grouper object. + q_thresh : float + The quantile threshold to use. + cluster_thresh : float + The threshold for clustering. + dist : Any + The distribution to fit. + quantiles : array-like + The quantiles to compute. + + Returns + ------- + xr.Dataset + The dataset containing the quantiles, the adjustment factors, and the threshold. 
+ """ px_hist, af, thresh = xr.apply_ufunc( _extremes_train_1d, ds.ref, @@ -408,9 +553,42 @@ def _fit_cluster_and_cdf(data, thresh, dist, cluster_thresh): @map_blocks(reduces=["quantiles", Grouper.PROP], scen=[]) def extremes_adjust( - ds, *, group, frac, power, dist, interp, extrapolation, cluster_thresh -): - """Adjust extremes to reflect many distribution factors.""" + ds: xr.Dataset, + *, + group: Grouper, + frac: float, + power: float, + dist, + interp: str, + extrapolation: str, + cluster_thresh: float, +) -> xr.Dataset: + """Adjust extremes to reflect many distribution factors. + + Parameters + ---------- + ds : xr.Dataset + Dataset containing the reference and historical data. + group : Grouper + The grouper object. + frac : float + The fraction of the transition function. + power : float + The power of the transition function. + dist : Any + The distribution to fit. + interp : str + The interpolation method to use. + extrapolation : str + The extrapolation method to use. + cluster_thresh : float + The threshold for clustering. + + Returns + ------- + xr.Dataset + The dataset containing the adjusted data. + """ # Find probabilities of extremes of fut according to its own cluster-fitted dist. px_fut = xr.apply_ufunc( _fit_cluster_and_cdf, @@ -434,5 +612,6 @@ def extremes_adjust( ) ** power transition = transition.clip(0, 1) - out = (transition * scen) + ((1 - transition) * ds.scen) - return out.rename("scen").squeeze("group", drop=True).to_dataset() + adjusted: xr.DataArray = (transition * scen) + ((1 - transition) * ds.scen) + out = adjusted.rename("scen").squeeze("group", drop=True).to_dataset() + return out diff --git a/xclim/sdba/_processing.py b/xclim/sdba/_processing.py index c6cf7f486..cd2566b59 100644 --- a/xclim/sdba/_processing.py +++ b/xclim/sdba/_processing.py @@ -143,7 +143,7 @@ def _normalize( @map_groups(reordered=[Grouper.DIM], main_only=False) -def _reordering(ds, *, dim): +def _reordering(ds: xr.Dataset, *, dim: str) -> xr.Dataset: """Group-wise reordering. Parameters @@ -154,6 +154,11 @@ def _reordering(ds, *, dim): - ref : The timeseries whose rank to use. dim : str The dimension along which to reorder. + + Returns + ------- + xr.Dataset + The reordered timeseries. """ def _reordering_1d(data, ordr): diff --git a/xclim/sdba/adjustment.py b/xclim/sdba/adjustment.py index 8436d278b..113f8acb0 100644 --- a/xclim/sdba/adjustment.py +++ b/xclim/sdba/adjustment.py @@ -161,8 +161,10 @@ class TrainAdjust(BaseAdjustment): _repr_hide_params = ["hist_calendar", "train_units"] @classmethod - def train(cls, ref: DataArray, hist: DataArray, **kwargs): - """Train the adjustment object. Refer to the class documentation for the algorithm details. + def train(cls, ref: DataArray, hist: DataArray, **kwargs) -> TrainAdjust: + r"""Train the adjustment object. + + Refer to the class documentation for the algorithm details. Parameters ---------- @@ -170,6 +172,8 @@ def train(cls, ref: DataArray, hist: DataArray, **kwargs): Training target, usually a reference time series drawn from observations. hist : DataArray Training data, usually a model output whose biases are to be adjusted. + \*\*kwargs + Algorithm-specific keyword arguments, see class doc. """ kwargs = parse_group(cls._train, kwargs) skip_checks = kwargs.pop("skip_input_checks", False) @@ -195,7 +199,9 @@ def train(cls, ref: DataArray, hist: DataArray, **kwargs): return obj def adjust(self, sim: DataArray, *args, **kwargs): - """Return bias-adjusted data. Refer to the class documentation for the algorithm details. 
+ r"""Return bias-adjusted data. + + Refer to the class documentation for the algorithm details. Parameters ---------- @@ -203,7 +209,7 @@ def adjust(self, sim: DataArray, *args, **kwargs): Time series to be bias-adjusted, usually a model output. args : xr.DataArray Other DataArrays needed for the adjustment (usually none). - kwargs + \*\*kwargs Algorithm-specific keyword arguments, see class doc. """ skip_checks = kwargs.pop("skip_input_checks", False) @@ -246,10 +252,10 @@ def set_dataset(self, ds: xr.Dataset): @classmethod def _train(cls, ref: DataArray, hist: DataArray, *kwargs): - raise NotImplementedError + raise NotImplementedError() def _adjust(self, sim, **kwargs): - raise NotImplementedError + raise NotImplementedError() class Adjust(BaseAdjustment): @@ -266,7 +272,7 @@ def adjust( hist: xr.DataArray, sim: xr.DataArray, **kwargs, - ): + ) -> xr.Dataset: r"""Return bias-adjusted data. Refer to the class documentation for the algorithm details. Parameters @@ -279,6 +285,11 @@ def adjust( Time series to be bias-adjusted, usually a model output. \*\*kwargs Algorithm-specific keyword arguments, see class doc. + + Returns + ------- + xr.Dataset + The bias-adjusted Dataset. """ kwargs = parse_group(cls._adjust, kwargs) skip_checks = kwargs.pop("skip_input_checks", False) @@ -289,7 +300,7 @@ def adjust( (ref, hist, sim), _ = cls._harmonize_units(ref, hist, sim) - out = cls._adjust(ref, hist, sim, **kwargs) + out: xr.Dataset | xr.DataArray = cls._adjust(ref, hist, sim, **kwargs) if isinstance(out, xr.DataArray): out = out.rename("scen").to_dataset() @@ -359,7 +370,7 @@ def _train( kind: str = ADDITIVE, group: str | Grouper = "time", adapt_freq_thresh: str | None = None, - ): + ) -> tuple[xr.Dataset, dict[str, Any]]: if np.isscalar(nquantiles): quantiles = equally_spaced_nodes(nquantiles).astype(ref.dtype) else: @@ -1127,7 +1138,7 @@ def _adjust( pts_dim: str = "multivar", adj_kws: dict[str, Any] | None = None, rot_matrices: xr.DataArray | None = None, - ): + ) -> xr.Dataset: if base_kws is None: base_kws = {} if "kind" in base_kws: @@ -1291,7 +1302,7 @@ def _parse(s): ] ) - def _generate_SBCK_classes(): + def _generate_SBCK_classes(): # noqa: N802 classes = [] for clsname in dir(SBCK): cls = getattr(SBCK, clsname) diff --git a/xclim/sdba/base.py b/xclim/sdba/base.py index 6590601ba..af7652f70 100644 --- a/xclim/sdba/base.py +++ b/xclim/sdba/base.py @@ -457,21 +457,21 @@ def parse_group(func: Callable, kwargs=None, allow_only=None) -> Callable: else: default_group = None - def _update_kwargs(kwargs, allowed=None): - if default_group or "group" in kwargs: - kwargs.setdefault("group", default_group) - if not isinstance(kwargs["group"], Grouper): - kwargs = Grouper.from_kwargs(**kwargs) + def _update_kwargs(_kwargs, allowed=None): + if default_group or "group" in _kwargs: + _kwargs.setdefault("group", default_group) + if not isinstance(_kwargs["group"], Grouper): + _kwargs = Grouper.from_kwargs(**_kwargs) if ( allowed is not None - and "group" in kwargs - and kwargs["group"].prop not in allowed + and "group" in _kwargs + and _kwargs["group"].prop not in allowed ): raise ValueError( - f"Grouping on {kwargs['group'].prop_name} is not allowed for this " + f"Grouping on {_kwargs['group'].prop_name} is not allowed for this " f"function. Should be one of {allowed}." 
) - return kwargs + return _kwargs if kwargs is not None: # Not used as a decorator return _update_kwargs(kwargs, allowed=allow_only) diff --git a/xclim/sdba/measures.py b/xclim/sdba/measures.py index 2152bae20..e6957f434 100644 --- a/xclim/sdba/measures.py +++ b/xclim/sdba/measures.py @@ -285,17 +285,19 @@ def _rmse( Root mean square error """ - def _rmse(sim, ref): - return np.sqrt(np.mean((sim - ref) ** 2, axis=-1)) + def _rmse_internal(_sim: xr.DataArray, _ref: xr.DataArray) -> xr.DataArray: + _f: xr.DataArray = np.sqrt(np.mean((_sim - _ref) ** 2, axis=-1)) + return _f out = xr.apply_ufunc( - _rmse, + _rmse_internal, sim, ref, input_core_dims=[["time"], ["time"]], dask="parallelized", ) - return out.assign_attrs(units=ensure_delta(ref.units)) + out = out.assign_attrs(units=ensure_delta(ref.units)) + return out rmse = StatisticalPropertyMeasure( @@ -330,17 +332,19 @@ def _mae( Mean absolute error """ - def _mae(sim, ref): - return np.mean(np.abs(sim - ref), axis=-1) + def _mae_internal(_sim: xr.DataArray, _ref: xr.DataArray) -> xr.DataArray: + _f: xr.DataArray = np.mean(np.abs(_sim - _ref), axis=-1) + return _f out = xr.apply_ufunc( - _mae, + _mae_internal, sim, ref, input_core_dims=[["time"], ["time"]], dask="parallelized", ) - return out.assign_attrs(units=ensure_delta(ref.units)) + out = out.assign_attrs(units=ensure_delta(ref.units)) + return out mae = StatisticalPropertyMeasure( diff --git a/xclim/sdba/nbutils.py b/xclim/sdba/nbutils.py index 9fd245d20..aa87fd0a3 100644 --- a/xclim/sdba/nbutils.py +++ b/xclim/sdba/nbutils.py @@ -5,6 +5,8 @@ """ from __future__ import annotations +from collections.abc import Hashable, Sequence + import numpy as np from numba import boolean, float32, float64, guvectorize, njit from xarray import DataArray @@ -24,10 +26,26 @@ def _vecquantiles(arr, rnk, res): res[0] = np.nanquantile(arr, rnk) -def vecquantiles(da: DataArray, rnk: DataArray, dim: str | DataArray.dims) -> DataArray: +def vecquantiles( + da: DataArray, rnk: DataArray, dim: str | Sequence[Hashable] +) -> DataArray: """For when the quantile (rnk) is different for each point. da and rnk must share all dimensions but dim. + + Parameters + ---------- + da : xarray.DataArray + The data to compute the quantiles on. + rnk : xarray.DataArray + The quantiles to compute. + dim : str or sequence of str + The dimension along which to compute the quantiles. + + Returns + ------- + xarray.DataArray + The quantiles computed along the `dim` dimension. """ tem = utils.get_temp_dimname(da.dims, "temporal") dims = [dim] if isinstance(dim, str) else dim @@ -55,8 +73,23 @@ def _quantile(arr, q): return out -def quantile(da: DataArray, q, dim: str | DataArray.dims) -> DataArray: - """Compute the quantiles from a fixed list `q`.""" +def quantile(da: DataArray, q: np.ndarray, dim: str | Sequence[Hashable]) -> DataArray: + """Compute the quantiles from a fixed list `q`. + + Parameters + ---------- + da : xarray.DataArray + The data to compute the quantiles on. + q : array-like + The quantiles to compute. + dim : str or sequence of str + The dimension along which to compute the quantiles. + + Returns + ------- + xarray.DataArray + The quantiles computed along the `dim` dimension. + """ # We have two cases : # - When all dims are processed : we stack them and use _quantile1d # - When the quantiles are vectorized over some dims, these are also stacked and then _quantile2D is used. 
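# Reviewer sketch (not part of the patch): a condensed view of the
# stack -> compute -> unstack flow described in the `quantile` docstring above.
# xarray's built-in `.quantile` stands in for the numba kernels; passing the
# extra dimensions as a list (rather than a set) mirrors the change made in
# the next hunk. Shapes and names here are illustrative only.
import numpy as np
import xarray as xr

da = xr.DataArray(np.random.rand(4, 3, 5), dims=("time", "lat", "lon"))
q = np.array([0.1, 0.5, 0.9])
tem, extra = "time", "extra"

stacked = da.stack({extra: list(set(da.dims) - {tem})})
res = stacked.quantile(q, dim=tem).unstack(extra)
# res has dims ("quantile", "lat", "lon")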
@@ -78,14 +111,14 @@ def quantile(da: DataArray, q, dim: str | DataArray.dims) -> DataArray: if len(da.dims) > 1: # There are some extra dims extra = utils.get_temp_dimname(da.dims, "extra") - da = da.stack({extra: set(da.dims) - {tem}}) + da = da.stack({extra: list(set(da.dims) - {tem})}) da = da.transpose(..., tem) res = DataArray( _quantile(da.values, qc), dims=(extra, "quantiles"), coords={extra: da[extra], "quantiles": q}, attrs=da.attrs, - ).unstack(extra) + ).unstack([extra]) else: # All dims are processed diff --git a/xclim/sdba/processing.py b/xclim/sdba/processing.py index b4e904829..359d95eee 100644 --- a/xclim/sdba/processing.py +++ b/xclim/sdba/processing.py @@ -5,6 +5,7 @@ """ from __future__ import annotations +import types from collections.abc import Sequence import dask.array as dsk @@ -138,7 +139,8 @@ def jitter_under_thresh(x: xr.DataArray, thresh: str) -> xr.DataArray: ----- If thresh is high, this will change the mean value of x. """ - return jitter(x, lower=thresh, upper=None, minimum=None, maximum=None) + j: xr.DataArray = jitter(x, lower=thresh, upper=None, minimum=None, maximum=None) + return j def jitter_over_thresh(x: xr.DataArray, thresh: str, upper_bnd: str) -> xr.DataArray: @@ -166,7 +168,10 @@ def jitter_over_thresh(x: xr.DataArray, thresh: str, upper_bnd: str) -> xr.DataA If thresh is low, this will change the mean value of x. """ - return jitter(x, lower=None, upper=thresh, minimum=None, maximum=upper_bnd) + j: xr.DataArray = jitter( + x, lower=None, upper=thresh, minimum=None, maximum=upper_bnd + ) + return j @update_xclim_history @@ -208,31 +213,41 @@ def jitter( The two noise distributions are independent. """ with units.context(infer_context(x.attrs.get("standard_name"))): - out = x + out: xr.DataArray = x notnull = x.notnull() if lower is not None: - lower = convert_units_to(lower, x) - minimum = convert_units_to(minimum, x) if minimum is not None else 0 - minimum = minimum + np.finfo(x.dtype).eps + jitter_lower = np.array(convert_units_to(lower, x)).astype(float) + jitter_min = np.array( + convert_units_to(minimum, x) if minimum is not None else 0 + ).astype(float) + jitter_min = jitter_min + np.finfo(x.dtype).eps if uses_dask(x): jitter_dist = dsk.random.uniform( - low=minimum, high=lower, size=x.shape, chunks=x.chunks + low=jitter_min, high=jitter_lower, size=x.shape, chunks=x.chunks ) else: - jitter_dist = np.random.uniform(low=minimum, high=lower, size=x.shape) - out = out.where(~((x < lower) & notnull), jitter_dist.astype(x.dtype)) + jitter_dist = np.random.uniform( + low=jitter_min, high=jitter_lower, size=x.shape + ) + out = out.where( + ~((x < jitter_lower) & notnull), jitter_dist.astype(x.dtype) + ) if upper is not None: if maximum is None: raise ValueError("If 'upper' is given, so must 'maximum'.") - upper = convert_units_to(upper, x) - maximum = convert_units_to(maximum, x) + jitter_upper = np.array(convert_units_to(upper, x)).astype(float) + jitter_max = np.array(convert_units_to(maximum, x)).astype(float) if uses_dask(x): jitter_dist = dsk.random.uniform( - low=upper, high=maximum, size=x.shape, chunks=x.chunks + low=jitter_upper, high=jitter_max, size=x.shape, chunks=x.chunks ) else: - jitter_dist = np.random.uniform(low=upper, high=maximum, size=x.shape) - out = out.where(~((x >= upper) & notnull), jitter_dist.astype(x.dtype)) + jitter_dist = np.random.uniform( + low=jitter_upper, high=jitter_max, size=x.shape + ) + out = out.where( + ~((x >= jitter_upper) & notnull), jitter_dist.astype(x.dtype) + ) copy_all_attrs(out, x) # copy attrs 
and same units return out @@ -291,6 +306,8 @@ def uniform_noise_like( Noise is uniformly distributed between low and high. Alternative method to `jitter_under_thresh` for avoiding zeroes. """ + mod: types.ModuleType + kw: dict if uses_dask(da): mod = dsk kw = {"chunks": da.chunks} @@ -367,7 +384,7 @@ def reordering(ref: xr.DataArray, sim: xr.DataArray, group: str = "time") -> xr. """ ds = xr.Dataset({"sim": sim, "ref": ref}) - out = _reordering(ds, group=group).reordered + out: xr.Dataset = _reordering(ds, group=group).reordered copy_all_attrs(out, sim) return out @@ -450,7 +467,7 @@ def escore( # Otherwise, apply_ufunc tries to align both obs_dim together. new_dim = get_temp_dimname(tgt.dims, obs_dim) sim = sim.rename({obs_dim: new_dim}) - out = xr.apply_ufunc( + out: xr.DataArray = xr.apply_ufunc( _escore, tgt, sim, @@ -460,10 +477,12 @@ def escore( ) out.name = "escores" - out.attrs.update( - long_name="Energy dissimilarity metric", - description=f"Escores computed from {N or 'all'} points.", - references="Székely, G. J. and Rizzo, M. L. (2004) Testing for Equal Distributions in High Dimension, InterStat, November (5)", + out = out.assign_attrs( + dict( + long_name="Energy dissimilarity metric", + description=f"Escores computed from {N or 'all'} points.", + references="Székely, G. J. and Rizzo, M. L. (2004) Testing for Equal Distributions in High Dimension, InterStat, November (5)", + ) ) return out @@ -559,27 +578,33 @@ def to_additive_space( """ with units.context(infer_context(data.attrs.get("standard_name"))): - lower_bound = convert_units_to(lower_bound, data) + lower_bound_array = np.array(convert_units_to(lower_bound, data)).astype(float) if upper_bound is not None: - upper_bound = convert_units_to(upper_bound, data) + upper_bound_array = np.array(convert_units_to(upper_bound, data)).astype( + float + ) + + from typing import cast with xr.set_options(keep_attrs=True), np.errstate(divide="ignore"): if trans == "log": - out = np.log(data - lower_bound) + out = cast(xr.DataArray, np.log(data - lower_bound_array)) elif trans == "logit": - data_prime = (data - lower_bound) / (upper_bound - lower_bound) - out = np.log(data_prime / (1 - data_prime)) + data_prime = (data - lower_bound_array) / ( + upper_bound_array - lower_bound_array + ) + out = cast(xr.DataArray, np.log(data_prime / (1 - data_prime))) else: raise NotImplementedError("`trans` must be one of 'log' or 'logit'.") # Attributes to remember all this. 
- out.attrs["sdba_transform"] = trans - out.attrs["sdba_transform_lower"] = lower_bound + out = out.assign_attrs(sdba_transform=trans) + out = out.assign_attrs(sdba_transform_lower=lower_bound_array) if upper_bound is not None: - out.attrs["sdba_transform_upper"] = upper_bound + out = out.assign_attrs(sdba_transform_upper=upper_bound_array) if "units" in out.attrs: - out.attrs["sdba_transform_units"] = out.attrs.pop("units") - out.attrs["units"] = "" + out = out.assign_attrs(sdba_transform_units=out.attrs.pop("units")) + out = out.assign_attrs(units="") return out @@ -656,9 +681,13 @@ def from_additive_space( try: trans = data.attrs["sdba_transform"] units = data.attrs["sdba_transform_units"] - lower_bound = data.attrs["sdba_transform_lower"] + lower_bound_array = np.array(data.attrs["sdba_transform_lower"]).astype( + float + ) if trans == "logit": - upper_bound = data.attrs["sdba_transform_upper"] + upper_bound_array = np.array(data.attrs["sdba_transform_upper"]).astype( + float + ) except KeyError as err: raise ValueError( f"Attribute {err!s} must be present on the input data " @@ -670,9 +699,12 @@ def from_additive_space( and units is not None and (upper_bound is not None or trans == "log") ): - lower_bound = convert_units_to(lower_bound, units) + # FIXME: convert_units_to is causing issues since it can't handle all variations of Quantified here + lower_bound_array = np.array(convert_units_to(lower_bound, units)).astype(float) if trans == "logit": - upper_bound = convert_units_to(upper_bound, units) + upper_bound_array = np.array(convert_units_to(upper_bound, units)).astype( + float + ) else: raise ValueError( "Parameters missing. Either all parameters are given as attributes of data, " @@ -681,10 +713,12 @@ def from_additive_space( with xr.set_options(keep_attrs=True): if trans == "log": - out = np.exp(data) + lower_bound + out = np.exp(data) + lower_bound_array elif trans == "logit": out_prime = 1 / (1 + np.exp(-data)) - out = out_prime * (upper_bound - lower_bound) + lower_bound + out = ( + out_prime * (upper_bound_array - lower_bound_array) + lower_bound_array + ) else: raise NotImplementedError("`trans` must be one of 'log' or 'logit'.") @@ -693,7 +727,7 @@ def from_additive_space( out.attrs.pop("sdba_transform_lower", None) out.attrs.pop("sdba_transform_upper", None) out.attrs.pop("sdba_transform_units", None) - out.attrs["units"] = units + out = out.assign_attrs(units=units) return out @@ -725,7 +759,7 @@ def stack_variables(ds: xr.Dataset, rechunk: bool = True, dim: str = "multivar") """ # Store original arrays' attributes - attrs = {} + attrs: dict = {} # sort to have coherent order with different datasets data_vars = sorted(ds.data_vars.items(), key=lambda e: e[0]) nvar = len(data_vars) @@ -748,7 +782,7 @@ def stack_variables(ds: xr.Dataset, rechunk: bool = True, dim: str = "multivar") return da.rename("multivariate") -def unstack_variables(da: xr.DataArray, dim: str | None = None): +def unstack_variables(da: xr.DataArray, dim: str | None = None) -> xr.Dataset: """Unstack a DataArray created by `stack_variables` to a dataset. Parameters @@ -765,8 +799,9 @@ def unstack_variables(da: xr.DataArray, dim: str | None = None): Dataset holding each variable in an individual DataArray. 
""" if dim is None: - for dim, crd in da.coords.items(): - if crd.attrs.get("is_variables"): + for _dim, _crd in da.coords.items(): + if _crd.attrs.get("is_variables"): + dim = str(_dim) break else: raise ValueError("No variable coordinate found, were attributes removed?") diff --git a/xclim/sdba/properties.py b/xclim/sdba/properties.py index 08ee15848..14d8fbf55 100644 --- a/xclim/sdba/properties.py +++ b/xclim/sdba/properties.py @@ -764,7 +764,12 @@ def _relative_frequency( """ # mask of the ocean with NaNs mask = ~(da.isel({group.dim: 0}).isnull()).drop_vars(group.dim) - ops = {">": np.greater, "<": np.less, ">=": np.greater_equal, "<=": np.less_equal} + ops: dict[str, np.ufunc] = { + ">": np.greater, + "<": np.less, + ">=": np.greater_equal, + "<=": np.less_equal, + } t = convert_units_to(thresh, da, context="infer") length = da.sizes[group.dim] cond = ops[op](da, t) @@ -1069,7 +1074,7 @@ def _decorrelation_length( thresh: float = 0.50, dims: Sequence[str] | None = None, bins: int = 100, - group: str = "time", + group: xr.Coordinate | str | None = "time", # FIXME: this needs to be clarified ): """Decorrelation length. @@ -1087,13 +1092,13 @@ def _decorrelation_length( Threshold correlation defining decorrelation. The decorrelation length is defined as the center of the distance bin that has a correlation closest to this threshold. - dims: sequence of strings + dims : sequence of strings Name of the spatial dimensions. Once these are stacked, the longitude and latitude coordinates must be 1D. bins : int Same as argument `bins` from :py:meth:`scipy.stats.binned_statistic`. If given as a scalar, the equal-width bin limits from 0 to radius are generated here (instead of letting scipy do it) to improve performance. - group : str + group : xarray.Coordinate or str, optional Useless for now. Returns @@ -1105,7 +1110,7 @@ def _decorrelation_length( ----- Calculating this property requires a lot of memory. It will not work with large datasets. 
""" - if dims is None: + if dims is None and group is not None: dims = [d for d in da.dims if d != group.dim] corr = _pairwise_spearman(da, dims) @@ -1121,15 +1126,19 @@ def _decorrelation_length( ) if np.isscalar(bins): - bins = np.linspace(0, radius, bins + 1) + bin_array = np.linspace(0, radius, bins + 1) + elif isinstance(bins, np.ndarray): + bin_array = bins + else: + raise ValueError("bins must be a scalar or a numpy array.") if uses_dask(corr): dists = dists.chunk() trans_dists = trans_dists.chunk() - w = np.diff(bins) + w = np.diff(bin_array) centers = xr.DataArray( - bins[:-1] + w / 2, + bin_array[:-1] + w / 2, dims=("distance_bins",), attrs={ "units": "km", @@ -1140,15 +1149,16 @@ def _decorrelation_length( # only keep points inside the radius ds = ds.where(ds.distance < radius) - ds = ds.where(ds.distance2 < radius) - def _bin_corr(corr, distance): + def _bin_corr(_corr, _distance): """Bin and mean.""" - mask_nan = ~np.isnan(corr) - return stats.binned_statistic( - distance[mask_nan], corr[mask_nan], statistic="mean", bins=bins - ).statistic + mask_nan = ~np.isnan(_corr) + binned_corr = stats.binned_statistic( + _distance[mask_nan], _corr[mask_nan], statistic="mean", bins=bin_array + ) + stat = binned_corr.statistic + return stat # (_spatial, _spatial2) -> (_spatial, distance_bins) binned = ( @@ -1163,7 +1173,7 @@ def _bin_corr(corr, distance): output_dtypes=[float], dask_gufunc_kwargs={ "allow_rechunk": True, - "output_sizes": {"distance_bins": len(bins)}, + "output_sizes": {"distance_bins": len(bin_array)}, }, ) .rename("corr") diff --git a/xclim/sdba/utils.py b/xclim/sdba/utils.py index 09a3129af..2815f5f42 100644 --- a/xclim/sdba/utils.py +++ b/xclim/sdba/utils.py @@ -155,12 +155,13 @@ def apply_correction( """ kind = kind or factor.get("kind", None) with xr.set_options(keep_attrs=True): + out: xr.DataArray if kind == ADDITIVE: out = x + factor elif kind == MULTIPLICATIVE: out = x * factor else: - raise ValueError + raise ValueError("kind must be `+` or `*`.") return out @@ -309,7 +310,7 @@ def add_cyclic_bounds( return ensure_chunk_size(qmf, **{att: -1}) -def _interp_on_quantiles_1D(newx, oldx, oldy, method, extrap): +def _interp_on_quantiles_1D(newx, oldx, oldy, method, extrap): # noqa: N802 mask_new = np.isnan(newx) mask_old = np.isnan(oldy) | np.isnan(oldx) out = np.full_like(newx, np.NaN, dtype=f"float{oldy.dtype.itemsize * 8}") diff --git a/xclim/testing/diagnostics.py b/xclim/testing/diagnostics.py index 8aede8e47..875e68d1d 100644 --- a/xclim/testing/diagnostics.py +++ b/xclim/testing/diagnostics.py @@ -8,6 +8,8 @@ """ from __future__ import annotations +import warnings + import numpy as np from scipy.stats import gaussian_kde, scoreatpercentile @@ -22,7 +24,7 @@ try: from matplotlib import pyplot as plt except ModuleNotFoundError: - plt = False + warnings.warn("Matplotlib not found, plot-generating functions will not work.") __all__ = ["adapt_freq_graph", "cannon_2015_figure_2", "synth_rainfall"] diff --git a/xclim/testing/helpers.py b/xclim/testing/helpers.py index abce4f21d..94d87fa36 100644 --- a/xclim/testing/helpers.py +++ b/xclim/testing/helpers.py @@ -74,7 +74,7 @@ ] -def generate_atmos(cache_dir: str | Path): +def generate_atmos(cache_dir: Path): """Create the `atmosds` synthetic testing dataset.""" with _open_dataset( "ERA5/daily_surface_cancities_1990-1993.nc", @@ -166,9 +166,11 @@ def populate_testing_data( return -def add_example_file_paths(cache_dir: Path) -> dict[str]: +def add_example_file_paths( + cache_dir: Path, +) -> dict[str, str | 
list[xr.DataArray]]: """Create a dictionary of relevant datasets to be patched into the xdoctest namespace.""" - ns = dict() + ns: dict = dict() ns["path_to_ensemble_file"] = "EnsembleReduce/TestEnsReduceCriteria.nc" ns["path_to_pr_file"] = "NRCANdaily/nrcan_canada_daily_pr_1990.nc" ns["path_to_sfcWind_file"] = "ERA5/daily_surface_cancities_1990-1993.nc" @@ -223,12 +225,12 @@ def add_example_file_paths(cache_dir: Path) -> dict[str]: def test_timeseries( values, variable, - start="2000-07-01", - units=None, - freq="D", - as_dataset=False, - cftime=False, -): + start: str = "2000-07-01", + units: str | None = None, + freq: str = "D", + as_dataset: bool = False, + cftime: bool = False, +) -> xr.DataArray | xr.Dataset: """Create a generic timeseries object based on pre-defined dictionaries of existing variables.""" if cftime: coords = xr.cftime_range(start, periods=len(values), freq=freq) diff --git a/xclim/testing/utils.py b/xclim/testing/utils.py index 3d3fe16d1..7e3ddc6af 100644 --- a/xclim/testing/utils.py +++ b/xclim/testing/utils.py @@ -80,7 +80,7 @@ def file_md5_checksum(f_name): def get_file( - name: str | os.PathLike | Sequence[str | os.PathLike], + name: str | os.PathLike[str] | Sequence[str | os.PathLike[str]], github_url: str = "https://github.com/Ouranosinc/xclim-testdata", branch: str = "main", cache_dir: Path = _default_cache_dir, @@ -91,7 +91,7 @@ def get_file( Parameters ---------- - name : str | os.PathLike | Sequence[str | os.PathLike] + name : str | os.PathLike[str] | Sequence[str | os.PathLike[str]] Name of the file or list/tuple of names of files containing the dataset(s) including suffixes. github_url : str URL to GitHub repository where the data is stored. @@ -104,10 +104,10 @@ def get_file( ------- Path | list[Path] """ - if isinstance(name, (str, Path)): + if isinstance(name, (str, os.PathLike)): name = [name] - files = list() + files = [] for n in name: fullname = Path(n) suffix = fullname.suffix @@ -291,7 +291,7 @@ def _get( # idea copied from raven that it borrowed from xclim that borrowed it from xarray that was borrowed from Seaborn def open_dataset( - name: str | os.PathLike, + name: str | os.PathLike[str], suffix: str | None = None, dap_url: str | None = None, github_url: str = "https://github.com/Ouranosinc/xclim-testdata", @@ -331,7 +331,7 @@ def open_dataset( -------- xarray.open_dataset """ - if isinstance(name, str): + if isinstance(name, (str, os.PathLike)): name = Path(name) if suffix is None: suffix = ".nc" @@ -538,13 +538,14 @@ def publish_release_notes( if isinstance(file, (Path, os.PathLike)): with Path(file).open("w") as f: print(changes, file=f) - return - print(changes, file=file) + else: + print(changes, file=file) + return None def show_versions( file: os.PathLike | StringIO | TextIO | None = None, - deps: list | None = None, + deps: list[str] | None = None, ) -> str | None: """Print the versions of xclim and its dependencies. @@ -552,19 +553,22 @@ def show_versions( ---------- file : {os.PathLike, StringIO, TextIO}, optional If provided, prints to the given file-like object. Otherwise, returns a string. - deps : list, optional + deps : list of str, optional A list of dependencies to gather and print version information from. Otherwise, prints `xclim` dependencies. 
Returns ------- str or None """ + dependencies: list[str] if deps is None: - deps = _xclim_deps + dependencies = _xclim_deps + else: + dependencies = deps - dependency_versions = [(d, lambda mod: mod.__version__) for d in deps] + dependency_versions = [(d, lambda mod: mod.__version__) for d in dependencies] - deps_blob = [] + deps_blob: list[tuple[str, str | None]] = [] for modname, ver_f in dependency_versions: try: if modname in sys.modules: @@ -597,5 +601,6 @@ def show_versions( if isinstance(file, (Path, os.PathLike)): with Path(file).open("w") as f: print(message, file=f) - return - print(message, file=file) + else: + print(message, file=file) + return None
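Editorial note: the last two hunks (``publish_release_notes`` and ``show_versions``) apply the same consistent-return shape — write to a path-like target, otherwise print to an already-open file-like object, and finish with an explicit ``return None`` so every exit is well typed for mypy. The sketch below is a minimal, self-contained illustration of that pattern only; the ``emit_report`` helper and its behaviour are hypothetical and are not part of xclim's API.

from __future__ import annotations

import os
from io import StringIO
from pathlib import Path
from typing import TextIO


def emit_report(
    message: str,
    file: os.PathLike | StringIO | TextIO | None = None,
) -> str | None:
    """Return `message`, or print it to a path-like or file-like target.

    Hypothetical helper mirroring the if/else plus trailing ``return None``
    structure used in the hunks above to keep return types consistent.
    """
    if file is None:
        # No destination given: hand the text back to the caller.
        return message
    if isinstance(file, (Path, os.PathLike)):
        # Path-like destination: open it ourselves and write.
        with Path(file).open("w") as f:
            print(message, file=f)
    else:
        # Already-open file-like object (StringIO, sys.stdout, ...).
        print(message, file=file)
    return None


# Exercising the in-memory branches:
buf = StringIO()
assert emit_report("report") == "report"
assert emit_report("report", file=buf) is None
assert buf.getvalue() == "report\n"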