diff --git a/.github/get_pypi_info.py b/.github/get_pypi_info.py index 03d2f1ab15..fd7a2c9238 100644 --- a/.github/get_pypi_info.py +++ b/.github/get_pypi_info.py @@ -18,12 +18,14 @@ def get_info(package_name: str = "") -> dict: ::return:: A dict with last_version, url and sha256 """ if package_name == "": - raise ValueError("Package name not provided.") + msg = "Package name not provided." + raise ValueError(msg) url = f"https://pypi.org/pypi/{package_name}/json" print(f"Calling {url}") # noqa: T201 resp = requests.get(url) if resp.status_code != 200: - raise Exception(f"ERROR calling PyPI ({url}) : {resp}") + msg = f"ERROR calling PyPI ({url}) : {resp}" + raise Exception(msg) resp = resp.json() version = resp["info"]["version"] @@ -38,19 +40,19 @@ def get_info(package_name: str = "") -> dict: return {} -def replace_in_file(filepath: str, info: dict): +def replace_in_file(filepath: str, info: dict) -> None: """Replace placeholder in meta.yaml by their values. ::filepath:: Path to meta.yaml, with filename. ::info:: Dict with information to populate. """ - with open(filepath, "rt", encoding="utf-8") as fin: + with open(filepath, encoding="utf-8") as fin: meta = fin.read() # Replace with info from PyPi meta = meta.replace("PYPI_VERSION", info["last_version"]) meta = meta.replace("PYPI_URL", info["url"]) meta = meta.replace("PYPI_SHA256", info["sha256"]) - with open(filepath, "wt", encoding="utf-8") as fout: + with open(filepath, "w", encoding="utf-8") as fout: fout.write(meta) print(f"File {filepath} has been updated with info from PyPi.") # noqa: T201 @@ -75,6 +77,7 @@ def replace_in_file(filepath: str, info: dict): args = parser.parse_args() info = get_info(args.package) print( # noqa: T201 - "Information of the last published PyPi package :", info["last_version"] + "Information of the last published PyPi package :", + info["last_version"], ) replace_in_file(args.filename, info) diff --git a/CHANGELOG.md b/CHANGELOG.md index ade8610231..002cf14714 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +### 41.5.6 [#1185](https://github.com/openfisca/openfisca-core/pull/1185) + +#### Technical changes + +- Remove pre Python 3.9 syntax. + ### 41.5.5 [#1220](https://github.com/openfisca/openfisca-core/pull/1220) #### Technical changes diff --git a/openfisca_core/commons/__init__.py b/openfisca_core/commons/__init__.py index b3b5d8cbb2..807abec778 100644 --- a/openfisca_core/commons/__init__.py +++ b/openfisca_core/commons/__init__.py @@ -52,9 +52,9 @@ # Official Public API -from .formulas import apply_thresholds, concat, switch # noqa: F401 -from .misc import empty_clone, stringify_array # noqa: F401 -from .rates import average_rate, marginal_rate # noqa: F401 +from .formulas import apply_thresholds, concat, switch +from .misc import empty_clone, stringify_array +from .rates import average_rate, marginal_rate __all__ = ["apply_thresholds", "concat", "switch"] __all__ = ["empty_clone", "stringify_array", *__all__] @@ -62,6 +62,6 @@ # Deprecated -from .dummy import Dummy # noqa: F401 +from .dummy import Dummy __all__ = ["Dummy", *__all__] diff --git a/openfisca_core/commons/dummy.py b/openfisca_core/commons/dummy.py index 3788e48705..5135a8f555 100644 --- a/openfisca_core/commons/dummy.py +++ b/openfisca_core/commons/dummy.py @@ -20,4 +20,3 @@ def __init__(self) -> None: "and will be removed in the future.", ] warnings.warn(" ".join(message), DeprecationWarning, stacklevel=2) - pass diff --git a/openfisca_core/commons/formulas.py b/openfisca_core/commons/formulas.py index bce9206938..909c4cd14a 100644 --- a/openfisca_core/commons/formulas.py +++ b/openfisca_core/commons/formulas.py @@ -7,10 +7,10 @@ def apply_thresholds( - input: t.Array[numpy.float_], + input: t.Array[numpy.float64], thresholds: t.ArrayLike[float], choices: t.ArrayLike[float], -) -> t.Array[numpy.float_]: +) -> t.Array[numpy.float64]: """Makes a choice based on an input and thresholds. From a list of ``choices``, this function selects one of these values @@ -38,7 +38,6 @@ def apply_thresholds( array([10, 10, 15, 15, 20]) """ - condlist: list[Union[t.Array[numpy.bool_], bool]] condlist = [input <= threshold for threshold in thresholds] @@ -47,12 +46,9 @@ def apply_thresholds( # must be true to return it. condlist += [True] - assert len(condlist) == len(choices), " ".join( - [ - "'apply_thresholds' must be called with the same number of", - "thresholds than choices, or one more choice.", - ] - ) + assert len(condlist) == len( + choices + ), "'apply_thresholds' must be called with the same number of thresholds than choices, or one more choice." return numpy.select(condlist, choices) @@ -78,7 +74,6 @@ def concat( array(['this1.0', 'that2.5']...) """ - if isinstance(this, numpy.ndarray) and not numpy.issubdtype(this.dtype, numpy.str_): this = this.astype("str") @@ -89,9 +84,9 @@ def concat( def switch( - conditions: t.Array[numpy.float_], + conditions: t.Array[numpy.float64], value_by_condition: Mapping[float, float], -) -> t.Array[numpy.float_]: +) -> t.Array[numpy.float64]: """Mimicks a switch statement. Given an array of conditions, returns an array of the same size, @@ -115,11 +110,10 @@ def switch( array([80, 80, 80, 90]) """ - assert ( len(value_by_condition) > 0 ), "'switch' must be called with at least one value." - condlist = [conditions == condition for condition in value_by_condition.keys()] + condlist = [conditions == condition for condition in value_by_condition] return numpy.select(condlist, tuple(value_by_condition.values())) diff --git a/openfisca_core/commons/misc.py b/openfisca_core/commons/misc.py index 342bbbe5fb..3c9cd5feab 100644 --- a/openfisca_core/commons/misc.py +++ b/openfisca_core/commons/misc.py @@ -30,7 +30,6 @@ def empty_clone(original: T) -> T: True """ - Dummy: object new: T @@ -60,7 +59,7 @@ def stringify_array(array: Optional[t.Array[numpy.generic]]) -> str: >>> stringify_array(None) 'None' - >>> array = numpy.array([10, 20.]) + >>> array = numpy.array([10, 20.0]) >>> stringify_array(array) '[10.0, 20.0]' @@ -73,7 +72,6 @@ def stringify_array(array: Optional[t.Array[numpy.generic]]) -> str: "[, {}, Array[numpy.float_]: +) -> Array[numpy.float64]: """Computes the average rate of a target net income. Given a ``target`` net income, and according to the ``varying`` gross @@ -35,13 +35,12 @@ def average_rate( Examples: >>> target = numpy.array([1, 2, 3]) >>> varying = [2, 2, 2] - >>> trim = [-1, .25] + >>> trim = [-1, 0.25] >>> average_rate(target, varying, trim) array([ nan, 0. , -0.5]) """ - - average_rate: Array[numpy.float_] + average_rate: Array[numpy.float64] average_rate = 1 - target / varying @@ -62,10 +61,10 @@ def average_rate( def marginal_rate( - target: Array[numpy.float_], - varying: Array[numpy.float_], + target: Array[numpy.float64], + varying: Array[numpy.float64], trim: Optional[ArrayLike[float]] = None, -) -> Array[numpy.float_]: +) -> Array[numpy.float64]: """Computes the marginal rate of a target net income. Given a ``target`` net income, and according to the ``varying`` gross @@ -91,13 +90,12 @@ def marginal_rate( Examples: >>> target = numpy.array([1, 2, 3]) >>> varying = numpy.array([1, 2, 4]) - >>> trim = [.25, .75] + >>> trim = [0.25, 0.75] >>> marginal_rate(target, varying, trim) array([nan, 0.5]) """ - - marginal_rate: Array[numpy.float_] + marginal_rate: Array[numpy.float64] marginal_rate = +1 - (target[:-1] - target[1:]) / (varying[:-1] - varying[1:]) diff --git a/openfisca_core/commons/tests/test_dummy.py b/openfisca_core/commons/tests/test_dummy.py index d4ecec3842..4dd13eabab 100644 --- a/openfisca_core/commons/tests/test_dummy.py +++ b/openfisca_core/commons/tests/test_dummy.py @@ -3,8 +3,7 @@ from openfisca_core.commons import Dummy -def test_dummy_deprecation(): +def test_dummy_deprecation() -> None: """Dummy throws a deprecation warning when instantiated.""" - with pytest.warns(DeprecationWarning): assert Dummy() diff --git a/openfisca_core/commons/tests/test_formulas.py b/openfisca_core/commons/tests/test_formulas.py index 82755583e6..91866bd0c0 100644 --- a/openfisca_core/commons/tests/test_formulas.py +++ b/openfisca_core/commons/tests/test_formulas.py @@ -5,9 +5,8 @@ from openfisca_core import commons -def test_apply_thresholds_when_several_inputs(): +def test_apply_thresholds_when_several_inputs() -> None: """Makes a choice for any given input.""" - input_ = numpy.array([4, 5, 6, 7, 8, 9, 10]) thresholds = [5, 7, 9] choices = [10, 15, 20, 25] @@ -17,9 +16,8 @@ def test_apply_thresholds_when_several_inputs(): assert_array_equal(result, [10, 10, 15, 15, 20, 20, 25]) -def test_apply_thresholds_when_too_many_thresholds(): +def test_apply_thresholds_when_too_many_thresholds() -> None: """Raises an AssertionError when thresholds > choices.""" - input_ = numpy.array([6]) thresholds = [5, 7, 9, 11] choices = [10, 15, 20] @@ -28,9 +26,8 @@ def test_apply_thresholds_when_too_many_thresholds(): assert commons.apply_thresholds(input_, thresholds, choices) -def test_apply_thresholds_when_too_many_choices(): +def test_apply_thresholds_when_too_many_choices() -> None: """Raises an AssertionError when thresholds < choices - 1.""" - input_ = numpy.array([6]) thresholds = [5, 7] choices = [10, 15, 20, 25] @@ -39,9 +36,8 @@ def test_apply_thresholds_when_too_many_choices(): assert commons.apply_thresholds(input_, thresholds, choices) -def test_concat_when_this_is_array_not_str(): +def test_concat_when_this_is_array_not_str() -> None: """Casts ``this`` to ``str`` when it is a NumPy array other than string.""" - this = numpy.array([1, 2]) that = numpy.array(["la", "o"]) @@ -50,9 +46,8 @@ def test_concat_when_this_is_array_not_str(): assert_array_equal(result, ["1la", "2o"]) -def test_concat_when_that_is_array_not_str(): +def test_concat_when_that_is_array_not_str() -> None: """Casts ``that`` to ``str`` when it is a NumPy array other than string.""" - this = numpy.array(["ho", "cha"]) that = numpy.array([1, 2]) @@ -61,9 +56,8 @@ def test_concat_when_that_is_array_not_str(): assert_array_equal(result, ["ho1", "cha2"]) -def test_concat_when_args_not_str_array_like(): +def test_concat_when_args_not_str_array_like() -> None: """Raises a TypeError when args are not a string array-like object.""" - this = (1, 2) that = (3, 4) @@ -71,9 +65,8 @@ def test_concat_when_args_not_str_array_like(): commons.concat(this, that) -def test_switch_when_values_are_empty(): +def test_switch_when_values_are_empty() -> None: """Raises an AssertionError when the values are empty.""" - conditions = [1, 1, 1, 2] value_by_condition = {} diff --git a/openfisca_core/commons/tests/test_rates.py b/openfisca_core/commons/tests/test_rates.py index 01565d9527..54e24b8d0f 100644 --- a/openfisca_core/commons/tests/test_rates.py +++ b/openfisca_core/commons/tests/test_rates.py @@ -4,9 +4,8 @@ from openfisca_core import commons -def test_average_rate_when_varying_is_zero(): +def test_average_rate_when_varying_is_zero() -> None: """Yields infinity when the varying gross income crosses zero.""" - target = numpy.array([1, 2, 3]) varying = [0, 0, 0] @@ -15,9 +14,8 @@ def test_average_rate_when_varying_is_zero(): assert_array_equal(result, [-numpy.inf, -numpy.inf, -numpy.inf]) -def test_marginal_rate_when_varying_is_zero(): +def test_marginal_rate_when_varying_is_zero() -> None: """Yields infinity when the varying gross income crosses zero.""" - target = numpy.array([1, 2, 3]) varying = numpy.array([0, 0, 0]) diff --git a/openfisca_core/data_storage/in_memory_storage.py b/openfisca_core/data_storage/in_memory_storage.py index 8fb472046b..0808612ba8 100644 --- a/openfisca_core/data_storage/in_memory_storage.py +++ b/openfisca_core/data_storage/in_memory_storage.py @@ -5,11 +5,9 @@ class InMemoryStorage: - """ - Low-level class responsible for storing and retrieving calculated vectors in memory - """ + """Low-level class responsible for storing and retrieving calculated vectors in memory.""" - def __init__(self, is_eternal=False): + def __init__(self, is_eternal=False) -> None: self._arrays = {} self.is_eternal = is_eternal @@ -23,14 +21,14 @@ def get(self, period): return None return values - def put(self, value, period): + def put(self, value, period) -> None: if self.is_eternal: period = periods.period(DateUnit.ETERNITY) period = periods.period(period) self._arrays[period] = value - def delete(self, period=None): + def delete(self, period=None) -> None: if period is None: self._arrays = {} return @@ -50,16 +48,16 @@ def get_known_periods(self): def get_memory_usage(self): if not self._arrays: - return dict( - nb_arrays=0, - total_nb_bytes=0, - cell_size=numpy.nan, - ) + return { + "nb_arrays": 0, + "total_nb_bytes": 0, + "cell_size": numpy.nan, + } nb_arrays = len(self._arrays) array = next(iter(self._arrays.values())) - return dict( - nb_arrays=nb_arrays, - total_nb_bytes=array.nbytes * nb_arrays, - cell_size=array.itemsize, - ) + return { + "nb_arrays": nb_arrays, + "total_nb_bytes": array.nbytes * nb_arrays, + "cell_size": array.itemsize, + } diff --git a/openfisca_core/data_storage/on_disk_storage.py b/openfisca_core/data_storage/on_disk_storage.py index dbf8a4eb13..9133db2376 100644 --- a/openfisca_core/data_storage/on_disk_storage.py +++ b/openfisca_core/data_storage/on_disk_storage.py @@ -9,11 +9,11 @@ class OnDiskStorage: - """ - Low-level class responsible for storing and retrieving calculated vectors on disk - """ + """Low-level class responsible for storing and retrieving calculated vectors on disk.""" - def __init__(self, storage_dir, is_eternal=False, preserve_storage_dir=False): + def __init__( + self, storage_dir, is_eternal=False, preserve_storage_dir=False + ) -> None: self._files = {} self._enums = {} self.is_eternal = is_eternal @@ -24,8 +24,7 @@ def _decode_file(self, file): enum = self._enums.get(file) if enum is not None: return EnumArray(numpy.load(file), enum) - else: - return numpy.load(file) + return numpy.load(file) def get(self, period): if self.is_eternal: @@ -37,7 +36,7 @@ def get(self, period): return None return self._decode_file(values) - def put(self, value, period): + def put(self, value, period) -> None: if self.is_eternal: period = periods.period(DateUnit.ETERNITY) period = periods.period(period) @@ -50,7 +49,7 @@ def put(self, value, period): numpy.save(path, value) self._files[period] = path - def delete(self, period=None): + def delete(self, period=None) -> None: if period is None: self._files = {} return @@ -69,7 +68,7 @@ def delete(self, period=None): def get_known_periods(self): return self._files.keys() - def restore(self): + def restore(self) -> None: self._files = files = {} # Restore self._files from content of storage_dir. for filename in os.listdir(self.storage_dir): @@ -80,7 +79,7 @@ def restore(self): period = periods.period(filename_core) files[period] = path - def __del__(self): + def __del__(self) -> None: if self.preserve_storage_dir: return shutil.rmtree(self.storage_dir) # Remove the holder temporary files diff --git a/openfisca_core/entities/_core_entity.py b/openfisca_core/entities/_core_entity.py index 9a2707d19d..da3e6ea981 100644 --- a/openfisca_core/entities/_core_entity.py +++ b/openfisca_core/entities/_core_entity.py @@ -29,9 +29,13 @@ class _CoreEntity: @abstractmethod def __init__( - self, key: str, plural: str, label: str, doc: str, *args: object - ) -> None: - ... + self, + key: str, + plural: str, + label: str, + doc: str, + *args: object, + ) -> None: ... def __repr__(self) -> str: return f"{self.__class__.__name__}({self.key})" @@ -47,8 +51,9 @@ def get_variable( ) -> t.Variable | None: """Get a ``variable_name`` from ``variables``.""" if self._tax_benefit_system is None: + msg = "You must set 'tax_benefit_system' before calling this method." raise ValueError( - "You must set 'tax_benefit_system' before calling this method." + msg, ) return self._tax_benefit_system.get_variable(variable_name, check_existence) @@ -75,4 +80,5 @@ def check_variable_defined_for_entity(self, variable_name: str) -> None: def check_role_validity(self, role: object) -> None: """Check if a ``role`` is an instance of Role.""" if role is not None and not isinstance(role, Role): - raise ValueError(f"{role} is not a valid role") + msg = f"{role} is not a valid role" + raise ValueError(msg) diff --git a/openfisca_core/entities/entity.py b/openfisca_core/entities/entity.py index 8194772663..a3fbaddac3 100644 --- a/openfisca_core/entities/entity.py +++ b/openfisca_core/entities/entity.py @@ -5,9 +5,7 @@ class Entity(_CoreEntity): - """ - Represents an entity (e.g. a person, a household, etc.) on which calculations can be run. - """ + """Represents an entity (e.g. a person, a household, etc.) on which calculations can be run.""" def __init__(self, key: str, plural: str, label: str, doc: str) -> None: self.key = t.EntityKey(key) diff --git a/openfisca_core/entities/group_entity.py b/openfisca_core/entities/group_entity.py index d2242983d6..b47c92a525 100644 --- a/openfisca_core/entities/group_entity.py +++ b/openfisca_core/entities/group_entity.py @@ -26,7 +26,7 @@ class GroupEntity(_CoreEntity): containing_entities: The list of keys of group entities whose members are guaranteed to be a superset of this group's entities. - """ # noqa RST301 + """ def __init__( self, @@ -56,7 +56,7 @@ def __init__( role.subroles = (*role.subroles, subrole) role.max = len(role.subroles) self.flattened_roles = tuple( - chain.from_iterable(role.subroles or [role] for role in self.roles) + chain.from_iterable(role.subroles or [role] for role in self.roles), ) self.is_person = False diff --git a/openfisca_core/entities/helpers.py b/openfisca_core/entities/helpers.py index 90b9ffd948..d5ba6cc6a0 100644 --- a/openfisca_core/entities/helpers.py +++ b/openfisca_core/entities/helpers.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections.abc import Iterable, Sequence -from typing import Optional from . import types as t from .entity import Entity @@ -13,7 +12,7 @@ def build_entity( plural: str, label: str, doc: str = "", - roles: Optional[Sequence[t.RoleParams]] = None, + roles: Sequence[t.RoleParams] | None = None, is_person: bool = False, class_override: object | None = None, containing_entities: Sequence[str] = (), @@ -45,17 +44,17 @@ def build_entity( ... "syndicate", ... "syndicates", ... "Banks loaning jointly.", - ... roles = [], - ... containing_entities = (), - ... ) + ... roles=[], + ... containing_entities=(), + ... ) GroupEntity(syndicate) >>> build_entity( ... "company", ... "companies", ... "A small or medium company.", - ... is_person = True, - ... ) + ... is_person=True, + ... ) Entity(company) >>> role = entities.Role({"key": "key"}, object()) @@ -64,26 +63,33 @@ def build_entity( ... "syndicate", ... "syndicates", ... "Banks loaning jointly.", - ... roles = role, - ... ) + ... roles=role, + ... ) Traceback (most recent call last): TypeError: 'Role' object is not iterable """ - if is_person: return Entity(key, plural, label, doc) if roles is not None: return GroupEntity( - key, plural, label, doc, roles, containing_entities=containing_entities + key, + plural, + label, + doc, + roles, + containing_entities=containing_entities, ) raise NotImplementedError def find_role( - roles: Iterable[t.Role], key: t.RoleKey, *, total: int | None = None + roles: Iterable[t.Role], + key: t.RoleKey, + *, + total: int | None = None, ) -> t.Role | None: """Find a Role in a GroupEntity. @@ -141,7 +147,6 @@ def find_role( Role(first_parent) """ - for role in roles: if role.subroles: for subrole in role.subroles: diff --git a/openfisca_core/entities/role.py b/openfisca_core/entities/role.py index d703578160..45193fffc0 100644 --- a/openfisca_core/entities/role.py +++ b/openfisca_core/entities/role.py @@ -85,7 +85,7 @@ def __init__(self, description: Mapping[str, Any], entity: SingleEntity) -> None key: value for key, value in description.items() if key in {"key", "plural", "label", "doc"} - } + }, ) self.entity = entity self.max = description.get("max") @@ -96,7 +96,7 @@ def __repr__(self) -> str: @dataclasses.dataclass(frozen=True) class _Description: - """A Role's description. + r"""A Role's description. Examples: >>> data = { diff --git a/openfisca_core/entities/tests/test_entity.py b/openfisca_core/entities/tests/test_entity.py index 488d271ff5..b3cb813ddc 100644 --- a/openfisca_core/entities/tests/test_entity.py +++ b/openfisca_core/entities/tests/test_entity.py @@ -3,7 +3,6 @@ def test_init_when_doc_indented() -> None: """De-indent the ``doc`` attribute if it is passed at initialisation.""" - key = "\tkey" doc = "\tdoc" entity = entities.Entity(key, "label", "plural", doc) diff --git a/openfisca_core/entities/tests/test_group_entity.py b/openfisca_core/entities/tests/test_group_entity.py index ed55648d71..092c9d3575 100644 --- a/openfisca_core/entities/tests/test_group_entity.py +++ b/openfisca_core/entities/tests/test_group_entity.py @@ -43,7 +43,6 @@ def group_entity(role: Mapping[str, Any]) -> entities.GroupEntity: def test_init_when_doc_indented() -> None: """De-indent the ``doc`` attribute if it is passed at initialisation.""" - key = "\tkey" doc = "\tdoc" group_entity = entities.GroupEntity(key, "label", "plural", doc, ()) @@ -52,18 +51,20 @@ def test_init_when_doc_indented() -> None: def test_group_entity_with_roles( - group_entity: entities.GroupEntity, parent: str, uncle: str + group_entity: entities.GroupEntity, + parent: str, + uncle: str, ) -> None: """Assign a Role for each role-like passed as argument.""" - assert hasattr(group_entity, parent.upper()) assert not hasattr(group_entity, uncle.upper()) def test_group_entity_with_subroles( - group_entity: entities.GroupEntity, first_parent: str, second_parent: str + group_entity: entities.GroupEntity, + first_parent: str, + second_parent: str, ) -> None: """Assign a Role for each subrole-like passed as argument.""" - assert hasattr(group_entity, first_parent.upper()) assert not hasattr(group_entity, second_parent.upper()) diff --git a/openfisca_core/entities/tests/test_role.py b/openfisca_core/entities/tests/test_role.py index 83692e8236..ffb1fdddb8 100644 --- a/openfisca_core/entities/tests/test_role.py +++ b/openfisca_core/entities/tests/test_role.py @@ -3,7 +3,6 @@ def test_init_when_doc_indented() -> None: """De-indent the ``doc`` attribute if it is passed at initialisation.""" - key = "\tkey" doc = "\tdoc" role = entities.Role({"key": key, "doc": doc}, object()) diff --git a/openfisca_core/entities/types.py b/openfisca_core/entities/types.py index 2f9acd0402..38607d5488 100644 --- a/openfisca_core/entities/types.py +++ b/openfisca_core/entities/types.py @@ -26,12 +26,10 @@ class CoreEntity(t.CoreEntity, Protocol): plural: EntityPlural | None -class SingleEntity(t.SingleEntity, Protocol): - ... +class SingleEntity(t.SingleEntity, Protocol): ... -class GroupEntity(t.GroupEntity, Protocol): - ... +class GroupEntity(t.GroupEntity, Protocol): ... class Role(t.Role, Protocol): @@ -50,12 +48,10 @@ class RoleParams(TypedDict, total=False): # Tax-Benefit systems -class TaxBenefitSystem(t.TaxBenefitSystem, Protocol): - ... +class TaxBenefitSystem(t.TaxBenefitSystem, Protocol): ... # Variables -class Variable(t.Variable, Protocol): - ... +class Variable(t.Variable, Protocol): ... diff --git a/openfisca_core/errors/__init__.py b/openfisca_core/errors/__init__.py index 41d4760bee..2c4d438116 100644 --- a/openfisca_core/errors/__init__.py +++ b/openfisca_core/errors/__init__.py @@ -24,18 +24,22 @@ from .cycle_error import CycleError from .empty_argument_error import EmptyArgumentError from .nan_creation_error import NaNCreationError -from .parameter_not_found_error import ParameterNotFoundError -from .parameter_not_found_error import ParameterNotFoundError as ParameterNotFound +from .parameter_not_found_error import ( + ParameterNotFoundError, + ParameterNotFoundError as ParameterNotFound, +) from .parameter_parsing_error import ParameterParsingError from .period_mismatch_error import PeriodMismatchError from .situation_parsing_error import SituationParsingError from .spiral_error import SpiralError -from .variable_name_config_error import VariableNameConflictError from .variable_name_config_error import ( + VariableNameConflictError, VariableNameConflictError as VariableNameConflict, ) -from .variable_not_found_error import VariableNotFoundError -from .variable_not_found_error import VariableNotFoundError as VariableNotFound +from .variable_not_found_error import ( + VariableNotFoundError, + VariableNotFoundError as VariableNotFound, +) __all__ = [ "CycleError", diff --git a/openfisca_core/errors/cycle_error.py b/openfisca_core/errors/cycle_error.py index b4d44b5993..b81cc7b3f9 100644 --- a/openfisca_core/errors/cycle_error.py +++ b/openfisca_core/errors/cycle_error.py @@ -1,4 +1,2 @@ class CycleError(Exception): """Simulation error.""" - - pass diff --git a/openfisca_core/errors/empty_argument_error.py b/openfisca_core/errors/empty_argument_error.py index 0d0205b432..960d8d28c2 100644 --- a/openfisca_core/errors/empty_argument_error.py +++ b/openfisca_core/errors/empty_argument_error.py @@ -16,7 +16,7 @@ def __init__( class_name: str, method_name: str, arg_name: str, - arg_value: typing.Union[typing.List, numpy.ndarray], + arg_value: typing.Union[list, numpy.ndarray], ) -> None: message = [ f"'{class_name}.{method_name}' can't be run with an empty '{arg_name}':\n", diff --git a/openfisca_core/errors/nan_creation_error.py b/openfisca_core/errors/nan_creation_error.py index dfd1b7af7e..373e391517 100644 --- a/openfisca_core/errors/nan_creation_error.py +++ b/openfisca_core/errors/nan_creation_error.py @@ -1,4 +1,2 @@ class NaNCreationError(Exception): """Simulation error.""" - - pass diff --git a/openfisca_core/errors/parameter_not_found_error.py b/openfisca_core/errors/parameter_not_found_error.py index 1a8528f45c..bad33c89f4 100644 --- a/openfisca_core/errors/parameter_not_found_error.py +++ b/openfisca_core/errors/parameter_not_found_error.py @@ -1,21 +1,16 @@ class ParameterNotFoundError(AttributeError): - """ - Exception raised when a parameter is not found in the parameters. - """ + """Exception raised when a parameter is not found in the parameters.""" - def __init__(self, name, instant_str, variable_name=None): - """ - :param name: Name of the parameter + def __init__(self, name, instant_str, variable_name=None) -> None: + """:param name: Name of the parameter :param instant_str: Instant where the parameter does not exist, in the format `YYYY-MM-DD`. :param variable_name: If the parameter was queried during the computation of a variable, name of that variable. """ self.name = name self.instant_str = instant_str self.variable_name = variable_name - message = "The parameter '{}'".format(name) + message = f"The parameter '{name}'" if variable_name is not None: - message += " requested by variable '{}'".format(variable_name) - message += (" was not found in the {} tax and benefit system.").format( - instant_str - ) - super(ParameterNotFoundError, self).__init__(message) + message += f" requested by variable '{variable_name}'" + message += f" was not found in the {instant_str} tax and benefit system." + super().__init__(message) diff --git a/openfisca_core/errors/parameter_parsing_error.py b/openfisca_core/errors/parameter_parsing_error.py index 48b44e3341..7628e42d86 100644 --- a/openfisca_core/errors/parameter_parsing_error.py +++ b/openfisca_core/errors/parameter_parsing_error.py @@ -2,20 +2,17 @@ class ParameterParsingError(Exception): - """ - Exception raised when a parameter cannot be parsed. - """ + """Exception raised when a parameter cannot be parsed.""" - def __init__(self, message, file=None, traceback=None): - """ - :param message: Error message + def __init__(self, message, file=None, traceback=None) -> None: + """:param message: Error message :param file: Parameter file which caused the error (optional) :param traceback: Traceback (optional) """ if file is not None: message = os.linesep.join( - ["Error parsing parameter file '{}':".format(file), message] + [f"Error parsing parameter file '{file}':", message], ) if traceback is not None: message = os.linesep.join([traceback, message]) - super(ParameterParsingError, self).__init__(message) + super().__init__(message) diff --git a/openfisca_core/errors/period_mismatch_error.py b/openfisca_core/errors/period_mismatch_error.py index 2937d11968..fcece9474d 100644 --- a/openfisca_core/errors/period_mismatch_error.py +++ b/openfisca_core/errors/period_mismatch_error.py @@ -1,9 +1,7 @@ class PeriodMismatchError(ValueError): - """ - Exception raised when one tries to set a variable value for a period that doesn't match its definition period - """ + """Exception raised when one tries to set a variable value for a period that doesn't match its definition period.""" - def __init__(self, variable_name: str, period, definition_period, message): + def __init__(self, variable_name: str, period, definition_period, message) -> None: self.variable_name = variable_name self.period = period self.definition_period = definition_period diff --git a/openfisca_core/errors/situation_parsing_error.py b/openfisca_core/errors/situation_parsing_error.py index ff3839d5f7..a5d7ee88d3 100644 --- a/openfisca_core/errors/situation_parsing_error.py +++ b/openfisca_core/errors/situation_parsing_error.py @@ -8,12 +8,13 @@ class SituationParsingError(Exception): - """ - Exception raised when the situation provided as an input for a simulation cannot be parsed - """ + """Exception raised when the situation provided as an input for a simulation cannot be parsed.""" def __init__( - self, path: Iterable[str], message: str, code: int | None = None + self, + path: Iterable[str], + message: str, + code: int | None = None, ) -> None: self.error = {} dpath_path = "/".join([str(item) for item in path]) diff --git a/openfisca_core/errors/spiral_error.py b/openfisca_core/errors/spiral_error.py index 0495439b68..ffa7fe2850 100644 --- a/openfisca_core/errors/spiral_error.py +++ b/openfisca_core/errors/spiral_error.py @@ -1,4 +1,2 @@ class SpiralError(Exception): """Simulation error.""" - - pass diff --git a/openfisca_core/errors/variable_name_config_error.py b/openfisca_core/errors/variable_name_config_error.py index 7a87d7f5c8..fec1c45864 100644 --- a/openfisca_core/errors/variable_name_config_error.py +++ b/openfisca_core/errors/variable_name_config_error.py @@ -1,6 +1,2 @@ class VariableNameConflictError(Exception): - """ - Exception raised when two variables with the same name are added to a tax and benefit system. - """ - - pass + """Exception raised when two variables with the same name are added to a tax and benefit system.""" diff --git a/openfisca_core/errors/variable_not_found_error.py b/openfisca_core/errors/variable_not_found_error.py index ab71239c7d..46ece4b13c 100644 --- a/openfisca_core/errors/variable_not_found_error.py +++ b/openfisca_core/errors/variable_not_found_error.py @@ -2,36 +2,27 @@ class VariableNotFoundError(Exception): - """ - Exception raised when a variable has been queried but is not defined in the TaxBenefitSystem. - """ + """Exception raised when a variable has been queried but is not defined in the TaxBenefitSystem.""" - def __init__(self, variable_name: str, tax_benefit_system): - """ - :param variable_name: Name of the variable that was queried. + def __init__(self, variable_name: str, tax_benefit_system) -> None: + """:param variable_name: Name of the variable that was queried. :param tax_benefit_system: Tax benefits system that does not contain `variable_name` """ country_package_metadata = tax_benefit_system.get_package_metadata() country_package_name = country_package_metadata["name"] country_package_version = country_package_metadata["version"] if country_package_version: - country_package_id = "{}@{}".format( - country_package_name, country_package_version - ) + country_package_id = f"{country_package_name}@{country_package_version}" else: country_package_id = country_package_name message = os.linesep.join( [ - "You tried to calculate or to set a value for variable '{0}', but it was not found in the loaded tax and benefit system ({1}).".format( - variable_name, country_package_id - ), - "Are you sure you spelled '{0}' correctly?".format(variable_name), + f"You tried to calculate or to set a value for variable '{variable_name}', but it was not found in the loaded tax and benefit system ({country_package_id}).", + f"Are you sure you spelled '{variable_name}' correctly?", "If this code used to work and suddenly does not, this is most probably linked to an update of the tax and benefit system.", "Look at its changelog to learn about renames and removals and update your code. If it is an official package,", - "it is probably available on .".format( - country_package_name - ), - ] + f"it is probably available on .", + ], ) self.message = message self.variable_name = variable_name diff --git a/openfisca_core/experimental/memory_config.py b/openfisca_core/experimental/memory_config.py index b5a0af5317..fec38e3a54 100644 --- a/openfisca_core/experimental/memory_config.py +++ b/openfisca_core/experimental/memory_config.py @@ -5,8 +5,11 @@ class MemoryConfig: def __init__( - self, max_memory_occupation, priority_variables=None, variables_to_drop=None - ): + self, + max_memory_occupation, + priority_variables=None, + variables_to_drop=None, + ) -> None: message = [ "Memory configuration is a feature that is still currently under experimentation.", "You are very welcome to use it and send us precious feedback,", @@ -16,7 +19,8 @@ def __init__( self.max_memory_occupation = float(max_memory_occupation) if self.max_memory_occupation > 1: - raise ValueError("max_memory_occupation must be <= 1") + msg = "max_memory_occupation must be <= 1" + raise ValueError(msg) self.max_memory_occupation_pc = self.max_memory_occupation * 100 self.priority_variables = ( set(priority_variables) if priority_variables else set() diff --git a/openfisca_core/holders/helpers.py b/openfisca_core/holders/helpers.py index 0e88964fc7..fcc6563c79 100644 --- a/openfisca_core/holders/helpers.py +++ b/openfisca_core/holders/helpers.py @@ -7,9 +7,8 @@ log = logging.getLogger(__name__) -def set_input_dispatch_by_period(holder, period, array): - """ - This function can be declared as a ``set_input`` attribute of a variable. +def set_input_dispatch_by_period(holder, period, array) -> None: + """This function can be declared as a ``set_input`` attribute of a variable. In this case, the variable will accept inputs on larger periods that its definition period, and the value for the larger period will be applied to all its subperiods. @@ -23,8 +22,9 @@ def set_input_dispatch_by_period(holder, period, array): if holder.variable.definition_period not in ( periods.DateUnit.isoformat + periods.DateUnit.isocalendar ): + msg = "set_input_dispatch_by_period can't be used for eternal variables." raise ValueError( - "set_input_dispatch_by_period can't be used for eternal variables." + msg, ) cached_period_unit = holder.variable.definition_period @@ -43,9 +43,8 @@ def set_input_dispatch_by_period(holder, period, array): sub_period = sub_period.offset(1) -def set_input_divide_by_period(holder, period, array): - """ - This function can be declared as a ``set_input`` attribute of a variable. +def set_input_divide_by_period(holder, period, array) -> None: + """This function can be declared as a ``set_input`` attribute of a variable. In this case, the variable will accept inputs on larger periods that its definition period, and the value for the larger period will be divided between its subperiods. @@ -59,8 +58,9 @@ def set_input_divide_by_period(holder, period, array): if holder.variable.definition_period not in ( periods.DateUnit.isoformat + periods.DateUnit.isocalendar ): + msg = "set_input_divide_by_period can't be used for eternal variables." raise ValueError( - "set_input_divide_by_period can't be used for eternal variables." + msg, ) cached_period_unit = holder.variable.definition_period @@ -87,8 +87,7 @@ def set_input_divide_by_period(holder, period, array): holder._set(sub_period, divided_array) sub_period = sub_period.offset(1) elif not (remaining_array == 0).all(): + msg = f"Inconsistent input: variable {holder.variable.name} has already been set for all months contained in period {period}, and value {array} provided for {period} doesn't match the total ({array - remaining_array}). This error may also be thrown if you try to call set_input twice for the same variable and period." raise ValueError( - "Inconsistent input: variable {0} has already been set for all months contained in period {1}, and value {2} provided for {1} doesn't match the total ({3}). This error may also be thrown if you try to call set_input twice for the same variable and period.".format( - holder.variable.name, period, array, array - remaining_array - ) + msg, ) diff --git a/openfisca_core/holders/holder.py b/openfisca_core/holders/holder.py index 230c916d06..a8ddf3ed3a 100644 --- a/openfisca_core/holders/holder.py +++ b/openfisca_core/holders/holder.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any import os import warnings @@ -8,21 +9,23 @@ import numpy import psutil -from openfisca_core import commons -from openfisca_core import data_storage as storage -from openfisca_core import errors -from openfisca_core import indexed_enums as enums -from openfisca_core import periods, tools, types +from openfisca_core import ( + commons, + data_storage as storage, + errors, + indexed_enums as enums, + periods, + tools, + types, +) from .memory_usage import MemoryUsage class Holder: - """ - A holder keeps tracks of a variable values after they have been calculated, or set as an input. - """ + """A holder keeps tracks of a variable values after they have been calculated, or set as an input.""" - def __init__(self, variable, population): + def __init__(self, variable, population) -> None: self.population = population self.variable = variable self.simulation = population.simulation @@ -44,9 +47,7 @@ def __init__(self, variable, population): self._do_not_store = True def clone(self, population): - """ - Copy the holder just enough to be able to run a new simulation without modifying the original simulation. - """ + """Copy the holder just enough to be able to run a new simulation without modifying the original simulation.""" new = commons.empty_clone(self) new_dict = new.__dict__ @@ -66,23 +67,22 @@ def create_disk_storage(self, directory=None, preserve=False): if not os.path.isdir(storage_dir): os.mkdir(storage_dir) return storage.OnDiskStorage( - storage_dir, self._eternal, preserve_storage_dir=preserve + storage_dir, + self._eternal, + preserve_storage_dir=preserve, ) - def delete_arrays(self, period=None): - """ - If ``period`` is ``None``, remove all known values of the variable. + def delete_arrays(self, period=None) -> None: + """If ``period`` is ``None``, remove all known values of the variable. If ``period`` is not ``None``, only remove all values for any period included in period (e.g. if period is "2017", values for "2017-01", "2017-07", etc. would be removed) """ - self._memory_storage.delete(period) if self._disk_storage: self._disk_storage.delete(period) def get_array(self, period): - """ - Get the value of the variable for the given period. + """Get the value of the variable for the given period. If the value is not known, return ``None``. """ @@ -93,6 +93,7 @@ def get_array(self, period): return value if self._disk_storage: return self._disk_storage.get(period) + return None def get_memory_usage(self) -> MemoryUsage: """Get data about the virtual memory usage of the Holder. @@ -109,7 +110,7 @@ def get_memory_usage(self) -> MemoryUsage: ... simulations, ... taxbenefitsystems, ... variables, - ... ) + ... ) >>> entity = entities.Entity("", "", "", "") @@ -127,7 +128,7 @@ def get_memory_usage(self) -> MemoryUsage: >>> simulation = simulations.Simulation(tbs, entities) >>> holder.simulation = simulation - >>> pprint(holder.get_memory_usage(), indent = 3) + >>> pprint(holder.get_memory_usage(), indent=3) { 'cell_size': nan, 'dtype': , 'nb_arrays': 0, @@ -135,7 +136,6 @@ def get_memory_usage(self) -> MemoryUsage: 'total_nb_bytes': 0... """ - usage = MemoryUsage( nb_cells_by_array=self.population.count, dtype=self.variable.dtype, @@ -146,30 +146,29 @@ def get_memory_usage(self) -> MemoryUsage: if self.simulation.trace: nb_requests = self.simulation.tracer.get_nb_requests(self.variable.name) usage.update( - dict( - nb_requests=nb_requests, - nb_requests_by_array=nb_requests / float(usage["nb_arrays"]) - if usage["nb_arrays"] > 0 - else numpy.nan, - ) + { + "nb_requests": nb_requests, + "nb_requests_by_array": ( + nb_requests / float(usage["nb_arrays"]) + if usage["nb_arrays"] > 0 + else numpy.nan + ), + }, ) return usage def get_known_periods(self): - """ - Get the list of periods the variable value is known for. - """ - + """Get the list of periods the variable value is known for.""" return list(self._memory_storage.get_known_periods()) + list( - (self._disk_storage.get_known_periods() if self._disk_storage else []) + self._disk_storage.get_known_periods() if self._disk_storage else [], ) def set_input( self, period: types.Period, - array: Union[numpy.ndarray, Sequence[Any]], - ) -> Optional[numpy.ndarray]: + array: numpy.ndarray | Sequence[Any], + ) -> numpy.ndarray | None: """Set a Variable's array of values of a given Period. Args: @@ -187,6 +186,7 @@ def set_input( Examples: >>> from openfisca_core import entities, populations, variables + >>> entity = entities.Entity("", "", "", "") >>> class MyVariable(variables.Variable): @@ -212,7 +212,6 @@ def set_input( https://openfisca.org/doc/coding-the-legislation/35_periods.html#set-input-automatically-process-variable-inputs-defined-for-periods-not-matching-the-definition-period """ - period = periods.period(period) if period.unit == periods.DateUnit.ETERNITY and not self._eternal: @@ -220,7 +219,7 @@ def set_input( [ "Unable to set a value for variable {1} for {0}.", "{1} is only defined for {2}s. Please adapt your input.", - ] + ], ).format( periods.DateUnit.ETERNITY.upper(), self.variable.name, @@ -233,9 +232,7 @@ def set_input( error_message, ) if self.variable.is_neutralized: - warning_message = "You cannot set a value for the variable {}, as it has been neutralized. The value you provided ({}) will be ignored.".format( - self.variable.name, array - ) + warning_message = f"You cannot set a value for the variable {self.variable.name}, as it has been neutralized. The value you provided ({array}) will be ignored." return warnings.warn(warning_message, Warning, stacklevel=2) if self.variable.value_type in (float, int) and isinstance(array, str): array = tools.eval_expression(array) @@ -250,14 +247,9 @@ def _to_array(self, value): # 0-dim arrays are casted to scalar when they interact with float. We don't want that. value = value.reshape(1) if len(value) != self.population.count: + msg = f'Unable to set value "{value}" for variable "{self.variable.name}", as its length is {len(value)} while there are {self.population.count} {self.population.entity.plural} in the simulation.' raise ValueError( - 'Unable to set value "{}" for variable "{}", as its length is {} while there are {} {} in the simulation.'.format( - value, - self.variable.name, - len(value), - self.population.count, - self.population.entity.plural, - ) + msg, ) if self.variable.value_type == enums.Enum: value = self.variable.possible_values.encode(value) @@ -265,20 +257,22 @@ def _to_array(self, value): try: value = value.astype(self.variable.dtype) except ValueError: + msg = f'Unable to set value "{value}" for variable "{self.variable.name}", as the variable dtype "{self.variable.dtype}" does not match the value dtype "{value.dtype}".' raise ValueError( - 'Unable to set value "{}" for variable "{}", as the variable dtype "{}" does not match the value dtype "{}".'.format( - value, self.variable.name, self.variable.dtype, value.dtype - ) + msg, ) return value - def _set(self, period, value): + def _set(self, period, value) -> None: value = self._to_array(value) if not self._eternal: if period is None: - raise ValueError( + msg = ( f"A period must be specified to set values, except for variables with " - f"{periods.DateUnit.ETERNITY.upper()} as as period_definition.", + f"{periods.DateUnit.ETERNITY.upper()} as as period_definition." + ) + raise ValueError( + msg, ) if self.variable.definition_period != period.unit or period.size > 1: name = self.variable.name @@ -292,7 +286,7 @@ def _set(self, period, value): f'Unable to set a value for variable "{name}" for {period_size_adj}-long period "{period}".', f'"{name}" can only be set for one {self.variable.definition_period} at a time. Please adapt your input.', f'If you are the maintainer of "{name}", you can consider adding it a set_input attribute to enable automatic period casting.', - ] + ], ) raise errors.PeriodMismatchError( @@ -314,7 +308,7 @@ def _set(self, period, value): else: self._memory_storage.put(value, period) - def put_in_cache(self, value, period): + def put_in_cache(self, value, period) -> None: if self._do_not_store: return @@ -328,8 +322,5 @@ def put_in_cache(self, value, period): self._set(period, value) def default_array(self): - """ - Return a new array of the appropriate length for the entity, filled with the variable default values. - """ - + """Return a new array of the appropriate length for the entity, filled with the variable default values.""" return self.variable.default_array(self.population.count) diff --git a/openfisca_core/holders/tests/test_helpers.py b/openfisca_core/holders/tests/test_helpers.py index d76040d676..948f25288f 100644 --- a/openfisca_core/holders/tests/test_helpers.py +++ b/openfisca_core/holders/tests/test_helpers.py @@ -35,38 +35,33 @@ def population(people): @pytest.mark.parametrize( - "dispatch_unit, definition_unit, values, expected", + ("dispatch_unit", "definition_unit", "values", "expected"), [ - [DateUnit.YEAR, DateUnit.YEAR, [1.0], [3.0]], - [DateUnit.YEAR, DateUnit.MONTH, [1.0], [36.0]], - [DateUnit.YEAR, DateUnit.DAY, [1.0], [1096.0]], - [DateUnit.YEAR, DateUnit.WEEK, [1.0], [157.0]], - [DateUnit.YEAR, DateUnit.WEEKDAY, [1.0], [1096.0]], - [DateUnit.MONTH, DateUnit.YEAR, [1.0], [1.0]], - [DateUnit.MONTH, DateUnit.MONTH, [1.0], [3.0]], - [DateUnit.MONTH, DateUnit.DAY, [1.0], [90.0]], - [DateUnit.MONTH, DateUnit.WEEK, [1.0], [13.0]], - [DateUnit.MONTH, DateUnit.WEEKDAY, [1.0], [90.0]], - [DateUnit.DAY, DateUnit.YEAR, [1.0], [1.0]], - [DateUnit.DAY, DateUnit.MONTH, [1.0], [1.0]], - [DateUnit.DAY, DateUnit.DAY, [1.0], [3.0]], - [DateUnit.DAY, DateUnit.WEEK, [1.0], [1.0]], - [DateUnit.DAY, DateUnit.WEEKDAY, [1.0], [3.0]], - [DateUnit.WEEK, DateUnit.YEAR, [1.0], [1.0]], - [DateUnit.WEEK, DateUnit.MONTH, [1.0], [1.0]], - [DateUnit.WEEK, DateUnit.DAY, [1.0], [21.0]], - [DateUnit.WEEK, DateUnit.WEEK, [1.0], [3.0]], - [DateUnit.WEEK, DateUnit.WEEKDAY, [1.0], [21.0]], - [DateUnit.WEEK, DateUnit.YEAR, [1.0], [1.0]], - [DateUnit.WEEK, DateUnit.MONTH, [1.0], [1.0]], - [DateUnit.WEEK, DateUnit.DAY, [1.0], [21.0]], - [DateUnit.WEEK, DateUnit.WEEK, [1.0], [3.0]], - [DateUnit.WEEK, DateUnit.WEEKDAY, [1.0], [21.0]], - [DateUnit.WEEKDAY, DateUnit.YEAR, [1.0], [1.0]], - [DateUnit.WEEKDAY, DateUnit.MONTH, [1.0], [1.0]], - [DateUnit.WEEKDAY, DateUnit.DAY, [1.0], [3.0]], - [DateUnit.WEEKDAY, DateUnit.WEEK, [1.0], [1.0]], - [DateUnit.WEEKDAY, DateUnit.WEEKDAY, [1.0], [3.0]], + (DateUnit.YEAR, DateUnit.YEAR, [1.0], [3.0]), + (DateUnit.YEAR, DateUnit.MONTH, [1.0], [36.0]), + (DateUnit.YEAR, DateUnit.DAY, [1.0], [1096.0]), + (DateUnit.YEAR, DateUnit.WEEK, [1.0], [157.0]), + (DateUnit.YEAR, DateUnit.WEEKDAY, [1.0], [1096.0]), + (DateUnit.MONTH, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.MONTH, DateUnit.MONTH, [1.0], [3.0]), + (DateUnit.MONTH, DateUnit.DAY, [1.0], [90.0]), + (DateUnit.MONTH, DateUnit.WEEK, [1.0], [13.0]), + (DateUnit.MONTH, DateUnit.WEEKDAY, [1.0], [90.0]), + (DateUnit.DAY, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.DAY, [1.0], [3.0]), + (DateUnit.DAY, DateUnit.WEEK, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.WEEKDAY, [1.0], [3.0]), + (DateUnit.WEEK, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.WEEK, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.WEEK, DateUnit.DAY, [1.0], [21.0]), + (DateUnit.WEEK, DateUnit.WEEK, [1.0], [3.0]), + (DateUnit.WEEK, DateUnit.WEEKDAY, [1.0], [21.0]), + (DateUnit.WEEKDAY, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.DAY, [1.0], [3.0]), + (DateUnit.WEEKDAY, DateUnit.WEEK, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.WEEKDAY, [1.0], [3.0]), ], ) def test_set_input_dispatch_by_period( @@ -76,7 +71,7 @@ def test_set_input_dispatch_by_period( definition_unit, values, expected, -): +) -> None: Income.definition_period = definition_unit income = Income() holder = Holder(income, population) @@ -90,33 +85,33 @@ def test_set_input_dispatch_by_period( @pytest.mark.parametrize( - "divide_unit, definition_unit, values, expected", + ("divide_unit", "definition_unit", "values", "expected"), [ - [DateUnit.YEAR, DateUnit.YEAR, [3.0], [1.0]], - [DateUnit.YEAR, DateUnit.MONTH, [36.0], [1.0]], - [DateUnit.YEAR, DateUnit.DAY, [1095.0], [1.0]], - [DateUnit.YEAR, DateUnit.WEEK, [157.0], [1.0]], - [DateUnit.YEAR, DateUnit.WEEKDAY, [1095.0], [1.0]], - [DateUnit.MONTH, DateUnit.YEAR, [1.0], [1.0]], - [DateUnit.MONTH, DateUnit.MONTH, [3.0], [1.0]], - [DateUnit.MONTH, DateUnit.DAY, [90.0], [1.0]], - [DateUnit.MONTH, DateUnit.WEEK, [13.0], [1.0]], - [DateUnit.MONTH, DateUnit.WEEKDAY, [90.0], [1.0]], - [DateUnit.DAY, DateUnit.YEAR, [1.0], [1.0]], - [DateUnit.DAY, DateUnit.MONTH, [1.0], [1.0]], - [DateUnit.DAY, DateUnit.DAY, [3.0], [1.0]], - [DateUnit.DAY, DateUnit.WEEK, [1.0], [1.0]], - [DateUnit.DAY, DateUnit.WEEKDAY, [3.0], [1.0]], - [DateUnit.WEEK, DateUnit.YEAR, [1.0], [1.0]], - [DateUnit.WEEK, DateUnit.MONTH, [1.0], [1.0]], - [DateUnit.WEEK, DateUnit.DAY, [21.0], [1.0]], - [DateUnit.WEEK, DateUnit.WEEK, [3.0], [1.0]], - [DateUnit.WEEK, DateUnit.WEEKDAY, [21.0], [1.0]], - [DateUnit.WEEKDAY, DateUnit.YEAR, [1.0], [1.0]], - [DateUnit.WEEKDAY, DateUnit.MONTH, [1.0], [1.0]], - [DateUnit.WEEKDAY, DateUnit.DAY, [3.0], [1.0]], - [DateUnit.WEEKDAY, DateUnit.WEEK, [1.0], [1.0]], - [DateUnit.WEEKDAY, DateUnit.WEEKDAY, [3.0], [1.0]], + (DateUnit.YEAR, DateUnit.YEAR, [3.0], [1.0]), + (DateUnit.YEAR, DateUnit.MONTH, [36.0], [1.0]), + (DateUnit.YEAR, DateUnit.DAY, [1095.0], [1.0]), + (DateUnit.YEAR, DateUnit.WEEK, [157.0], [1.0]), + (DateUnit.YEAR, DateUnit.WEEKDAY, [1095.0], [1.0]), + (DateUnit.MONTH, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.MONTH, DateUnit.MONTH, [3.0], [1.0]), + (DateUnit.MONTH, DateUnit.DAY, [90.0], [1.0]), + (DateUnit.MONTH, DateUnit.WEEK, [13.0], [1.0]), + (DateUnit.MONTH, DateUnit.WEEKDAY, [90.0], [1.0]), + (DateUnit.DAY, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.DAY, [3.0], [1.0]), + (DateUnit.DAY, DateUnit.WEEK, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.WEEKDAY, [3.0], [1.0]), + (DateUnit.WEEK, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.WEEK, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.WEEK, DateUnit.DAY, [21.0], [1.0]), + (DateUnit.WEEK, DateUnit.WEEK, [3.0], [1.0]), + (DateUnit.WEEK, DateUnit.WEEKDAY, [21.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.DAY, [3.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.WEEK, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.WEEKDAY, [3.0], [1.0]), ], ) def test_set_input_divide_by_period( @@ -126,7 +121,7 @@ def test_set_input_divide_by_period( definition_unit, values, expected, -): +) -> None: Income.definition_period = definition_unit income = Income() holder = Holder(income, population) diff --git a/openfisca_core/indexed_enums/enum.py b/openfisca_core/indexed_enums/enum.py index 7957ced3a2..25b02cee74 100644 --- a/openfisca_core/indexed_enums/enum.py +++ b/openfisca_core/indexed_enums/enum.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Union - import enum import numpy @@ -11,8 +9,7 @@ class Enum(enum.Enum): - """ - Enum based on `enum34 `_, whose items + """Enum based on `enum34 `_, whose items have an index. """ @@ -33,15 +30,9 @@ def __init__(self, name: str) -> None: @classmethod def encode( cls, - array: Union[ - EnumArray, - numpy.int_, - numpy.float_, - numpy.object_, - ], + array: EnumArray | numpy.int_ | numpy.float64 | numpy.object_, ) -> EnumArray: - """ - Encode a string numpy array, an enum item numpy array, or an int numpy + """Encode a string numpy array, an enum item numpy array, or an int numpy array into an :any:`EnumArray`. See :any:`EnumArray.decode` for decoding. @@ -53,7 +44,7 @@ def encode( For instance: - >>> string_identifier_array = asarray(['free_lodger', 'owner']) + >>> string_identifier_array = asarray(["free_lodger", "owner"]) >>> encoded_array = HousingOccupancyStatus.encode(string_identifier_array) >>> encoded_array[0] 2 # Encoded value diff --git a/openfisca_core/indexed_enums/enum_array.py b/openfisca_core/indexed_enums/enum_array.py index 2742719ada..86b55f9f48 100644 --- a/openfisca_core/indexed_enums/enum_array.py +++ b/openfisca_core/indexed_enums/enum_array.py @@ -1,7 +1,7 @@ from __future__ import annotations import typing -from typing import Any, NoReturn, Optional, Type +from typing import Any, NoReturn import numpy @@ -10,8 +10,7 @@ class EnumArray(numpy.ndarray): - """ - NumPy array subclass representing an array of enum items. + """NumPy array subclass representing an array of enum items. EnumArrays are encoded as ``int`` arrays to improve performance """ @@ -22,20 +21,20 @@ class EnumArray(numpy.ndarray): def __new__( cls, input_array: numpy.int_, - possible_values: Optional[Type[Enum]] = None, + possible_values: type[Enum] | None = None, ) -> EnumArray: obj = numpy.asarray(input_array).view(cls) obj.possible_values = possible_values return obj # See previous comment - def __array_finalize__(self, obj: Optional[numpy.int_]) -> None: + def __array_finalize__(self, obj: numpy.int_ | None) -> None: if obj is None: return self.possible_values = getattr(obj, "possible_values", None) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: # When comparing to an item of self.possible_values, use the item index # to speed up the comparison. if other.__class__.__name__ is self.possible_values.__name__: @@ -45,13 +44,16 @@ def __eq__(self, other: Any) -> bool: return self.view(numpy.ndarray) == other - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return numpy.logical_not(self == other) def _forbidden_operation(self, other: Any) -> NoReturn: - raise TypeError( + msg = ( "Forbidden operation. The only operations allowed on EnumArrays " - "are '==' and '!='.", + "are '==' and '!='." + ) + raise TypeError( + msg, ) __add__ = _forbidden_operation @@ -64,12 +66,11 @@ def _forbidden_operation(self, other: Any) -> NoReturn: __or__ = _forbidden_operation def decode(self) -> numpy.object_: - """ - Return the array of enum items corresponding to self. + """Return the array of enum items corresponding to self. For instance: - >>> enum_array = household('housing_occupancy_status', period) + >>> enum_array = household("housing_occupancy_status", period) >>> enum_array[0] >>> 2 # Encoded value >>> enum_array.decode()[0] @@ -83,12 +84,11 @@ def decode(self) -> numpy.object_: ) def decode_to_str(self) -> numpy.str_: - """ - Return the array of string identifiers corresponding to self. + """Return the array of string identifiers corresponding to self. For instance: - >>> enum_array = household('housing_occupancy_status', period) + >>> enum_array = household("housing_occupancy_status", period) >>> enum_array[0] >>> 2 # Encoded value >>> enum_array.decode_to_str()[0] @@ -100,7 +100,7 @@ def decode_to_str(self) -> numpy.str_: ) def __repr__(self) -> str: - return f"{self.__class__.__name__}({str(self.decode())})" + return f"{self.__class__.__name__}({self.decode()!s})" def __str__(self) -> str: return str(self.decode_to_str()) diff --git a/openfisca_core/model_api.py b/openfisca_core/model_api.py index 553ee75b34..e36e0d5f76 100644 --- a/openfisca_core/model_api.py +++ b/openfisca_core/model_api.py @@ -1,10 +1,13 @@ from datetime import date -from numpy import logical_not as not_ -from numpy import maximum as max_ -from numpy import minimum as min_ -from numpy import round as round_ -from numpy import select, where +from numpy import ( + logical_not as not_, + maximum as max_, + minimum as min_, + round as round_, + select, + where, +) from openfisca_core.commons import apply_thresholds, concat, switch from openfisca_core.holders import ( diff --git a/openfisca_core/parameters/__init__.py b/openfisca_core/parameters/__init__.py index f64f577fd4..5d742d4611 100644 --- a/openfisca_core/parameters/__init__.py +++ b/openfisca_core/parameters/__init__.py @@ -36,10 +36,11 @@ from .parameter_at_instant import ParameterAtInstant from .parameter_node import ParameterNode from .parameter_node_at_instant import ParameterNodeAtInstant -from .parameter_scale import ParameterScale -from .parameter_scale import ParameterScale as Scale -from .parameter_scale_bracket import ParameterScaleBracket -from .parameter_scale_bracket import ParameterScaleBracket as Bracket +from .parameter_scale import ParameterScale, ParameterScale as Scale +from .parameter_scale_bracket import ( + ParameterScaleBracket, + ParameterScaleBracket as Bracket, +) from .values_history import ValuesHistory from .vectorial_asof_date_parameter_node_at_instant import ( VectorialAsofDateParameterNodeAtInstant, diff --git a/openfisca_core/parameters/at_instant_like.py b/openfisca_core/parameters/at_instant_like.py index 1a1db34beb..19c28e98c2 100644 --- a/openfisca_core/parameters/at_instant_like.py +++ b/openfisca_core/parameters/at_instant_like.py @@ -4,9 +4,7 @@ class AtInstantLike(abc.ABC): - """ - Base class for various types of parameters implementing the at instant protocol. - """ + """Base class for various types of parameters implementing the at instant protocol.""" def __call__(self, instant): return self.get_at_instant(instant) @@ -16,5 +14,4 @@ def get_at_instant(self, instant): return self._get_at_instant(instant) @abc.abstractmethod - def _get_at_instant(self, instant): - ... + def _get_at_instant(self, instant): ... diff --git a/openfisca_core/parameters/config.py b/openfisca_core/parameters/config.py index 1900d0f550..b97462a79d 100644 --- a/openfisca_core/parameters/config.py +++ b/openfisca_core/parameters/config.py @@ -1,5 +1,3 @@ -import typing - import os import warnings @@ -23,7 +21,7 @@ # 'unit' and 'reference' are only listed here for backward compatibility. # It is now recommended to include them in metadata, until a common consensus emerges. -ALLOWED_PARAM_TYPES = (float, int, bool, type(None), typing.List) +ALLOWED_PARAM_TYPES = (float, int, bool, type(None), list) COMMON_KEYS = {"description", "metadata", "unit", "reference", "documentation"} FILE_EXTENSIONS = {".yaml", ".yml"} @@ -39,9 +37,12 @@ def dict_no_duplicate_constructor(loader, node, deep=False): keys = [key.value for key, value in node.value] if len(keys) != len(set(keys)): - duplicate = next((key for key in keys if keys.count(key) > 1)) + duplicate = next(key for key in keys if keys.count(key) > 1) + msg = "" raise yaml.parser.ParserError( - "", node.start_mark, f"Found duplicate key '{duplicate}'" + msg, + node.start_mark, + f"Found duplicate key '{duplicate}'", ) return loader.construct_mapping(node, deep) diff --git a/openfisca_core/parameters/helpers.py b/openfisca_core/parameters/helpers.py index 30af4adcbc..09925bbcdb 100644 --- a/openfisca_core/parameters/helpers.py +++ b/openfisca_core/parameters/helpers.py @@ -10,21 +10,21 @@ def contains_nan(vector): if numpy.issubdtype(vector.dtype, numpy.record) or numpy.issubdtype( - vector.dtype, numpy.void + vector.dtype, + numpy.void, ): - return any([contains_nan(vector[name]) for name in vector.dtype.names]) - else: - return numpy.isnan(vector).any() + return any(contains_nan(vector[name]) for name in vector.dtype.names) + return numpy.isnan(vector).any() def load_parameter_file(file_path, name=""): - """ - Load parameters from a YAML file (or a directory containing YAML files). + """Load parameters from a YAML file (or a directory containing YAML files). :returns: An instance of :class:`.ParameterNode` or :class:`.ParameterScale` or :class:`.Parameter`. """ if not os.path.exists(file_path): - raise ValueError("{} does not exist".format(file_path)) + msg = f"{file_path} does not exist" + raise ValueError(msg) if os.path.isdir(file_path): return parameters.ParameterNode(name, directory_path=file_path) data = _load_yaml_file(file_path) @@ -35,26 +35,29 @@ def _compose_name(path, child_name=None, item_name=None): if not path: return child_name if child_name is not None: - return "{}.{}".format(path, child_name) + return f"{path}.{child_name}" if item_name is not None: - return "{}[{}]".format(path, item_name) + return f"{path}[{item_name}]" + return None def _load_yaml_file(file_path): - with open(file_path, "r") as f: + with open(file_path) as f: try: return config.yaml.load(f, Loader=config.Loader) except (config.yaml.scanner.ScannerError, config.yaml.parser.ParserError): stack_trace = traceback.format_exc() + msg = "Invalid YAML. Check the traceback above for more details." raise ParameterParsingError( - "Invalid YAML. Check the traceback above for more details.", + msg, file_path, stack_trace, ) except Exception: stack_trace = traceback.format_exc() + msg = "Invalid parameter file content. Check the traceback above for more details." raise ParameterParsingError( - "Invalid parameter file content. Check the traceback above for more details.", + msg, file_path, stack_trace, ) @@ -63,32 +66,32 @@ def _load_yaml_file(file_path): def _parse_child(child_name, child, child_path): if "values" in child: return parameters.Parameter(child_name, child, child_path) - elif "brackets" in child: + if "brackets" in child: return parameters.ParameterScale(child_name, child, child_path) - elif isinstance(child, dict) and all( - [periods.INSTANT_PATTERN.match(str(key)) for key in child.keys()] + if isinstance(child, dict) and all( + periods.INSTANT_PATTERN.match(str(key)) for key in child ): return parameters.Parameter(child_name, child, child_path) - else: - return parameters.ParameterNode(child_name, data=child, file_path=child_path) + return parameters.ParameterNode(child_name, data=child, file_path=child_path) -def _set_backward_compatibility_metadata(parameter, data): +def _set_backward_compatibility_metadata(parameter, data) -> None: if data.get("unit") is not None: parameter.metadata["unit"] = data["unit"] if data.get("reference") is not None: parameter.metadata["reference"] = data["reference"] -def _validate_parameter(parameter, data, data_type=None, allowed_keys=None): +def _validate_parameter(parameter, data, data_type=None, allowed_keys=None) -> None: type_map = { dict: "object", list: "array", } if data_type is not None and not isinstance(data, data_type): + msg = f"'{parameter.name}' must be of type {type_map[data_type]}." raise ParameterParsingError( - "'{}' must be of type {}.".format(parameter.name, type_map[data_type]), + msg, parameter.file_path, ) @@ -96,9 +99,8 @@ def _validate_parameter(parameter, data, data_type=None, allowed_keys=None): keys = data.keys() for key in keys: if key not in allowed_keys: + msg = f"Unexpected property '{key}' in '{parameter.name}'. Allowed properties are {list(allowed_keys)}." raise ParameterParsingError( - "Unexpected property '{}' in '{}'. Allowed properties are {}.".format( - key, parameter.name, list(allowed_keys) - ), + msg, parameter.file_path, ) diff --git a/openfisca_core/parameters/parameter.py b/openfisca_core/parameters/parameter.py index 9c9d7e7093..528f54cccd 100644 --- a/openfisca_core/parameters/parameter.py +++ b/openfisca_core/parameters/parameter.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Dict, List, Optional - import copy import os @@ -46,20 +44,22 @@ class Parameter(AtInstantLike): """ - def __init__(self, name: str, data: dict, file_path: Optional[str] = None) -> None: + def __init__(self, name: str, data: dict, file_path: str | None = None) -> None: self.name: str = name - self.file_path: Optional[str] = file_path + self.file_path: str | None = file_path helpers._validate_parameter(self, data, data_type=dict) - self.description: Optional[str] = None - self.metadata: Dict = {} - self.documentation: Optional[str] = None + self.description: str | None = None + self.metadata: dict = {} + self.documentation: str | None = None self.values_history = self # Only for backward compatibility # Normal parameter declaration: the values are declared under the 'values' key: parse the description and metadata. if data.get("values"): # 'unit' and 'reference' are only listed here for backward compatibility helpers._validate_parameter( - self, data, allowed_keys=config.COMMON_KEYS.union({"values"}) + self, + data, + allowed_keys=config.COMMON_KEYS.union({"values"}), ) self.description = data.get("description") @@ -75,16 +75,16 @@ def __init__(self, name: str, data: dict, file_path: Optional[str] = None) -> No values = data instants = sorted( - values.keys(), reverse=True + values.keys(), + reverse=True, ) # sort in reverse chronological order values_list = [] for instant_str in instants: if not periods.INSTANT_PATTERN.match(instant_str): + msg = f"Invalid property '{instant_str}' in '{self.name}'. Properties must be valid YYYY-MM-DD instants, such as 2017-01-15." raise ParameterParsingError( - "Invalid property '{}' in '{}'. Properties must be valid YYYY-MM-DD instants, such as 2017-01-15.".format( - instant_str, self.name - ), + msg, file_path, ) @@ -108,9 +108,9 @@ def __init__(self, name: str, data: dict, file_path: Optional[str] = None) -> No ) values_list.append(value_at_instant) - self.values_list: List[ParameterAtInstant] = values_list + self.values_list: list[ParameterAtInstant] = values_list - def __repr__(self): + def __repr__(self) -> str: return os.linesep.join( [ "{}: {}".format( @@ -118,7 +118,7 @@ def __repr__(self): value.value if value.value is not None else "null", ) for value in self.values_list - ] + ], ) def __eq__(self, other): @@ -134,9 +134,8 @@ def clone(self): ] return clone - def update(self, period=None, start=None, stop=None, value=None): - """ - Change the value for a given period. + def update(self, period=None, start=None, stop=None, value=None) -> None: + """Change the value for a given period. :param period: Period where the value is modified. If set, `start` and `stop` should be `None`. :param start: Start of the period. Instance of `openfisca_core.periods.Instant`. If set, `period` should be `None`. @@ -145,15 +144,17 @@ def update(self, period=None, start=None, stop=None, value=None): """ if period is not None: if start is not None or stop is not None: + msg = "Wrong input for 'update' method: use either 'update(period, value = value)' or 'update(start = start, stop = stop, value = value)'. You cannot both use 'period' and 'start' or 'stop'." raise TypeError( - "Wrong input for 'update' method: use either 'update(period, value = value)' or 'update(start = start, stop = stop, value = value)'. You cannot both use 'period' and 'start' or 'stop'." + msg, ) if isinstance(period, str): period = periods.period(period) start = period.start stop = period.stop if start is None: - raise ValueError("You must provide either a start or a period") + msg = "You must provide either a start or a period" + raise ValueError(msg) start_str = str(start) stop_str = str(stop.offset(1, "day")) if stop else None @@ -172,20 +173,23 @@ def update(self, period=None, start=None, stop=None, value=None): if stop_str: if new_values and (stop_str == new_values[-1].instant_str): pass # such interval is empty + elif i < n: + overlapped_value = old_values[i].value + value_name = helpers._compose_name(self.name, item_name=stop_str) + new_interval = ParameterAtInstant( + value_name, + stop_str, + data={"value": overlapped_value}, + ) + new_values.append(new_interval) else: - if i < n: - overlapped_value = old_values[i].value - value_name = helpers._compose_name(self.name, item_name=stop_str) - new_interval = ParameterAtInstant( - value_name, stop_str, data={"value": overlapped_value} - ) - new_values.append(new_interval) - else: - value_name = helpers._compose_name(self.name, item_name=stop_str) - new_interval = ParameterAtInstant( - value_name, stop_str, data={"value": None} - ) - new_values.append(new_interval) + value_name = helpers._compose_name(self.name, item_name=stop_str) + new_interval = ParameterAtInstant( + value_name, + stop_str, + data={"value": None}, + ) + new_values.append(new_interval) # Insert new interval value_name = helpers._compose_name(self.name, item_name=start_str) diff --git a/openfisca_core/parameters/parameter_at_instant.py b/openfisca_core/parameters/parameter_at_instant.py index b84dc5b2b6..ae525cf829 100644 --- a/openfisca_core/parameters/parameter_at_instant.py +++ b/openfisca_core/parameters/parameter_at_instant.py @@ -1,5 +1,3 @@ -import typing - import copy from openfisca_core import commons @@ -8,23 +6,22 @@ class ParameterAtInstant: - """ - A value of a parameter at a given instant. - """ + """A value of a parameter at a given instant.""" # 'unit' and 'reference' are only listed here for backward compatibility - _allowed_keys = set(["value", "metadata", "unit", "reference"]) + _allowed_keys = {"value", "metadata", "unit", "reference"} - def __init__(self, name, instant_str, data=None, file_path=None, metadata=None): - """ - :param str name: name of the parameter, e.g. "taxes.some_tax.some_param" + def __init__( + self, name, instant_str, data=None, file_path=None, metadata=None + ) -> None: + """:param str name: name of the parameter, e.g. "taxes.some_tax.some_param" :param str instant_str: Date of the value in the format `YYYY-MM-DD`. :param dict data: Data, usually loaded from a YAML file. """ self.name: str = name self.instant_str: str = instant_str self.file_path: str = file_path - self.metadata: typing.Dict = {} + self.metadata: dict = {} # Accept { 2015-01-01: 4000 } if not isinstance(data, dict) and isinstance(data, config.ALLOWED_PARAM_TYPES): @@ -39,21 +36,25 @@ def __init__(self, name, instant_str, data=None, file_path=None, metadata=None): helpers._set_backward_compatibility_metadata(self, data) self.metadata.update(data.get("metadata", {})) - def validate(self, data): + def validate(self, data) -> None: helpers._validate_parameter( - self, data, data_type=dict, allowed_keys=self._allowed_keys + self, + data, + data_type=dict, + allowed_keys=self._allowed_keys, ) try: value = data["value"] except KeyError: + msg = f"Missing 'value' property for {self.name}" raise ParameterParsingError( - "Missing 'value' property for {}".format(self.name), self.file_path + msg, + self.file_path, ) if not isinstance(value, config.ALLOWED_PARAM_TYPES): + msg = f"Value in {self.name} has type {type(value)}, which is not one of the allowed types ({config.ALLOWED_PARAM_TYPES}): {value}" raise ParameterParsingError( - "Value in {} has type {}, which is not one of the allowed types ({}): {}".format( - self.name, type(value), config.ALLOWED_PARAM_TYPES, value - ), + msg, self.file_path, ) @@ -64,8 +65,8 @@ def __eq__(self, other): and (self.value == other.value) ) - def __repr__(self): - return "ParameterAtInstant({})".format({self.instant_str: self.value}) + def __repr__(self) -> str: + return "ParameterAtInstant({self.instant_str: self.value})" def clone(self): clone = commons.empty_clone(self) diff --git a/openfisca_core/parameters/parameter_node.py b/openfisca_core/parameters/parameter_node.py index 6a344a09a9..2be3a9acfd 100644 --- a/openfisca_core/parameters/parameter_node.py +++ b/openfisca_core/parameters/parameter_node.py @@ -14,17 +14,14 @@ class ParameterNode(AtInstantLike): - """ - A node in the legislation `parameter tree `_. - """ + """A node in the legislation `parameter tree `_.""" - _allowed_keys: typing.Optional[ - typing.Iterable[str] - ] = None # By default, no restriction on the keys + _allowed_keys: None | (typing.Iterable[str]) = ( + None # By default, no restriction on the keys + ) - def __init__(self, name="", directory_path=None, data=None, file_path=None): - """ - Instantiate a ParameterNode either from a dict, (using `data`), or from a directory containing YAML files (using `directory_path`). + def __init__(self, name="", directory_path=None, data=None, file_path=None) -> None: + """Instantiate a ParameterNode either from a dict, (using `data`), or from a directory containing YAML files (using `directory_path`). :param str name: Name of the node, eg "taxes.some_tax". :param str directory_path: Directory containing YAML files describing the node. @@ -51,16 +48,20 @@ def __init__(self, name="", directory_path=None, data=None, file_path=None): Instantiate a ParameterNode from a directory containing YAML parameter files: - >>> node = ParameterNode('benefits', directory_path = '/path/to/country_package/parameters/benefits') + >>> node = ParameterNode( + ... "benefits", + ... directory_path="/path/to/country_package/parameters/benefits", + ... ) """ self.name: str = name - self.children: typing.Dict[ - str, typing.Union[ParameterNode, Parameter, parameters.ParameterScale] + self.children: dict[ + str, + ParameterNode | Parameter | parameters.ParameterScale, ] = {} self.description: str = None self.documentation: str = None self.file_path: str = None - self.metadata: typing.Dict = {} + self.metadata: dict = {} if directory_path: self.file_path = directory_path @@ -76,7 +77,9 @@ def __init__(self, name="", directory_path=None, data=None, file_path=None): if child_name == "index": data = helpers._load_yaml_file(child_path) or {} helpers._validate_parameter( - self, data, allowed_keys=config.COMMON_KEYS + self, + data, + allowed_keys=config.COMMON_KEYS, ) self.description = data.get("description") self.documentation = data.get("documentation") @@ -85,7 +88,8 @@ def __init__(self, name="", directory_path=None, data=None, file_path=None): else: child_name_expanded = helpers._compose_name(name, child_name) child = helpers.load_parameter_file( - child_path, child_name_expanded + child_path, + child_name_expanded, ) self.add_child(child_name, child) @@ -93,14 +97,18 @@ def __init__(self, name="", directory_path=None, data=None, file_path=None): child_name = os.path.basename(child_path) child_name_expanded = helpers._compose_name(name, child_name) child = ParameterNode( - child_name_expanded, directory_path=child_path + child_name_expanded, + directory_path=child_path, ) self.add_child(child_name, child) else: self.file_path = file_path helpers._validate_parameter( - self, data, data_type=dict, allowed_keys=self._allowed_keys + self, + data, + data_type=dict, + allowed_keys=self._allowed_keys, ) self.description = data.get("description") self.documentation = data.get("documentation") @@ -115,50 +123,43 @@ def __init__(self, name="", directory_path=None, data=None, file_path=None): child = helpers._parse_child(child_name_expanded, child, file_path) self.add_child(child_name, child) - def merge(self, other): - """ - Merges another ParameterNode into the current node. + def merge(self, other) -> None: + """Merges another ParameterNode into the current node. In case of child name conflict, the other node child will replace the current node child. """ for child_name, child in other.children.items(): self.add_child(child_name, child) - def add_child(self, name, child): - """ - Add a new child to the node. + def add_child(self, name, child) -> None: + """Add a new child to the node. :param name: Name of the child that must be used to access that child. Should not contain anything that could interfere with the operator `.` (dot). :param child: The new child, an instance of :class:`.ParameterScale` or :class:`.Parameter` or :class:`.ParameterNode`. """ if name in self.children: - raise ValueError("{} has already a child named {}".format(self.name, name)) + msg = f"{self.name} has already a child named {name}" + raise ValueError(msg) if not ( - isinstance(child, ParameterNode) - or isinstance(child, Parameter) - or isinstance(child, parameters.ParameterScale) + isinstance(child, (ParameterNode, Parameter, parameters.ParameterScale)) ): + msg = f"child must be of type ParameterNode, Parameter, or Scale. Instead got {type(child)}" raise TypeError( - "child must be of type ParameterNode, Parameter, or Scale. Instead got {}".format( - type(child) - ) + msg, ) self.children[name] = child setattr(self, name, child) - def __repr__(self): - result = os.linesep.join( + def __repr__(self) -> str: + return os.linesep.join( [ os.linesep.join(["{}:", "{}"]).format(name, tools.indent(repr(value))) for name, value in sorted(self.children.items()) - ] + ], ) - return result def get_descendants(self): - """ - Return a generator containing all the parameters and nodes recursively contained in this `ParameterNode` - """ + """Return a generator containing all the parameters and nodes recursively contained in this `ParameterNode`.""" for child in self.children.values(): yield child yield from child.get_descendants() diff --git a/openfisca_core/parameters/parameter_node_at_instant.py b/openfisca_core/parameters/parameter_node_at_instant.py index d98d88a698..b66c0c1ed7 100644 --- a/openfisca_core/parameters/parameter_node_at_instant.py +++ b/openfisca_core/parameters/parameter_node_at_instant.py @@ -1,5 +1,4 @@ import os -import sys import numpy @@ -9,17 +8,13 @@ class ParameterNodeAtInstant: - """ - Parameter node of the legislation, at a given instant. - """ + """Parameter node of the legislation, at a given instant.""" - def __init__(self, name, node, instant_str): - """ - :param name: Name of the node. + def __init__(self, name, node, instant_str) -> None: + """:param name: Name of the node. :param node: Original :any:`ParameterNode` instance. :param instant_str: A date in the format `YYYY-MM-DD`. """ - # The "technical" attributes are hidden, so that the node children can be easily browsed with auto-completion without pollution self._name = name self._instant_str = instant_str @@ -30,7 +25,7 @@ def __init__(self, name, node, instant_str): if child_at_instant is not None: self.add_child(child_name, child_at_instant) - def add_child(self, child_name, child_at_instant): + def add_child(self, child_name, child_at_instant) -> None: self._children[child_name] = child_at_instant setattr(self, child_name, child_at_instant) @@ -45,7 +40,7 @@ def __getitem__(self, key): if numpy.issubdtype(key.dtype, numpy.datetime64): return ( parameters.VectorialAsofDateParameterNodeAtInstant.build_from_node( - self + self, )[key] ) @@ -55,13 +50,10 @@ def __getitem__(self, key): def __iter__(self): return iter(self._children) - def __repr__(self): - result = os.linesep.join( + def __repr__(self) -> str: + return os.linesep.join( [ os.linesep.join(["{}:", "{}"]).format(name, tools.indent(repr(value))) for name, value in self._children.items() - ] + ], ) - if sys.version_info < (3, 0): - return result - return result diff --git a/openfisca_core/parameters/parameter_scale.py b/openfisca_core/parameters/parameter_scale.py index f3d636ed0c..b01b6a372a 100644 --- a/openfisca_core/parameters/parameter_scale.py +++ b/openfisca_core/parameters/parameter_scale.py @@ -1,5 +1,3 @@ -import typing - import copy import os @@ -15,34 +13,33 @@ class ParameterScale(AtInstantLike): - """ - A parameter scale (for instance a marginal scale). - """ + """A parameter scale (for instance a marginal scale).""" # 'unit' and 'reference' are only listed here for backward compatibility _allowed_keys = config.COMMON_KEYS.union({"brackets"}) - def __init__(self, name, data, file_path): - """ - :param name: name of the scale, eg "taxes.some_scale" + def __init__(self, name, data, file_path) -> None: + """:param name: name of the scale, eg "taxes.some_scale" :param data: Data loaded from a YAML file. In case of a reform, the data can also be created dynamically. :param file_path: File the parameter was loaded from. """ self.name: str = name self.file_path: str = file_path helpers._validate_parameter( - self, data, data_type=dict, allowed_keys=self._allowed_keys + self, + data, + data_type=dict, + allowed_keys=self._allowed_keys, ) self.description: str = data.get("description") - self.metadata: typing.Dict = {} + self.metadata: dict = {} helpers._set_backward_compatibility_metadata(self, data) self.metadata.update(data.get("metadata", {})) if not isinstance(data.get("brackets", []), list): + msg = f"Property 'brackets' of scale '{self.name}' must be of type array." raise ParameterParsingError( - "Property 'brackets' of scale '{}' must be of type array.".format( - self.name - ), + msg, self.file_path, ) @@ -50,24 +47,25 @@ def __init__(self, name, data, file_path): for i, bracket_data in enumerate(data.get("brackets", [])): bracket_name = helpers._compose_name(name, item_name=i) bracket = parameters.ParameterScaleBracket( - name=bracket_name, data=bracket_data, file_path=file_path + name=bracket_name, + data=bracket_data, + file_path=file_path, ) brackets.append(bracket) - self.brackets: typing.List[parameters.ParameterScaleBracket] = brackets + self.brackets: list[parameters.ParameterScaleBracket] = brackets def __getitem__(self, key): if isinstance(key, int) and key < len(self.brackets): return self.brackets[key] - else: - raise KeyError(key) + raise KeyError(key) - def __repr__(self): + def __repr__(self) -> str: return os.linesep.join( ["brackets:"] + [ tools.indent("-" + tools.indent(repr(bracket))[1:]) for bracket in self.brackets - ] + ], ) def get_descendants(self): @@ -93,7 +91,7 @@ def _get_at_instant(self, instant): threshold = bracket.threshold scale.add_bracket(threshold, amount) return scale - elif any("amount" in bracket._children for bracket in brackets): + if any("amount" in bracket._children for bracket in brackets): scale = MarginalAmountTaxScale() for bracket in brackets: if "amount" in bracket._children and "threshold" in bracket._children: @@ -101,7 +99,7 @@ def _get_at_instant(self, instant): threshold = bracket.threshold scale.add_bracket(threshold, amount) return scale - elif any("average_rate" in bracket._children for bracket in brackets): + if any("average_rate" in bracket._children for bracket in brackets): scale = LinearAverageRateTaxScale() for bracket in brackets: @@ -113,12 +111,11 @@ def _get_at_instant(self, instant): threshold = bracket.threshold scale.add_bracket(threshold, average_rate) return scale - else: - scale = MarginalRateTaxScale() - - for bracket in brackets: - if "rate" in bracket._children and "threshold" in bracket._children: - rate = bracket.rate - threshold = bracket.threshold - scale.add_bracket(threshold, rate) - return scale + scale = MarginalRateTaxScale() + + for bracket in brackets: + if "rate" in bracket._children and "threshold" in bracket._children: + rate = bracket.rate + threshold = bracket.threshold + scale.add_bracket(threshold, rate) + return scale diff --git a/openfisca_core/parameters/parameter_scale_bracket.py b/openfisca_core/parameters/parameter_scale_bracket.py index 2e3e65e649..b9691ea3ca 100644 --- a/openfisca_core/parameters/parameter_scale_bracket.py +++ b/openfisca_core/parameters/parameter_scale_bracket.py @@ -2,8 +2,6 @@ class ParameterScaleBracket(ParameterNode): - """ - A parameter scale bracket. - """ + """A parameter scale bracket.""" - _allowed_keys = set(["amount", "threshold", "rate", "average_rate"]) + _allowed_keys = {"amount", "threshold", "rate", "average_rate"} diff --git a/openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py b/openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py index e00ce11733..27be1f6946 100644 --- a/openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py +++ b/openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py @@ -7,9 +7,8 @@ class VectorialAsofDateParameterNodeAtInstant(VectorialParameterNodeAtInstant): - """ - Parameter node of the legislation at a given instant which has been vectorized along some date. - Vectorized parameters allow requests such as parameters.housing_benefit[date], where date is a np.datetime64 type vector + """Parameter node of the legislation at a given instant which has been vectorized along some date. + Vectorized parameters allow requests such as parameters.housing_benefit[date], where date is a numpy.datetime64 type vector. """ @staticmethod @@ -19,13 +18,15 @@ def build_from_node(node): # Recursively vectorize the children of the node vectorial_subnodes = tuple( [ - VectorialAsofDateParameterNodeAtInstant.build_from_node( - node[subnode_name] - ).vector - if isinstance(node[subnode_name], ParameterNodeAtInstant) - else node[subnode_name] + ( + VectorialAsofDateParameterNodeAtInstant.build_from_node( + node[subnode_name], + ).vector + if isinstance(node[subnode_name], ParameterNodeAtInstant) + else node[subnode_name] + ) for subnode_name in subnodes_name - ] + ], ) # A vectorial node is a wrapper around a numpy recarray # We first build the recarray @@ -40,7 +41,9 @@ def build_from_node(node): ], ) return VectorialAsofDateParameterNodeAtInstant( - node._name, recarray.view(numpy.recarray), node._instant_str + node._name, + recarray.view(numpy.recarray), + node._instant_str, ) def __getitem__(self, key): @@ -49,12 +52,12 @@ def __getitem__(self, key): key = numpy.array([key], dtype="datetime64[D]") return self.__getattr__(key) # If the key is a vector, e.g. ['1990-11-25', '1983-04-17', '1969-09-09'] - elif isinstance(key, numpy.ndarray): + if isinstance(key, numpy.ndarray): assert numpy.issubdtype(key.dtype, numpy.datetime64) names = list( - self.dtype.names + self.dtype.names, ) # Get all the names of the subnodes, e.g. ['before_X', 'after_X', 'after_Y'] - values = numpy.asarray([value for value in self.vector[0]]) + values = numpy.asarray(list(self.vector[0])) names = [name for name in names if not name.startswith("before")] names = [ numpy.datetime64("-".join(name[len("after_") :].split("_"))) @@ -65,10 +68,14 @@ def __getitem__(self, key): # If the result is not a leaf, wrap the result in a vectorial node. if numpy.issubdtype(result.dtype, numpy.record) or numpy.issubdtype( - result.dtype, numpy.void + result.dtype, + numpy.void, ): return VectorialAsofDateParameterNodeAtInstant( - self._name, result.view(numpy.recarray), self._instant_str + self._name, + result.view(numpy.recarray), + self._instant_str, ) return result + return None diff --git a/openfisca_core/parameters/vectorial_parameter_node_at_instant.py b/openfisca_core/parameters/vectorial_parameter_node_at_instant.py index 0681848cfa..74cd02d378 100644 --- a/openfisca_core/parameters/vectorial_parameter_node_at_instant.py +++ b/openfisca_core/parameters/vectorial_parameter_node_at_instant.py @@ -1,3 +1,5 @@ +from typing import NoReturn + import numpy from openfisca_core import parameters @@ -7,9 +9,8 @@ class VectorialParameterNodeAtInstant: - """ - Parameter node of the legislation at a given instant which has been vectorized. - Vectorized parameters allow requests such as parameters.housing_benefit[zipcode], where zipcode is a vector + """Parameter node of the legislation at a given instant which has been vectorized. + Vectorized parameters allow requests such as parameters.housing_benefit[zipcode], where zipcode is a vector. """ @staticmethod @@ -19,13 +20,15 @@ def build_from_node(node): # Recursively vectorize the children of the node vectorial_subnodes = tuple( [ - VectorialParameterNodeAtInstant.build_from_node( - node[subnode_name] - ).vector - if isinstance(node[subnode_name], parameters.ParameterNodeAtInstant) - else node[subnode_name] + ( + VectorialParameterNodeAtInstant.build_from_node( + node[subnode_name], + ).vector + if isinstance(node[subnode_name], parameters.ParameterNodeAtInstant) + else node[subnode_name] + ) for subnode_name in subnodes_name - ] + ], ) # A vectorial node is a wrapper around a numpy recarray # We first build the recarray @@ -41,45 +44,33 @@ def build_from_node(node): ) return VectorialParameterNodeAtInstant( - node._name, recarray.view(numpy.recarray), node._instant_str + node._name, + recarray.view(numpy.recarray), + node._instant_str, ) @staticmethod - def check_node_vectorisable(node): - """ - Check that a node can be casted to a vectorial node, in order to be able to use fancy indexing. - """ + def check_node_vectorisable(node) -> None: + """Check that a node can be casted to a vectorial node, in order to be able to use fancy indexing.""" MESSAGE_PART_1 = "Cannot use fancy indexing on parameter node '{}', as" MESSAGE_PART_3 = ( "To use fancy indexing on parameter node, its children must be homogenous." ) MESSAGE_PART_4 = "See more at ." - def raise_key_inhomogeneity_error(node_with_key, node_without_key, missing_key): - message = " ".join( - [ - MESSAGE_PART_1, - "'{}' exists, but '{}' doesn't.", - MESSAGE_PART_3, - MESSAGE_PART_4, - ] - ).format( + def raise_key_inhomogeneity_error( + node_with_key, node_without_key, missing_key + ) -> NoReturn: + message = f"{MESSAGE_PART_1} '{{}}' exists, but '{{}}' doesn't. {MESSAGE_PART_3} {MESSAGE_PART_4}".format( node._name, - ".".join([node_with_key, missing_key]), - ".".join([node_without_key, missing_key]), + f"{node_with_key}.{missing_key}", + f"{node_without_key}.{missing_key}", ) raise ValueError(message) - def raise_type_inhomogeneity_error(node_name, non_node_name): - message = " ".join( - [ - MESSAGE_PART_1, - "'{}' is a node, but '{}' is not.", - MESSAGE_PART_3, - MESSAGE_PART_4, - ] - ).format( + def raise_type_inhomogeneity_error(node_name, non_node_name) -> NoReturn: + message = f"{MESSAGE_PART_1} '{{}}' is a node, but '{{}}' is not. {MESSAGE_PART_3} {MESSAGE_PART_4}".format( node._name, node_name, non_node_name, @@ -87,14 +78,8 @@ def raise_type_inhomogeneity_error(node_name, non_node_name): raise ValueError(message) - def raise_not_implemented(node_name, node_type): - message = " ".join( - [ - MESSAGE_PART_1, - "'{}' is a '{}', and fancy indexing has not been implemented yet on this kind of parameters.", - MESSAGE_PART_4, - ] - ).format( + def raise_not_implemented(node_name, node_type) -> NoReturn: + message = f"{MESSAGE_PART_1} '{{}}' is a '{{}}', and fancy indexing has not been implemented yet on this kind of parameters. {MESSAGE_PART_4}".format( node._name, node_name, node_type, @@ -103,14 +88,11 @@ def raise_not_implemented(node_name, node_type): def extract_named_children(node): return { - ".".join([node._name, key]): value - for key, value in node._children.items() + f"{node._name}.{key}": value for key, value in node._children.items() } - def check_nodes_homogeneous(named_nodes): - """ - Check than several nodes (or parameters, or baremes) have the same structure. - """ + def check_nodes_homogeneous(named_nodes) -> None: + """Check than several nodes (or parameters, or baremes) have the same structure.""" names = list(named_nodes.keys()) nodes = list(named_nodes.values()) first_node = nodes[0] @@ -122,11 +104,13 @@ def check_nodes_homogeneous(named_nodes): raise_type_inhomogeneity_error(first_name, name) first_node_keys = first_node._children.keys() node_keys = node._children.keys() - if not first_node_keys == node_keys: + if first_node_keys != node_keys: missing_keys = set(first_node_keys).difference(node_keys) if missing_keys: # If the first_node has a key that node hasn't raise_key_inhomogeneity_error( - first_name, name, missing_keys.pop() + first_name, + name, + missing_keys.pop(), ) else: # If If the node has a key that first_node doesn't have missing_key = ( @@ -135,9 +119,9 @@ def check_nodes_homogeneous(named_nodes): raise_key_inhomogeneity_error(name, first_name, missing_key) children.update(extract_named_children(node)) check_nodes_homogeneous(children) - elif isinstance(first_node, float) or isinstance(first_node, int): + elif isinstance(first_node, (float, int)): for node, name in list(zip(nodes, names))[1:]: - if isinstance(node, int) or isinstance(node, float): + if isinstance(node, (int, float)): pass elif isinstance(node, parameters.ParameterNodeAtInstant): raise_type_inhomogeneity_error(name, first_name) @@ -149,7 +133,7 @@ def check_nodes_homogeneous(named_nodes): check_nodes_homogeneous(extract_named_children(node)) - def __init__(self, name, vector, instant_str): + def __init__(self, name, vector, instant_str) -> None: self.vector = vector self._name = name self._instant_str = instant_str @@ -165,13 +149,14 @@ def __getitem__(self, key): if isinstance(key, str): return self.__getattr__(key) # If the key is a vector, e.g. ['zone_1', 'zone_2', 'zone_1'] - elif isinstance(key, numpy.ndarray): + if isinstance(key, numpy.ndarray): if not numpy.issubdtype(key.dtype, numpy.str_): # In case the key is not a string vector, stringify it if key.dtype == object and issubclass(type(key[0]), Enum): enum = type(key[0]) key = numpy.select( - [key == item for item in enum], [item.name for item in enum] + [key == item for item in enum], + [item.name for item in enum], ) elif isinstance(key, EnumArray): enum = key.possible_values @@ -182,26 +167,33 @@ def __getitem__(self, key): else: key = key.astype("str") names = list( - self.dtype.names + self.dtype.names, ) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2'] default = numpy.full_like( - self.vector[key[0]], numpy.nan + self.vector[key[0]], + numpy.nan, ) # In case of unexpected key, we will set the corresponding value to NaN. conditions = [key == name for name in names] values = [self.vector[name] for name in names] result = numpy.select(conditions, values, default) if helpers.contains_nan(result): unexpected_key = set(key).difference(self.vector.dtype.names).pop() + msg = f"{self._name}.{unexpected_key}" raise ParameterNotFoundError( - ".".join([self._name, unexpected_key]), self._instant_str + msg, + self._instant_str, ) # If the result is not a leaf, wrap the result in a vectorial node. if numpy.issubdtype(result.dtype, numpy.record) or numpy.issubdtype( - result.dtype, numpy.void + result.dtype, + numpy.void, ): return VectorialParameterNodeAtInstant( - self._name, result.view(numpy.recarray), self._instant_str + self._name, + result.view(numpy.recarray), + self._instant_str, ) return result + return None diff --git a/openfisca_core/periods/_parsers.py b/openfisca_core/periods/_parsers.py index 64b2077831..95a17fb041 100644 --- a/openfisca_core/periods/_parsers.py +++ b/openfisca_core/periods/_parsers.py @@ -29,7 +29,6 @@ def _parse_period(value: str) -> Optional[Period]: Period((, Instant((2022, 1, 16)), 1)) """ - # If it's a complex period, next! if len(value.split(":")) != 1: return None @@ -77,26 +76,22 @@ def _parse_unit(value: str) -> DateUnit: """ - length = len(value.split("-")) isweek = value.find("W") != -1 if length == 1: return DateUnit.YEAR - elif length == 2: + if length == 2: if isweek: return DateUnit.WEEK - else: - return DateUnit.MONTH + return DateUnit.MONTH - elif length == 3: + if length == 3: if isweek: return DateUnit.WEEKDAY - else: - return DateUnit.DAY + return DateUnit.DAY - else: - raise ValueError + raise ValueError diff --git a/openfisca_core/periods/config.py b/openfisca_core/periods/config.py index 17807160e4..26ce30a5aa 100644 --- a/openfisca_core/periods/config.py +++ b/openfisca_core/periods/config.py @@ -12,11 +12,11 @@ # Matches "2015", "2015-01", "2015-01-01" # Does not match "2015-13", "2015-12-32" INSTANT_PATTERN = re.compile( - r"^\d{4}(-(0[1-9]|1[012]))?(-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01]))?$" + r"^\d{4}(-(0[1-9]|1[012]))?(-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01]))?$", ) date_by_instant_cache: dict = {} str_by_instant_cache: dict = {} year_or_month_or_day_re = re.compile( - r"(18|19|20)\d{2}(-(0?[1-9]|1[0-2])(-([0-2]?\d|3[0-1]))?)?$" + r"(18|19|20)\d{2}(-(0?[1-9]|1[0-2])(-([0-2]?\d|3[0-1]))?)?$", ) diff --git a/openfisca_core/periods/date_unit.py b/openfisca_core/periods/date_unit.py index a813211495..61f7fbc66f 100644 --- a/openfisca_core/periods/date_unit.py +++ b/openfisca_core/periods/date_unit.py @@ -7,7 +7,7 @@ class DateUnitMeta(EnumMeta): @property - def isoformat(self) -> tuple[DateUnit, ...]: + def isoformat(cls) -> tuple[DateUnit, ...]: """Creates a :obj:`tuple` of ``key`` with isoformat items. Returns: @@ -24,11 +24,10 @@ def isoformat(self) -> tuple[DateUnit, ...]: False """ - return DateUnit.DAY, DateUnit.MONTH, DateUnit.YEAR @property - def isocalendar(self) -> tuple[DateUnit, ...]: + def isocalendar(cls) -> tuple[DateUnit, ...]: """Creates a :obj:`tuple` of ``key`` with isocalendar items. Returns: @@ -45,7 +44,6 @@ def isocalendar(self) -> tuple[DateUnit, ...]: False """ - return DateUnit.WEEKDAY, DateUnit.WEEK, DateUnit.YEAR diff --git a/openfisca_core/periods/helpers.py b/openfisca_core/periods/helpers.py index 2ce4e0cd35..c1ccc4a3a2 100644 --- a/openfisca_core/periods/helpers.py +++ b/openfisca_core/periods/helpers.py @@ -48,15 +48,15 @@ def instant(instant) -> Optional[Instant]: Instant((2021, 1, 1)) """ - if instant is None: return None if isinstance(instant, Instant): return instant if isinstance(instant, str): if not config.INSTANT_PATTERN.match(instant): + msg = f"'{instant}' is not a valid instant. Instants are described using the 'YYYY-MM-DD' format, for instance '2015-06-15'." raise ValueError( - f"'{instant}' is not a valid instant. Instants are described using the 'YYYY-MM-DD' format, for instance '2015-06-15'." + msg, ) instant = Instant(int(fragment) for fragment in instant.split("-", 2)[:3]) elif isinstance(instant, datetime.date): @@ -93,7 +93,6 @@ def instant_date(instant: Optional[Instant]) -> Optional[datetime.date]: Date(2021, 1, 1) """ - if instant is None: return None @@ -150,7 +149,6 @@ def period(value) -> Period: """ - if isinstance(value, Period): return value @@ -171,7 +169,7 @@ def period(value) -> Period: DateUnit.ETERNITY, instant(datetime.date.min), float("inf"), - ) + ), ) # For example ``2021`` gives @@ -256,13 +254,12 @@ def _raise_error(value: str) -> NoReturn: .", - ] + ], ) raise ValueError(message) @@ -290,7 +287,6 @@ def key_period_size(period: Period) -> str: '300_3' """ - unit, start, size = period return f"{unit_weight(unit)}_{size}" @@ -304,7 +300,6 @@ def unit_weights() -> dict[str, int]: {: 100, ...ETERNITY: 'eternity'>: 400} """ - return { DateUnit.WEEKDAY: 100, DateUnit.WEEK: 200, @@ -323,5 +318,4 @@ def unit_weight(unit: str) -> int: 100 """ - return unit_weights()[unit] diff --git a/openfisca_core/periods/instant_.py b/openfisca_core/periods/instant_.py index 9d0893ba41..5042209492 100644 --- a/openfisca_core/periods/instant_.py +++ b/openfisca_core/periods/instant_.py @@ -78,10 +78,10 @@ class Instant(tuple): """ - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({super().__repr__()})" - def __str__(self): + def __str__(self) -> str: instant_str = config.str_by_instant_cache.get(self) if instant_str is None: @@ -135,7 +135,6 @@ def offset(self, offset, unit): Instant((2019, 12, 29)) """ - year, month, day = self assert unit in ( @@ -146,52 +145,55 @@ def offset(self, offset, unit): if unit == DateUnit.YEAR: return self.__class__((year, 1, 1)) - elif unit == DateUnit.MONTH: + if unit == DateUnit.MONTH: return self.__class__((year, month, 1)) - elif unit == DateUnit.WEEK: + if unit == DateUnit.WEEK: date = self.date date = date.start_of("week") return self.__class__((date.year, date.month, date.day)) + return None - elif offset == "last-of": + if offset == "last-of": if unit == DateUnit.YEAR: return self.__class__((year, 12, 31)) - elif unit == DateUnit.MONTH: + if unit == DateUnit.MONTH: date = self.date date = date.end_of("month") return self.__class__((date.year, date.month, date.day)) - elif unit == DateUnit.WEEK: + if unit == DateUnit.WEEK: date = self.date date = date.end_of("week") return self.__class__((date.year, date.month, date.day)) - - else: - assert isinstance( - offset, int - ), f"Invalid offset: {offset} of type {type(offset)}" - - if unit == DateUnit.YEAR: - date = self.date - date = date.add(years=offset) - return self.__class__((date.year, date.month, date.day)) - - elif unit == DateUnit.MONTH: - date = self.date - date = date.add(months=offset) - return self.__class__((date.year, date.month, date.day)) - - elif unit == DateUnit.WEEK: - date = self.date - date = date.add(weeks=offset) - return self.__class__((date.year, date.month, date.day)) - - elif unit in (DateUnit.DAY, DateUnit.WEEKDAY): - date = self.date - date = date.add(days=offset) - return self.__class__((date.year, date.month, date.day)) + return None + + assert isinstance( + offset, + int, + ), f"Invalid offset: {offset} of type {type(offset)}" + + if unit == DateUnit.YEAR: + date = self.date + date = date.add(years=offset) + return self.__class__((date.year, date.month, date.day)) + + if unit == DateUnit.MONTH: + date = self.date + date = date.add(months=offset) + return self.__class__((date.year, date.month, date.day)) + + if unit == DateUnit.WEEK: + date = self.date + date = date.add(weeks=offset) + return self.__class__((date.year, date.month, date.day)) + + if unit in (DateUnit.DAY, DateUnit.WEEKDAY): + date = self.date + date = date.add(days=offset) + return self.__class__((date.year, date.month, date.day)) + return None @property def year(self): diff --git a/openfisca_core/periods/period_.py b/openfisca_core/periods/period_.py index 11a7b671b4..0dcf960bbf 100644 --- a/openfisca_core/periods/period_.py +++ b/openfisca_core/periods/period_.py @@ -1,7 +1,6 @@ from __future__ import annotations import typing -from collections.abc import Sequence import calendar import datetime @@ -13,6 +12,8 @@ from .instant_ import Instant if typing.TYPE_CHECKING: + from collections.abc import Sequence + from pendulum.datetime import Date @@ -141,9 +142,8 @@ def __str__(self) -> str: if month == 1: # civil year starting from january return str(f_year) - else: - # rolling year - return f"{DateUnit.YEAR}:{f_year}-{month:02d}" + # rolling year + return f"{DateUnit.YEAR}:{f_year}-{month:02d}" # simple month if unit == DateUnit.MONTH and size == 1: @@ -156,8 +156,7 @@ def __str__(self) -> str: if unit == DateUnit.DAY: if size == 1: return f"{f_year}-{month:02d}-{day:02d}" - else: - return f"{unit}:{f_year}-{month:02d}-{day:02d}:{size}" + return f"{unit}:{f_year}-{month:02d}-{day:02d}:{size}" # 1 week if unit == DateUnit.WEEK and size == 1: @@ -201,7 +200,6 @@ def unit(self) -> str: """ - return self[0] @property @@ -215,7 +213,6 @@ def start(self) -> Instant: Instant((2021, 10, 1)) """ - return self[1] @property @@ -229,7 +226,6 @@ def size(self) -> int: 3 """ - return self[2] @property @@ -249,9 +245,9 @@ def date(self) -> Date: ValueError: "date" is undefined for a period of size > 1: year:2021-10:3. """ - if self.size != 1: - raise ValueError(f'"date" is undefined for a period of size > 1: {self}.') + msg = f'"date" is undefined for a period of size > 1: {self}.' + raise ValueError(msg) return self.start.date @@ -272,11 +268,11 @@ def size_in_years(self) -> int: ValueError: Can't calculate number of years in a month. """ - if self.unit == DateUnit.YEAR: return self.size - raise ValueError(f"Can't calculate number of years in a {self.unit}.") + msg = f"Can't calculate number of years in a {self.unit}." + raise ValueError(msg) @property def size_in_months(self) -> int: @@ -295,14 +291,14 @@ def size_in_months(self) -> int: ValueError: Can't calculate number of months in a day. """ - if self.unit == DateUnit.YEAR: return self.size * 12 if self.unit == DateUnit.MONTH: return self.size - raise ValueError(f"Can't calculate number of months in a {self.unit}.") + msg = f"Can't calculate number of months in a {self.unit}." + raise ValueError(msg) @property def size_in_days(self) -> int: @@ -320,7 +316,6 @@ def size_in_days(self) -> int: 92 """ - if self.unit in (DateUnit.YEAR, DateUnit.MONTH): last_day = self.start.offset(self.size, self.unit).offset(-1, DateUnit.DAY) return (last_day.date - self.start.date).days + 1 @@ -331,7 +326,8 @@ def size_in_days(self) -> int: if self.unit in (DateUnit.DAY, DateUnit.WEEKDAY): return self.size - raise ValueError(f"Can't calculate number of days in a {self.unit}.") + msg = f"Can't calculate number of days in a {self.unit}." + raise ValueError(msg) @property def size_in_weeks(self): @@ -349,7 +345,6 @@ def size_in_weeks(self): 261 """ - if self.unit == DateUnit.YEAR: start = self.start.date cease = start.add(years=self.size) @@ -365,7 +360,8 @@ def size_in_weeks(self): if self.unit == DateUnit.WEEK: return self.size - raise ValueError(f"Can't calculate number of weeks in a {self.unit}.") + msg = f"Can't calculate number of weeks in a {self.unit}." + raise ValueError(msg) @property def size_in_weekdays(self): @@ -383,7 +379,6 @@ def size_in_weekdays(self): 21 """ - if self.unit == DateUnit.YEAR: return self.size_in_weeks * 7 @@ -397,7 +392,8 @@ def size_in_weekdays(self): if self.unit in (DateUnit.DAY, DateUnit.WEEKDAY): return self.size - raise ValueError(f"Can't calculate number of weekdays in a {self.unit}.") + msg = f"Can't calculate number of weekdays in a {self.unit}." + raise ValueError(msg) @property def days(self): @@ -430,7 +426,7 @@ def intersection(self, start, stop): DateUnit.YEAR, intersection_start, intersection_stop.year - intersection_start.year + 1, - ) + ), ) if ( intersection_start.day == 1 @@ -447,14 +443,14 @@ def intersection(self, start, stop): - intersection_start.month + 1 ), - ) + ), ) return self.__class__( ( DateUnit.DAY, intersection_start, (intersection_stop.date - intersection_start.date).days + 1, - ) + ), ) def get_subperiods(self, unit: DateUnit) -> Sequence[Period]: @@ -470,9 +466,9 @@ def get_subperiods(self, unit: DateUnit) -> Sequence[Period]: [Period((, Instant((2021, 1, 1)), 1)),...((2022, 1, 1)), 1))] """ - if helpers.unit_weight(self.unit) < helpers.unit_weight(unit): - raise ValueError(f"Cannot subdivide {self.unit} into {unit}") + msg = f"Cannot subdivide {self.unit} into {unit}" + raise ValueError(msg) if unit == DateUnit.YEAR: return [self.this_year.offset(i, DateUnit.YEAR) for i in range(self.size)] @@ -500,7 +496,8 @@ def get_subperiods(self, unit: DateUnit) -> Sequence[Period]: for i in range(self.size_in_weekdays) ] - raise ValueError(f"Cannot subdivide {self.unit} into {unit}") + msg = f"Cannot subdivide {self.unit} into {unit}" + raise ValueError(msg) def offset(self, offset, unit=None): """Increment (or decrement) the given period with offset units. @@ -524,7 +521,9 @@ def offset(self, offset, unit=None): >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset(1, DateUnit.DAY) Period((, Instant((2021, 1, 2)), 12)) - >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset(1, DateUnit.MONTH) + >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset( + ... 1, DateUnit.MONTH + ... ) Period((, Instant((2021, 2, 1)), 12)) >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset(1, DateUnit.YEAR) @@ -578,110 +577,157 @@ def offset(self, offset, unit=None): >>> Period((DateUnit.YEAR, Instant((2014, 1, 1)), 1)).offset(-3) Period((, Instant((2011, 1, 1)), 1)) - >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset("first-of", DateUnit.MONTH) + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.MONTH + ... ) Period((, Instant((2014, 2, 1)), 1)) - >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset("first-of", DateUnit.YEAR) + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.YEAR + ... ) Period((, Instant((2014, 1, 1)), 1)) - >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset("first-of", DateUnit.MONTH) + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset( + ... "first-of", DateUnit.MONTH + ... ) Period((, Instant((2014, 2, 1)), 4)) - >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset("first-of", DateUnit.YEAR) + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset( + ... "first-of", DateUnit.YEAR + ... ) Period((, Instant((2014, 1, 1)), 4)) >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset("first-of") Period((, Instant((2014, 2, 1)), 1)) - >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset("first-of", DateUnit.MONTH) + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.MONTH + ... ) Period((, Instant((2014, 2, 1)), 1)) - >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset("first-of", DateUnit.YEAR) + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.YEAR + ... ) Period((, Instant((2014, 1, 1)), 1)) >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset("first-of") Period((, Instant((2014, 2, 1)), 4)) - >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset("first-of", DateUnit.MONTH) + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset( + ... "first-of", DateUnit.MONTH + ... ) Period((, Instant((2014, 2, 1)), 4)) - >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset("first-of", DateUnit.YEAR) + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset( + ... "first-of", DateUnit.YEAR + ... ) Period((, Instant((2014, 1, 1)), 4)) >>> Period((DateUnit.YEAR, Instant((2014, 1, 30)), 1)).offset("first-of") Period((, Instant((2014, 1, 1)), 1)) - >>> Period((DateUnit.YEAR, Instant((2014, 1, 30)), 1)).offset("first-of", DateUnit.MONTH) + >>> Period((DateUnit.YEAR, Instant((2014, 1, 30)), 1)).offset( + ... "first-of", DateUnit.MONTH + ... ) Period((, Instant((2014, 1, 1)), 1)) - >>> Period((DateUnit.YEAR, Instant((2014, 1, 30)), 1)).offset("first-of", DateUnit.YEAR) + >>> Period((DateUnit.YEAR, Instant((2014, 1, 30)), 1)).offset( + ... "first-of", DateUnit.YEAR + ... ) Period((, Instant((2014, 1, 1)), 1)) >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("first-of") Period((, Instant((2014, 1, 1)), 1)) - >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("first-of", DateUnit.MONTH) + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.MONTH + ... ) Period((, Instant((2014, 2, 1)), 1)) - >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("first-of", DateUnit.YEAR) + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.YEAR + ... ) Period((, Instant((2014, 1, 1)), 1)) - >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset("last-of", DateUnit.MONTH) + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.MONTH + ... ) Period((, Instant((2014, 2, 28)), 1)) - >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset("last-of", DateUnit.YEAR) + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.YEAR + ... ) Period((, Instant((2014, 12, 31)), 1)) - >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset("last-of", DateUnit.MONTH) + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset( + ... "last-of", DateUnit.MONTH + ... ) Period((, Instant((2014, 2, 28)), 4)) - >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset("last-of", DateUnit.YEAR) + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset( + ... "last-of", DateUnit.YEAR + ... ) Period((, Instant((2014, 12, 31)), 4)) >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset("last-of") Period((, Instant((2014, 2, 28)), 1)) - >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset("last-of", DateUnit.MONTH) + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.MONTH + ... ) Period((, Instant((2014, 2, 28)), 1)) - >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset("last-of", DateUnit.YEAR) + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.YEAR + ... ) Period((, Instant((2014, 12, 31)), 1)) >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset("last-of") Period((, Instant((2014, 2, 28)), 4)) - >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset("last-of", DateUnit.MONTH) + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset( + ... "last-of", DateUnit.MONTH + ... ) Period((, Instant((2014, 2, 28)), 4)) - >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset("last-of", DateUnit.YEAR) + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset( + ... "last-of", DateUnit.YEAR + ... ) Period((, Instant((2014, 12, 31)), 4)) >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("last-of") Period((, Instant((2014, 12, 31)), 1)) - >>> Period((DateUnit.YEAR, Instant((2014, 1, 1)), 1)).offset("last-of", DateUnit.MONTH) + >>> Period((DateUnit.YEAR, Instant((2014, 1, 1)), 1)).offset( + ... "last-of", DateUnit.MONTH + ... ) Period((, Instant((2014, 1, 31)), 1)) - >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("last-of", DateUnit.YEAR) + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.YEAR + ... ) Period((, Instant((2014, 12, 31)), 1)) >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("last-of") Period((, Instant((2014, 12, 31)), 1)) - >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("last-of", DateUnit.MONTH) + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.MONTH + ... ) Period((, Instant((2014, 2, 28)), 1)) - >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("last-of", DateUnit.YEAR) + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.YEAR + ... ) Period((, Instant((2014, 12, 31)), 1)) """ - return self.__class__( ( self[0], self[1].offset(offset, self[0] if unit is None else unit), self[2], - ) + ), ) def contains(self, other: Period) -> bool: @@ -690,7 +736,6 @@ def contains(self, other: Period) -> bool: For instance, ``period(2015)`` contains ``period(2015-01)``. """ - return self.start <= other.start and self.stop >= other.stop @property @@ -726,31 +771,29 @@ def stop(self) -> Instant: Instant((2012, 3, 1)) """ - unit, start_instant, size = self year, month, day = start_instant if unit == DateUnit.ETERNITY: return Instant((float("inf"), float("inf"), float("inf"))) - elif unit == DateUnit.YEAR: + if unit == DateUnit.YEAR: date = start_instant.date.add(years=size, days=-1) return Instant((date.year, date.month, date.day)) - elif unit == DateUnit.MONTH: + if unit == DateUnit.MONTH: date = start_instant.date.add(months=size, days=-1) return Instant((date.year, date.month, date.day)) - elif unit == DateUnit.WEEK: + if unit == DateUnit.WEEK: date = start_instant.date.add(weeks=size, days=-1) return Instant((date.year, date.month, date.day)) - elif unit in (DateUnit.DAY, DateUnit.WEEKDAY): + if unit in (DateUnit.DAY, DateUnit.WEEKDAY): date = start_instant.date.add(days=size - 1) return Instant((date.year, date.month, date.day)) - else: - raise ValueError + raise ValueError # Reference periods diff --git a/openfisca_core/periods/tests/helpers/test_helpers.py b/openfisca_core/periods/tests/helpers/test_helpers.py index bb409323d1..3cbf078a2e 100644 --- a/openfisca_core/periods/tests/helpers/test_helpers.py +++ b/openfisca_core/periods/tests/helpers/test_helpers.py @@ -7,49 +7,49 @@ @pytest.mark.parametrize( - "arg, expected", + ("arg", "expected"), [ - [None, None], - [Instant((1, 1, 1)), datetime.date(1, 1, 1)], - [Instant((4, 2, 29)), datetime.date(4, 2, 29)], - [(1, 1, 1), datetime.date(1, 1, 1)], + (None, None), + (Instant((1, 1, 1)), datetime.date(1, 1, 1)), + (Instant((4, 2, 29)), datetime.date(4, 2, 29)), + ((1, 1, 1), datetime.date(1, 1, 1)), ], ) -def test_instant_date(arg, expected): +def test_instant_date(arg, expected) -> None: assert periods.instant_date(arg) == expected @pytest.mark.parametrize( - "arg, error", + ("arg", "error"), [ - [Instant((-1, 1, 1)), ValueError], - [Instant((1, -1, 1)), ValueError], - [Instant((1, 13, -1)), ValueError], - [Instant((1, 1, -1)), ValueError], - [Instant((1, 1, 32)), ValueError], - [Instant((1, 2, 29)), ValueError], - [Instant(("1", 1, 1)), TypeError], - [(1,), TypeError], - [(1, 1), TypeError], + (Instant((-1, 1, 1)), ValueError), + (Instant((1, -1, 1)), ValueError), + (Instant((1, 13, -1)), ValueError), + (Instant((1, 1, -1)), ValueError), + (Instant((1, 1, 32)), ValueError), + (Instant((1, 2, 29)), ValueError), + (Instant(("1", 1, 1)), TypeError), + ((1,), TypeError), + ((1, 1), TypeError), ], ) -def test_instant_date_with_an_invalid_argument(arg, error): +def test_instant_date_with_an_invalid_argument(arg, error) -> None: with pytest.raises(error): periods.instant_date(arg) @pytest.mark.parametrize( - "arg, expected", + ("arg", "expected"), [ - [Period((DateUnit.WEEKDAY, Instant((1, 1, 1)), 5)), "100_5"], - [Period((DateUnit.WEEK, Instant((1, 1, 1)), 26)), "200_26"], - [Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), "100_365"], - [Period((DateUnit.MONTH, Instant((1, 1, 1)), 12)), "200_12"], - [Period((DateUnit.YEAR, Instant((1, 1, 1)), 2)), "300_2"], - [Period((DateUnit.ETERNITY, Instant((1, 1, 1)), 1)), "400_1"], - [(DateUnit.DAY, None, 1), "100_1"], - [(DateUnit.MONTH, None, -1000), "200_-1000"], + (Period((DateUnit.WEEKDAY, Instant((1, 1, 1)), 5)), "100_5"), + (Period((DateUnit.WEEK, Instant((1, 1, 1)), 26)), "200_26"), + (Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), "100_365"), + (Period((DateUnit.MONTH, Instant((1, 1, 1)), 12)), "200_12"), + (Period((DateUnit.YEAR, Instant((1, 1, 1)), 2)), "300_2"), + (Period((DateUnit.ETERNITY, Instant((1, 1, 1)), 1)), "400_1"), + ((DateUnit.DAY, None, 1), "100_1"), + ((DateUnit.MONTH, None, -1000), "200_-1000"), ], ) -def test_key_period_size(arg, expected): +def test_key_period_size(arg, expected) -> None: assert periods.key_period_size(arg) == expected diff --git a/openfisca_core/periods/tests/helpers/test_instant.py b/openfisca_core/periods/tests/helpers/test_instant.py index cb74c55ca4..73f37ece6f 100644 --- a/openfisca_core/periods/tests/helpers/test_instant.py +++ b/openfisca_core/periods/tests/helpers/test_instant.py @@ -7,70 +7,70 @@ @pytest.mark.parametrize( - "arg, expected", + ("arg", "expected"), [ - [None, None], - [datetime.date(1, 1, 1), Instant((1, 1, 1))], - [Instant((1, 1, 1)), Instant((1, 1, 1))], - [Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), Instant((1, 1, 1))], - [-1, Instant((-1, 1, 1))], - [0, Instant((0, 1, 1))], - [1, Instant((1, 1, 1))], - [999, Instant((999, 1, 1))], - [1000, Instant((1000, 1, 1))], - ["1000", Instant((1000, 1, 1))], - ["1000-01", Instant((1000, 1, 1))], - ["1000-01-01", Instant((1000, 1, 1))], - [(None,), Instant((None, 1, 1))], - [(None, None), Instant((None, None, 1))], - [(None, None, None), Instant((None, None, None))], - [(datetime.date(1, 1, 1),), Instant((datetime.date(1, 1, 1), 1, 1))], - [(Instant((1, 1, 1)),), Instant((Instant((1, 1, 1)), 1, 1))], - [ + (None, None), + (datetime.date(1, 1, 1), Instant((1, 1, 1))), + (Instant((1, 1, 1)), Instant((1, 1, 1))), + (Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), Instant((1, 1, 1))), + (-1, Instant((-1, 1, 1))), + (0, Instant((0, 1, 1))), + (1, Instant((1, 1, 1))), + (999, Instant((999, 1, 1))), + (1000, Instant((1000, 1, 1))), + ("1000", Instant((1000, 1, 1))), + ("1000-01", Instant((1000, 1, 1))), + ("1000-01-01", Instant((1000, 1, 1))), + ((None,), Instant((None, 1, 1))), + ((None, None), Instant((None, None, 1))), + ((None, None, None), Instant((None, None, None))), + ((datetime.date(1, 1, 1),), Instant((datetime.date(1, 1, 1), 1, 1))), + ((Instant((1, 1, 1)),), Instant((Instant((1, 1, 1)), 1, 1))), + ( (Period((DateUnit.DAY, Instant((1, 1, 1)), 365)),), Instant((Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), 1, 1)), - ], - [(-1,), Instant((-1, 1, 1))], - [(-1, -1), Instant((-1, -1, 1))], - [(-1, -1, -1), Instant((-1, -1, -1))], - [("-1",), Instant(("-1", 1, 1))], - [("-1", "-1"), Instant(("-1", "-1", 1))], - [("-1", "-1", "-1"), Instant(("-1", "-1", "-1"))], - [("1-1",), Instant(("1-1", 1, 1))], - [("1-1-1",), Instant(("1-1-1", 1, 1))], + ), + ((-1,), Instant((-1, 1, 1))), + ((-1, -1), Instant((-1, -1, 1))), + ((-1, -1, -1), Instant((-1, -1, -1))), + (("-1",), Instant(("-1", 1, 1))), + (("-1", "-1"), Instant(("-1", "-1", 1))), + (("-1", "-1", "-1"), Instant(("-1", "-1", "-1"))), + (("1-1",), Instant(("1-1", 1, 1))), + (("1-1-1",), Instant(("1-1-1", 1, 1))), ], ) -def test_instant(arg, expected): +def test_instant(arg, expected) -> None: assert periods.instant(arg) == expected @pytest.mark.parametrize( - "arg, error", + ("arg", "error"), [ - [DateUnit.YEAR, ValueError], - [DateUnit.ETERNITY, ValueError], - ["1000-0", ValueError], - ["1000-0-0", ValueError], - ["1000-1", ValueError], - ["1000-1-1", ValueError], - ["1", ValueError], - ["a", ValueError], - ["year", ValueError], - ["eternity", ValueError], - ["999", ValueError], - ["1:1000-01-01", ValueError], - ["a:1000-01-01", ValueError], - ["year:1000-01-01", ValueError], - ["year:1000-01-01:1", ValueError], - ["year:1000-01-01:3", ValueError], - ["1000-01-01:a", ValueError], - ["1000-01-01:1", ValueError], - [(), AssertionError], - [{}, AssertionError], - ["", ValueError], - [(None, None, None, None), AssertionError], + (DateUnit.YEAR, ValueError), + (DateUnit.ETERNITY, ValueError), + ("1000-0", ValueError), + ("1000-0-0", ValueError), + ("1000-1", ValueError), + ("1000-1-1", ValueError), + ("1", ValueError), + ("a", ValueError), + ("year", ValueError), + ("eternity", ValueError), + ("999", ValueError), + ("1:1000-01-01", ValueError), + ("a:1000-01-01", ValueError), + ("year:1000-01-01", ValueError), + ("year:1000-01-01:1", ValueError), + ("year:1000-01-01:3", ValueError), + ("1000-01-01:a", ValueError), + ("1000-01-01:1", ValueError), + ((), AssertionError), + ({}, AssertionError), + ("", ValueError), + ((None, None, None, None), AssertionError), ], ) -def test_instant_with_an_invalid_argument(arg, error): +def test_instant_with_an_invalid_argument(arg, error) -> None: with pytest.raises(error): periods.instant(arg) diff --git a/openfisca_core/periods/tests/helpers/test_period.py b/openfisca_core/periods/tests/helpers/test_period.py index 7d50abe102..c31e54c2ca 100644 --- a/openfisca_core/periods/tests/helpers/test_period.py +++ b/openfisca_core/periods/tests/helpers/test_period.py @@ -7,128 +7,128 @@ @pytest.mark.parametrize( - "arg, expected", + ("arg", "expected"), [ - ["eternity", Period((DateUnit.ETERNITY, Instant((1, 1, 1)), float("inf")))], - ["ETERNITY", Period((DateUnit.ETERNITY, Instant((1, 1, 1)), float("inf")))], - [ + ("eternity", Period((DateUnit.ETERNITY, Instant((1, 1, 1)), float("inf")))), + ("ETERNITY", Period((DateUnit.ETERNITY, Instant((1, 1, 1)), float("inf")))), + ( DateUnit.ETERNITY, Period((DateUnit.ETERNITY, Instant((1, 1, 1)), float("inf"))), - ], - [datetime.date(1, 1, 1), Period((DateUnit.DAY, Instant((1, 1, 1)), 1))], - [Instant((1, 1, 1)), Period((DateUnit.DAY, Instant((1, 1, 1)), 1))], - [ + ), + (datetime.date(1, 1, 1), Period((DateUnit.DAY, Instant((1, 1, 1)), 1))), + (Instant((1, 1, 1)), Period((DateUnit.DAY, Instant((1, 1, 1)), 1))), + ( Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), - ], - [-1, Period((DateUnit.YEAR, Instant((-1, 1, 1)), 1))], - [0, Period((DateUnit.YEAR, Instant((0, 1, 1)), 1))], - [1, Period((DateUnit.YEAR, Instant((1, 1, 1)), 1))], - [999, Period((DateUnit.YEAR, Instant((999, 1, 1)), 1))], - [1000, Period((DateUnit.YEAR, Instant((1000, 1, 1)), 1))], - ["1001", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))], - ["1001-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))], - ["1001-01-01", Period((DateUnit.DAY, Instant((1001, 1, 1)), 1))], - ["1004-02-29", Period((DateUnit.DAY, Instant((1004, 2, 29)), 1))], - ["1001-W01", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))], - ["1001-W01-1", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 1))], - ["year:1001", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))], - ["year:1001-01", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))], - ["year:1001-01-01", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))], - ["year:1001-W01", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))], - ["year:1001-W01-1", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))], - ["year:1001:1", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))], - ["year:1001-01:1", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))], - ["year:1001-01-01:1", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))], - ["year:1001-W01:1", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))], - ["year:1001-W01-1:1", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))], - ["year:1001:3", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 3))], - ["year:1001-01:3", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 3))], - ["year:1001-01-01:3", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 3))], - ["year:1001-W01:3", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 3))], - ["year:1001-W01-1:3", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 3))], - ["month:1001-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))], - ["month:1001-01-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))], - ["week:1001-W01", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))], - ["week:1001-W01-1", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))], - ["month:1001-01:1", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))], - ["month:1001-01:3", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 3))], - ["month:1001-01-01:3", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 3))], - ["week:1001-W01:1", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))], - ["week:1001-W01:3", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 3))], - ["week:1001-W01-1:3", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 3))], - ["day:1001-01-01", Period((DateUnit.DAY, Instant((1001, 1, 1)), 1))], - ["day:1001-01-01:3", Period((DateUnit.DAY, Instant((1001, 1, 1)), 3))], - ["weekday:1001-W01-1", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 1))], - [ + ), + (-1, Period((DateUnit.YEAR, Instant((-1, 1, 1)), 1))), + (0, Period((DateUnit.YEAR, Instant((0, 1, 1)), 1))), + (1, Period((DateUnit.YEAR, Instant((1, 1, 1)), 1))), + (999, Period((DateUnit.YEAR, Instant((999, 1, 1)), 1))), + (1000, Period((DateUnit.YEAR, Instant((1000, 1, 1)), 1))), + ("1001", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("1001-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("1001-01-01", Period((DateUnit.DAY, Instant((1001, 1, 1)), 1))), + ("1004-02-29", Period((DateUnit.DAY, Instant((1004, 2, 29)), 1))), + ("1001-W01", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("1001-W01-1", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 1))), + ("year:1001", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-01", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-01-01", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-W01", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))), + ("year:1001-W01-1", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))), + ("year:1001:1", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-01:1", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-01-01:1", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-W01:1", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))), + ("year:1001-W01-1:1", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))), + ("year:1001:3", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 3))), + ("year:1001-01:3", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 3))), + ("year:1001-01-01:3", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 3))), + ("year:1001-W01:3", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 3))), + ("year:1001-W01-1:3", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 3))), + ("month:1001-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("month:1001-01-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("week:1001-W01", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("week:1001-W01-1", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("month:1001-01:1", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("month:1001-01:3", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 3))), + ("month:1001-01-01:3", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 3))), + ("week:1001-W01:1", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("week:1001-W01:3", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 3))), + ("week:1001-W01-1:3", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 3))), + ("day:1001-01-01", Period((DateUnit.DAY, Instant((1001, 1, 1)), 1))), + ("day:1001-01-01:3", Period((DateUnit.DAY, Instant((1001, 1, 1)), 3))), + ("weekday:1001-W01-1", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 1))), + ( "weekday:1001-W01-1:3", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 3)), - ], + ), ], ) -def test_period(arg, expected): +def test_period(arg, expected) -> None: assert periods.period(arg) == expected @pytest.mark.parametrize( - "arg, error", + ("arg", "error"), [ - [None, ValueError], - [DateUnit.YEAR, ValueError], - ["1", ValueError], - ["999", ValueError], - ["1000-0", ValueError], - ["1000-13", ValueError], - ["1000-W0", ValueError], - ["1000-W54", ValueError], - ["1000-0-0", ValueError], - ["1000-1-0", ValueError], - ["1000-2-31", ValueError], - ["1000-W0-0", ValueError], - ["1000-W1-0", ValueError], - ["1000-W1-8", ValueError], - ["a", ValueError], - ["year", ValueError], - ["1:1000", ValueError], - ["a:1000", ValueError], - ["month:1000", ValueError], - ["week:1000", ValueError], - ["day:1000-01", ValueError], - ["weekday:1000-W1", ValueError], - ["1000:a", ValueError], - ["1000:1", ValueError], - ["1000-01:1", ValueError], - ["1000-01-01:1", ValueError], - ["1000-W1:1", ValueError], - ["1000-W1-1:1", ValueError], - ["month:1000:1", ValueError], - ["week:1000:1", ValueError], - ["day:1000:1", ValueError], - ["day:1000-01:1", ValueError], - ["weekday:1000:1", ValueError], - ["weekday:1000-W1:1", ValueError], - [(), ValueError], - [{}, ValueError], - ["", ValueError], - [(None,), ValueError], - [(None, None), ValueError], - [(None, None, None), ValueError], - [(None, None, None, None), ValueError], - [(Instant((1, 1, 1)),), ValueError], - [(Period((DateUnit.DAY, Instant((1, 1, 1)), 365)),), ValueError], - [(1,), ValueError], - [(1, 1), ValueError], - [(1, 1, 1), ValueError], - [(-1,), ValueError], - [(-1, -1), ValueError], - [(-1, -1, -1), ValueError], - [("-1",), ValueError], - [("-1", "-1"), ValueError], - [("-1", "-1", "-1"), ValueError], - [("1-1",), ValueError], - [("1-1-1",), ValueError], + (None, ValueError), + (DateUnit.YEAR, ValueError), + ("1", ValueError), + ("999", ValueError), + ("1000-0", ValueError), + ("1000-13", ValueError), + ("1000-W0", ValueError), + ("1000-W54", ValueError), + ("1000-0-0", ValueError), + ("1000-1-0", ValueError), + ("1000-2-31", ValueError), + ("1000-W0-0", ValueError), + ("1000-W1-0", ValueError), + ("1000-W1-8", ValueError), + ("a", ValueError), + ("year", ValueError), + ("1:1000", ValueError), + ("a:1000", ValueError), + ("month:1000", ValueError), + ("week:1000", ValueError), + ("day:1000-01", ValueError), + ("weekday:1000-W1", ValueError), + ("1000:a", ValueError), + ("1000:1", ValueError), + ("1000-01:1", ValueError), + ("1000-01-01:1", ValueError), + ("1000-W1:1", ValueError), + ("1000-W1-1:1", ValueError), + ("month:1000:1", ValueError), + ("week:1000:1", ValueError), + ("day:1000:1", ValueError), + ("day:1000-01:1", ValueError), + ("weekday:1000:1", ValueError), + ("weekday:1000-W1:1", ValueError), + ((), ValueError), + ({}, ValueError), + ("", ValueError), + ((None,), ValueError), + ((None, None), ValueError), + ((None, None, None), ValueError), + ((None, None, None, None), ValueError), + ((Instant((1, 1, 1)),), ValueError), + ((Period((DateUnit.DAY, Instant((1, 1, 1)), 365)),), ValueError), + ((1,), ValueError), + ((1, 1), ValueError), + ((1, 1, 1), ValueError), + ((-1,), ValueError), + ((-1, -1), ValueError), + ((-1, -1, -1), ValueError), + (("-1",), ValueError), + (("-1", "-1"), ValueError), + (("-1", "-1", "-1"), ValueError), + (("1-1",), ValueError), + (("1-1-1",), ValueError), ], ) -def test_period_with_an_invalid_argument(arg, error): +def test_period_with_an_invalid_argument(arg, error) -> None: with pytest.raises(error): periods.period(arg) diff --git a/openfisca_core/periods/tests/test__parsers.py b/openfisca_core/periods/tests/test__parsers.py index 6c88c9cd11..67a2891a32 100644 --- a/openfisca_core/periods/tests/test__parsers.py +++ b/openfisca_core/periods/tests/test__parsers.py @@ -5,65 +5,65 @@ @pytest.mark.parametrize( - "arg, expected", + ("arg", "expected"), [ - ["1001", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))], - ["1001-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))], - ["1001-12", Period((DateUnit.MONTH, Instant((1001, 12, 1)), 1))], - ["1001-01-01", Period((DateUnit.DAY, Instant((1001, 1, 1)), 1))], - ["1001-W01", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))], - ["1001-W52", Period((DateUnit.WEEK, Instant((1001, 12, 21)), 1))], - ["1001-W01-1", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 1))], + ("1001", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("1001-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("1001-12", Period((DateUnit.MONTH, Instant((1001, 12, 1)), 1))), + ("1001-01-01", Period((DateUnit.DAY, Instant((1001, 1, 1)), 1))), + ("1001-W01", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("1001-W52", Period((DateUnit.WEEK, Instant((1001, 12, 21)), 1))), + ("1001-W01-1", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 1))), ], ) -def test__parse_period(arg, expected): +def test__parse_period(arg, expected) -> None: assert _parsers._parse_period(arg) == expected @pytest.mark.parametrize( - "arg, error", + ("arg", "error"), [ - [None, AttributeError], - [{}, AttributeError], - [(), AttributeError], - [[], AttributeError], - [1, AttributeError], - ["", AttributeError], - ["à", ParserError], - ["1", ValueError], - ["-1", ValueError], - ["999", ParserError], - ["1000-0", ParserError], - ["1000-1", ParserError], - ["1000-1-1", ParserError], - ["1000-00", ParserError], - ["1000-13", ParserError], - ["1000-01-00", ParserError], - ["1000-01-99", ParserError], - ["1000-W0", ParserError], - ["1000-W1", ParserError], - ["1000-W99", ParserError], - ["1000-W1-0", ParserError], - ["1000-W1-1", ParserError], - ["1000-W1-99", ParserError], - ["1000-W01-0", ParserError], - ["1000-W01-00", ParserError], + (None, AttributeError), + ({}, AttributeError), + ((), AttributeError), + ([], AttributeError), + (1, AttributeError), + ("", AttributeError), + ("à", ParserError), + ("1", ValueError), + ("-1", ValueError), + ("999", ParserError), + ("1000-0", ParserError), + ("1000-1", ParserError), + ("1000-1-1", ParserError), + ("1000-00", ParserError), + ("1000-13", ParserError), + ("1000-01-00", ParserError), + ("1000-01-99", ParserError), + ("1000-W0", ParserError), + ("1000-W1", ParserError), + ("1000-W99", ParserError), + ("1000-W1-0", ParserError), + ("1000-W1-1", ParserError), + ("1000-W1-99", ParserError), + ("1000-W01-0", ParserError), + ("1000-W01-00", ParserError), ], ) -def test__parse_period_with_invalid_argument(arg, error): +def test__parse_period_with_invalid_argument(arg, error) -> None: with pytest.raises(error): _parsers._parse_period(arg) @pytest.mark.parametrize( - "arg, expected", + ("arg", "expected"), [ - ["2022", DateUnit.YEAR], - ["2022-01", DateUnit.MONTH], - ["2022-01-01", DateUnit.DAY], - ["2022-W01", DateUnit.WEEK], - ["2022-W01-01", DateUnit.WEEKDAY], + ("2022", DateUnit.YEAR), + ("2022-01", DateUnit.MONTH), + ("2022-01-01", DateUnit.DAY), + ("2022-W01", DateUnit.WEEK), + ("2022-W01-01", DateUnit.WEEKDAY), ], ) -def test__parse_unit(arg, expected): +def test__parse_unit(arg, expected) -> None: assert _parsers._parse_unit(arg) == expected diff --git a/openfisca_core/periods/tests/test_instant.py b/openfisca_core/periods/tests/test_instant.py index 21549008f4..e9c73ef6aa 100644 --- a/openfisca_core/periods/tests/test_instant.py +++ b/openfisca_core/periods/tests/test_instant.py @@ -4,29 +4,29 @@ @pytest.mark.parametrize( - "instant, offset, unit, expected", + ("instant", "offset", "unit", "expected"), [ - [Instant((2020, 2, 29)), "first-of", DateUnit.YEAR, Instant((2020, 1, 1))], - [Instant((2020, 2, 29)), "first-of", DateUnit.MONTH, Instant((2020, 2, 1))], - [Instant((2020, 2, 29)), "first-of", DateUnit.WEEK, Instant((2020, 2, 24))], - [Instant((2020, 2, 29)), "first-of", DateUnit.DAY, None], - [Instant((2020, 2, 29)), "first-of", DateUnit.WEEKDAY, None], - [Instant((2020, 2, 29)), "last-of", DateUnit.YEAR, Instant((2020, 12, 31))], - [Instant((2020, 2, 29)), "last-of", DateUnit.MONTH, Instant((2020, 2, 29))], - [Instant((2020, 2, 29)), "last-of", DateUnit.WEEK, Instant((2020, 3, 1))], - [Instant((2020, 2, 29)), "last-of", DateUnit.DAY, None], - [Instant((2020, 2, 29)), "last-of", DateUnit.WEEKDAY, None], - [Instant((2020, 2, 29)), -3, DateUnit.YEAR, Instant((2017, 2, 28))], - [Instant((2020, 2, 29)), -3, DateUnit.MONTH, Instant((2019, 11, 29))], - [Instant((2020, 2, 29)), -3, DateUnit.WEEK, Instant((2020, 2, 8))], - [Instant((2020, 2, 29)), -3, DateUnit.DAY, Instant((2020, 2, 26))], - [Instant((2020, 2, 29)), -3, DateUnit.WEEKDAY, Instant((2020, 2, 26))], - [Instant((2020, 2, 29)), 3, DateUnit.YEAR, Instant((2023, 2, 28))], - [Instant((2020, 2, 29)), 3, DateUnit.MONTH, Instant((2020, 5, 29))], - [Instant((2020, 2, 29)), 3, DateUnit.WEEK, Instant((2020, 3, 21))], - [Instant((2020, 2, 29)), 3, DateUnit.DAY, Instant((2020, 3, 3))], - [Instant((2020, 2, 29)), 3, DateUnit.WEEKDAY, Instant((2020, 3, 3))], + (Instant((2020, 2, 29)), "first-of", DateUnit.YEAR, Instant((2020, 1, 1))), + (Instant((2020, 2, 29)), "first-of", DateUnit.MONTH, Instant((2020, 2, 1))), + (Instant((2020, 2, 29)), "first-of", DateUnit.WEEK, Instant((2020, 2, 24))), + (Instant((2020, 2, 29)), "first-of", DateUnit.DAY, None), + (Instant((2020, 2, 29)), "first-of", DateUnit.WEEKDAY, None), + (Instant((2020, 2, 29)), "last-of", DateUnit.YEAR, Instant((2020, 12, 31))), + (Instant((2020, 2, 29)), "last-of", DateUnit.MONTH, Instant((2020, 2, 29))), + (Instant((2020, 2, 29)), "last-of", DateUnit.WEEK, Instant((2020, 3, 1))), + (Instant((2020, 2, 29)), "last-of", DateUnit.DAY, None), + (Instant((2020, 2, 29)), "last-of", DateUnit.WEEKDAY, None), + (Instant((2020, 2, 29)), -3, DateUnit.YEAR, Instant((2017, 2, 28))), + (Instant((2020, 2, 29)), -3, DateUnit.MONTH, Instant((2019, 11, 29))), + (Instant((2020, 2, 29)), -3, DateUnit.WEEK, Instant((2020, 2, 8))), + (Instant((2020, 2, 29)), -3, DateUnit.DAY, Instant((2020, 2, 26))), + (Instant((2020, 2, 29)), -3, DateUnit.WEEKDAY, Instant((2020, 2, 26))), + (Instant((2020, 2, 29)), 3, DateUnit.YEAR, Instant((2023, 2, 28))), + (Instant((2020, 2, 29)), 3, DateUnit.MONTH, Instant((2020, 5, 29))), + (Instant((2020, 2, 29)), 3, DateUnit.WEEK, Instant((2020, 3, 21))), + (Instant((2020, 2, 29)), 3, DateUnit.DAY, Instant((2020, 3, 3))), + (Instant((2020, 2, 29)), 3, DateUnit.WEEKDAY, Instant((2020, 3, 3))), ], ) -def test_offset(instant, offset, unit, expected): +def test_offset(instant, offset, unit, expected) -> None: assert instant.offset(offset, unit) == expected diff --git a/openfisca_core/periods/tests/test_period.py b/openfisca_core/periods/tests/test_period.py index 6553c4fd9b..9e53bf7d12 100644 --- a/openfisca_core/periods/tests/test_period.py +++ b/openfisca_core/periods/tests/test_period.py @@ -4,278 +4,278 @@ @pytest.mark.parametrize( - "date_unit, instant, size, expected", + ("date_unit", "instant", "size", "expected"), [ - [DateUnit.YEAR, Instant((2022, 1, 1)), 1, "2022"], - [DateUnit.MONTH, Instant((2022, 1, 1)), 12, "2022"], - [DateUnit.YEAR, Instant((2022, 3, 1)), 1, "year:2022-03"], - [DateUnit.MONTH, Instant((2022, 3, 1)), 12, "year:2022-03"], - [DateUnit.YEAR, Instant((2022, 1, 1)), 3, "year:2022:3"], - [DateUnit.YEAR, Instant((2022, 1, 3)), 3, "year:2022:3"], + (DateUnit.YEAR, Instant((2022, 1, 1)), 1, "2022"), + (DateUnit.MONTH, Instant((2022, 1, 1)), 12, "2022"), + (DateUnit.YEAR, Instant((2022, 3, 1)), 1, "year:2022-03"), + (DateUnit.MONTH, Instant((2022, 3, 1)), 12, "year:2022-03"), + (DateUnit.YEAR, Instant((2022, 1, 1)), 3, "year:2022:3"), + (DateUnit.YEAR, Instant((2022, 1, 3)), 3, "year:2022:3"), ], ) -def test_str_with_years(date_unit, instant, size, expected): +def test_str_with_years(date_unit, instant, size, expected) -> None: assert str(Period((date_unit, instant, size))) == expected @pytest.mark.parametrize( - "date_unit, instant, size, expected", + ("date_unit", "instant", "size", "expected"), [ - [DateUnit.MONTH, Instant((2022, 1, 1)), 1, "2022-01"], - [DateUnit.MONTH, Instant((2022, 1, 1)), 3, "month:2022-01:3"], - [DateUnit.MONTH, Instant((2022, 3, 1)), 3, "month:2022-03:3"], + (DateUnit.MONTH, Instant((2022, 1, 1)), 1, "2022-01"), + (DateUnit.MONTH, Instant((2022, 1, 1)), 3, "month:2022-01:3"), + (DateUnit.MONTH, Instant((2022, 3, 1)), 3, "month:2022-03:3"), ], ) -def test_str_with_months(date_unit, instant, size, expected): +def test_str_with_months(date_unit, instant, size, expected) -> None: assert str(Period((date_unit, instant, size))) == expected @pytest.mark.parametrize( - "date_unit, instant, size, expected", + ("date_unit", "instant", "size", "expected"), [ - [DateUnit.DAY, Instant((2022, 1, 1)), 1, "2022-01-01"], - [DateUnit.DAY, Instant((2022, 1, 1)), 3, "day:2022-01-01:3"], - [DateUnit.DAY, Instant((2022, 3, 1)), 3, "day:2022-03-01:3"], + (DateUnit.DAY, Instant((2022, 1, 1)), 1, "2022-01-01"), + (DateUnit.DAY, Instant((2022, 1, 1)), 3, "day:2022-01-01:3"), + (DateUnit.DAY, Instant((2022, 3, 1)), 3, "day:2022-03-01:3"), ], ) -def test_str_with_days(date_unit, instant, size, expected): +def test_str_with_days(date_unit, instant, size, expected) -> None: assert str(Period((date_unit, instant, size))) == expected @pytest.mark.parametrize( - "date_unit, instant, size, expected", + ("date_unit", "instant", "size", "expected"), [ - [DateUnit.WEEK, Instant((2022, 1, 1)), 1, "2021-W52"], - [DateUnit.WEEK, Instant((2022, 1, 1)), 3, "week:2021-W52:3"], - [DateUnit.WEEK, Instant((2022, 3, 1)), 1, "2022-W09"], - [DateUnit.WEEK, Instant((2022, 3, 1)), 3, "week:2022-W09:3"], + (DateUnit.WEEK, Instant((2022, 1, 1)), 1, "2021-W52"), + (DateUnit.WEEK, Instant((2022, 1, 1)), 3, "week:2021-W52:3"), + (DateUnit.WEEK, Instant((2022, 3, 1)), 1, "2022-W09"), + (DateUnit.WEEK, Instant((2022, 3, 1)), 3, "week:2022-W09:3"), ], ) -def test_str_with_weeks(date_unit, instant, size, expected): +def test_str_with_weeks(date_unit, instant, size, expected) -> None: assert str(Period((date_unit, instant, size))) == expected @pytest.mark.parametrize( - "date_unit, instant, size, expected", + ("date_unit", "instant", "size", "expected"), [ - [DateUnit.WEEKDAY, Instant((2022, 1, 1)), 1, "2021-W52-6"], - [DateUnit.WEEKDAY, Instant((2022, 1, 1)), 3, "weekday:2021-W52-6:3"], - [DateUnit.WEEKDAY, Instant((2022, 3, 1)), 1, "2022-W09-2"], - [DateUnit.WEEKDAY, Instant((2022, 3, 1)), 3, "weekday:2022-W09-2:3"], + (DateUnit.WEEKDAY, Instant((2022, 1, 1)), 1, "2021-W52-6"), + (DateUnit.WEEKDAY, Instant((2022, 1, 1)), 3, "weekday:2021-W52-6:3"), + (DateUnit.WEEKDAY, Instant((2022, 3, 1)), 1, "2022-W09-2"), + (DateUnit.WEEKDAY, Instant((2022, 3, 1)), 3, "weekday:2022-W09-2:3"), ], ) -def test_str_with_weekdays(date_unit, instant, size, expected): +def test_str_with_weekdays(date_unit, instant, size, expected) -> None: assert str(Period((date_unit, instant, size))) == expected @pytest.mark.parametrize( - "date_unit, instant, size, expected", + ("date_unit", "instant", "size", "expected"), [ - [DateUnit.YEAR, Instant((2022, 12, 1)), 1, 1], - [DateUnit.YEAR, Instant((2022, 1, 1)), 2, 2], + (DateUnit.YEAR, Instant((2022, 12, 1)), 1, 1), + (DateUnit.YEAR, Instant((2022, 1, 1)), 2, 2), ], ) -def test_size_in_years(date_unit, instant, size, expected): +def test_size_in_years(date_unit, instant, size, expected) -> None: period = Period((date_unit, instant, size)) assert period.size_in_years == expected @pytest.mark.parametrize( - "date_unit, instant, size, expected", + ("date_unit", "instant", "size", "expected"), [ - [DateUnit.YEAR, Instant((2020, 1, 1)), 1, 12], - [DateUnit.YEAR, Instant((2022, 1, 1)), 2, 24], - [DateUnit.MONTH, Instant((2012, 1, 3)), 3, 3], + (DateUnit.YEAR, Instant((2020, 1, 1)), 1, 12), + (DateUnit.YEAR, Instant((2022, 1, 1)), 2, 24), + (DateUnit.MONTH, Instant((2012, 1, 3)), 3, 3), ], ) -def test_size_in_months(date_unit, instant, size, expected): +def test_size_in_months(date_unit, instant, size, expected) -> None: period = Period((date_unit, instant, size)) assert period.size_in_months == expected @pytest.mark.parametrize( - "date_unit, instant, size, expected", + ("date_unit", "instant", "size", "expected"), [ - [DateUnit.YEAR, Instant((2022, 12, 1)), 1, 365], - [DateUnit.YEAR, Instant((2020, 1, 1)), 1, 366], - [DateUnit.YEAR, Instant((2022, 1, 1)), 2, 730], - [DateUnit.MONTH, Instant((2022, 12, 1)), 1, 31], - [DateUnit.MONTH, Instant((2020, 2, 3)), 1, 29], - [DateUnit.MONTH, Instant((2022, 1, 3)), 3, 31 + 28 + 31], - [DateUnit.MONTH, Instant((2012, 1, 3)), 3, 31 + 29 + 31], - [DateUnit.DAY, Instant((2022, 12, 31)), 1, 1], - [DateUnit.DAY, Instant((2022, 12, 31)), 3, 3], - [DateUnit.WEEK, Instant((2022, 12, 31)), 1, 7], - [DateUnit.WEEK, Instant((2022, 12, 31)), 3, 21], - [DateUnit.WEEKDAY, Instant((2022, 12, 31)), 1, 1], - [DateUnit.WEEKDAY, Instant((2022, 12, 31)), 3, 3], + (DateUnit.YEAR, Instant((2022, 12, 1)), 1, 365), + (DateUnit.YEAR, Instant((2020, 1, 1)), 1, 366), + (DateUnit.YEAR, Instant((2022, 1, 1)), 2, 730), + (DateUnit.MONTH, Instant((2022, 12, 1)), 1, 31), + (DateUnit.MONTH, Instant((2020, 2, 3)), 1, 29), + (DateUnit.MONTH, Instant((2022, 1, 3)), 3, 31 + 28 + 31), + (DateUnit.MONTH, Instant((2012, 1, 3)), 3, 31 + 29 + 31), + (DateUnit.DAY, Instant((2022, 12, 31)), 1, 1), + (DateUnit.DAY, Instant((2022, 12, 31)), 3, 3), + (DateUnit.WEEK, Instant((2022, 12, 31)), 1, 7), + (DateUnit.WEEK, Instant((2022, 12, 31)), 3, 21), + (DateUnit.WEEKDAY, Instant((2022, 12, 31)), 1, 1), + (DateUnit.WEEKDAY, Instant((2022, 12, 31)), 3, 3), ], ) -def test_size_in_days(date_unit, instant, size, expected): +def test_size_in_days(date_unit, instant, size, expected) -> None: period = Period((date_unit, instant, size)) assert period.size_in_days == expected assert period.size_in_days == period.days @pytest.mark.parametrize( - "date_unit, instant, size, expected", + ("date_unit", "instant", "size", "expected"), [ - [DateUnit.YEAR, Instant((2022, 12, 1)), 1, 52], - [DateUnit.YEAR, Instant((2020, 1, 1)), 5, 261], - [DateUnit.MONTH, Instant((2022, 12, 1)), 1, 4], - [DateUnit.MONTH, Instant((2020, 2, 3)), 1, 4], - [DateUnit.MONTH, Instant((2022, 1, 3)), 3, 12], - [DateUnit.MONTH, Instant((2012, 1, 3)), 3, 13], - [DateUnit.WEEK, Instant((2022, 12, 31)), 1, 1], - [DateUnit.WEEK, Instant((2022, 12, 31)), 3, 3], + (DateUnit.YEAR, Instant((2022, 12, 1)), 1, 52), + (DateUnit.YEAR, Instant((2020, 1, 1)), 5, 261), + (DateUnit.MONTH, Instant((2022, 12, 1)), 1, 4), + (DateUnit.MONTH, Instant((2020, 2, 3)), 1, 4), + (DateUnit.MONTH, Instant((2022, 1, 3)), 3, 12), + (DateUnit.MONTH, Instant((2012, 1, 3)), 3, 13), + (DateUnit.WEEK, Instant((2022, 12, 31)), 1, 1), + (DateUnit.WEEK, Instant((2022, 12, 31)), 3, 3), ], ) -def test_size_in_weeks(date_unit, instant, size, expected): +def test_size_in_weeks(date_unit, instant, size, expected) -> None: period = Period((date_unit, instant, size)) assert period.size_in_weeks == expected @pytest.mark.parametrize( - "date_unit, instant, size, expected", + ("date_unit", "instant", "size", "expected"), [ - [DateUnit.YEAR, Instant((2022, 12, 1)), 1, 364], - [DateUnit.YEAR, Instant((2020, 1, 1)), 1, 364], - [DateUnit.YEAR, Instant((2022, 1, 1)), 2, 728], - [DateUnit.MONTH, Instant((2022, 12, 1)), 1, 31], - [DateUnit.MONTH, Instant((2020, 2, 3)), 1, 29], - [DateUnit.MONTH, Instant((2022, 1, 3)), 3, 31 + 28 + 31], - [DateUnit.MONTH, Instant((2012, 1, 3)), 3, 31 + 29 + 31], - [DateUnit.DAY, Instant((2022, 12, 31)), 1, 1], - [DateUnit.DAY, Instant((2022, 12, 31)), 3, 3], - [DateUnit.WEEK, Instant((2022, 12, 31)), 1, 7], - [DateUnit.WEEK, Instant((2022, 12, 31)), 3, 21], - [DateUnit.WEEKDAY, Instant((2022, 12, 31)), 1, 1], - [DateUnit.WEEKDAY, Instant((2022, 12, 31)), 3, 3], + (DateUnit.YEAR, Instant((2022, 12, 1)), 1, 364), + (DateUnit.YEAR, Instant((2020, 1, 1)), 1, 364), + (DateUnit.YEAR, Instant((2022, 1, 1)), 2, 728), + (DateUnit.MONTH, Instant((2022, 12, 1)), 1, 31), + (DateUnit.MONTH, Instant((2020, 2, 3)), 1, 29), + (DateUnit.MONTH, Instant((2022, 1, 3)), 3, 31 + 28 + 31), + (DateUnit.MONTH, Instant((2012, 1, 3)), 3, 31 + 29 + 31), + (DateUnit.DAY, Instant((2022, 12, 31)), 1, 1), + (DateUnit.DAY, Instant((2022, 12, 31)), 3, 3), + (DateUnit.WEEK, Instant((2022, 12, 31)), 1, 7), + (DateUnit.WEEK, Instant((2022, 12, 31)), 3, 21), + (DateUnit.WEEKDAY, Instant((2022, 12, 31)), 1, 1), + (DateUnit.WEEKDAY, Instant((2022, 12, 31)), 3, 3), ], ) -def test_size_in_weekdays(date_unit, instant, size, expected): +def test_size_in_weekdays(date_unit, instant, size, expected) -> None: period = Period((date_unit, instant, size)) assert period.size_in_weekdays == expected @pytest.mark.parametrize( - "period_unit, sub_unit, instant, start, cease, count", + ("period_unit", "sub_unit", "instant", "start", "cease", "count"), [ - [ + ( DateUnit.YEAR, DateUnit.YEAR, Instant((2022, 12, 31)), Instant((2022, 1, 1)), Instant((2024, 1, 1)), 3, - ], - [ + ), + ( DateUnit.YEAR, DateUnit.MONTH, Instant((2022, 12, 31)), Instant((2022, 12, 1)), Instant((2025, 11, 1)), 36, - ], - [ + ), + ( DateUnit.YEAR, DateUnit.DAY, Instant((2022, 12, 31)), Instant((2022, 12, 31)), Instant((2025, 12, 30)), 1096, - ], - [ + ), + ( DateUnit.YEAR, DateUnit.WEEK, Instant((2022, 12, 31)), Instant((2022, 12, 26)), Instant((2025, 12, 15)), 156, - ], - [ + ), + ( DateUnit.YEAR, DateUnit.WEEKDAY, Instant((2022, 12, 31)), Instant((2022, 12, 31)), Instant((2025, 12, 26)), 1092, - ], - [ + ), + ( DateUnit.MONTH, DateUnit.MONTH, Instant((2022, 12, 31)), Instant((2022, 12, 1)), Instant((2023, 2, 1)), 3, - ], - [ + ), + ( DateUnit.MONTH, DateUnit.DAY, Instant((2022, 12, 31)), Instant((2022, 12, 31)), Instant((2023, 3, 30)), 90, - ], - [ + ), + ( DateUnit.DAY, DateUnit.DAY, Instant((2022, 12, 31)), Instant((2022, 12, 31)), Instant((2023, 1, 2)), 3, - ], - [ + ), + ( DateUnit.DAY, DateUnit.WEEKDAY, Instant((2022, 12, 31)), Instant((2022, 12, 31)), Instant((2023, 1, 2)), 3, - ], - [ + ), + ( DateUnit.WEEK, DateUnit.DAY, Instant((2022, 12, 31)), Instant((2022, 12, 31)), Instant((2023, 1, 20)), 21, - ], - [ + ), + ( DateUnit.WEEK, DateUnit.WEEK, Instant((2022, 12, 31)), Instant((2022, 12, 26)), Instant((2023, 1, 9)), 3, - ], - [ + ), + ( DateUnit.WEEK, DateUnit.WEEKDAY, Instant((2022, 12, 31)), Instant((2022, 12, 31)), Instant((2023, 1, 20)), 21, - ], - [ + ), + ( DateUnit.WEEKDAY, DateUnit.DAY, Instant((2022, 12, 31)), Instant((2022, 12, 31)), Instant((2023, 1, 2)), 3, - ], - [ + ), + ( DateUnit.WEEKDAY, DateUnit.WEEKDAY, Instant((2022, 12, 31)), Instant((2022, 12, 31)), Instant((2023, 1, 2)), 3, - ], + ), ], ) -def test_subperiods(period_unit, sub_unit, instant, start, cease, count): +def test_subperiods(period_unit, sub_unit, instant, start, cease, count) -> None: period = Period((period_unit, instant, 3)) subperiods = period.get_subperiods(sub_unit) assert len(subperiods) == count diff --git a/openfisca_core/populations/group_population.py b/openfisca_core/populations/group_population.py index d77816face..4e68762f19 100644 --- a/openfisca_core/populations/group_population.py +++ b/openfisca_core/populations/group_population.py @@ -8,7 +8,7 @@ class GroupPopulation(Population): - def __init__(self, entity, members): + def __init__(self, entity, members) -> None: super().__init__(entity) self.members = members self._members_entity_id = None @@ -46,7 +46,7 @@ def members_position(self): return self._members_position @members_position.setter - def members_position(self, members_position): + def members_position(self, members_position) -> None: self._members_position = members_position @property @@ -54,7 +54,7 @@ def members_entity_id(self): return self._members_entity_id @members_entity_id.setter - def members_entity_id(self, members_entity_id): + def members_entity_id(self, members_entity_id) -> None: self._members_entity_id = members_entity_id @property @@ -65,14 +65,13 @@ def members_role(self): return self._members_role @members_role.setter - def members_role(self, members_role: typing.Iterable[entities.Role]): + def members_role(self, members_role: typing.Iterable[entities.Role]) -> None: if members_role is not None: self._members_role = numpy.array(list(members_role)) @property def ordered_members_map(self): - """ - Mask to group the persons by entity + """Mask to group the persons by entity This function only caches the map value, to see what the map is used for, see value_nth_person method. """ if self._ordered_members_map is None: @@ -89,18 +88,19 @@ def get_role(self, role_name): @projectors.projectable def sum(self, array, role=None): - """ - Return the sum of ``array`` for the members of the entity. + """Return the sum of ``array`` for the members of the entity. ``array`` must have the dimension of the number of persons in the simulation If ``role`` is provided, only the entity member with the given role are taken into account. Example: - - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] >>> household.sum(salaries) >>> array([3500]) + """ self.entity.check_role_validity(role) self.members.check_array_compatible_with_entity(array) @@ -111,23 +111,23 @@ def sum(self, array, role=None): weights=array[role_filter], minlength=self.count, ) - else: - return numpy.bincount(self.members_entity_id, weights=array) + return numpy.bincount(self.members_entity_id, weights=array) @projectors.projectable def any(self, array, role=None): - """ - Return ``True`` if ``array`` is ``True`` for any members of the entity. + """Return ``True`` if ``array`` is ``True`` for any members of the entity. ``array`` must have the dimension of the number of persons in the simulation If ``role`` is provided, only the entity member with the given role are taken into account. Example: - - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] >>> household.any(salaries >= 1800) >>> array([True]) + """ sum_in_entity = self.sum(array, role=role) return sum_in_entity > 0 @@ -141,7 +141,7 @@ def reduce(self, array, reducer, neutral_element, role=None): filtered_array = numpy.where(role_filter, array, neutral_element) result = self.filled_array( - neutral_element + neutral_element, ) # Neutral value that will be returned if no one with the given role exists. # We loop over the positions in the entity @@ -156,87 +156,98 @@ def reduce(self, array, reducer, neutral_element, role=None): @projectors.projectable def all(self, array, role=None): - """ - Return ``True`` if ``array`` is ``True`` for all members of the entity. + """Return ``True`` if ``array`` is ``True`` for all members of the entity. ``array`` must have the dimension of the number of persons in the simulation If ``role`` is provided, only the entity member with the given role are taken into account. Example: - - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] >>> household.all(salaries >= 1800) >>> array([False]) + """ return self.reduce( - array, reducer=numpy.logical_and, neutral_element=True, role=role + array, + reducer=numpy.logical_and, + neutral_element=True, + role=role, ) @projectors.projectable def max(self, array, role=None): - """ - Return the maximum value of ``array`` for the entity members. + """Return the maximum value of ``array`` for the entity members. ``array`` must have the dimension of the number of persons in the simulation If ``role`` is provided, only the entity member with the given role are taken into account. Example: - - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] >>> household.max(salaries) >>> array([2000]) + """ return self.reduce( - array, reducer=numpy.maximum, neutral_element=-numpy.infty, role=role + array, + reducer=numpy.maximum, + neutral_element=-numpy.inf, + role=role, ) @projectors.projectable def min(self, array, role=None): - """ - Return the minimum value of ``array`` for the entity members. + """Return the minimum value of ``array`` for the entity members. ``array`` must have the dimension of the number of persons in the simulation If ``role`` is provided, only the entity member with the given role are taken into account. Example: - - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] >>> household.min(salaries) >>> array([0]) - >>> household.min(salaries, role = Household.PARENT) # Assuming the 1st two persons are parents + >>> household.min( + ... salaries, role=Household.PARENT + ... ) # Assuming the 1st two persons are parents >>> array([1500]) + """ return self.reduce( - array, reducer=numpy.minimum, neutral_element=numpy.infty, role=role + array, + reducer=numpy.minimum, + neutral_element=numpy.inf, + role=role, ) @projectors.projectable def nb_persons(self, role=None): - """ - Returns the number of persons contained in the entity. + """Returns the number of persons contained in the entity. If ``role`` is provided, only the entity member with the given role are taken into account. """ if role: if role.subroles: role_condition = numpy.logical_or.reduce( - [self.members_role == subrole for subrole in role.subroles] + [self.members_role == subrole for subrole in role.subroles], ) else: role_condition = self.members_role == role return self.sum(role_condition) - else: - return numpy.bincount(self.members_entity_id) + return numpy.bincount(self.members_entity_id) # Projection person -> entity @projectors.projectable def value_from_person(self, array, role, default=0): - """ - Get the value of ``array`` for the person with the unique role ``role``. + """Get the value of ``array`` for the person with the unique role ``role``. ``array`` must have the dimension of the number of persons in the simulation @@ -246,10 +257,9 @@ def value_from_person(self, array, role, default=0): """ self.entity.check_role_validity(role) if role.max != 1: + msg = f"You can only use value_from_person with a role that is unique in {self.key}. Role {role.key} is not unique." raise Exception( - "You can only use value_from_person with a role that is unique in {}. Role {} is not unique.".format( - self.key, role.key - ) + msg, ) self.members.check_array_compatible_with_entity(array) members_map = self.ordered_members_map @@ -265,8 +275,7 @@ def value_from_person(self, array, role, default=0): @projectors.projectable def value_nth_person(self, n, array, default=0): - """ - Get the value of array for the person whose position in the entity is n. + """Get the value of array for the person whose position in the entity is n. Note that this position is arbitrary, and that members are not sorted. @@ -301,6 +310,5 @@ def project(self, array, role=None): self.entity.check_role_validity(role) if role is None: return array[self.members_entity_id] - else: - role_condition = self.members.has_role(role) - return numpy.where(role_condition, array[self.members_entity_id], 0) + role_condition = self.members.has_role(role) + return numpy.where(role_condition, array[self.members_entity_id], 0) diff --git a/openfisca_core/populations/population.py b/openfisca_core/populations/population.py index e3ef6b209a..f9eee1c2a2 100644 --- a/openfisca_core/populations/population.py +++ b/openfisca_core/populations/population.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Dict, NamedTuple, Optional, Sequence, Union +from collections.abc import Sequence +from typing import NamedTuple from typing_extensions import TypedDict from openfisca_core.types import Array, Period, Role, Simulation, SingleEntity @@ -15,9 +16,9 @@ class Population: - simulation: Optional[Simulation] + simulation: Simulation | None entity: SingleEntity - _holders: Dict[str, holders.Holder] + _holders: dict[str, holders.Holder] count: int ids: Array[str] @@ -44,22 +45,21 @@ def empty_array(self) -> Array[float]: def filled_array( self, - value: Union[float, bool], - dtype: Optional[numpy.dtype] = None, - ) -> Union[Array[float], Array[bool]]: + value: float | bool, + dtype: numpy.dtype | None = None, + ) -> Array[float] | Array[bool]: return numpy.full(self.count, value, dtype) def __getattr__(self, attribute: str) -> projectors.Projector: - projector: Optional[projectors.Projector] + projector: projectors.Projector | None projector = projectors.get_projector_from_shortcut(self, attribute) if isinstance(projector, projectors.Projector): return projector + msg = f"You tried to use the '{attribute}' of '{self.entity.key}' but that is not a known attribute." raise AttributeError( - "You tried to use the '{}' of '{}' but that is not a known attribute.".format( - attribute, self.entity.key - ) + msg, ) def get_index(self, id: str) -> int: @@ -72,51 +72,48 @@ def check_array_compatible_with_entity( array: Array[float], ) -> None: if self.count == array.size: - return None + return + msg = f"Input {array} is not a valid value for the entity {self.entity.key} (size = {array.size} != {self.count} = count)" raise ValueError( - "Input {} is not a valid value for the entity {} (size = {} != {} = count)".format( - array, self.entity.key, array.size, self.count - ) + msg, ) def check_period_validity( self, variable_name: str, - period: Optional[Union[int, str, Period]], + period: int | str | Period | None, ) -> None: if isinstance(period, (int, str, Period)): - return None + return stack = traceback.extract_stack() filename, line_number, function_name, line_of_code = stack[-3] - raise ValueError( - """ -You requested computation of variable "{}", but you did not specify on which period in "{}:{}": - {} + msg = f""" +You requested computation of variable "{variable_name}", but you did not specify on which period in "{filename}:{line_number}": + {line_of_code} When you request the computation of a variable within a formula, you must always specify the period as the second parameter. The convention is to call this parameter "period". For example: computed_salary = person('salary', period). See more information at . -""".format( - variable_name, filename, line_number, line_of_code - ) +""" + raise ValueError( + msg, ) def __call__( self, variable_name: str, - period: Optional[Union[int, str, Period]] = None, - options: Optional[Sequence[str]] = None, - ) -> Optional[Array[float]]: - """ - Calculate the variable ``variable_name`` for the entity and the period ``period``, using the variable formula if it exists. + period: int | str | Period | None = None, + options: Sequence[str] | None = None, + ) -> Array[float] | None: + """Calculate the variable ``variable_name`` for the entity and the period ``period``, using the variable formula if it exists. Example: - - >>> person('salary', '2017-04') - >>> array([300.]) + >>> person("salary", "2017-04") + >>> array([300.0]) :returns: A numpy array containing the result of the calculation + """ if self.simulation is None: return None @@ -149,11 +146,7 @@ def __call__( ) raise ValueError( - "Options config.ADD and config.DIVIDE are incompatible (trying to compute variable {})".format( - variable_name - ).encode( - "utf-8" - ) + f"Options config.ADD and config.DIVIDE are incompatible (trying to compute variable {variable_name})".encode(), ) # Helpers @@ -169,7 +162,7 @@ def get_holder(self, variable_name: str) -> holders.Holder: def get_memory_usage( self, - variables: Optional[Sequence[str]] = None, + variables: Sequence[str] | None = None, ) -> MemoryUsageByVariable: holders_memory_usage = { variable_name: holder.get_memory_usage() @@ -186,20 +179,18 @@ def get_memory_usage( { "total_nb_bytes": total_memory_usage, "by_variable": holders_memory_usage, - } + }, ) @projectors.projectable - def has_role(self, role: Role) -> Optional[Array[bool]]: - """ - Check if a person has a given role within its `GroupEntity` + def has_role(self, role: Role) -> Array[bool] | None: + """Check if a person has a given role within its `GroupEntity`. Example: - >>> person.has_role(Household.CHILD) >>> array([False]) - """ + """ if self.simulation is None: return None @@ -209,11 +200,10 @@ def has_role(self, role: Role) -> Optional[Array[bool]]: if role.subroles: return numpy.logical_or.reduce( - [group_population.members_role == subrole for subrole in role.subroles] + [group_population.members_role == subrole for subrole in role.subroles], ) - else: - return group_population.members_role == role + return group_population.members_role == role @projectors.projectable def value_from_partner( @@ -221,13 +211,14 @@ def value_from_partner( array: Array[float], entity: projectors.Projector, role: Role, - ) -> Optional[Array[float]]: + ) -> Array[float] | None: self.check_array_compatible_with_entity(array) self.entity.check_role_validity(role) - if not role.subroles or not len(role.subroles) == 2: + if not role.subroles or len(role.subroles) != 2: + msg = "Projection to partner is only implemented for roles having exactly two subroles." raise Exception( - "Projection to partner is only implemented for roles having exactly two subroles." + msg, ) [subrole_1, subrole_2] = role.subroles @@ -246,22 +237,24 @@ def get_rank( criteria: Array[float], condition: bool = True, ) -> Array[int]: - """ - Get the rank of a person within an entity according to a criteria. + """Get the rank of a person within an entity according to a criteria. The person with rank 0 has the minimum value of criteria. If condition is specified, then the persons who don't respect it are not taken into account and their rank is -1. Example: - - >>> age = person('age', period) # e.g [32, 34, 2, 8, 1] + >>> age = person("age", period) # e.g [32, 34, 2, 8, 1] >>> person.get_rank(household, age) >>> [3, 4, 0, 2, 1] - >>> is_child = person.has_role(Household.CHILD) # [False, False, True, True, True] - >>> person.get_rank(household, - age, condition = is_child) # Sort in reverse order so that the eldest child gets the rank 0. + >>> is_child = person.has_role( + ... Household.CHILD + ... ) # [False, False, True, True, True] + >>> person.get_rank( + ... household, -age, condition=is_child + ... ) # Sort in reverse order so that the eldest child gets the rank 0. >>> [-1, -1, 1, 0, 2] - """ + """ # If entity is for instance 'person.household', we get the reference entity 'household' behind the projector entity = ( entity @@ -279,7 +272,7 @@ def get_rank( [ entity.value_nth_person(k, filtered_criteria, default=numpy.inf) for k in range(biggest_entity_size) - ] + ], ).transpose() # We double-argsort all lines of the matrix. @@ -297,9 +290,9 @@ def get_rank( class Calculate(NamedTuple): variable: str period: Period - option: Optional[Sequence[str]] + option: Sequence[str] | None class MemoryUsageByVariable(TypedDict, total=False): - by_variable: Dict[str, holders.MemoryUsage] + by_variable: dict[str, holders.MemoryUsage] total_nb_bytes: int diff --git a/openfisca_core/projectors/entity_to_person_projector.py b/openfisca_core/projectors/entity_to_person_projector.py index ca6245a1f7..392fda08a1 100644 --- a/openfisca_core/projectors/entity_to_person_projector.py +++ b/openfisca_core/projectors/entity_to_person_projector.py @@ -4,7 +4,7 @@ class EntityToPersonProjector(Projector): """For instance person.family.""" - def __init__(self, entity, parent=None): + def __init__(self, entity, parent=None) -> None: self.reference_entity = entity self.parent = parent diff --git a/openfisca_core/projectors/first_person_to_entity_projector.py b/openfisca_core/projectors/first_person_to_entity_projector.py index 4b4e7b7994..d986460cdc 100644 --- a/openfisca_core/projectors/first_person_to_entity_projector.py +++ b/openfisca_core/projectors/first_person_to_entity_projector.py @@ -4,7 +4,7 @@ class FirstPersonToEntityProjector(Projector): """For instance famille.first_person.""" - def __init__(self, entity, parent=None): + def __init__(self, entity, parent=None) -> None: self.target_entity = entity self.reference_entity = entity.members self.parent = parent diff --git a/openfisca_core/projectors/helpers.py b/openfisca_core/projectors/helpers.py index b3b7e6f2d3..4c7712106a 100644 --- a/openfisca_core/projectors/helpers.py +++ b/openfisca_core/projectors/helpers.py @@ -10,8 +10,7 @@ def projectable(function): - """ - Decorator to indicate that when called on a projector, the outcome of the function must be projected. + """Decorator to indicate that when called on a projector, the outcome of the function must be projected. For instance person.household.sum(...) must be projected on person, while it would not make sense for person.household.get_holder. """ function.projectable = True @@ -109,15 +108,15 @@ def get_projector_from_shortcut( <...UniqueRoleToEntityProjector object at ...> """ - entity: SingleEntity | GroupEntity = population.entity if isinstance(entity, entities.Entity): populations: Mapping[ - str, Population | GroupPopulation + str, + Population | GroupPopulation, ] = population.simulation.populations - if shortcut not in populations.keys(): + if shortcut not in populations: return None return projectors.EntityToPersonProjector(populations[shortcut], parent) @@ -133,7 +132,8 @@ def get_projector_from_shortcut( if shortcut in entity.containing_entities: projector: projectors.Projector = getattr( - projectors.FirstPersonToEntityProjector(population, parent), shortcut + projectors.FirstPersonToEntityProjector(population, parent), + shortcut, ) return projector diff --git a/openfisca_core/projectors/projector.py b/openfisca_core/projectors/projector.py index 5ab5f6d958..37881201dc 100644 --- a/openfisca_core/projectors/projector.py +++ b/openfisca_core/projectors/projector.py @@ -7,7 +7,9 @@ class Projector: def __getattr__(self, attribute): projector = helpers.get_projector_from_shortcut( - self.reference_entity, attribute, parent=self + self.reference_entity, + attribute, + parent=self, ) if projector: return projector @@ -30,8 +32,7 @@ def transform_and_bubble_up(self, result): transformed_result = self.transform(result) if self.parent is None: return transformed_result - else: - return self.parent.transform_and_bubble_up(transformed_result) + return self.parent.transform_and_bubble_up(transformed_result) def transform(self, result): return NotImplementedError() diff --git a/openfisca_core/projectors/typing.py b/openfisca_core/projectors/typing.py index 186f90e30c..a49bc96621 100644 --- a/openfisca_core/projectors/typing.py +++ b/openfisca_core/projectors/typing.py @@ -8,25 +8,20 @@ class Population(Protocol): @property - def entity(self) -> SingleEntity: - ... + def entity(self) -> SingleEntity: ... @property - def simulation(self) -> Simulation: - ... + def simulation(self) -> Simulation: ... class GroupPopulation(Protocol): @property - def entity(self) -> GroupEntity: - ... + def entity(self) -> GroupEntity: ... @property - def simulation(self) -> Simulation: - ... + def simulation(self) -> Simulation: ... class Simulation(Protocol): @property - def populations(self) -> Mapping[str, Population | GroupPopulation]: - ... + def populations(self) -> Mapping[str, Population | GroupPopulation]: ... diff --git a/openfisca_core/projectors/unique_role_to_entity_projector.py b/openfisca_core/projectors/unique_role_to_entity_projector.py index fed2f249ca..c565484339 100644 --- a/openfisca_core/projectors/unique_role_to_entity_projector.py +++ b/openfisca_core/projectors/unique_role_to_entity_projector.py @@ -4,7 +4,7 @@ class UniqueRoleToEntityProjector(Projector): """For instance famille.declarant_principal.""" - def __init__(self, entity, role, parent=None): + def __init__(self, entity, role, parent=None) -> None: self.target_entity = entity self.reference_entity = entity.members self.parent = parent diff --git a/openfisca_core/reforms/reform.py b/openfisca_core/reforms/reform.py index 8c179596ed..76e7152334 100644 --- a/openfisca_core/reforms/reform.py +++ b/openfisca_core/reforms/reform.py @@ -7,23 +7,22 @@ class Reform(TaxBenefitSystem): - """A modified TaxBenefitSystem + """A modified TaxBenefitSystem. All reforms must subclass `Reform` and implement a method `apply()`. In this method, the reform can add or replace variables and call `modify_parameters` to modify the parameters of the legislation. - Example: - + Example: >>> from openfisca_core import reforms >>> from openfisca_core.parameters import load_parameter_file >>> >>> def modify_my_parameters(parameters): - >>> # Add new parameters + >>> # Add new parameters >>> new_parameters = load_parameter_file(name='reform_name', file_path='path_to_yaml_file.yaml') >>> parameters.add_child('reform_name', new_parameters) >>> - >>> # Update a value + >>> # Update a value >>> parameters.taxes.some_tax.some_param.update(period=some_period, value=1000.0) >>> >>> return parameters @@ -33,14 +32,13 @@ class Reform(TaxBenefitSystem): >>> self.add_variable(some_variable) >>> self.update_variable(some_other_variable) >>> self.modify_parameters(modifier_function = modify_my_parameters) + """ name = None - def __init__(self, baseline): - """ - :param baseline: Baseline TaxBenefitSystem. - """ + def __init__(self, baseline) -> None: + """:param baseline: Baseline TaxBenefitSystem.""" super().__init__(baseline.entities) self.baseline = baseline self.parameters = baseline.parameters @@ -49,9 +47,8 @@ def __init__(self, baseline): self.decomposition_file_path = baseline.decomposition_file_path self.key = self.__class__.__name__ if not hasattr(self, "apply"): - raise Exception( - "Reform {} must define an `apply` function".format(self.key) - ) + msg = f"Reform {self.key} must define an `apply` function" + raise Exception(msg) self.apply() def __getattr__(self, attribute): @@ -60,12 +57,12 @@ def __getattr__(self, attribute): @property def full_key(self): key = self.key - assert key is not None, "key was not set for reform {} (name: {!r})".format( - self, self.name - ) + assert ( + key is not None + ), f"key was not set for reform {self} (name: {self.name!r})" if self.baseline is not None and hasattr(self.baseline, "key"): baseline_full_key = self.baseline.full_key - key = ".".join([baseline_full_key, key]) + key = f"{baseline_full_key}.{key}" return key def modify_parameters(self, modifier_function): @@ -75,16 +72,15 @@ def modify_parameters(self, modifier_function): Args: modifier_function: A function that takes a :obj:`.ParameterNode` and should return an object of the same type. + """ baseline_parameters = self.baseline.parameters baseline_parameters_copy = copy.deepcopy(baseline_parameters) reform_parameters = modifier_function(baseline_parameters_copy) if not isinstance(reform_parameters, ParameterNode): return ValueError( - "modifier_function {} in module {} must return a ParameterNode".format( - modifier_function.__name__, - modifier_function.__module__, - ) + f"modifier_function {modifier_function.__name__} in module {modifier_function.__module__} must return a ParameterNode", ) self.parameters = reform_parameters self._parameters_at_instant_cache = {} + return None diff --git a/openfisca_core/scripts/__init__.py b/openfisca_core/scripts/__init__.py index 6366c8df15..e9080f2381 100644 --- a/openfisca_core/scripts/__init__.py +++ b/openfisca_core/scripts/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import importlib import logging import pkgutil @@ -17,7 +15,11 @@ def add_tax_benefit_system_arguments(parser): help='country package to use. If not provided, an automatic detection will be attempted by scanning the python packages installed in your environment which name contains the word "openfisca".', ) parser.add_argument( - "-e", "--extensions", action="store", help="extensions to load", nargs="*" + "-e", + "--extensions", + action="store", + help="extensions to load", + nargs="*", ) parser.add_argument( "-r", @@ -39,18 +41,17 @@ def build_tax_benefit_system(country_package_name, extensions, reforms): message = linesep.join( [ traceback.format_exc(), - "Could not import module `{}`.".format(country_package_name), + f"Could not import module `{country_package_name}`.", "Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.", "See more at .", - ] + ], ) raise ImportError(message) if not hasattr(country_package, "CountryTaxBenefitSystem"): + msg = f"`{country_package_name}` does not seem to be a valid Openfisca country package." raise ImportError( - "`{}` does not seem to be a valid Openfisca country package.".format( - country_package_name - ) + msg, ) country_package = importlib.import_module(country_package_name) @@ -82,22 +83,24 @@ def detect_country_package(): message = linesep.join( [ traceback.format_exc(), - "Could not import module `{}`.".format(module_name), + f"Could not import module `{module_name}`.", "Look at the stack trace above to determine the error that stopped installed modules detection.", - ] + ], ) raise ImportError(message) if hasattr(module, "CountryTaxBenefitSystem"): installed_country_packages.append(module_name) if len(installed_country_packages) == 0: + msg = "No country package has been detected on your environment. If your country package is installed but not detected, please use the --country-package option." raise ImportError( - "No country package has been detected on your environment. If your country package is installed but not detected, please use the --country-package option." + msg, ) if len(installed_country_packages) > 1: log.warning( "Several country packages detected : `{}`. Using `{}` by default. To use another package, please use the --country-package option.".format( - ", ".join(installed_country_packages), installed_country_packages[0] - ) + ", ".join(installed_country_packages), + installed_country_packages[0], + ), ) return installed_country_packages[0] diff --git a/openfisca_core/scripts/find_placeholders.py b/openfisca_core/scripts/find_placeholders.py index 37f31f6727..b7b5a81969 100644 --- a/openfisca_core/scripts/find_placeholders.py +++ b/openfisca_core/scripts/find_placeholders.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # flake8: noqa T001 import fnmatch @@ -10,7 +9,7 @@ def find_param_files(input_dir): param_files = [] - for root, dirnames, filenames in os.walk(input_dir): + for root, _dirnames, filenames in os.walk(input_dir): for filename in fnmatch.filter(filenames, "*.xml"): param_files.append(os.path.join(root, filename)) @@ -18,7 +17,7 @@ def find_param_files(input_dir): def find_placeholders(filename_input): - with open(filename_input, "r") as f: + with open(filename_input) as f: xml_content = f.read() xml_parsed = BeautifulSoup(xml_content, "lxml-xml") @@ -29,26 +28,17 @@ def find_placeholders(filename_input): for placeholder in placeholders: parent_list = list(placeholder.parents)[:-1] path = ".".join( - [p.attrs["code"] for p in parent_list if "code" in p.attrs][::-1] + [p.attrs["code"] for p in parent_list if "code" in p.attrs][::-1], ) deb = placeholder.attrs["deb"] output_list.append((deb, path)) - output_list = sorted(output_list, key=lambda x: x[0]) - - return output_list + return sorted(output_list, key=lambda x: x[0]) if __name__ == "__main__": - print( - """find_placeholders.py : Find nodes PLACEHOLDER in xml parameter files -Usage : - python find_placeholders /dir/to/search -""" - ) - assert len(sys.argv) == 2 input_dir = sys.argv[1] @@ -57,9 +47,5 @@ def find_placeholders(filename_input): for filename_input in param_files: output_list = find_placeholders(filename_input) - print("File {}".format(filename_input)) - - for deb, path in output_list: - print("{} {}".format(deb, path)) - - print("\n") + for _deb, _path in output_list: + pass diff --git a/openfisca_core/scripts/measure_numpy_condition_notations.py b/openfisca_core/scripts/measure_numpy_condition_notations.py index f737413bf4..65e48f6e2c 100755 --- a/openfisca_core/scripts/measure_numpy_condition_notations.py +++ b/openfisca_core/scripts/measure_numpy_condition_notations.py @@ -1,16 +1,15 @@ #! /usr/bin/env python -# -*- coding: utf-8 -*- # flake8: noqa T001 -""" -Measure and compare different vectorial condition notations: +"""Measure and compare different vectorial condition notations: - using multiplication notation: (choice == 1) * choice_1_value + (choice == 2) * choice_2_value - using numpy.select: the same than multiplication but more idiomatic like a "switch" control-flow statement -- using numpy.fromiter: iterates in Python over the array and calculates lazily only the required values +- using numpy.fromiter: iterates in Python over the array and calculates lazily only the required values. The aim of this script is to compare the time taken by the calculation of the values """ + import argparse import sys import time @@ -23,10 +22,9 @@ @contextmanager def measure_time(title): - t1 = time.time() + time.time() yield - t2 = time.time() - print("{}\t: {:.8f} seconds elapsed".format(title, t2 - t1)) + time.time() def switch_fromiter(conditions, function_by_condition, dtype): @@ -45,21 +43,21 @@ def get_or_store_value(condition): def switch_select(conditions, value_by_condition): - condlist = [conditions == condition for condition in value_by_condition.keys()] + condlist = [conditions == condition for condition in value_by_condition] return numpy.select(condlist, value_by_condition.values()) -def calculate_choice_1_value(): +def calculate_choice_1_value() -> int: time.sleep(args.calculate_time) return 80 -def calculate_choice_2_value(): +def calculate_choice_2_value() -> int: time.sleep(args.calculate_time) return 90 -def calculate_choice_3_value(): +def calculate_choice_3_value() -> int: time.sleep(args.calculate_time) return 95 @@ -68,32 +66,30 @@ def test_multiplication(choice): choice_1_value = calculate_choice_1_value() choice_2_value = calculate_choice_2_value() choice_3_value = calculate_choice_3_value() - result = ( + return ( (choice == 1) * choice_1_value + (choice == 2) * choice_2_value + (choice == 3) * choice_3_value ) - return result def test_switch_fromiter(choice): - result = switch_fromiter( + return switch_fromiter( choice, { 1: calculate_choice_1_value, 2: calculate_choice_2_value, 3: calculate_choice_3_value, }, - dtype=numpy.int, + dtype=int, ) - return result def test_switch_select(choice): choice_1_value = calculate_choice_1_value() choice_2_value = calculate_choice_2_value() choice_3_value = calculate_choice_2_value() - result = switch_select( + return switch_select( choice, { 1: choice_1_value, @@ -101,10 +97,9 @@ def test_switch_select(choice): 3: choice_3_value, }, ) - return result -def test_all_notations(): +def test_all_notations() -> None: # choice is an array with 1 and 2 items like [2, 1, ..., 1, 2] choice = numpy.random.randint(2, size=args.array_length) + 1 @@ -118,10 +113,13 @@ def test_all_notations(): test_switch_fromiter(choice) -def main(): +def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - "--array-length", default=1000, type=int, help="length of the array" + "--array-length", + default=1000, + type=int, + help="length of the array", ) parser.add_argument( "--calculate-time", @@ -132,7 +130,6 @@ def main(): global args args = parser.parse_args() - print(args) test_all_notations() diff --git a/openfisca_core/scripts/measure_performances.py b/openfisca_core/scripts/measure_performances.py index 89fd47b441..48b99c93f8 100644 --- a/openfisca_core/scripts/measure_performances.py +++ b/openfisca_core/scripts/measure_performances.py @@ -1,15 +1,15 @@ #! /usr/bin/env python -# -*- coding: utf-8 -*- # flake8: noqa T001 """Measure performances of a basic tax-benefit system to compare to other OpenFisca implementations.""" + import argparse import logging import sys import time -import numpy as np +import numpy from numpy.core.defchararray import startswith from openfisca_core import periods, simulations @@ -24,11 +24,9 @@ def timeit(method): def timed(*args, **kwargs): - start_time = time.time() - result = method(*args, **kwargs) + time.time() + return method(*args, **kwargs) # print '%r (%r, %r) %2.9f s' % (method.__name__, args, kw, time.time() - start_time) - print("{:2.6f} s".format(time.time() - start_time)) - return result return timed @@ -106,7 +104,7 @@ def formula(self, simulation, period): if age_en_mois is not None: return age_en_mois // 12 birth = simulation.calculate("birth", period) - return (np.datetime64(period.date) - birth).astype("timedelta64[Y]") + return (numpy.datetime64(period.date) - birth).astype("timedelta64[Y]") class dom_tom(Variable): @@ -117,7 +115,9 @@ class dom_tom(Variable): def formula(self, simulation, period): period = period.start.period(DateUnit.YEAR).offset("first-of") city_code = simulation.calculate("city_code", period) - return np.logical_or(startswith(city_code, "97"), startswith(city_code, "98")) + return numpy.logical_or( + startswith(city_code, "97"), startswith(city_code, "98") + ) class revenu_disponible(Variable): @@ -158,10 +158,10 @@ class salaire_imposable(Variable): entity = Individu label = "Salaire imposable" - def formula(individu, period): + def formula(self, period): period = period.start.period(DateUnit.YEAR).offset("first-of") - dom_tom = individu.famille("dom_tom", period) - salaire_net = individu("salaire_net", period) + dom_tom = self.famille("dom_tom", period) + salaire_net = self("salaire_net", period) return salaire_net * 0.9 - 100 * dom_tom @@ -195,9 +195,10 @@ def formula(self, simulation, period): @timeit -def check_revenu_disponible(year, city_code, expected_revenu_disponible): +def check_revenu_disponible(year, city_code, expected_revenu_disponible) -> None: simulation = simulations.Simulation( - period=periods.period(year), tax_benefit_system=tax_benefit_system + period=periods.period(year), + tax_benefit_system=tax_benefit_system, ) famille = simulation.populations["famille"] famille.count = 3 @@ -206,20 +207,22 @@ def check_revenu_disponible(year, city_code, expected_revenu_disponible): individu = simulation.populations["individu"] individu.count = 6 individu.step_size = 2 - simulation.get_or_new_holder("city_code").array = np.array( - [city_code, city_code, city_code] + simulation.get_or_new_holder("city_code").array = numpy.array( + [city_code, city_code, city_code], ) - famille.members_entity_id = np.array([0, 0, 1, 1, 2, 2]) - simulation.get_or_new_holder("salaire_brut").array = np.array( - [0.0, 0.0, 50000.0, 0.0, 100000.0, 0.0] + famille.members_entity_id = numpy.array([0, 0, 1, 1, 2, 2]) + simulation.get_or_new_holder("salaire_brut").array = numpy.array( + [0.0, 0.0, 50000.0, 0.0, 100000.0, 0.0], ) revenu_disponible = simulation.calculate("revenu_disponible") assert_near( - revenu_disponible, expected_revenu_disponible, absolute_error_margin=0.005 + revenu_disponible, + expected_revenu_disponible, + absolute_error_margin=0.005, ) -def main(): +def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "-v", @@ -231,37 +234,56 @@ def main(): global args args = parser.parse_args() logging.basicConfig( - level=logging.DEBUG if args.verbose else logging.WARNING, stream=sys.stdout + level=logging.DEBUG if args.verbose else logging.WARNING, + stream=sys.stdout, ) - check_revenu_disponible(2009, "75101", np.array([0, 0, 25200, 0, 50400, 0])) + check_revenu_disponible(2009, "75101", numpy.array([0, 0, 25200, 0, 50400, 0])) check_revenu_disponible( - 2010, "75101", np.array([1200, 1200, 25200, 1200, 50400, 1200]) + 2010, + "75101", + numpy.array([1200, 1200, 25200, 1200, 50400, 1200]), ) check_revenu_disponible( - 2011, "75101", np.array([2400, 2400, 25200, 2400, 50400, 2400]) + 2011, + "75101", + numpy.array([2400, 2400, 25200, 2400, 50400, 2400]), ) check_revenu_disponible( - 2012, "75101", np.array([2400, 2400, 25200, 2400, 50400, 2400]) + 2012, + "75101", + numpy.array([2400, 2400, 25200, 2400, 50400, 2400]), ) check_revenu_disponible( - 2013, "75101", np.array([3600, 3600, 25200, 3600, 50400, 3600]) + 2013, + "75101", + numpy.array([3600, 3600, 25200, 3600, 50400, 3600]), ) check_revenu_disponible( - 2009, "97123", np.array([-70.0, -70.0, 25130.0, -70.0, 50330.0, -70.0]) + 2009, + "97123", + numpy.array([-70.0, -70.0, 25130.0, -70.0, 50330.0, -70.0]), ) check_revenu_disponible( - 2010, "97123", np.array([1130.0, 1130.0, 25130.0, 1130.0, 50330.0, 1130.0]) + 2010, + "97123", + numpy.array([1130.0, 1130.0, 25130.0, 1130.0, 50330.0, 1130.0]), ) check_revenu_disponible( - 2011, "98456", np.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0]) + 2011, + "98456", + numpy.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0]), ) check_revenu_disponible( - 2012, "98456", np.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0]) + 2012, + "98456", + numpy.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0]), ) check_revenu_disponible( - 2013, "98456", np.array([3530.0, 3530.0, 25130.0, 3530.0, 50330.0, 3530.0]) + 2013, + "98456", + numpy.array([3530.0, 3530.0, 25130.0, 3530.0, 50330.0, 3530.0]), ) diff --git a/openfisca_core/scripts/measure_performances_fancy_indexing.py b/openfisca_core/scripts/measure_performances_fancy_indexing.py index 030b1af7aa..7c261e2fe3 100644 --- a/openfisca_core/scripts/measure_performances_fancy_indexing.py +++ b/openfisca_core/scripts/measure_performances_fancy_indexing.py @@ -2,24 +2,25 @@ import timeit -import numpy as np +import numpy from openfisca_france import CountryTaxBenefitSystem tbs = CountryTaxBenefitSystem() N = 200000 al_plaf_acc = tbs.get_parameters_at_instant("2015-01-01").prestations.al_plaf_acc -zone_apl = np.random.choice([1, 2, 3], N) -al_nb_pac = np.random.choice(6, N) -couple = np.random.choice([True, False], N) +zone_apl = numpy.random.choice([1, 2, 3], N) +al_nb_pac = numpy.random.choice(6, N) +couple = numpy.random.choice([True, False], N) formatted_zone = concat( - "plafond_pour_accession_a_la_propriete_zone_", zone_apl + "plafond_pour_accession_a_la_propriete_zone_", + zone_apl, ) # zone_apl returns 1, 2 or 3 but the parameters have a long name def formula_with(): plafonds = al_plaf_acc[formatted_zone] - result = ( + return ( plafonds.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) + plafonds.menage_seul * couple * (al_nb_pac == 0) + plafonds.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) @@ -32,8 +33,6 @@ def formula_with(): * (al_nb_pac - 5) ) - return result - def formula_without(): z1 = al_plaf_acc.plafond_pour_accession_a_la_propriete_zone_1 @@ -79,14 +78,12 @@ def formula_without(): if __name__ == "__main__": time_with = timeit.timeit( - "formula_with()", setup="from __main__ import formula_with", number=50 + "formula_with()", + setup="from __main__ import formula_with", + number=50, ) time_without = timeit.timeit( - "formula_without()", setup="from __main__ import formula_without", number=50 - ) - - print("Computing with dynamic legislation computing took {}".format(time_with)) - print( - "Computing without dynamic legislation computing took {}".format(time_without) + "formula_without()", + setup="from __main__ import formula_without", + number=50, ) - print("Ratio: {}".format(time_with / time_without)) diff --git a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py index 3ff9c3d7ac..38538d644a 100644 --- a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py +++ b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py @@ -1,12 +1,11 @@ -# -*- coding: utf-8 -*- - -""" xml_to_yaml_country_template.py : Parse XML parameter files for Country-Template and convert them to YAML files. Comments are NOT transformed. +"""xml_to_yaml_country_template.py : Parse XML parameter files for Country-Template and convert them to YAML files. Comments are NOT transformed. Usage : `python xml_to_yaml_country_template.py output_dir` or just (output is written in a directory called `yaml_parameters`): `python xml_to_yaml_country_template.py` """ + import os import sys @@ -16,10 +15,7 @@ tax_benefit_system = CountryTaxBenefitSystem() -if len(sys.argv) > 1: - target_path = sys.argv[1] -else: - target_path = "yaml_parameters" +target_path = sys.argv[1] if len(sys.argv) > 1 else "yaml_parameters" param_dir = os.path.join(COUNTRY_DIR, "parameters") param_files = [ diff --git a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py index 34b7ca430d..0b57c19016 100644 --- a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py +++ b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -""" xml_to_yaml_extension_template.py : Parse XML parameter files for Extension-Template and convert them to YAML files. Comments are NOT transformed. +"""xml_to_yaml_extension_template.py : Parse XML parameter files for Extension-Template and convert them to YAML files. Comments are NOT transformed. Usage : `python xml_to_yaml_extension_template.py output_dir` @@ -15,10 +13,7 @@ from . import xml_to_yaml -if len(sys.argv) > 1: - target_path = sys.argv[1] -else: - target_path = "yaml_parameters" +target_path = sys.argv[1] if len(sys.argv) > 1 else "yaml_parameters" param_dir = os.path.dirname(openfisca_extension_template.__file__) param_files = [ diff --git a/openfisca_core/scripts/migrations/v24_to_25.py b/openfisca_core/scripts/migrations/v24_to_25.py index 1eefd426ad..08bbeddc3b 100644 --- a/openfisca_core/scripts/migrations/v24_to_25.py +++ b/openfisca_core/scripts/migrations/v24_to_25.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # flake8: noqa T001 import argparse @@ -33,21 +32,21 @@ def build_parser(): parser = argparse.ArgumentParser() parser.add_argument( - "path", help="paths (files or directories) of tests to execute", nargs="+" + "path", + help="paths (files or directories) of tests to execute", + nargs="+", ) - parser = add_tax_benefit_system_arguments(parser) + return add_tax_benefit_system_arguments(parser) - return parser - -class Migrator(object): - def __init__(self, tax_benefit_system): +class Migrator: + def __init__(self, tax_benefit_system) -> None: self.tax_benefit_system = tax_benefit_system self.entities_by_plural = { entity.plural: entity for entity in self.tax_benefit_system.entities } - def migrate(self, path): + def migrate(self, path) -> None: if isinstance(path, list): for item in path: self.migrate(item) @@ -65,8 +64,6 @@ def migrate(self, path): return - print("Migrating {}.".format(path)) - with open(path) as yaml_file: tests = yaml.safe_load(yaml_file) if isinstance(tests, CommentedSeq): @@ -107,14 +104,12 @@ def convert_inputs(self, inputs): continue results[entity_plural] = self.convert_entities(entity, entities_description) - results = self.generate_missing_entities(results) - - return results + return self.generate_missing_entities(results) def convert_entities(self, entity, entities_description): return { - entity_description.get("id", "{}_{}".format(entity.key, index)): remove_id( - entity_description + entity_description.get("id", f"{entity.key}_{index}"): remove_id( + entity_description, ) for index, entity_description in enumerate(entities_description) } @@ -127,12 +122,12 @@ def generate_missing_entities(self, inputs): if len(persons) == 1: person_id = next(iter(persons)) inputs[entity.key] = { - entity.roles[0].plural or entity.roles[0].key: [person_id] + entity.roles[0].plural or entity.roles[0].key: [person_id], } else: inputs[entity.plural] = { - "{}_{}".format(entity.key, index): { - entity.roles[0].plural or entity.roles[0].key: [person_id] + f"{entity.key}_{index}": { + entity.roles[0].plural or entity.roles[0].key: [person_id], } for index, person_id in enumerate(persons.keys()) } @@ -143,13 +138,15 @@ def remove_id(input_dict): return {key: value for (key, value) in input_dict.items() if key != "id"} -def main(): +def main() -> None: parser = build_parser() args = parser.parse_args() paths = [os.path.abspath(path) for path in args.path] tax_benefit_system = build_tax_benefit_system( - args.country_package, args.extensions, args.reforms + args.country_package, + args.extensions, + args.reforms, ) Migrator(tax_benefit_system).migrate(paths) diff --git a/openfisca_core/scripts/openfisca_command.py b/openfisca_core/scripts/openfisca_command.py index 3b835e73a3..d82e0aef61 100644 --- a/openfisca_core/scripts/openfisca_command.py +++ b/openfisca_core/scripts/openfisca_command.py @@ -30,7 +30,10 @@ def build_serve_parser(parser): type=int, ) parser.add_argument( - "--tracker-url", action="store", help="tracking service url", type=str + "--tracker-url", + action="store", + help="tracking service url", + type=str, ) parser.add_argument( "--tracker-idsite", @@ -65,7 +68,9 @@ def build_serve_parser(parser): def build_test_parser(parser): parser.add_argument( - "path", help="paths (files or directories) of tests to execute", nargs="+" + "path", + help="paths (files or directories) of tests to execute", + nargs="+", ) parser = add_tax_benefit_system_arguments(parser) parser.add_argument( @@ -156,6 +161,7 @@ def main(): from openfisca_core.scripts.run_test import main return sys.exit(main(parser)) + return None if __name__ == "__main__": diff --git a/openfisca_core/scripts/remove_fuzzy.py b/openfisca_core/scripts/remove_fuzzy.py index 2c06b149b1..a4827aef39 100755 --- a/openfisca_core/scripts/remove_fuzzy.py +++ b/openfisca_core/scripts/remove_fuzzy.py @@ -10,7 +10,7 @@ assert len(sys.argv) == 2 filename = sys.argv[1] -with open(filename, "r") as f: +with open(filename) as f: lines = f.readlines() diff --git a/openfisca_core/scripts/run_test.py b/openfisca_core/scripts/run_test.py index ab292c4165..458dc7e50e 100644 --- a/openfisca_core/scripts/run_test.py +++ b/openfisca_core/scripts/run_test.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import logging import os import sys @@ -8,14 +6,17 @@ from openfisca_core.tools.test_runner import run_tests -def main(parser): +def main(parser) -> None: args = parser.parse_args() logging.basicConfig( - level=logging.DEBUG if args.verbose else logging.WARNING, stream=sys.stdout + level=logging.DEBUG if args.verbose else logging.WARNING, + stream=sys.stdout, ) tax_benefit_system = build_tax_benefit_system( - args.country_package, args.extensions, args.reforms + args.country_package, + args.extensions, + args.reforms, ) options = { diff --git a/openfisca_core/scripts/simulation_generator.py b/openfisca_core/scripts/simulation_generator.py index 4f2c5bfd4f..eca2fa30d1 100644 --- a/openfisca_core/scripts/simulation_generator.py +++ b/openfisca_core/scripts/simulation_generator.py @@ -6,21 +6,22 @@ def make_simulation(tax_benefit_system, nb_persons, nb_groups, **kwargs): - """ - Generate a simulation containing nb_persons persons spread in nb_groups groups. + """Generate a simulation containing nb_persons persons spread in nb_groups groups. Example: - >>> from openfisca_core.scripts.simulation_generator import make_simulation >>> from openfisca_france import CountryTaxBenefitSystem >>> tbs = CountryTaxBenefitSystem() - >>> simulation = make_simulation(tbs, 400, 100) # Create a simulation with 400 persons, spread among 100 families - >>> simulation.calculate('revenu_disponible', 2017) + >>> simulation = make_simulation( + ... tbs, 400, 100 + ... ) # Create a simulation with 400 persons, spread among 100 families + >>> simulation.calculate("revenu_disponible", 2017) + """ simulation = Simulation(tax_benefit_system=tax_benefit_system, **kwargs) simulation.persons.ids = numpy.arange(nb_persons) simulation.persons.count = nb_persons - adults = [0] + sorted(random.sample(range(1, nb_persons), nb_groups - 1)) + adults = [0, *sorted(random.sample(range(1, nb_persons), nb_groups - 1))] members_entity_id = numpy.empty(nb_persons, dtype=int) @@ -50,26 +51,40 @@ def make_simulation(tax_benefit_system, nb_persons, nb_groups, **kwargs): def randomly_init_variable( - simulation, variable_name: str, period, max_value, condition=None -): - """ - Initialise a variable with random values (from 0 to max_value) for the given period. + simulation, + variable_name: str, + period, + max_value, + condition=None, +) -> None: + """Initialise a variable with random values (from 0 to max_value) for the given period. If a condition vector is provided, only set the value of persons or groups for which condition is True. Example: - - >>> from openfisca_core.scripts.simulation_generator import make_simulation, randomly_init_variable + >>> from openfisca_core.scripts.simulation_generator import ( + ... make_simulation, + ... randomly_init_variable, + ... ) >>> from openfisca_france import CountryTaxBenefitSystem >>> tbs = CountryTaxBenefitSystem() - >>> simulation = make_simulation(tbs, 400, 100) # Create a simulation with 400 persons, spread among 100 families - >>> randomly_init_variable(simulation, 'salaire_net', 2017, max_value = 50000, condition = simulation.persons.has_role(simulation.famille.DEMANDEUR)) # Randomly set a salaire_net for all persons between 0 and 50000? - >>> simulation.calculate('revenu_disponible', 2017) + >>> simulation = make_simulation( + ... tbs, 400, 100 + ... ) # Create a simulation with 400 persons, spread among 100 families + >>> randomly_init_variable( + ... simulation, + ... "salaire_net", + ... 2017, + ... max_value=50000, + ... condition=simulation.persons.has_role(simulation.famille.DEMANDEUR), + ... ) # Randomly set a salaire_net for all persons between 0 and 50000? + >>> simulation.calculate("revenu_disponible", 2017) + """ if condition is None: condition = True variable = simulation.tax_benefit_system.get_variable(variable_name) population = simulation.get_variable_population(variable_name) value = (numpy.random.rand(population.count) * max_value * condition).astype( - variable.dtype + variable.dtype, ) simulation.set_input(variable_name, period, value) diff --git a/openfisca_core/simulations/__init__.py b/openfisca_core/simulations/__init__.py index 670b922ebb..9ab10f81a7 100644 --- a/openfisca_core/simulations/__init__.py +++ b/openfisca_core/simulations/__init__.py @@ -21,20 +21,16 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from openfisca_core.errors import ( # noqa: F401 - CycleError, - NaNCreationError, - SpiralError, -) +from openfisca_core.errors import CycleError, NaNCreationError, SpiralError -from .helpers import ( # noqa: F401 +from .helpers import ( calculate_output_add, calculate_output_divide, check_type, transform_to_strict_syntax, ) -from .simulation import Simulation # noqa: F401 -from .simulation_builder import SimulationBuilder # noqa: F401 +from .simulation import Simulation +from .simulation_builder import SimulationBuilder __all__ = [ "CycleError", diff --git a/openfisca_core/simulations/_build_default_simulation.py b/openfisca_core/simulations/_build_default_simulation.py index f99c1d210a..adc7cf4783 100644 --- a/openfisca_core/simulations/_build_default_simulation.py +++ b/openfisca_core/simulations/_build_default_simulation.py @@ -27,9 +27,9 @@ class _BuildDefaultSimulation: >>> count = 1 >>> builder = ( ... _BuildDefaultSimulation(tax_benefit_system, count) - ... .add_count() - ... .add_ids() - ... .add_members_entity_id() + ... .add_count() + ... .add_ids() + ... .add_members_entity_id() ... ) >>> builder.count @@ -84,7 +84,6 @@ def add_count(self) -> Self: 2 """ - for population in self.populations.values(): population.count = self.count @@ -117,7 +116,6 @@ def add_ids(self) -> Self: array([0, 1]) """ - for population in self.populations.values(): population.ids = numpy.array(range(self.count)) @@ -154,7 +152,6 @@ def add_members_entity_id(self) -> Self: array([0, 1]) """ - for population in self.populations.values(): if hasattr(population, "members_entity_id"): population.members_entity_id = numpy.array(range(self.count)) diff --git a/openfisca_core/simulations/_build_from_variables.py b/openfisca_core/simulations/_build_from_variables.py index 60ff6148e7..20f49ce113 100644 --- a/openfisca_core/simulations/_build_from_variables.py +++ b/openfisca_core/simulations/_build_from_variables.py @@ -139,7 +139,6 @@ def add_dated_values(self) -> Self: >>> pack.get_array(period) """ - for variable, value in self.variables.items(): if is_variable_dated(dated_variable := value): for period, dated_value in dated_variable.items(): @@ -197,7 +196,6 @@ def add_undated_values(self) -> Self: array([5000], dtype=int32) """ - for variable, value in self.variables.items(): if not is_variable_dated(undated_value := value): if (period := self.default_period) is None: diff --git a/openfisca_core/simulations/_type_guards.py b/openfisca_core/simulations/_type_guards.py index c34361041a..990248213d 100644 --- a/openfisca_core/simulations/_type_guards.py +++ b/openfisca_core/simulations/_type_guards.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Iterable +from collections.abc import Iterable from typing_extensions import TypeGuard from .typing import ( @@ -17,7 +17,8 @@ def are_entities_fully_specified( - params: Params, items: Iterable[str] + params: Params, + items: Iterable[str], ) -> TypeGuard[FullySpecifiedEntities]: """Check if the params contain fully specified entities. @@ -33,28 +34,32 @@ def are_entities_fully_specified( >>> params = { ... "axes": [ - ... [{"count": 2, "max": 3000, "min": 0, "name": "rent", "period": "2018-11"}] + ... [ + ... { + ... "count": 2, + ... "max": 3000, + ... "min": 0, + ... "name": "rent", + ... "period": "2018-11", + ... } + ... ] ... ], ... "households": { ... "housea": {"parents": ["Alicia", "Javier"]}, ... "houseb": {"parents": ["Tom"]}, - ... }, + ... }, ... "persons": {"Alicia": {"salary": {"2018-11": 0}}, "Javier": {}, "Tom": {}}, ... } >>> are_entities_fully_specified(params, entities) True - >>> params = { - ... "persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}} - ... } + >>> params = {"persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}}} >>> are_entities_fully_specified(params, entities) True - >>> params = { - ... "persons": {"Javier": {"salary": {"2018-11": 2000}}} - ... } + >>> params = {"persons": {"Javier": {"salary": {"2018-11": 2000}}}} >>> are_entities_fully_specified(params, entities) True @@ -80,15 +85,15 @@ def are_entities_fully_specified( False """ - if not params: return False - return all(key in items for key in params.keys() if key != "axes") + return all(key in items for key in params if key != "axes") def are_entities_short_form( - params: Params, items: Iterable[str] + params: Params, + items: Iterable[str], ) -> TypeGuard[ImplicitGroupEntities]: """Check if the params contain short form entities. @@ -103,25 +108,23 @@ def are_entities_short_form( >>> entities = {"person", "household"} >>> params = { - ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, ... "households": {"household": {"parents": ["Javier"]}}, - ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], ... } >>> are_entities_short_form(params, entities) False - >>> params = { - ... "persons": {"Javier": {"salary": {"2018-11": 2000}}} - ... } + >>> params = {"persons": {"Javier": {"salary": {"2018-11": 2000}}}} >>> are_entities_short_form(params, entities) False >>> params = { - ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, ... "household": {"parents": ["Javier"]}, - ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], ... } >>> are_entities_short_form(params, entities) @@ -129,7 +132,7 @@ def are_entities_short_form( >>> params = { ... "household": {"parents": ["Javier"]}, - ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], ... } >>> are_entities_short_form(params, entities) @@ -161,12 +164,12 @@ def are_entities_short_form( False """ - - return not not set(params).intersection(items) + return bool(set(params).intersection(items)) def are_entities_specified( - params: Params, items: Iterable[str] + params: Params, + items: Iterable[str], ) -> TypeGuard[Variables]: """Check if the params contains entities at all. @@ -181,24 +184,20 @@ def are_entities_specified( >>> variables = {"salary"} >>> params = { - ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, ... "households": {"household": {"parents": ["Javier"]}}, - ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], ... } >>> are_entities_specified(params, variables) True - >>> params = { - ... "persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}} - ... } + >>> params = {"persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}}} >>> are_entities_specified(params, variables) True - >>> params = { - ... "persons": {"Javier": {"salary": {"2018-11": 2000}}} - ... } + >>> params = {"persons": {"Javier": {"salary": {"2018-11": 2000}}}} >>> are_entities_specified(params, variables) True @@ -234,11 +233,10 @@ def are_entities_specified( False """ - if not params: return False - return not any(key in items for key in params.keys()) + return not any(key in items for key in params) def has_axes(params: Params) -> TypeGuard[Axes]: @@ -252,23 +250,20 @@ def has_axes(params: Params) -> TypeGuard[Axes]: Examples: >>> params = { - ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, ... "households": {"household": {"parents": ["Javier"]}}, - ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], ... } >>> has_axes(params) True - >>> params = { - ... "persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}} - ... } + >>> params = {"persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}}} >>> has_axes(params) False """ - return params.get("axes", None) is not None @@ -300,5 +295,4 @@ def is_variable_dated( False """ - return isinstance(variable, dict) diff --git a/openfisca_core/simulations/helpers.py b/openfisca_core/simulations/helpers.py index d5984d88b6..7929c5beda 100644 --- a/openfisca_core/simulations/helpers.py +++ b/openfisca_core/simulations/helpers.py @@ -13,7 +13,7 @@ def calculate_output_divide(simulation, variable_name: str, period): return simulation.calculate_divide(variable_name, period) -def check_type(input, input_type, path=None): +def check_type(input, input_type, path=None) -> None: json_type_map = { dict: "Object", list: "Array", @@ -26,12 +26,13 @@ def check_type(input, input_type, path=None): if not isinstance(input, input_type): raise errors.SituationParsingError( path, - "Invalid type: must be of type '{}'.".format(json_type_map[input_type]), + f"Invalid type: must be of type '{json_type_map[input_type]}'.", ) def check_unexpected_entities( - params: ParamsWithoutAxes, entities: Iterable[str] + params: ParamsWithoutAxes, + entities: Iterable[str], ) -> None: """Check if the input contains entities that are not in the system. @@ -47,21 +48,18 @@ def check_unexpected_entities( >>> params = { ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, - ... "households": {"household": {"parents": ["Javier"]}} + ... "households": {"household": {"parents": ["Javier"]}}, ... } >>> check_unexpected_entities(params, entities) - >>> params = { - ... "dogs": {"Bart": {"damages": {"2018-11": 2000}}} - ... } + >>> params = {"dogs": {"Bart": {"damages": {"2018-11": 2000}}}} >>> check_unexpected_entities(params, entities) Traceback (most recent call last): openfisca_core.errors.situation_parsing_error.SituationParsingError """ - if has_unexpected_entities(params, entities): unexpected_entities = [entity for entity in params if entity not in entities] @@ -90,21 +88,18 @@ def has_unexpected_entities(params: ParamsWithoutAxes, entities: Iterable[str]) >>> params = { ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, - ... "households": {"household": {"parents": ["Javier"]}} + ... "households": {"household": {"parents": ["Javier"]}}, ... } >>> has_unexpected_entities(params, entities) False - >>> params = { - ... "dogs": {"Bart": {"damages": {"2018-11": 2000}}} - ... } + >>> params = {"dogs": {"Bart": {"damages": {"2018-11": 2000}}}} >>> has_unexpected_entities(params, entities) True """ - return any(entity for entity in params if entity not in entities) diff --git a/openfisca_core/simulations/simulation.py b/openfisca_core/simulations/simulation.py index 93becda960..c32fea22af 100644 --- a/openfisca_core/simulations/simulation.py +++ b/openfisca_core/simulations/simulation.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Dict, Mapping, NamedTuple, Optional, Set +from collections.abc import Mapping +from typing import NamedTuple from openfisca_core.types import Population, TaxBenefitSystem, Variable @@ -9,26 +10,29 @@ import numpy -from openfisca_core import commons, errors, indexed_enums, periods, tracers -from openfisca_core import warnings as core_warnings +from openfisca_core import ( + commons, + errors, + indexed_enums, + periods, + tracers, + warnings as core_warnings, +) class Simulation: - """ - Represents a simulation, and handles the calculation logic - """ + """Represents a simulation, and handles the calculation logic.""" tax_benefit_system: TaxBenefitSystem - populations: Dict[str, Population] - invalidated_caches: Set[Cache] + populations: dict[str, Population] + invalidated_caches: set[Cache] def __init__( self, tax_benefit_system: TaxBenefitSystem, populations: Mapping[str, Population], - ): - """ - This constructor is reserved for internal use; see :any:`SimulationBuilder`, + ) -> None: + """This constructor is reserved for internal use; see :any:`SimulationBuilder`, which is the preferred way to obtain a Simulation initialized with a consistent set of Entities. """ @@ -57,37 +61,37 @@ def trace(self): return self._trace @trace.setter - def trace(self, trace): + def trace(self, trace) -> None: self._trace = trace if trace: self.tracer = tracers.FullTracer() else: self.tracer = tracers.SimpleTracer() - def link_to_entities_instances(self): - for _key, entity_instance in self.populations.items(): + def link_to_entities_instances(self) -> None: + for entity_instance in self.populations.values(): entity_instance.simulation = self - def create_shortcuts(self): - for _key, population in self.populations.items(): + def create_shortcuts(self) -> None: + for population in self.populations.values(): # create shortcut simulation.person and simulation.household (for instance) setattr(self, population.entity.key, population) @property def data_storage_dir(self): - """ - Temporary folder used to store intermediate calculation data in case the memory is saturated - """ + """Temporary folder used to store intermediate calculation data in case the memory is saturated.""" if self._data_storage_dir is None: self._data_storage_dir = tempfile.mkdtemp(prefix="openfisca_") message = [ ( - "Intermediate results will be stored on disk in {} in case of memory overflow." - ).format(self._data_storage_dir), + f"Intermediate results will be stored on disk in {self._data_storage_dir} in case of memory overflow." + ), "You should remove this directory once you're done with your simulation.", ] warnings.warn( - " ".join(message), core_warnings.TempfileWarning, stacklevel=2 + " ".join(message), + core_warnings.TempfileWarning, + stacklevel=2, ) return self._data_storage_dir @@ -95,7 +99,6 @@ def data_storage_dir(self): def calculate(self, variable_name: str, period): """Calculate ``variable_name`` for ``period``.""" - if period is not None and not isinstance(period, periods.Period): period = periods.period(period) @@ -111,17 +114,17 @@ def calculate(self, variable_name: str, period): self.purge_cache_of_invalid_values() def _calculate(self, variable_name: str, period: periods.Period): - """ - Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists. + """Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists. :returns: A numpy array containing the result of the calculation """ - variable: Optional[Variable] + variable: Variable | None population = self.get_variable_population(variable_name) holder = population.get_holder(variable_name) variable = self.tax_benefit_system.get_variable( - variable_name, check_existence=True + variable_name, + check_existence=True, ) if variable is None: @@ -153,7 +156,7 @@ def _calculate(self, variable_name: str, period: periods.Period): return array - def purge_cache_of_invalid_values(self): + def purge_cache_of_invalid_values(self) -> None: # We wait for the end of calculate(), signalled by an empty stack, before purging the cache if self.tracer.stack: return @@ -163,10 +166,11 @@ def purge_cache_of_invalid_values(self): self.invalidated_caches = set() def calculate_add(self, variable_name: str, period): - variable: Optional[Variable] + variable: Variable | None variable = self.tax_benefit_system.get_variable( - variable_name, check_existence=True + variable_name, + check_existence=True, ) if variable is None: @@ -177,23 +181,29 @@ def calculate_add(self, variable_name: str, period): # Check that the requested period matches definition_period if periods.unit_weight(variable.definition_period) > periods.unit_weight( - period.unit + period.unit, ): - raise ValueError( + msg = ( f"Unable to compute variable '{variable.name}' for period " f"{period}: '{variable.name}' can only be computed for " f"{variable.definition_period}-long periods. You can use the " f"DIVIDE option to get an estimate of {variable.name}." ) + raise ValueError( + msg, + ) if variable.definition_period not in ( periods.DateUnit.isoformat + periods.DateUnit.isocalendar ): - raise ValueError( + msg = ( f"Unable to ADD constant variable '{variable.name}' over " f"the period {period}: eternal variables can't be summed " "over time." ) + raise ValueError( + msg, + ) return sum( self.calculate(variable_name, sub_period) @@ -201,10 +211,11 @@ def calculate_add(self, variable_name: str, period): ) def calculate_divide(self, variable_name: str, period): - variable: Optional[Variable] + variable: Variable | None variable = self.tax_benefit_system.get_variable( - variable_name, check_existence=True + variable_name, + check_existence=True, ) if variable is None: @@ -218,32 +229,41 @@ def calculate_divide(self, variable_name: str, period): < periods.unit_weight(period.unit) or period.size > 1 ): - raise ValueError( + msg = ( f"Can't calculate variable '{variable.name}' for period " f"{period}: '{variable.name}' can only be computed for " f"{variable.definition_period}-long periods. You can use the " f"ADD option to get an estimate of {variable.name}." ) + raise ValueError( + msg, + ) if variable.definition_period not in ( periods.DateUnit.isoformat + periods.DateUnit.isocalendar ): - raise ValueError( + msg = ( f"Unable to DIVIDE constant variable '{variable.name}' over " f"the period {period}: eternal variables can't be divided " "over time." ) + raise ValueError( + msg, + ) if ( period.unit not in (periods.DateUnit.isoformat + periods.DateUnit.isocalendar) or period.size != 1 ): - raise ValueError( + msg = ( f"Unable to DIVIDE constant variable '{variable.name}' over " f"the period {period}: eternal variables can't be used " "as a denominator to divide a variable over time." ) + raise ValueError( + msg, + ) if variable.definition_period == periods.DateUnit.YEAR: calculation_period = period.this_year @@ -278,14 +298,12 @@ def calculate_divide(self, variable_name: str, period): return self.calculate(variable_name, calculation_period) / denominator def calculate_output(self, variable_name: str, period): - """ - Calculate the value of a variable using the ``calculate_output`` attribute of the variable. - """ - - variable: Optional[Variable] + """Calculate the value of a variable using the ``calculate_output`` attribute of the variable.""" + variable: Variable | None variable = self.tax_benefit_system.get_variable( - variable_name, check_existence=True + variable_name, + check_existence=True, ) if variable is None: @@ -303,10 +321,7 @@ def trace_parameters_at_instant(self, formula_period): ) def _run_formula(self, variable, population, period): - """ - Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``. - """ - + """Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``.""" formula = variable.get_formula(period) if formula is None: return None @@ -323,10 +338,8 @@ def _run_formula(self, variable, population, period): return array - def _check_period_consistency(self, period, variable): - """ - Check that a period matches the variable definition_period - """ + def _check_period_consistency(self, period, variable) -> None: + """Check that a period matches the variable definition_period.""" if variable.definition_period == periods.DateUnit.ETERNITY: return # For variables which values are constant in time, all periods are accepted @@ -334,42 +347,39 @@ def _check_period_consistency(self, period, variable): variable.definition_period == periods.DateUnit.YEAR and period.unit != periods.DateUnit.YEAR ): + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {variable.name} by dividing the yearly value by 12, or change the requested period to 'period.this_year'." raise ValueError( - "Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format( - variable.name, period - ) + msg, ) if ( variable.definition_period == periods.DateUnit.MONTH and period.unit != periods.DateUnit.MONTH ): + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole month. You can use the ADD option to sum '{variable.name}' over the requested period, or change the requested period to 'period.first_month'." raise ValueError( - "Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole month. You can use the ADD option to sum '{0}' over the requested period, or change the requested period to 'period.first_month'.".format( - variable.name, period - ) + msg, ) if ( variable.definition_period == periods.DateUnit.WEEK and period.unit != periods.DateUnit.WEEK ): + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole week. You can use the ADD option to sum '{variable.name}' over the requested period, or change the requested period to 'period.first_week'." raise ValueError( - "Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole week. You can use the ADD option to sum '{0}' over the requested period, or change the requested period to 'period.first_week'.".format( - variable.name, period - ) + msg, ) if period.size != 1: + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole {variable.definition_period}. You can use the ADD option to sum '{variable.name}' over the requested period." raise ValueError( - "Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole {2}. You can use the ADD option to sum '{0}' over the requested period.".format( - variable.name, period, variable.definition_period - ) + msg, ) def _cast_formula_result(self, value, variable): if variable.value_type == indexed_enums.Enum and not isinstance( - value, indexed_enums.EnumArray + value, + indexed_enums.EnumArray, ): return variable.possible_values.encode(value) @@ -384,9 +394,8 @@ def _cast_formula_result(self, value, variable): # ----- Handle circular dependencies in a calculation ----- # - def _check_for_cycle(self, variable: str, period): - """ - Raise an exception in the case of a circular definition, where evaluating a variable for + def _check_for_cycle(self, variable: str, period) -> None: + """Raise an exception in the case of a circular definition, where evaluating a variable for a given period loops around to evaluating the same variable/period pair. Also guards, as a heuristic, against "quasicircles", where the evaluation of a variable at a period involves the same variable at a different period. @@ -398,21 +407,20 @@ def _check_for_cycle(self, variable: str, period): if frame["name"] == variable ] if period in previous_periods: + msg = f"Circular definition detected on formula {variable}@{period}" raise errors.CycleError( - "Circular definition detected on formula {}@{}".format(variable, period) + msg, ) spiral = len(previous_periods) >= self.max_spiral_loops if spiral: self.invalidate_spiral_variables(variable) - message = "Quasicircular definition detected on formula {}@{} involving {}".format( - variable, period, self.tracer.stack - ) + message = f"Quasicircular definition detected on formula {variable}@{period} involving {self.tracer.stack}" raise errors.SpiralError(message, variable) - def invalidate_cache_entry(self, variable: str, period): + def invalidate_cache_entry(self, variable: str, period) -> None: self.invalidated_caches.add(Cache(variable, period)) - def invalidate_spiral_variables(self, variable: str): + def invalidate_spiral_variables(self, variable: str) -> None: # Visit the stack, from the bottom (most recent) up; we know that we'll find # the variable implicated in the spiral (max_spiral_loops+1) times; we keep the # intermediate values computed (to avoid impacting performance) but we mark them @@ -428,8 +436,7 @@ def invalidate_spiral_variables(self, variable: str): # ----- Methods to access stored values ----- # def get_array(self, variable_name: str, period): - """ - Return the value of ``variable_name`` for ``period``, if this value is alreay in the cache (if it has been set as an input or previously calculated). + """Return the value of ``variable_name`` for ``period``, if this value is alreay in the cache (if it has been set as an input or previously calculated). Unlike :meth:`.calculate`, this method *does not* trigger calculations and *does not* use any formula. """ @@ -442,10 +449,8 @@ def get_holder(self, variable_name: str): return self.get_variable_population(variable_name).get_holder(variable_name) def get_memory_usage(self, variables=None): - """ - Get data about the virtual memory usage of the simulation - """ - result = dict(total_nb_bytes=0, by_variable={}) + """Get data about the virtual memory usage of the simulation.""" + result = {"total_nb_bytes": 0, "by_variable": {}} for entity in self.populations.values(): entity_memory_usage = entity.get_memory_usage(variables=variables) result["total_nb_bytes"] += entity_memory_usage["total_nb_bytes"] @@ -454,55 +459,52 @@ def get_memory_usage(self, variables=None): # ----- Misc ----- # - def delete_arrays(self, variable, period=None): - """ - Delete a variable's value for a given period + def delete_arrays(self, variable, period=None) -> None: + """Delete a variable's value for a given period. :param variable: the variable to be set :param period: the period for which the value should be deleted Example: - >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) - >>> simulation.set_input('age', '2018-04', [12, 14]) - >>> simulation.set_input('age', '2018-05', [13, 14]) - >>> simulation.get_array('age', '2018-05') + >>> simulation.set_input("age", "2018-04", [12, 14]) + >>> simulation.set_input("age", "2018-05", [13, 14]) + >>> simulation.get_array("age", "2018-05") array([13, 14], dtype=int32) - >>> simulation.delete_arrays('age', '2018-05') - >>> simulation.get_array('age', '2018-04') + >>> simulation.delete_arrays("age", "2018-05") + >>> simulation.get_array("age", "2018-04") array([12, 14], dtype=int32) - >>> simulation.get_array('age', '2018-05') is None + >>> simulation.get_array("age", "2018-05") is None True - >>> simulation.set_input('age', '2018-05', [13, 14]) - >>> simulation.delete_arrays('age') - >>> simulation.get_array('age', '2018-04') is None + >>> simulation.set_input("age", "2018-05", [13, 14]) + >>> simulation.delete_arrays("age") + >>> simulation.get_array("age", "2018-04") is None True - >>> simulation.get_array('age', '2018-05') is None + >>> simulation.get_array("age", "2018-05") is None True + """ self.get_holder(variable).delete_arrays(period) def get_known_periods(self, variable): - """ - Get a list variable's known period, i.e. the periods where a value has been initialized and + """Get a list variable's known period, i.e. the periods where a value has been initialized and. :param variable: the variable to be set Example: - >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) - >>> simulation.set_input('age', '2018-04', [12, 14]) - >>> simulation.set_input('age', '2018-05', [13, 14]) - >>> simulation.get_known_periods('age') + >>> simulation.set_input("age", "2018-04", [12, 14]) + >>> simulation.set_input("age", "2018-05", [13, 14]) + >>> simulation.get_known_periods("age") [Period((u'month', Instant((2018, 5, 1)), 1)), Period((u'month', Instant((2018, 4, 1)), 1))] + """ return self.get_holder(variable).get_known_periods() - def set_input(self, variable_name: str, period, value): - """ - Set a variable's value for a given period + def set_input(self, variable_name: str, period, value) -> None: + """Set a variable's value for a given period. :param variable: the variable to be set :param value: the input value for the variable @@ -511,16 +513,18 @@ def set_input(self, variable_name: str, period, value): Example: >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) - >>> simulation.set_input('age', '2018-04', [12, 14]) - >>> simulation.get_array('age', '2018-04') + >>> simulation.set_input("age", "2018-04", [12, 14]) + >>> simulation.get_array("age", "2018-04") array([12, 14], dtype=int32) If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation `_. + """ - variable: Optional[Variable] + variable: Variable | None variable = self.tax_benefit_system.get_variable( - variable_name, check_existence=True + variable_name, + check_existence=True, ) if variable is None: @@ -532,10 +536,11 @@ def set_input(self, variable_name: str, period, value): self.get_holder(variable_name).set_input(period, value) def get_variable_population(self, variable_name: str) -> Population: - variable: Optional[Variable] + variable: Variable | None variable = self.tax_benefit_system.get_variable( - variable_name, check_existence=True + variable_name, + check_existence=True, ) if variable is None: @@ -543,7 +548,7 @@ def get_variable_population(self, variable_name: str) -> Population: return self.populations[variable.entity.key] - def get_population(self, plural: Optional[str] = None) -> Optional[Population]: + def get_population(self, plural: str | None = None) -> Population | None: return next( ( population @@ -555,8 +560,8 @@ def get_population(self, plural: Optional[str] = None) -> Optional[Population]: def get_entity( self, - plural: Optional[str] = None, - ) -> Optional[Population]: + plural: str | None = None, + ) -> Population | None: population = self.get_population(plural) return population and population.entity @@ -567,9 +572,7 @@ def describe_entities(self): } def clone(self, debug=False, trace=False): - """ - Copy the simulation just enough to be able to run the copy without modifying the original simulation - """ + """Copy the simulation just enough to be able to run the copy without modifying the original simulation.""" new = commons.empty_clone(self) new_dict = new.__dict__ @@ -585,7 +588,9 @@ def clone(self, debug=False, trace=False): population = self.populations[entity.key].clone(new) new.populations[entity.key] = population setattr( - new, entity.key, population + new, + entity.key, + population, ) # create shortcut simulation.household (for instance) new.debug = debug diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index c42d0e4f22..1ebb499239 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -2,7 +2,7 @@ from collections.abc import Iterable, Sequence from numpy.typing import NDArray as Array -from typing import Dict, List +from typing import NoReturn import copy @@ -39,7 +39,7 @@ class SimulationBuilder: - def __init__(self): + def __init__(self) -> None: self.default_period = ( None # Simulation period used for variables when no period is defined ) @@ -48,26 +48,27 @@ def __init__(self): ) # JSON input - Memory of known input values. Indexed by variable or axis name. - self.input_buffer: Dict[ - variables.Variable.name, Dict[str(periods.period), numpy.array] + self.input_buffer: dict[ + variables.Variable.name, + dict[str(periods.period), numpy.array], ] = {} - self.populations: Dict[entities.Entity.key, populations.Population] = {} + self.populations: dict[entities.Entity.key, populations.Population] = {} # JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes. - self.entity_counts: Dict[entities.Entity.plural, int] = {} + self.entity_counts: dict[entities.Entity.plural, int] = {} # JSON input - List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. - self.entity_ids: Dict[entities.Entity.plural, List[int]] = {} + self.entity_ids: dict[entities.Entity.plural, list[int]] = {} # Links entities with persons. For each person index in persons ids list, set entity index in entity ids id. E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id) - self.memberships: Dict[entities.Entity.plural, List[int]] = {} - self.roles: Dict[entities.Entity.plural, List[int]] = {} + self.memberships: dict[entities.Entity.plural, list[int]] = {} + self.roles: dict[entities.Entity.plural, list[int]] = {} - self.variable_entities: Dict[variables.Variable.name, entities.Entity] = {} + self.variable_entities: dict[variables.Variable.name, entities.Entity] = {} self.axes = [[]] - self.axes_entity_counts: Dict[entities.Entity.plural, int] = {} - self.axes_entity_ids: Dict[entities.Entity.plural, List[int]] = {} - self.axes_memberships: Dict[entities.Entity.plural, List[int]] = {} - self.axes_roles: Dict[entities.Entity.plural, List[int]] = {} + self.axes_entity_counts: dict[entities.Entity.plural, int] = {} + self.axes_entity_ids: dict[entities.Entity.plural, list[int]] = {} + self.axes_memberships: dict[entities.Entity.plural, list[int]] = {} + self.axes_roles: dict[entities.Entity.plural, list[int]] = {} def build_from_dict( self, @@ -91,9 +92,9 @@ def build_from_dict( >>> entities = {"person", "household"} >>> params = { - ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, ... "household": {"parents": ["Javier"]}, - ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], ... } >>> are_entities_short_form(params, entities) @@ -103,13 +104,25 @@ def build_from_dict( >>> params = { ... "axes": [ - ... [{"count": 2, "max": 3000, "min": 0, "name": "rent", "period": "2018-11"}] + ... [ + ... { + ... "count": 2, + ... "max": 3000, + ... "min": 0, + ... "name": "rent", + ... "period": "2018-11", + ... } + ... ] ... ], ... "households": { ... "housea": {"parents": ["Alicia", "Javier"]}, ... "houseb": {"parents": ["Tom"]}, - ... }, - ... "persons": {"Alicia": {"salary": {"2018-11": 0}}, "Javier": {}, "Tom": {}}, + ... }, + ... "persons": { + ... "Alicia": {"salary": {"2018-11": 0}}, + ... "Javier": {}, + ... "Tom": {}, + ... }, ... } >>> are_entities_short_form(params, entities) @@ -121,7 +134,6 @@ def build_from_dict( True """ - #: The plural names of the entities in the tax and benefits system. plural: Iterable[str] = tax_benefit_system.entities_plural() @@ -140,6 +152,7 @@ def build_from_dict( if not are_entities_specified(params := input_dict, variables): return self.build_from_variables(tax_benefit_system, params) + return None def build_from_entities( self, @@ -153,16 +166,15 @@ def build_from_entities( >>> entities = {"person", "household"} >>> params = { - ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, ... "household": {"parents": ["Javier"]}, - ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], ... } >>> are_entities_short_form(params, entities) True """ - # Create the populations populations = tax_benefit_system.instantiate_entities() @@ -203,9 +215,7 @@ def build_from_entities( if not persons_json: raise errors.SituationParsingError( [person_entity.plural], - "No {0} found. At least one {0} must be defined to run a simulation.".format( - person_entity.key - ), + f"No {person_entity.key} found. At least one {person_entity.key} must be defined to run a simulation.", ) persons_ids = self.add_person_entity(simulation.persons.entity, persons_json) @@ -215,7 +225,10 @@ def build_from_entities( if instances_json is not None: self.add_group_entity( - self.persons_plural, persons_ids, entity_class, instances_json + self.persons_plural, + persons_ids, + entity_class, + instances_json, ) elif axes is not None: @@ -257,7 +270,9 @@ def build_from_entities( return simulation def build_from_variables( - self, tax_benefit_system: TaxBenefitSystem, input_dict: Variables + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: Variables, ) -> Simulation: """Build a simulation from a Python dict ``input_dict`` describing variables values without expliciting entities. @@ -276,18 +291,17 @@ def build_from_variables( SituationParsingError: If the input is not valid. Examples: - >>> params = {'salary': {'2016-10': 12000}} + >>> params = {"salary": {"2016-10": 12000}} >>> are_entities_specified(params, {"salary"}) False - >>> params = {'salary': 12000} + >>> params = {"salary": 12000} >>> are_entities_specified(params, {"salary"}) False """ - return ( _BuildFromVariables(tax_benefit_system, input_dict, self.default_period) .add_dated_values() @@ -297,7 +311,8 @@ def build_from_variables( @staticmethod def build_default_simulation( - tax_benefit_system: TaxBenefitSystem, count: int = 1 + tax_benefit_system: TaxBenefitSystem, + count: int = 1, ) -> Simulation: """Build a default simulation. @@ -307,7 +322,6 @@ def build_default_simulation( - Every person has, in each entity, the first role """ - return ( _BuildDefaultSimulation(tax_benefit_system, count) .add_count() @@ -316,10 +330,10 @@ def build_default_simulation( .simulation ) - def create_entities(self, tax_benefit_system): + def create_entities(self, tax_benefit_system) -> None: self.populations = tax_benefit_system.instantiate_entities() - def declare_person_entity(self, person_singular, persons_ids: Iterable): + def declare_person_entity(self, person_singular, persons_ids: Iterable) -> None: person_instance = self.populations[person_singular] person_instance.ids = numpy.array(list(persons_ids)) person_instance.count = len(person_instance.ids) @@ -336,11 +350,15 @@ def nb_persons(self, entity_singular, role=None): return self.populations[entity_singular].nb_persons(role=role) def join_with_persons( - self, group_population, persons_group_assignment, roles: Iterable[str] - ): + self, + group_population, + persons_group_assignment, + roles: Iterable[str], + ) -> None: # Maps group's identifiers to a 0-based integer range, for indexing into members_roles (see PR#876) group_sorted_indices = numpy.unique( - persons_group_assignment, return_inverse=True + persons_group_assignment, + return_inverse=True, )[1] group_population.members_entity_id = numpy.argsort(group_population.ids)[ group_sorted_indices @@ -350,27 +368,32 @@ def join_with_persons( roles_array = numpy.array(roles) if numpy.issubdtype(roles_array.dtype, numpy.integer): group_population.members_role = numpy.array(flattened_roles)[roles_array] + elif len(flattened_roles) == 0: + group_population.members_role = numpy.int64(0) else: - if len(flattened_roles) == 0: - group_population.members_role = numpy.int64(0) - else: - group_population.members_role = numpy.select( - [roles_array == role.key for role in flattened_roles], - flattened_roles, - ) + group_population.members_role = numpy.select( + [roles_array == role.key for role in flattened_roles], + flattened_roles, + ) def build(self, tax_benefit_system): return Simulation(tax_benefit_system, self.populations) def explicit_singular_entities( - self, tax_benefit_system: TaxBenefitSystem, input_dict: ImplicitGroupEntities + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: ImplicitGroupEntities, ) -> GroupEntities: """Preprocess ``input_dict`` to explicit entities defined using the - single-entity shortcut + single-entity shortcut. Examples: - - >>> params = {'persons': {'Javier': {}, }, 'household': {'parents': ['Javier']}} + >>> params = { + ... "persons": { + ... "Javier": {}, + ... }, + ... "household": {"parents": ["Javier"]}, + ... } >>> are_entities_fully_specified(params, {"persons", "households"}) False @@ -378,7 +401,10 @@ def explicit_singular_entities( >>> are_entities_short_form(params, {"person", "household"}) True - >>> params = {'persons': {'Javier': {}}, 'households': {'household': {'parents': ['Javier']}}} + >>> params = { + ... "persons": {"Javier": {}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... } >>> are_entities_fully_specified(params, {"persons", "households"}) True @@ -387,9 +413,8 @@ def explicit_singular_entities( False """ - singular_keys = set(input_dict).intersection( - tax_benefit_system.entities_by_singular() + tax_benefit_system.entities_by_singular(), ) result = { @@ -405,9 +430,7 @@ def explicit_singular_entities( return result def add_person_entity(self, entity, instances_json): - """ - Add the simulation's instances of the persons entity as described in ``instances_json``. - """ + """Add the simulation's instances of the persons entity as described in ``instances_json``.""" helpers.check_type(instances_json, dict, [entity.plural]) entity_ids = list(map(str, instances_json.keys())) self.persons_plural = entity.plural @@ -421,14 +444,16 @@ def add_person_entity(self, entity, instances_json): return self.get_ids(entity.plural) def add_default_group_entity( - self, persons_ids: list[str], entity: GroupEntity + self, + persons_ids: list[str], + entity: GroupEntity, ) -> None: persons_count = len(persons_ids) roles = list(entity.flattened_roles) self.entity_ids[entity.plural] = persons_ids self.entity_counts[entity.plural] = persons_count self.memberships[entity.plural] = list( - numpy.arange(0, persons_count, dtype=numpy.int32) + numpy.arange(0, persons_count, dtype=numpy.int32), ) self.roles[entity.plural] = [roles[0]] * persons_count @@ -439,9 +464,7 @@ def add_group_entity( entity: GroupEntity, instances_json, ) -> None: - """ - Add all instances of one of the model's entities as described in ``instances_json``. - """ + """Add all instances of one of the model's entities as described in ``instances_json``.""" helpers.check_type(instances_json, dict, [entity.plural]) entity_ids = list(map(str, instances_json.keys())) @@ -464,14 +487,16 @@ def add_group_entity( roles_json = { role.plural or role.key: helpers.transform_to_strict_syntax( - variables_json.pop(role.plural or role.key, []) + variables_json.pop(role.plural or role.key, []), ) for role in entity.roles } for role_id, role_definition in roles_json.items(): helpers.check_type( - role_definition, list, [entity.plural, instance_id, role_id] + role_definition, + list, + [entity.plural, instance_id, role_id], ) for index, person_id in enumerate(role_definition): entity_plural = entity.plural @@ -515,7 +540,7 @@ def add_group_entity( for person_id in persons_to_allocate: person_index = persons_ids.index(person_id) self.memberships[entity.plural][person_index] = entity_ids.index( - person_id + person_id, ) self.roles[entity.plural][person_index] = entity.flattened_roles[0] # Adjust previously computed ids and counts @@ -526,7 +551,7 @@ def add_group_entity( self.roles[entity.plural] = self.roles[entity.plural].tolist() self.memberships[entity.plural] = self.memberships[entity.plural].tolist() - def set_default_period(self, period_str): + def set_default_period(self, period_str) -> None: if period_str: self.default_period = str(periods.period(period_str)) @@ -546,26 +571,24 @@ def check_persons_to_allocate( role_id, persons_to_allocate, index, - ): + ) -> None: helpers.check_type( - person_id, str, [entity_plural, entity_id, role_id, str(index)] + person_id, + str, + [entity_plural, entity_id, role_id, str(index)], ) if person_id not in persons_ids: raise errors.SituationParsingError( [entity_plural, entity_id, role_id], - "Unexpected value: {0}. {0} has been declared in {1} {2}, but has not been declared in {3}.".format( - person_id, entity_id, role_id, persons_plural - ), + f"Unexpected value: {person_id}. {person_id} has been declared in {entity_id} {role_id}, but has not been declared in {persons_plural}.", ) if person_id not in persons_to_allocate: raise errors.SituationParsingError( [entity_plural, entity_id, role_id], - "{} has been declared more than once in {}".format( - person_id, entity_plural - ), + f"{person_id} has been declared more than once in {entity_plural}", ) - def init_variable_values(self, entity, instance_object, instance_id): + def init_variable_values(self, entity, instance_object, instance_id) -> None: for variable_name, variable_values in instance_object.items(): path_in_json = [entity.plural, instance_id, variable_name] try: @@ -592,12 +615,23 @@ def init_variable_values(self, entity, instance_object, instance_id): raise errors.SituationParsingError(path_in_json, e.args[0]) variable = entity.get_variable(variable_name) self.add_variable_value( - entity, variable, instance_index, instance_id, period_str, value + entity, + variable, + instance_index, + instance_id, + period_str, + value, ) def add_variable_value( - self, entity, variable, instance_index, instance_id, period_str, value - ): + self, + entity, + variable, + instance_index, + instance_id, + period_str, + value, + ) -> None: path_in_json = [entity.plural, instance_id, variable.name, period_str] if value is None: @@ -618,7 +652,7 @@ def add_variable_value( self.input_buffer[variable.name][str(periods.period(period_str))] = array - def finalize_variables_init(self, population): + def finalize_variables_init(self, population) -> None: # Due to set_input mechanism, we must bufferize all inputs, then actually set them, # so that the months are set first and the years last. plural_key = population.entity.plural @@ -628,7 +662,7 @@ def finalize_variables_init(self, population): if plural_key in self.memberships: population.members_entity_id = numpy.array(self.get_memberships(plural_key)) population.members_role = numpy.array(self.get_roles(plural_key)) - for variable_name in self.input_buffer.keys(): + for variable_name in self.input_buffer: try: holder = population.get_holder(variable_name) except ValueError: # Wrong entity, we can just ignore that @@ -636,7 +670,7 @@ def finalize_variables_init(self, population): buffer = self.input_buffer[variable_name] unsorted_periods = [ periods.period(period_str) - for period_str in self.input_buffer[variable_name].keys() + for period_str in self.input_buffer[variable_name] ] # We need to handle small periods first for set_input to work sorted_periods = sorted(unsorted_periods, key=periods.key_period_size) @@ -651,21 +685,23 @@ def finalize_variables_init(self, population): if (variable.end is None) or (period_value.start.date <= variable.end): holder.set_input(period_value, array) - def raise_period_mismatch(self, entity, json, e): + def raise_period_mismatch(self, entity, json, e) -> NoReturn: # This error happens when we try to set a variable value for a period that doesn't match its definition period # It is only raised when we consume the buffer. We thus don't know which exact key caused the error. # We do a basic research to find the culprit path culprit_path = next( dpath.util.search( - json, f"*/{e.variable_name}/{str(e.period)}", yielded=True + json, + f"*/{e.variable_name}/{e.period!s}", + yielded=True, ), None, ) if culprit_path: - path = [entity.plural] + culprit_path[0].split("/") + path = [entity.plural, *culprit_path[0].split("/")] else: path = [ - entity.plural + entity.plural, ] # Fallback: if we can't find the culprit, just set the error at the entities level raise errors.SituationParsingError(path, e.message) @@ -682,7 +718,8 @@ def get_ids(self, entity_name: str) -> list[str]: def get_memberships(self, entity_name): # Return empty array for the "persons" entity return self.axes_memberships.get( - entity_name, self.memberships.get(entity_name, []) + entity_name, + self.memberships.get(entity_name, []), ) # Returns the roles of individuals in this entity, including when there is replication along axes @@ -710,7 +747,7 @@ def expand_axes(self) -> None: cell_count *= axis_count # Scale the "prototype" situation, repeating it cell_count times - for entity_name in self.entity_counts.keys(): + for entity_name in self.entity_counts: # Adjust counts self.axes_entity_counts[entity_name] = ( self.get_count(entity_name) * cell_count @@ -718,7 +755,8 @@ def expand_axes(self) -> None: # Adjust ids original_ids: list[str] = self.get_ids(entity_name) * cell_count indices: Array[numpy.int_] = numpy.arange( - 0, cell_count * self.entity_counts[entity_name] + 0, + cell_count * self.entity_counts[entity_name], ) adjusted_ids: list[str] = [ original_id + str(index) @@ -792,7 +830,7 @@ def expand_axes(self) -> None: array = self.get_input(axis_name, str(axis_period)) if array is None: array = variable.default_array( - cell_count * axis_entity_step_size + cell_count * axis_entity_step_size, ) elif array.size == axis_entity_step_size: array = numpy.tile(array, cell_count) diff --git a/openfisca_core/simulations/typing.py b/openfisca_core/simulations/typing.py index 8603d0d811..8091994e53 100644 --- a/openfisca_core/simulations/typing.py +++ b/openfisca_core/simulations/typing.py @@ -10,12 +10,14 @@ import datetime from abc import abstractmethod -from numpy import bool_ as Bool -from numpy import datetime64 as Date -from numpy import float32 as Float -from numpy import int16 as Enum -from numpy import int32 as Int -from numpy import str_ as String +from numpy import ( + bool_ as Bool, + datetime64 as Date, + float32 as Float, + int16 as Enum, + int32 as Int, + str_ as String, +) #: Generic type variables. E = TypeVar("E") @@ -54,7 +56,9 @@ #: Type alias for a simulation dictionary without axes parameters. ParamsWithoutAxes: TypeAlias = Union[ - Variables, ImplicitGroupEntities, FullySpecifiedEntities + Variables, + ImplicitGroupEntities, + FullySpecifiedEntities, ] #: Type alias for a simulation dictionary with axes parameters. diff --git a/openfisca_core/taxbenefitsystems/tax_benefit_system.py b/openfisca_core/taxbenefitsystems/tax_benefit_system.py index 14b607feac..8c48f64715 100644 --- a/openfisca_core/taxbenefitsystems/tax_benefit_system.py +++ b/openfisca_core/taxbenefitsystems/tax_benefit_system.py @@ -1,7 +1,7 @@ from __future__ import annotations -import typing -from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any from openfisca_core.types import ParameterNodeAtInstant @@ -48,7 +48,7 @@ class TaxBenefitSystem: person_entity: Entity _base_tax_benefit_system = None - _parameters_at_instant_cache: Dict[Instant, ParameterNodeAtInstant] = {} + _parameters_at_instant_cache: dict[Instant, ParameterNodeAtInstant] = {} person_key_plural = None preprocess_parameters = None baseline = None # Baseline tax-benefit system. Used only by reforms. Note: Reforms can be chained. @@ -57,14 +57,17 @@ class TaxBenefitSystem: def __init__(self, entities: Sequence[Entity]) -> None: # TODO: Currently: Don't use a weakref, because they are cleared by Paste (at least) at each call. - self.parameters: Optional[ParameterNode] = None - self.variables: Dict[Any, Any] = {} - self.open_api_config: Dict[Any, Any] = {} + self.parameters: ParameterNode | None = None + self.variables: dict[Any, Any] = {} + self.open_api_config: dict[Any, Any] = {} # Tax benefit systems are mutable, so entities (which need to know about our variables) can't be shared among them if entities is None or len(entities) == 0: - raise Exception("A tax and benefit sytem must have at least an entity.") + msg = "A tax and benefit sytem must have at least an entity." + raise Exception(msg) self.entities = [copy.copy(entity) for entity in entities] - self.person_entity = [entity for entity in self.entities if entity.is_person][0] + self.person_entity = next( + entity for entity in self.entities if entity.is_person + ) self.group_entities = [ entity for entity in self.entities if not entity.is_person ] @@ -78,15 +81,15 @@ def base_tax_benefit_system(self): baseline = self.baseline if baseline is None: return self - self._base_tax_benefit_system = ( - base_tax_benefit_system - ) = baseline.base_tax_benefit_system + self._base_tax_benefit_system = base_tax_benefit_system = ( + baseline.base_tax_benefit_system + ) return base_tax_benefit_system def instantiate_entities(self): person = self.person_entity members = Population(person) - entities: typing.Dict[Entity.key, Entity] = {person.key: members} + entities: dict[Entity.key, Entity] = {person.key: members} for entity in self.group_entities: entities[entity.key] = GroupPopulation(entity, members) @@ -95,8 +98,8 @@ def instantiate_entities(self): # Deprecated method of constructing simulations, to be phased out in favor of SimulationBuilder def new_scenario(self): - class ScenarioAdapter(object): - def __init__(self, tax_benefit_system): + class ScenarioAdapter: + def __init__(self, tax_benefit_system) -> None: self.tax_benefit_system = tax_benefit_system def init_from_attributes(self, **attributes): @@ -110,7 +113,11 @@ def init_from_dict(self, dict): return self def new_simulation( - self, debug=False, opt_out_cache=False, use_baseline=False, trace=False + self, + debug=False, + opt_out_cache=False, + use_baseline=False, + trace=False, ): # Legacy from scenarios, used in reforms tax_benefit_system = self.tax_benefit_system @@ -127,12 +134,14 @@ def new_simulation( period = self.attributes.get("period") builder.set_default_period(period) simulation = builder.build_from_variables( - tax_benefit_system, variables + tax_benefit_system, + variables, ) else: builder.set_default_period(self.period) simulation = builder.build_from_entities( - tax_benefit_system, self.dict + tax_benefit_system, + self.dict, ) simulation.trace = trace @@ -143,7 +152,7 @@ def new_simulation( return ScenarioAdapter(self) - def prefill_cache(self): + def prefill_cache(self) -> None: pass def load_variable(self, variable_class, update=False): @@ -152,10 +161,9 @@ def load_variable(self, variable_class, update=False): # Check if a Variable with the same name is already registered. baseline_variable = self.get_variable(name) if baseline_variable and not update: + msg = f'Variable "{name}" is already defined. Use `update_variable` to replace it.' raise VariableNameConflictError( - 'Variable "{}" is already defined. Use `update_variable` to replace it.'.format( - name - ) + msg, ) variable = variable_class(baseline_variable=baseline_variable) @@ -213,16 +221,14 @@ def update_variable(self, variable: Variable) -> Variable: The added variable. """ - return self.load_variable(variable, update=True) - def add_variables_from_file(self, file_path): - """ - Adds all OpenFisca variables contained in a given file to the tax and benefit system. - """ + def add_variables_from_file(self, file_path) -> None: + """Adds all OpenFisca variables contained in a given file to the tax and benefit system.""" try: source_file_path = file_path.replace( - self.get_package_metadata()["location"], "" + self.get_package_metadata()["location"], + "", ) file_name = os.path.splitext(os.path.basename(file_path))[0] @@ -244,9 +250,9 @@ def add_variables_from_file(self, file_path): spec.loader.exec_module(module) except NameError as e: - logging.error( + logging.exception( str(e) - + ": if this code used to work, this error might be due to a major change in OpenFisca-Core. Checkout the changelog to learn more: " + + ": if this code used to work, this error might be due to a major change in OpenFisca-Core. Checkout the changelog to learn more: ", ) raise potential_variables = [ @@ -269,15 +275,11 @@ def add_variables_from_file(self, file_path): ) self.add_variable(pot_variable) except Exception: - log.error( - 'Unable to load OpenFisca variables from file "{}"'.format(file_path) - ) + log.exception(f'Unable to load OpenFisca variables from file "{file_path}"') raise - def add_variables_from_directory(self, directory): - """ - Recursively explores a directory, and adds all OpenFisca variables found there to the tax and benefit system. - """ + def add_variables_from_directory(self, directory) -> None: + """Recursively explores a directory, and adds all OpenFisca variables found there to the tax and benefit system.""" py_files = glob.glob(os.path.join(directory, "*.py")) for py_file in py_files: self.add_variables_from_file(py_file) @@ -285,18 +287,16 @@ def add_variables_from_directory(self, directory): for subdirectory in subdirectories: self.add_variables_from_directory(subdirectory) - def add_variables(self, *variables): - """ - Adds a list of OpenFisca Variables to the `TaxBenefitSystem`. + def add_variables(self, *variables) -> None: + """Adds a list of OpenFisca Variables to the `TaxBenefitSystem`. See also :any:`add_variable` """ for variable in variables: self.add_variable(variable) - def load_extension(self, extension): - """ - Loads an extension to the tax and benefit system. + def load_extension(self, extension) -> None: + """Loads an extension to the tax and benefit system. :param str extension: The extension to load. Can be an absolute path pointing to an extension directory, or the name of an OpenFisca extension installed as a pip package. @@ -309,12 +309,10 @@ def load_extension(self, extension): message = os.linesep.join( [ traceback.format_exc(), - "Error loading extension: `{}` is neither a directory, nor a package.".format( - extension - ), + f"Error loading extension: `{extension}` is neither a directory, nor a package.", "Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.", "See more at .", - ] + ], ) raise ValueError(message) @@ -324,7 +322,7 @@ def load_extension(self, extension): extension_parameters = ParameterNode(directory_path=param_dir) self.parameters.merge(extension_parameters) - def apply_reform(self, reform_path: str) -> "TaxBenefitSystem": + def apply_reform(self, reform_path: str) -> TaxBenefitSystem: """Generates a new tax and benefit system applying a reform to the tax and benefit system. The current tax and benefit system is **not** mutated. @@ -336,8 +334,7 @@ def apply_reform(self, reform_path: str) -> "TaxBenefitSystem": TaxBenefitSystem: A reformed tax and benefit system. Example: - - >>> self.apply_reform('openfisca_france.reforms.inversion_revenus') + >>> self.apply_reform("openfisca_france.reforms.inversion_revenus") """ from openfisca_core.reforms import Reform @@ -345,10 +342,9 @@ def apply_reform(self, reform_path: str) -> "TaxBenefitSystem": try: reform_package, reform_name = reform_path.rsplit(".", 1) except ValueError: + msg = f"`{reform_path}` does not seem to be a path pointing to a reform. A path looks like `some_country_package.reforms.some_reform.`" raise ValueError( - "`{}` does not seem to be a path pointing to a reform. A path looks like `some_country_package.reforms.some_reform.`".format( - reform_path - ) + msg, ) try: reform_module = importlib.import_module(reform_package) @@ -356,19 +352,19 @@ def apply_reform(self, reform_path: str) -> "TaxBenefitSystem": message = os.linesep.join( [ traceback.format_exc(), - "Could not import `{}`.".format(reform_package), + f"Could not import `{reform_package}`.", "Are you sure of this reform module name? If so, look at the stack trace above to determine the origin of this error.", - ] + ], ) raise ValueError(message) reform = getattr(reform_module, reform_name, None) if reform is None: - raise ValueError( - "{} has no attribute {}".format(reform_package, reform_name) - ) + msg = f"{reform_package} has no attribute {reform_name}" + raise ValueError(msg) if not issubclass(reform, Reform): + msg = f"`{reform_path}` does not seem to be a valid Openfisca reform." raise ValueError( - "`{}` does not seem to be a valid Openfisca reform.".format(reform_path) + msg, ) return reform(self) @@ -377,15 +373,14 @@ def get_variable( self, variable_name: str, check_existence: bool = False, - ) -> Optional[Variable]: - """ - Get a variable from the tax and benefit system. + ) -> Variable | None: + """Get a variable from the tax and benefit system. :param variable_name: Name of the requested variable. :param check_existence: If True, raise an error if the requested variable does not exist. """ - variables: Dict[str, Optional[Variable]] = self.variables - variable: Optional[Variable] = variables.get(variable_name) + variables: dict[str, Variable | None] = self.variables + variable: Variable | None = variables.get(variable_name) if isinstance(variable, Variable): return variable @@ -395,25 +390,24 @@ def get_variable( raise VariableNotFoundError(variable_name, self) - def neutralize_variable(self, variable_name: str): - """ - Neutralizes an OpenFisca variable existing in the tax and benefit system. + def neutralize_variable(self, variable_name: str) -> None: + """Neutralizes an OpenFisca variable existing in the tax and benefit system. A neutralized variable always returns its default value when computed. Trying to set inputs for a neutralized variable has no effect except raising a warning. """ self.variables[variable_name] = variables.get_neutralized_variable( - self.get_variable(variable_name) + self.get_variable(variable_name), ) def annualize_variable( self, variable_name: str, - period: Optional[Period] = None, + period: Period | None = None, ) -> None: check: bool - variable: Optional[Variable] + variable: Variable | None annualised_variable: Variable check = bool(period) @@ -426,17 +420,15 @@ def annualize_variable( self.variables[variable_name] = annualised_variable - def load_parameters(self, path_to_yaml_dir): - """ - Loads the legislation parameter for a directory containing YAML parameters files. + def load_parameters(self, path_to_yaml_dir) -> None: + """Loads the legislation parameter for a directory containing YAML parameters files. :param path_to_yaml_dir: Absolute path towards the YAML parameter directory. Example: + >>> self.load_parameters("/path/to/yaml/parameters/dir") - >>> self.load_parameters('/path/to/yaml/parameters/dir') """ - parameters = ParameterNode("", directory_path=path_to_yaml_dir) if self.preprocess_parameters is not None: @@ -450,12 +442,12 @@ def _get_baseline_parameters_at_instant(self, instant): return self.get_parameters_at_instant(instant) return baseline._get_baseline_parameters_at_instant(instant) - @functools.lru_cache() # noqa BO19 + @functools.lru_cache def get_parameters_at_instant( self, - instant: Union[str, int, Period, Instant], - ) -> Optional[ParameterNodeAtInstant]: - """Get the parameters of the legislation at a given instant + instant: str | int | Period | Instant, + ) -> ParameterNodeAtInstant | None: + """Get the parameters of the legislation at a given instant. Args: instant: :obj:`str` formatted "YYYY-MM-DD" or :class:`~openfisca_core.periods.Instant`. @@ -464,8 +456,7 @@ def get_parameters_at_instant( The parameters of the legislation at a given instant. """ - - key: Optional[Instant] + key: Instant | None msg: str if isinstance(instant, Instant): @@ -486,7 +477,7 @@ def get_parameters_at_instant( return self.parameters.get_at_instant(key) - def get_package_metadata(self) -> Dict[str, str]: + def get_package_metadata(self) -> dict[str, str]: """Gets metadata relative to the country package. Returns: @@ -502,7 +493,6 @@ def get_package_metadata(self) -> Dict[str, str]: >>> } """ - # Handle reforms if self.baseline: return self.baseline.get_package_metadata() @@ -515,7 +505,7 @@ def get_package_metadata(self) -> Dict[str, str]: distribution = importlib.metadata.distribution(package_name) source_metadata = distribution.metadata except Exception as e: - log.warn("Unable to load package metadata, exposing default metadata", e) + log.warning("Unable to load package metadata, exposing default metadata", e) source_metadata = { "Name": self.__class__.__name__, "Version": "0.0.0", @@ -526,7 +516,7 @@ def get_package_metadata(self) -> Dict[str, str]: source_file = inspect.getsourcefile(module) location = source_file.split(package_name)[0].rstrip("/") except Exception as e: - log.warn("Unable to load package source folder", e) + log.warning("Unable to load package source folder", e) location = "_unknown_" repository_url = "" @@ -535,7 +525,7 @@ def get_package_metadata(self) -> Dict[str, str]: filter( lambda url: url.startswith("Repository"), source_metadata.get_all("Project-URL"), - ) + ), ).split("Repository, ")[-1] else: # setup.py format repository_url = source_metadata.get("Home-page") @@ -549,8 +539,8 @@ def get_package_metadata(self) -> Dict[str, str]: def get_variables( self, - entity: Optional[Entity] = None, - ) -> Dict[str, Variable]: + entity: Entity | None = None, + ) -> dict[str, Variable]: """Gets all variables contained in a tax and benefit system. Args: @@ -560,16 +550,14 @@ def get_variables( A dictionary, indexed by variable names. """ - if not entity: return self.variables - else: - return { - variable_name: variable - for variable_name, variable in self.variables.items() - # TODO - because entities are copied (see constructor) they can't be compared - if variable.entity.key == entity.key - } + return { + variable_name: variable + for variable_name, variable in self.variables.items() + # TODO - because entities are copied (see constructor) they can't be compared + if variable.entity.key == entity.key + } def clone(self): new = commons.empty_clone(self) diff --git a/openfisca_core/taxscales/abstract_rate_tax_scale.py b/openfisca_core/taxscales/abstract_rate_tax_scale.py index cd04ba872e..84ab4eb913 100644 --- a/openfisca_core/taxscales/abstract_rate_tax_scale.py +++ b/openfisca_core/taxscales/abstract_rate_tax_scale.py @@ -9,18 +9,17 @@ if typing.TYPE_CHECKING: import numpy - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int_, numpy.float64] class AbstractRateTaxScale(RateTaxScaleLike): - """ - Base class for various types of rate-based tax scales: marginal rate, + """Base class for various types of rate-based tax scales: marginal rate, linear average rate... """ def __init__( self, - name: typing.Optional[str] = None, + name: str | None = None, option: typing.Any = None, unit: typing.Any = None, ) -> None: @@ -37,6 +36,7 @@ def calc( tax_base: NumericalArray, right: bool, ) -> typing.NoReturn: + msg = "Method 'calc' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method 'calc' is not implemented for " f"{self.__class__.__name__}", + msg, ) diff --git a/openfisca_core/taxscales/abstract_tax_scale.py b/openfisca_core/taxscales/abstract_tax_scale.py index 43b21f8141..933f36d47d 100644 --- a/openfisca_core/taxscales/abstract_tax_scale.py +++ b/openfisca_core/taxscales/abstract_tax_scale.py @@ -9,18 +9,17 @@ if typing.TYPE_CHECKING: import numpy - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int_, numpy.float64] class AbstractTaxScale(TaxScaleLike): - """ - Base class for various types of tax scales: amount-based tax scales, + """Base class for various types of tax scales: amount-based tax scales, rate-based tax scales... """ def __init__( self, - name: typing.Optional[str] = None, + name: str | None = None, option: typing.Any = None, unit: numpy.int_ = None, ) -> None: @@ -33,8 +32,9 @@ def __init__( super().__init__(name, option, unit) def __repr__(self) -> typing.NoReturn: + msg = "Method '__repr__' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method '__repr__' is not implemented for " f"{self.__class__.__name__}", + msg, ) def calc( @@ -42,11 +42,13 @@ def calc( tax_base: NumericalArray, right: bool, ) -> typing.NoReturn: + msg = "Method 'calc' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method 'calc' is not implemented for " f"{self.__class__.__name__}", + msg, ) def to_dict(self) -> typing.NoReturn: + msg = f"Method 'to_dict' is not implemented for {self.__class__.__name__}" raise NotImplementedError( - f"Method 'to_dict' is not implemented for " f"{self.__class__.__name__}", + msg, ) diff --git a/openfisca_core/taxscales/amount_tax_scale_like.py b/openfisca_core/taxscales/amount_tax_scale_like.py index 865ce3200c..1dc9acf4b3 100644 --- a/openfisca_core/taxscales/amount_tax_scale_like.py +++ b/openfisca_core/taxscales/amount_tax_scale_like.py @@ -10,12 +10,11 @@ class AmountTaxScaleLike(TaxScaleLike, abc.ABC): - """ - Base class for various types of amount-based tax scales: single amount, + """Base class for various types of amount-based tax scales: single amount, marginal amount... """ - amounts: typing.List + amounts: list def __init__( self, @@ -32,8 +31,8 @@ def __repr__(self) -> str: [ f"- threshold: {threshold}{os.linesep} amount: {amount}" for (threshold, amount) in zip(self.thresholds, self.amounts) - ] - ) + ], + ), ) def add_bracket( diff --git a/openfisca_core/taxscales/helpers.py b/openfisca_core/taxscales/helpers.py index 62ee431be9..687db41a3b 100644 --- a/openfisca_core/taxscales/helpers.py +++ b/openfisca_core/taxscales/helpers.py @@ -18,11 +18,9 @@ def combine_tax_scales( node: ParameterNodeAtInstant, combined_tax_scales: TaxScales = None, ) -> TaxScales: - """ - Combine all the MarginalRateTaxScales in the node into a single + """Combine all the MarginalRateTaxScales in the node into a single MarginalRateTaxScale. """ - name = next(iter(node or []), None) if name is None: diff --git a/openfisca_core/taxscales/linear_average_rate_tax_scale.py b/openfisca_core/taxscales/linear_average_rate_tax_scale.py index 60b2053d5d..ec1b22e0c2 100644 --- a/openfisca_core/taxscales/linear_average_rate_tax_scale.py +++ b/openfisca_core/taxscales/linear_average_rate_tax_scale.py @@ -13,7 +13,7 @@ log = logging.getLogger(__name__) if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int_, numpy.float64] class LinearAverageRateTaxScale(RateTaxScaleLike): @@ -21,7 +21,7 @@ def calc( self, tax_base: NumericalArray, right: bool = False, - ) -> numpy.float_: + ) -> numpy.float64: if len(self.rates) == 1: return tax_base * self.rates[0] diff --git a/openfisca_core/taxscales/marginal_amount_tax_scale.py b/openfisca_core/taxscales/marginal_amount_tax_scale.py index fa8a0897f7..ac021351be 100644 --- a/openfisca_core/taxscales/marginal_amount_tax_scale.py +++ b/openfisca_core/taxscales/marginal_amount_tax_scale.py @@ -7,7 +7,7 @@ from .amount_tax_scale_like import AmountTaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int_, numpy.float64] class MarginalAmountTaxScale(AmountTaxScaleLike): @@ -15,19 +15,20 @@ def calc( self, tax_base: NumericalArray, right: bool = False, - ) -> numpy.float_: - """ - Matches the input amount to a set of brackets and returns the sum of + ) -> numpy.float64: + """Matches the input amount to a set of brackets and returns the sum of cell values from the lowest bracket to the one containing the input. """ base1 = numpy.tile(tax_base, (len(self.thresholds), 1)).T thresholds1 = numpy.tile( - numpy.hstack((self.thresholds, numpy.inf)), (len(tax_base), 1) + numpy.hstack((self.thresholds, numpy.inf)), + (len(tax_base), 1), ) a = numpy.maximum( - numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0 + numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], + 0, ) return numpy.dot(self.amounts, a.T > 0) diff --git a/openfisca_core/taxscales/marginal_rate_tax_scale.py b/openfisca_core/taxscales/marginal_rate_tax_scale.py index 2604c156e1..c81da8e7e9 100644 --- a/openfisca_core/taxscales/marginal_rate_tax_scale.py +++ b/openfisca_core/taxscales/marginal_rate_tax_scale.py @@ -12,7 +12,7 @@ from .rate_tax_scale_like import RateTaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int_, numpy.float64] class MarginalRateTaxScale(RateTaxScaleLike): @@ -36,10 +36,9 @@ def calc( self, tax_base: NumericalArray, factor: float = 1.0, - round_base_decimals: typing.Optional[int] = None, - ) -> numpy.float_: - """ - Compute the tax amount for the given tax bases by applying a taxscale. + round_base_decimals: int | None = None, + ) -> numpy.float64: + """Compute the tax amount for the given tax bases by applying a taxscale. :param ndarray tax_base: Array of the tax bases. :param float factor: Factor to apply to the thresholds of the taxscale. @@ -68,30 +67,30 @@ def calc( # # numpy.finfo(float_).eps thresholds1 = numpy.outer( - factor + numpy.finfo(numpy.float_).eps, - numpy.array(self.thresholds + [numpy.inf]), + factor + numpy.finfo(numpy.float64).eps, + numpy.array([*self.thresholds, numpy.inf]), ) if round_base_decimals is not None: - thresholds1 = numpy.round_(thresholds1, round_base_decimals) + thresholds1 = numpy.round(thresholds1, round_base_decimals) a = numpy.maximum( - numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0 + numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], + 0, ) if round_base_decimals is None: return numpy.dot(self.rates, a.T) - else: - r = numpy.tile(self.rates, (len(tax_base), 1)) - b = numpy.round_(a, round_base_decimals) - return numpy.round_(r * b, round_base_decimals).sum(axis=1) + r = numpy.tile(self.rates, (len(tax_base), 1)) + b = numpy.round(a, round_base_decimals) + return numpy.round(r * b, round_base_decimals).sum(axis=1) def combine_bracket( self, - rate: typing.Union[int, float], + rate: int | float, threshold_low: int = 0, - threshold_high: typing.Union[int, bool] = False, + threshold_high: int | bool = False, ) -> None: # Insert threshold_low and threshold_high without modifying rates if threshold_low not in self.thresholds: @@ -119,10 +118,9 @@ def marginal_rates( self, tax_base: NumericalArray, factor: float = 1.0, - round_base_decimals: typing.Optional[int] = None, - ) -> numpy.float_: - """ - Compute the marginal tax rates relevant for the given tax bases. + round_base_decimals: int | None = None, + ) -> numpy.float64: + """Compute the marginal tax rates relevant for the given tax bases. :param ndarray tax_base: Array of the tax bases. :param float factor: Factor to apply to the thresholds of a tax scale. @@ -152,9 +150,8 @@ def marginal_rates( def rate_from_bracket_indice( self, bracket_indice: numpy.int_, - ) -> numpy.float_: - """ - Compute the relevant tax rates for the given bracket indices. + ) -> numpy.float64: + """Compute the relevant tax rates for the given bracket indices. :param: ndarray bracket_indice: Array of the bracket indices. @@ -173,23 +170,24 @@ def rate_from_bracket_indice( >>> tax_scale.rate_from_bracket_indice(bracket_indice) array([0. , 0.25, 0.1 ]) """ - if bracket_indice.max() > len(self.rates) - 1: - raise IndexError( + msg = ( f"bracket_indice parameter ({bracket_indice}) " f"contains one or more bracket indice which is unavailable " f"inside current {self.__class__.__name__} :\n" f"{self}" ) + raise IndexError( + msg, + ) return numpy.array(self.rates)[bracket_indice] def rate_from_tax_base( self, tax_base: NumericalArray, - ) -> numpy.float_: - """ - Compute the relevant tax rates for the given tax bases. + ) -> numpy.float64: + """Compute the relevant tax rates for the given tax bases. :param: ndarray tax_base: Array of the tax bases. @@ -207,12 +205,10 @@ def rate_from_tax_base( >>> tax_scale.rate_from_tax_base(tax_base) array([0.25, 0. , 0.1 ]) """ - return self.rate_from_bracket_indice(self.bracket_indices(tax_base)) def inverse(self) -> MarginalRateTaxScale: - """ - Returns a new instance of MarginalRateTaxScale. + """Returns a new instance of MarginalRateTaxScale. Invert a taxscale: diff --git a/openfisca_core/taxscales/rate_tax_scale_like.py b/openfisca_core/taxscales/rate_tax_scale_like.py index eb8afd872d..60ea9c20e1 100644 --- a/openfisca_core/taxscales/rate_tax_scale_like.py +++ b/openfisca_core/taxscales/rate_tax_scale_like.py @@ -14,20 +14,19 @@ from .tax_scale_like import TaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int_, numpy.float64] class RateTaxScaleLike(TaxScaleLike, abc.ABC): - """ - Base class for various types of rate-based tax scales: marginal rate, + """Base class for various types of rate-based tax scales: marginal rate, linear average rate... """ - rates: typing.List + rates: list def __init__( self, - name: typing.Optional[str] = None, + name: str | None = None, option: typing.Any = None, unit: typing.Any = None, ) -> None: @@ -40,14 +39,14 @@ def __repr__(self) -> str: [ f"- threshold: {threshold}{os.linesep} rate: {rate}" for (threshold, rate) in zip(self.thresholds, self.rates) - ] - ) + ], + ), ) def add_bracket( self, - threshold: typing.Union[int, float], - rate: typing.Union[int, float], + threshold: int | float, + rate: int | float, ) -> None: if threshold in self.thresholds: i = self.thresholds.index(threshold) @@ -62,7 +61,7 @@ def multiply_rates( self, factor: float, inplace: bool = True, - new_name: typing.Optional[str] = None, + new_name: str | None = None, ) -> RateTaxScaleLike: if inplace: assert new_name is None @@ -87,9 +86,9 @@ def multiply_rates( def multiply_thresholds( self, factor: float, - decimals: typing.Optional[int] = None, + decimals: int | None = None, inplace: bool = True, - new_name: typing.Optional[str] = None, + new_name: str | None = None, ) -> RateTaxScaleLike: if inplace: assert new_name is None @@ -128,10 +127,9 @@ def bracket_indices( self, tax_base: NumericalArray, factor: float = 1.0, - round_decimals: typing.Optional[int] = None, + round_decimals: int | None = None, ) -> numpy.int_: - """ - Compute the relevant bracket indices for the given tax bases. + """Compute the relevant bracket indices for the given tax bases. :param ndarray tax_base: Array of the tax bases. :param float factor: Factor to apply to the thresholds. @@ -149,7 +147,6 @@ def bracket_indices( >>> tax_scale.bracket_indices(tax_base) [0, 1] """ - if not numpy.size(numpy.array(self.thresholds)): raise EmptyArgumentError( self.__class__.__name__, @@ -177,11 +174,12 @@ def bracket_indices( # # numpy.finfo(float_).eps thresholds1 = numpy.outer( - +factor + numpy.finfo(numpy.float_).eps, numpy.array(self.thresholds) + +factor + numpy.finfo(numpy.float64).eps, + numpy.array(self.thresholds), ) if round_decimals is not None: - thresholds1 = numpy.round_(thresholds1, round_decimals) + thresholds1 = numpy.round(thresholds1, round_decimals) return (base1 - thresholds1 >= 0).sum(axis=1) - 1 @@ -189,8 +187,7 @@ def threshold_from_tax_base( self, tax_base: NumericalArray, ) -> NumericalArray: - """ - Compute the relevant thresholds for the given tax bases. + """Compute the relevant thresholds for the given tax bases. :param: ndarray tax_base: Array of the tax bases. @@ -209,7 +206,6 @@ def threshold_from_tax_base( >>> tax_scale.threshold_from_tax_base(tax_base) array([200, 500, 0]) """ - return numpy.array(self.thresholds)[self.bracket_indices(tax_base)] def to_dict(self) -> dict: diff --git a/openfisca_core/taxscales/single_amount_tax_scale.py b/openfisca_core/taxscales/single_amount_tax_scale.py index 8f8bdc22c9..1a39396398 100644 --- a/openfisca_core/taxscales/single_amount_tax_scale.py +++ b/openfisca_core/taxscales/single_amount_tax_scale.py @@ -7,7 +7,7 @@ from openfisca_core.taxscales import AmountTaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int_, numpy.float64] class SingleAmountTaxScale(AmountTaxScaleLike): @@ -15,12 +15,11 @@ def calc( self, tax_base: NumericalArray, right: bool = False, - ) -> numpy.float_: - """ - Matches the input amount to a set of brackets and returns the single + ) -> numpy.float64: + """Matches the input amount to a set of brackets and returns the single cell value that fits within that bracket. """ - guarded_thresholds = numpy.array([-numpy.inf] + self.thresholds + [numpy.inf]) + guarded_thresholds = numpy.array([-numpy.inf, *self.thresholds, numpy.inf]) bracket_indices = numpy.digitize( tax_base, @@ -28,6 +27,6 @@ def calc( right=right, ) - guarded_amounts = numpy.array([0] + self.amounts + [0]) + guarded_amounts = numpy.array([0, *self.amounts, 0]) return guarded_amounts[bracket_indices - 1] diff --git a/openfisca_core/taxscales/tax_scale_like.py b/openfisca_core/taxscales/tax_scale_like.py index 2d64e3afeb..683c771127 100644 --- a/openfisca_core/taxscales/tax_scale_like.py +++ b/openfisca_core/taxscales/tax_scale_like.py @@ -5,29 +5,28 @@ import abc import copy -import numpy - from openfisca_core import commons if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + import numpy + + NumericalArray = typing.Union[numpy.int_, numpy.float64] class TaxScaleLike(abc.ABC): - """ - Base class for various types of tax scales: amount-based tax scales, + """Base class for various types of tax scales: amount-based tax scales, rate-based tax scales... """ - name: typing.Optional[str] + name: str | None option: typing.Any unit: typing.Any - thresholds: typing.List + thresholds: list @abc.abstractmethod def __init__( self, - name: typing.Optional[str] = None, + name: str | None = None, option: typing.Any = None, unit: typing.Any = None, ) -> None: @@ -37,30 +36,29 @@ def __init__( self.thresholds = [] def __eq__(self, _other: object) -> typing.NoReturn: + msg = "Method '__eq__' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method '__eq__' is not implemented for " f"{self.__class__.__name__}", + msg, ) def __ne__(self, _other: object) -> typing.NoReturn: + msg = "Method '__ne__' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method '__ne__' is not implemented for " f"{self.__class__.__name__}", + msg, ) @abc.abstractmethod - def __repr__(self) -> str: - ... + def __repr__(self) -> str: ... @abc.abstractmethod def calc( self, tax_base: NumericalArray, right: bool, - ) -> numpy.float_: - ... + ) -> numpy.float64: ... @abc.abstractmethod - def to_dict(self) -> dict: - ... + def to_dict(self) -> dict: ... def copy(self) -> typing.Any: new = commons.empty_clone(self) diff --git a/openfisca_core/tools/__init__.py b/openfisca_core/tools/__init__.py index 9c3b1a4962..1416ed1529 100644 --- a/openfisca_core/tools/__init__.py +++ b/openfisca_core/tools/__init__.py @@ -1,6 +1,3 @@ -# -*- coding: utf-8 -*- - - import os import numexpr @@ -15,9 +12,7 @@ def assert_near( message="", relative_error_margin=None, ): - """ - - :param value: Value returned by the test + """:param value: Value returned by the test :param target_value: Value that the test should return to pass :param absolute_error_margin: Absolute error margin authorized :param message: Error message to be displayed if the test fails @@ -26,7 +21,6 @@ def assert_near( Limit : This function cannot be used to assert near periods. """ - import numpy if absolute_error_margin is None and relative_error_margin is None: @@ -48,36 +42,30 @@ def assert_near( if absolute_error_margin is not None: assert ( diff <= absolute_error_margin - ).all(), "{}{} differs from {} with an absolute margin {} > {}".format( - message, value, target_value, diff, absolute_error_margin - ) + ).all(), f"{message}{value} differs from {target_value} with an absolute margin {diff} > {absolute_error_margin}" if relative_error_margin is not None: assert ( diff <= abs(relative_error_margin * target_value) - ).all(), "{}{} differs from {} with a relative margin {} > {}".format( - message, - value, - target_value, - diff, - abs(relative_error_margin * target_value), - ) + ).all(), f"{message}{value} differs from {target_value} with a relative margin {diff} > {abs(relative_error_margin * target_value)}" + return None + return None -def assert_datetime_equals(value, target_value, message=""): - assert (value == target_value).all(), "{}{} differs from {}.".format( - message, value, target_value - ) +def assert_datetime_equals(value, target_value, message="") -> None: + assert ( + value == target_value + ).all(), f"{message}{value} differs from {target_value}." -def assert_enum_equals(value, target_value, message=""): +def assert_enum_equals(value, target_value, message="") -> None: value = value.decode_to_str() - assert (value == target_value).all(), "{}{} differs from {}.".format( - message, value, target_value - ) + assert ( + value == target_value + ).all(), f"{message}{value} differs from {target_value}." def indent(text): - return " {}".format(text.replace(os.linesep, "{} ".format(os.linesep))) + return " {}".format(text.replace(os.linesep, f"{os.linesep} ")) def get_trace_tool_link(scenario, variables, api_url, trace_tool_url): @@ -89,17 +77,16 @@ def get_trace_tool_link(scenario, variables, api_url, trace_tool_url): "scenarios": [scenario_json], "variables": variables, } - url = ( + return ( trace_tool_url + "?" + urllib.urlencode( { "simulation": json.dumps(simulation_json), "api_url": api_url, - } + }, ) ) - return url def eval_expression(expression): diff --git a/openfisca_core/tools/simulation_dumper.py b/openfisca_core/tools/simulation_dumper.py index ab21bd79f0..9b1f5708ad 100644 --- a/openfisca_core/tools/simulation_dumper.py +++ b/openfisca_core/tools/simulation_dumper.py @@ -1,19 +1,14 @@ -# -*- coding: utf-8 -*- - - import os -import numpy as np +import numpy from openfisca_core.data_storage import OnDiskStorage from openfisca_core.periods import DateUnit from openfisca_core.simulations import Simulation -def dump_simulation(simulation, directory): - """ - Write simulation data to directory, so that it can be restored later. - """ +def dump_simulation(simulation, directory) -> None: + """Write simulation data to directory, so that it can be restored later.""" parent_directory = os.path.abspath(os.path.join(directory, os.pardir)) if not os.path.isdir(parent_directory): # To deal with reforms os.mkdir(parent_directory) @@ -21,7 +16,8 @@ def dump_simulation(simulation, directory): os.mkdir(directory) if os.listdir(directory): - raise ValueError("Directory '{}' is not empty".format(directory)) + msg = f"Directory '{directory}' is not empty" + raise ValueError(msg) entities_dump_dir = os.path.join(directory, "__entities__") os.mkdir(entities_dump_dir) @@ -36,11 +32,10 @@ def dump_simulation(simulation, directory): def restore_simulation(directory, tax_benefit_system, **kwargs): - """ - Restore simulation from directory - """ + """Restore simulation from directory.""" simulation = Simulation( - tax_benefit_system, tax_benefit_system.instantiate_entities() + tax_benefit_system, + tax_benefit_system.instantiate_entities(), ) entities_dump_dir = os.path.join(directory, "__entities__") @@ -64,68 +59,74 @@ def restore_simulation(directory, tax_benefit_system, **kwargs): return simulation -def _dump_holder(holder, directory): +def _dump_holder(holder, directory) -> None: disk_storage = holder.create_disk_storage(directory, preserve=True) for period in holder.get_known_periods(): value = holder.get_array(period) disk_storage.put(value, period) -def _dump_entity(population, directory): +def _dump_entity(population, directory) -> None: path = os.path.join(directory, population.entity.key) os.mkdir(path) - np.save(os.path.join(path, "id.npy"), population.ids) + numpy.save(os.path.join(path, "id.npy"), population.ids) if population.entity.is_person: return - np.save(os.path.join(path, "members_position.npy"), population.members_position) - np.save(os.path.join(path, "members_entity_id.npy"), population.members_entity_id) + numpy.save(os.path.join(path, "members_position.npy"), population.members_position) + numpy.save( + os.path.join(path, "members_entity_id.npy"), population.members_entity_id + ) flattened_roles = population.entity.flattened_roles if len(flattened_roles) == 0: - encoded_roles = np.int64(0) + encoded_roles = numpy.int64(0) else: - encoded_roles = np.select( + encoded_roles = numpy.select( [population.members_role == role for role in flattened_roles], [role.key for role in flattened_roles], ) - np.save(os.path.join(path, "members_role.npy"), encoded_roles) + numpy.save(os.path.join(path, "members_role.npy"), encoded_roles) def _restore_entity(population, directory): path = os.path.join(directory, population.entity.key) - population.ids = np.load(os.path.join(path, "id.npy")) + population.ids = numpy.load(os.path.join(path, "id.npy")) if population.entity.is_person: - return + return None - population.members_position = np.load(os.path.join(path, "members_position.npy")) - population.members_entity_id = np.load(os.path.join(path, "members_entity_id.npy")) - encoded_roles = np.load(os.path.join(path, "members_role.npy")) + population.members_position = numpy.load(os.path.join(path, "members_position.npy")) + population.members_entity_id = numpy.load( + os.path.join(path, "members_entity_id.npy") + ) + encoded_roles = numpy.load(os.path.join(path, "members_role.npy")) flattened_roles = population.entity.flattened_roles if len(flattened_roles) == 0: - population.members_role = np.int64(0) + population.members_role = numpy.int64(0) else: - population.members_role = np.select( + population.members_role = numpy.select( [encoded_roles == role.key for role in flattened_roles], - [role for role in flattened_roles], + list(flattened_roles), ) person_count = len(population.members_entity_id) population.count = max(population.members_entity_id) + 1 return person_count -def _restore_holder(simulation, variable, directory): +def _restore_holder(simulation, variable, directory) -> None: storage_dir = os.path.join(directory, variable) is_variable_eternal = ( simulation.tax_benefit_system.get_variable(variable).definition_period == DateUnit.ETERNITY ) disk_storage = OnDiskStorage( - storage_dir, is_eternal=is_variable_eternal, preserve_storage_dir=True + storage_dir, + is_eternal=is_variable_eternal, + preserve_storage_dir=True, ) disk_storage.restore() diff --git a/openfisca_core/tools/test_runner.py b/openfisca_core/tools/test_runner.py index ea77401c6e..fcb5572b79 100644 --- a/openfisca_core/tools/test_runner.py +++ b/openfisca_core/tools/test_runner.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any from typing_extensions import Literal, TypedDict from openfisca_core.types import TaxBenefitSystem @@ -23,10 +24,10 @@ class Options(TypedDict, total=False): aggregate: bool - ignore_variables: Optional[Sequence[str]] - max_depth: Optional[int] - name_filter: Optional[str] - only_variables: Optional[Sequence[str]] + ignore_variables: Sequence[str] | None + max_depth: int | None + name_filter: str | None + only_variables: Sequence[str] | None pdb: bool performance_graph: bool performance_tables: bool @@ -35,9 +36,9 @@ class Options(TypedDict, total=False): @dataclasses.dataclass(frozen=True) class ErrorMargin: - __root__: Dict[Union[str, Literal["default"]], Optional[float]] + __root__: dict[str | Literal["default"], float | None] - def __getitem__(self, key: str) -> Optional[float]: + def __getitem__(self, key: str) -> float | None: if key in self.__root__: return self.__root__[key] @@ -49,19 +50,17 @@ class Test: absolute_error_margin: ErrorMargin relative_error_margin: ErrorMargin name: str = "" - input: Dict[str, Union[float, Dict[str, float]]] = dataclasses.field( - default_factory=dict - ) - output: Optional[Dict[str, Union[float, Dict[str, float]]]] = None - period: Optional[str] = None + input: dict[str, float | dict[str, float]] = dataclasses.field(default_factory=dict) + output: dict[str, float | dict[str, float]] | None = None + period: str | None = None reforms: Sequence[str] = dataclasses.field(default_factory=list) - keywords: Optional[Sequence[str]] = None + keywords: Sequence[str] | None = None extensions: Sequence[str] = dataclasses.field(default_factory=list) - description: Optional[str] = None - max_spiral_loops: Optional[int] = None + description: str | None = None + max_spiral_loops: int | None = None -def build_test(params: Dict[str, Any]) -> Test: +def build_test(params: dict[str, Any]) -> Test: for key in ["absolute_error_margin", "relative_error_margin"]: value = params.get(key) @@ -111,14 +110,14 @@ def import_yaml(): yaml, Loader = import_yaml() -_tax_benefit_system_cache: Dict = {} +_tax_benefit_system_cache: dict = {} options: Options = Options() def run_tests( tax_benefit_system: TaxBenefitSystem, - paths: Union[str, Sequence[str]], + paths: str | Sequence[str], options: Options = options, ) -> int: """Runs all the YAML tests contained in a file or a directory. @@ -147,7 +146,6 @@ def run_tests( +-------------------------------+-----------+-------------------------------------------+ """ - argv = [] plugins = [OpenFiscaPlugin(tax_benefit_system, options)] @@ -164,8 +162,8 @@ def run_tests( class YamlFile(pytest.File): - def __init__(self, *, tax_benefit_system, options, **kwargs): - super(YamlFile, self).__init__(**kwargs) + def __init__(self, *, tax_benefit_system, options, **kwargs) -> None: + super().__init__(**kwargs) self.tax_benefit_system = tax_benefit_system self.options = options @@ -177,12 +175,12 @@ def collect(self): [ traceback.format_exc(), f"'{self.path}' is not a valid YAML file. Check the stack trace above for more details.", - ] + ], ) raise ValueError(message) if not isinstance(tests, list): - tests: Sequence[Dict] = [tests] + tests: Sequence[dict] = [tests] for test in tests: if not self.should_ignore(test): @@ -205,19 +203,17 @@ def should_ignore(self, test): class YamlItem(pytest.Item): - """ - Terminal nodes of the test collection tree. - """ + """Terminal nodes of the test collection tree.""" - def __init__(self, *, baseline_tax_benefit_system, test, options, **kwargs): - super(YamlItem, self).__init__(**kwargs) + def __init__(self, *, baseline_tax_benefit_system, test, options, **kwargs) -> None: + super().__init__(**kwargs) self.baseline_tax_benefit_system = baseline_tax_benefit_system self.options = options self.test = build_test(test) self.simulation = None self.tax_benefit_system = None - def runtest(self): + def runtest(self) -> None: self.name = self.test.name if self.test.output is None: @@ -247,10 +243,10 @@ def runtest(self): raise except Exception as e: error_message = os.linesep.join( - [str(e), "", f"Unexpected error raised while parsing '{self.path}'"] + [str(e), "", f"Unexpected error raised while parsing '{self.path}'"], ) raise ValueError(error_message).with_traceback( - sys.exc_info()[2] + sys.exc_info()[2], ) from e # Keep the stack trace from the root error if max_spiral_loops: @@ -268,17 +264,16 @@ def runtest(self): if performance_tables: self.generate_performance_tables(tracer) - def print_computation_log(self, tracer, aggregate, max_depth): - print("Computation log:") # noqa T001 + def print_computation_log(self, tracer, aggregate, max_depth) -> None: tracer.print_computation_log(aggregate, max_depth) - def generate_performance_graph(self, tracer): + def generate_performance_graph(self, tracer) -> None: tracer.generate_performance_graph(".") - def generate_performance_tables(self, tracer): + def generate_performance_tables(self, tracer) -> None: tracer.generate_performance_tables(".") - def check_output(self): + def check_output(self) -> None: output = self.test.output if output is None: @@ -296,18 +291,25 @@ def check_output(self): for variable_name, value in instance_values.items(): entity_index = population.get_index(instance_id) self.check_variable( - variable_name, value, self.test.period, entity_index + variable_name, + value, + self.test.period, + entity_index, ) else: raise VariableNotFound(key, self.tax_benefit_system) def check_variable( - self, variable_name: str, expected_value, period, entity_index=None + self, + variable_name: str, + expected_value, + period, + entity_index=None, ): if self.should_ignore_variable(variable_name): - return + return None - if isinstance(expected_value, Dict): + if isinstance(expected_value, dict): for requested_period, expected_value_at_period in expected_value.items(): self.check_variable( variable_name, @@ -345,9 +347,10 @@ def should_ignore_variable(self, variable_name: str): def repr_failure(self, excinfo): if not isinstance( - excinfo.value, (AssertionError, VariableNotFound, SituationParsingError) + excinfo.value, + (AssertionError, VariableNotFound, SituationParsingError), ): - return super(YamlItem, self).repr_failure(excinfo) + return super().repr_failure(excinfo) message = excinfo.value.args[0] if isinstance(excinfo.value, SituationParsingError): @@ -355,21 +358,20 @@ def repr_failure(self, excinfo): return os.linesep.join( [ - f"{str(self.path)}:", - f" Test '{str(self.name)}':", + f"{self.path!s}:", + f" Test '{self.name!s}':", textwrap.indent(message, " "), - ] + ], ) -class OpenFiscaPlugin(object): - def __init__(self, tax_benefit_system, options): +class OpenFiscaPlugin: + def __init__(self, tax_benefit_system, options) -> None: self.tax_benefit_system = tax_benefit_system self.options = options def pytest_collect_file(self, parent, path): - """ - Called by pytest for all plugins. + """Called by pytest for all plugins. :return: The collector for test methods. """ if path.ext in [".yaml", ".yml"]: @@ -379,6 +381,7 @@ def pytest_collect_file(self, parent, path): tax_benefit_system=self.tax_benefit_system, options=self.options, ) + return None def _get_tax_benefit_system(baseline, reforms, extensions): @@ -396,7 +399,7 @@ def _get_tax_benefit_system(baseline, reforms, extensions): for reform_path in reforms: current_tax_benefit_system = current_tax_benefit_system.apply_reform( - reform_path + reform_path, ) for extension in extensions: diff --git a/openfisca_core/tracers/computation_log.py b/openfisca_core/tracers/computation_log.py index 1013b828f2..6310eb8849 100644 --- a/openfisca_core/tracers/computation_log.py +++ b/openfisca_core/tracers/computation_log.py @@ -1,17 +1,17 @@ from __future__ import annotations import typing -from typing import List, Optional, Union +from typing import Union import numpy from openfisca_core.indexed_enums import EnumArray -from .. import tracers - if typing.TYPE_CHECKING: from numpy.typing import ArrayLike + from openfisca_core import tracers + Array = Union[EnumArray, ArrayLike] @@ -23,7 +23,7 @@ def __init__(self, full_tracer: tracers.FullTracer) -> None: def display( self, - value: Optional[Array], + value: Array | None, ) -> str: if isinstance(value, EnumArray): value = value.decode_to_str() @@ -33,8 +33,8 @@ def display( def lines( self, aggregate: bool = False, - max_depth: Optional[int] = None, - ) -> List[str]: + max_depth: int | None = None, + ) -> list[str]: depth = 1 lines_by_tree = [ @@ -45,8 +45,7 @@ def lines( return self._flatten(lines_by_tree) def print_log(self, aggregate=False, max_depth=None) -> None: - """ - Print the computation log of a simulation. + """Print the computation log of a simulation. If ``aggregate`` is ``False`` (default), print the value of each computed vector. @@ -61,16 +60,16 @@ def print_log(self, aggregate=False, max_depth=None) -> None: If ``max_depth`` is set, for example to ``3``, only print computed vectors up to a depth of ``max_depth``. """ - for line in self.lines(aggregate, max_depth): - print(line) # noqa T001 + for _line in self.lines(aggregate, max_depth): + pass def _get_node_log( self, node: tracers.TraceNode, depth: int, aggregate: bool, - max_depth: Optional[int], - ) -> List[str]: + max_depth: int | None, + ) -> list[str]: if max_depth is not None and depth > max_depth: return [] @@ -88,7 +87,7 @@ def _print_line( depth: int, node: tracers.TraceNode, aggregate: bool, - max_depth: Optional[int], + max_depth: int | None, ) -> str: indent = " " * depth value = node.value @@ -103,7 +102,7 @@ def _print_line( "avg": numpy.mean(value), "max": numpy.max(value), "min": numpy.min(value), - } + }, ) except TypeError: @@ -116,6 +115,6 @@ def _print_line( def _flatten( self, - lists: List[List[str]], - ) -> List[str]: + lists: list[list[str]], + ) -> list[str]: return [item for list_ in lists for item in list_] diff --git a/openfisca_core/tracers/flat_trace.py b/openfisca_core/tracers/flat_trace.py index 25aa75f21d..2090d537b8 100644 --- a/openfisca_core/tracers/flat_trace.py +++ b/openfisca_core/tracers/flat_trace.py @@ -1,18 +1,19 @@ from __future__ import annotations import typing -from typing import Dict, Optional, Union +from typing import Union import numpy -from openfisca_core import tracers from openfisca_core.indexed_enums import EnumArray if typing.TYPE_CHECKING: from numpy.typing import ArrayLike + from openfisca_core import tracers + Array = Union[EnumArray, ArrayLike] - Trace = Dict[str, dict] + Trace = dict[str, dict] class FlatTrace: @@ -39,7 +40,7 @@ def get_trace(self) -> dict: key: node_trace for key, node_trace in self._get_flat_trace(node).items() if key not in trace - } + }, ) return trace @@ -52,13 +53,14 @@ def get_serialized_trace(self) -> dict: def serialize( self, - value: Optional[Array], - ) -> Union[Optional[Array], list]: + value: Array | None, + ) -> Array | None | list: if isinstance(value, EnumArray): value = value.decode_to_str() if isinstance(value, numpy.ndarray) and numpy.issubdtype( - value.dtype, numpy.dtype(bytes) + value.dtype, + numpy.dtype(bytes), ): value = value.astype(numpy.dtype(str)) @@ -73,7 +75,7 @@ def _get_flat_trace( ) -> Trace: key = self.key(node) - node_trace = { + return { key: { "dependencies": [self.key(child) for child in node.children], "parameters": { @@ -85,5 +87,3 @@ def _get_flat_trace( "formula_time": node.formula_time(), }, } - - return node_trace diff --git a/openfisca_core/tracers/full_tracer.py b/openfisca_core/tracers/full_tracer.py index 085607e125..9fa94d5ab5 100644 --- a/openfisca_core/tracers/full_tracer.py +++ b/openfisca_core/tracers/full_tracer.py @@ -1,24 +1,25 @@ from __future__ import annotations import typing -from typing import Dict, Iterator, List, Optional, Union +from typing import Union import time -from .. import tracers +from openfisca_core import tracers if typing.TYPE_CHECKING: + from collections.abc import Iterator from numpy.typing import ArrayLike from openfisca_core.periods import Period - Stack = List[Dict[str, Union[str, Period]]] + Stack = list[dict[str, Union[str, Period]]] class FullTracer: _simple_tracer: tracers.SimpleTracer _trees: list - _current_node: Optional[tracers.TraceNode] + _current_node: tracers.TraceNode | None def __init__(self) -> None: self._simple_tracer = tracers.SimpleTracer() @@ -66,7 +67,7 @@ def record_parameter_access( def _record_start_time( self, - time_in_s: Optional[float] = None, + time_in_s: float | None = None, ) -> None: if time_in_s is None: time_in_s = self._get_time_in_sec() @@ -85,7 +86,7 @@ def record_calculation_end(self) -> None: def _record_end_time( self, - time_in_s: Optional[float] = None, + time_in_s: float | None = None, ) -> None: if time_in_s is None: time_in_s = self._get_time_in_sec() @@ -102,7 +103,7 @@ def stack(self) -> Stack: return self._simple_tracer.stack @property - def trees(self) -> List[tracers.TraceNode]: + def trees(self) -> list[tracers.TraceNode]: return self._trees @property @@ -120,7 +121,7 @@ def flat_trace(self) -> tracers.FlatTrace: def _get_time_in_sec(self) -> float: return time.time_ns() / (10**9) - def print_computation_log(self, aggregate=False, max_depth=None): + def print_computation_log(self, aggregate=False, max_depth=None) -> None: self.computation_log.print_log(aggregate, max_depth) def generate_performance_graph(self, dir_path: str) -> None: diff --git a/openfisca_core/tracers/performance_log.py b/openfisca_core/tracers/performance_log.py index 89917dc50e..f69a3dd3a2 100644 --- a/openfisca_core/tracers/performance_log.py +++ b/openfisca_core/tracers/performance_log.py @@ -8,12 +8,12 @@ import json import os -from .. import tracers +from openfisca_core import tracers if typing.TYPE_CHECKING: - Trace = typing.Dict[str, dict] - Calculation = typing.Tuple[str, dict] - SortedTrace = typing.List[Calculation] + Trace = dict[str, dict] + Calculation = tuple[str, dict] + SortedTrace = list[Calculation] class PerformanceLog: @@ -54,7 +54,7 @@ def generate_performance_tables(self, dir_path: str) -> None: aggregated_csv_rows = [ {"name": key, **aggregated_time} for key, aggregated_time in self.aggregate_calculation_times( - flat_trace + flat_trace, ).items() ] @@ -66,7 +66,7 @@ def generate_performance_tables(self, dir_path: str) -> None: def aggregate_calculation_times( self, flat_trace: Trace, - ) -> typing.Dict[str, dict]: + ) -> dict[str, dict]: def _aggregate_calculations(calculations: list) -> dict: calculation_count = len(calculations) @@ -83,10 +83,10 @@ def _aggregate_calculations(calculations: list) -> dict: "calculation_time": tracers.TraceNode.round(calculation_time), "formula_time": tracers.TraceNode.round(formula_time), "avg_calculation_time": tracers.TraceNode.round( - calculation_time / calculation_count + calculation_time / calculation_count, ), "avg_formula_time": tracers.TraceNode.round( - formula_time / calculation_count + formula_time / calculation_count, ), } @@ -98,7 +98,8 @@ def _groupby(calculation: Calculation) -> str: return { variable_name: _aggregate_calculations(list(calculations)) for variable_name, calculations in itertools.groupby( - all_calculations, _groupby + all_calculations, + _groupby, ) } @@ -122,7 +123,7 @@ def _json_tree(self, tree: tracers.TraceNode) -> dict: "children": children, } - def _write_csv(self, path: str, rows: typing.List[dict]) -> None: + def _write_csv(self, path: str, rows: list[dict]) -> None: fieldnames = list(rows[0].keys()) with open(path, "w") as csv_file: diff --git a/openfisca_core/tracers/simple_tracer.py b/openfisca_core/tracers/simple_tracer.py index 27bcad2e8c..84328730ef 100644 --- a/openfisca_core/tracers/simple_tracer.py +++ b/openfisca_core/tracers/simple_tracer.py @@ -1,14 +1,14 @@ from __future__ import annotations import typing -from typing import Dict, List, Union +from typing import Union if typing.TYPE_CHECKING: from numpy.typing import ArrayLike from openfisca_core.periods import Period - Stack = List[Dict[str, Union[str, Period]]] + Stack = list[dict[str, Union[str, Period]]] class SimpleTracer: @@ -23,7 +23,7 @@ def record_calculation_start(self, variable: str, period: Period | int) -> None: def record_calculation_result(self, value: ArrayLike) -> None: pass # ignore calculation result - def record_parameter_access(self, parameter: str, period, value): + def record_parameter_access(self, parameter: str, period, value) -> None: pass def record_calculation_end(self) -> None: diff --git a/openfisca_core/tracers/trace_node.py b/openfisca_core/tracers/trace_node.py index 4e0cceae0a..ff55a5714f 100644 --- a/openfisca_core/tracers/trace_node.py +++ b/openfisca_core/tracers/trace_node.py @@ -18,10 +18,10 @@ class TraceNode: name: str period: Period - parent: typing.Optional[TraceNode] = None - children: typing.List[TraceNode] = dataclasses.field(default_factory=list) - parameters: typing.List[TraceNode] = dataclasses.field(default_factory=list) - value: typing.Optional[Array] = None + parent: TraceNode | None = None + children: list[TraceNode] = dataclasses.field(default_factory=list) + parameters: list[TraceNode] = dataclasses.field(default_factory=list) + value: Array | None = None start: float = 0 end: float = 0 diff --git a/openfisca_core/tracers/tracing_parameter_node_at_instant.py b/openfisca_core/tracers/tracing_parameter_node_at_instant.py index f618f59e97..074c24221d 100644 --- a/openfisca_core/tracers/tracing_parameter_node_at_instant.py +++ b/openfisca_core/tracers/tracing_parameter_node_at_instant.py @@ -7,8 +7,6 @@ from openfisca_core import parameters -from .. import tracers - ParameterNode = Union[ parameters.VectorialParameterNodeAtInstant, parameters.ParameterNodeAtInstant, @@ -17,6 +15,8 @@ if typing.TYPE_CHECKING: from numpy.typing import ArrayLike + from openfisca_core import tracers + Child = Union[ParameterNode, ArrayLike] @@ -32,7 +32,7 @@ def __init__( def __getattr__( self, key: str, - ) -> Union[TracingParameterNodeAtInstant, Child]: + ) -> TracingParameterNodeAtInstant | Child: child = getattr(self.parameter_node_at_instant, key) return self.get_traced_child(child, key) @@ -44,16 +44,16 @@ def __iter__(self): def __getitem__( self, - key: Union[str, ArrayLike], - ) -> Union[TracingParameterNodeAtInstant, Child]: + key: str | ArrayLike, + ) -> TracingParameterNodeAtInstant | Child: child = self.parameter_node_at_instant[key] return self.get_traced_child(child, key) def get_traced_child( self, child: Child, - key: Union[str, ArrayLike], - ) -> Union[TracingParameterNodeAtInstant, Child]: + key: str | ArrayLike, + ) -> TracingParameterNodeAtInstant | Child: period = self.parameter_node_at_instant._instant_str if isinstance( @@ -75,9 +75,9 @@ def get_traced_child( name = self.parameter_node_at_instant._name else: - name = ".".join([self.parameter_node_at_instant._name, key]) + name = f"{self.parameter_node_at_instant._name}.{key}" - if isinstance(child, (numpy.ndarray,) + parameters.ALLOWED_PARAM_TYPES): + if isinstance(child, (numpy.ndarray, *parameters.ALLOWED_PARAM_TYPES)): self.tracer.record_parameter_access(name, period, child) return child diff --git a/openfisca_core/types.py b/openfisca_core/types.py index b34a555434..16a1f0e90e 100644 --- a/openfisca_core/types.py +++ b/openfisca_core/types.py @@ -35,28 +35,23 @@ class CoreEntity(Protocol): plural: Any @abc.abstractmethod - def check_role_validity(self, role: Any) -> None: - ... + def check_role_validity(self, role: Any) -> None: ... @abc.abstractmethod - def check_variable_defined_for_entity(self, variable_name: Any) -> None: - ... + def check_variable_defined_for_entity(self, variable_name: Any) -> None: ... @abc.abstractmethod def get_variable( self, variable_name: Any, check_existence: Any = ..., - ) -> Any | None: - ... + ) -> Any | None: ... -class SingleEntity(CoreEntity, Protocol): - ... +class SingleEntity(CoreEntity, Protocol): ... -class GroupEntity(CoreEntity, Protocol): - ... +class GroupEntity(CoreEntity, Protocol): ... class Role(Protocol): @@ -65,8 +60,7 @@ class Role(Protocol): subroles: Any @property - def key(self) -> str: - ... + def key(self) -> str: ... # Holders @@ -74,40 +68,34 @@ def key(self) -> str: class Holder(Protocol): @abc.abstractmethod - def clone(self, population: Any) -> Holder: - ... + def clone(self, population: Any) -> Holder: ... @abc.abstractmethod - def get_memory_usage(self) -> Any: - ... + def get_memory_usage(self) -> Any: ... # Parameters @typing_extensions.runtime_checkable -class ParameterNodeAtInstant(Protocol): - ... +class ParameterNodeAtInstant(Protocol): ... # Periods -class Instant(Protocol): - ... +class Instant(Protocol): ... @typing_extensions.runtime_checkable class Period(Protocol): @property @abc.abstractmethod - def start(self) -> Any: - ... + def start(self) -> Any: ... @property @abc.abstractmethod - def unit(self) -> Any: - ... + def unit(self) -> Any: ... # Populations @@ -117,8 +105,7 @@ class Population(Protocol): entity: Any @abc.abstractmethod - def get_holder(self, variable_name: Any) -> Any: - ... + def get_holder(self, variable_name: Any) -> Any: ... # Simulations @@ -126,20 +113,16 @@ def get_holder(self, variable_name: Any) -> Any: class Simulation(Protocol): @abc.abstractmethod - def calculate(self, variable_name: Any, period: Any) -> Any: - ... + def calculate(self, variable_name: Any, period: Any) -> Any: ... @abc.abstractmethod - def calculate_add(self, variable_name: Any, period: Any) -> Any: - ... + def calculate_add(self, variable_name: Any, period: Any) -> Any: ... @abc.abstractmethod - def calculate_divide(self, variable_name: Any, period: Any) -> Any: - ... + def calculate_divide(self, variable_name: Any, period: Any) -> Any: ... @abc.abstractmethod - def get_population(self, plural: Any | None) -> Any: - ... + def get_population(self, plural: Any | None) -> Any: ... # Tax-Benefit systems @@ -171,11 +154,9 @@ def __call__( population: Population, instant: Instant, params: Params, - ) -> Array[Any]: - ... + ) -> Array[Any]: ... class Params(Protocol): @abc.abstractmethod - def __call__(self, instant: Instant) -> ParameterNodeAtInstant: - ... + def __call__(self, instant: Instant) -> ParameterNodeAtInstant: ... diff --git a/openfisca_core/variables/helpers.py b/openfisca_core/variables/helpers.py index ce1eede9fc..5038a78240 100644 --- a/openfisca_core/variables/helpers.py +++ b/openfisca_core/variables/helpers.py @@ -1,19 +1,16 @@ from __future__ import annotations -from typing import Optional - import sortedcontainers +from openfisca_core import variables from openfisca_core.periods import Period -from .. import variables - def get_annualized_variable( - variable: variables.Variable, annualization_period: Optional[Period] = None + variable: variables.Variable, + annualization_period: Period | None = None, ) -> variables.Variable: - """ - Returns a clone of ``variable`` that is annualized for the period ``annualization_period``. + """Returns a clone of ``variable`` that is annualized for the period ``annualization_period``. When annualized, a variable's formula is only called for a January calculation, and the results for other months are assumed to be identical. """ @@ -34,23 +31,24 @@ def annual_formula(population, period, parameters): { key: make_annual_formula(formula, annualization_period) for key, formula in variable.formulas.items() - } + }, ) return new_variable def get_neutralized_variable(variable): - """ - Return a new neutralized variable (to be used by reforms). + """Return a new neutralized variable (to be used by reforms). A neutralized variable always returns its default value, and does not cache anything. """ result = variable.clone() result.is_neutralized = True result.label = ( - "[Neutralized]" - if variable.label is None - else "[Neutralized] {}".format(variable.label), + ( + "[Neutralized]" + if variable.label is None + else f"[Neutralized] {variable.label}" + ), ) return result diff --git a/openfisca_core/variables/tests/test_definition_period.py b/openfisca_core/variables/tests/test_definition_period.py index 7938aaeaef..8ef9bfaa87 100644 --- a/openfisca_core/variables/tests/test_definition_period.py +++ b/openfisca_core/variables/tests/test_definition_period.py @@ -13,31 +13,31 @@ class TestVariable(Variable): return TestVariable -def test_weekday_variable(variable): +def test_weekday_variable(variable) -> None: variable.definition_period = periods.WEEKDAY assert variable() -def test_week_variable(variable): +def test_week_variable(variable) -> None: variable.definition_period = periods.WEEK assert variable() -def test_day_variable(variable): +def test_day_variable(variable) -> None: variable.definition_period = periods.DAY assert variable() -def test_month_variable(variable): +def test_month_variable(variable) -> None: variable.definition_period = periods.MONTH assert variable() -def test_year_variable(variable): +def test_year_variable(variable) -> None: variable.definition_period = periods.YEAR assert variable() -def test_eternity_variable(variable): +def test_eternity_variable(variable) -> None: variable.definition_period = periods.ETERNITY assert variable() diff --git a/openfisca_core/variables/variable.py b/openfisca_core/variables/variable.py index e70c0d05d9..77411c32bb 100644 --- a/openfisca_core/variables/variable.py +++ b/openfisca_core/variables/variable.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional, Union +from typing import NoReturn from openfisca_core.types import Formula, Instant @@ -20,8 +20,7 @@ class Variable: - """ - A `variable `_ of the legislation. + """A `variable `_ of the legislation. Main attributes: @@ -102,7 +101,7 @@ class Variable: __name__: str - def __init__(self, baseline_variable=None): + def __init__(self, baseline_variable=None) -> None: self.name = self.__class__.__name__ attr = { name: value @@ -111,21 +110,30 @@ def __init__(self, baseline_variable=None): } self.baseline_variable = baseline_variable self.value_type = self.set( - attr, "value_type", required=True, allowed_values=config.VALUE_TYPES.keys() + attr, + "value_type", + required=True, + allowed_values=config.VALUE_TYPES.keys(), ) self.dtype = config.VALUE_TYPES[self.value_type]["dtype"] self.json_type = config.VALUE_TYPES[self.value_type]["json_type"] if self.value_type == Enum: self.possible_values = self.set( - attr, "possible_values", required=True, setter=self.set_possible_values + attr, + "possible_values", + required=True, + setter=self.set_possible_values, ) if self.value_type == str: self.max_length = self.set(attr, "max_length", allowed_type=int) if self.max_length: - self.dtype = "|S{}".format(self.max_length) + self.dtype = f"|S{self.max_length}" if self.value_type == Enum: self.default_value = self.set( - attr, "default_value", allowed_type=self.possible_values, required=True + attr, + "default_value", + allowed_type=self.possible_values, + required=True, ) else: self.default_value = self.set( @@ -136,7 +144,10 @@ def __init__(self, baseline_variable=None): ) self.entity = self.set(attr, "entity", required=True, setter=self.set_entity) self.definition_period = self.set( - attr, "definition_period", required=True, allowed_values=DateUnit + attr, + "definition_period", + required=True, + allowed_values=DateUnit, ) self.label = self.set(attr, "label", allowed_type=str, setter=self.set_label) self.end = self.set(attr, "end", allowed_type=str, setter=self.set_end) @@ -144,11 +155,14 @@ def __init__(self, baseline_variable=None): self.cerfa_field = self.set(attr, "cerfa_field", allowed_type=(str, dict)) self.unit = self.set(attr, "unit", allowed_type=str) self.documentation = self.set( - attr, "documentation", allowed_type=str, setter=self.set_documentation + attr, + "documentation", + allowed_type=str, + setter=self.set_documentation, ) self.set_input = self.set_set_input(attr.pop("set_input", None)) self.calculate_output = self.set_calculate_output( - attr.pop("calculate_output", None) + attr.pop("calculate_output", None), ) self.is_period_size_independent = self.set( attr, @@ -163,15 +177,18 @@ def __init__(self, baseline_variable=None): ) formulas_attr, unexpected_attrs = helpers._partition( - attr, lambda name, value: name.startswith(config.FORMULA_NAME_PREFIX) + attr, + lambda name, value: name.startswith(config.FORMULA_NAME_PREFIX), ) self.formulas = self.set_formulas(formulas_attr) if unexpected_attrs: + msg = 'Unexpected attributes in definition of variable "{}": {!r}'.format( + self.name, + ", ".join(sorted(unexpected_attrs.keys())), + ) raise ValueError( - 'Unexpected attributes in definition of variable "{}": {!r}'.format( - self.name, ", ".join(sorted(unexpected_attrs.keys())) - ) + msg, ) self.is_neutralized = False @@ -192,16 +209,14 @@ def set( if value is None and self.baseline_variable: return getattr(self.baseline_variable, attribute_name) if required and value is None: + msg = f"Missing attribute '{attribute_name}' in definition of variable '{self.name}'." raise ValueError( - "Missing attribute '{}' in definition of variable '{}'.".format( - attribute_name, self.name - ) + msg, ) if allowed_values is not None and value not in allowed_values: + msg = f"Invalid value '{value}' for attribute '{attribute_name}' in variable '{self.name}'. Allowed values are '{allowed_values}'." raise ValueError( - "Invalid value '{}' for attribute '{}' in variable '{}'. Allowed values are '{}'.".format( - value, attribute_name, self.name, allowed_values - ) + msg, ) if ( allowed_type is not None @@ -211,10 +226,9 @@ def set( if allowed_type == float and isinstance(value, int): value = float(value) else: + msg = f"Invalid value '{value}' for attribute '{attribute_name}' in variable '{self.name}'. Must be of type '{allowed_type}'." raise ValueError( - "Invalid value '{}' for attribute '{}' in variable '{}'. Must be of type '{}'.".format( - value, attribute_name, self.name, allowed_type - ) + msg, ) if setter is not None: value = setter(value) @@ -224,35 +238,38 @@ def set( def set_entity(self, entity): if not isinstance(entity, (Entity, GroupEntity)): - raise ValueError( + msg = ( f"Invalid value '{entity}' for attribute 'entity' in variable " f"'{self.name}'. Must be an instance of Entity or GroupEntity." ) + raise ValueError( + msg, + ) return entity def set_possible_values(self, possible_values): if not issubclass(possible_values, Enum): + msg = f"Invalid value '{possible_values}' for attribute 'possible_values' in variable '{self.name}'. Must be a subclass of {Enum}." raise ValueError( - "Invalid value '{}' for attribute 'possible_values' in variable '{}'. Must be a subclass of {}.".format( - possible_values, self.name, Enum - ) + msg, ) return possible_values def set_label(self, label): if label: return label + return None def set_end(self, end): if end: try: return datetime.datetime.strptime(end, "%Y-%m-%d").date() except ValueError: + msg = f"Incorrect 'end' attribute format in '{self.name}'. 'YYYY-MM-DD' expected where YYYY, MM and DD are year, month and day. Found: {end}" raise ValueError( - "Incorrect 'end' attribute format in '{}'. 'YYYY-MM-DD' expected where YYYY, MM and DD are year, month and day. Found: {}".format( - self.name, end - ) + msg, ) + return None def set_reference(self, reference): if reference: @@ -263,18 +280,16 @@ def set_reference(self, reference): elif isinstance(reference, tuple): reference = list(reference) else: + msg = f"The reference of the variable {self.name} is a {type(reference)} instead of a String or a List of Strings." raise TypeError( - "The reference of the variable {} is a {} instead of a String or a List of Strings.".format( - self.name, type(reference) - ) + msg, ) for element in reference: if not isinstance(element, str): + msg = f"The reference of the variable {self.name} is a {type(reference)} instead of a String or a List of Strings." raise TypeError( - "The reference of the variable {} is a {} instead of a String or a List of Strings.".format( - self.name, type(reference) - ) + msg, ) return reference @@ -282,6 +297,7 @@ def set_reference(self, reference): def set_documentation(self, documentation): if documentation: return textwrap.dedent(documentation) + return None def set_set_input(self, set_input): if not set_input and self.baseline_variable: @@ -299,10 +315,9 @@ def set_formulas(self, formulas_attr): starting_date = self.parse_formula_name(formula_name) if self.end is not None and starting_date > self.end: + msg = f'You declared that "{self.name}" ends on "{self.end}", but you wrote a formula to calculate it from "{starting_date}" ({formula_name}). The "end" attribute of a variable must be posterior to the start dates of all its formulas.' raise ValueError( - 'You declared that "{}" ends on "{}", but you wrote a formula to calculate it from "{}" ({}). The "end" attribute of a variable must be posterior to the start dates of all its formulas.'.format( - self.name, self.end, starting_date, formula_name - ) + msg, ) formulas[str(starting_date)] = formula @@ -316,14 +331,13 @@ def set_formulas(self, formulas_attr): for baseline_start_date, baseline_formula in self.baseline_variable.formulas.items() if first_reform_formula_date is None or baseline_start_date < first_reform_formula_date - } + }, ) return formulas def parse_formula_name(self, attribute_name): - """ - Returns the starting date of a formula based on its name. + """Returns the starting date of a formula based on its name. Valid dated name formats are : 'formula', 'formula_YYYY', 'formula_YYYY_MM' and 'formula_YYYY_MM_DD' where YYYY, MM and DD are a year, month and day. @@ -333,11 +347,10 @@ def parse_formula_name(self, attribute_name): - `formula_YYYY_MM` is `YYYY-MM-01` """ - def raise_error(): + def raise_error() -> NoReturn: + msg = f'Unrecognized formula name in variable "{self.name}". Expecting "formula_YYYY" or "formula_YYYY_MM" or "formula_YYYY_MM_DD where YYYY, MM and DD are year, month and day. Found: "{attribute_name}".' raise ValueError( - 'Unrecognized formula name in variable "{}". Expecting "formula_YYYY" or "formula_YYYY_MM" or "formula_YYYY_MM_DD where YYYY, MM and DD are year, month and day. Found: "{}".'.format( - self.name, attribute_name - ) + msg, ) if attribute_name == config.FORMULA_NAME_PREFIX: @@ -349,7 +362,7 @@ def raise_error(): if not match: raise_error() date_str = "-".join( - [match.group(1), match.group(2) or "01", match.group(3) or "01"] + [match.group(1), match.group(2) or "01", match.group(3) or "01"], ) try: @@ -360,9 +373,7 @@ def raise_error(): # ----- Methods ----- # def is_input_variable(self): - """ - Returns True if the variable is an input variable. - """ + """Returns True if the variable is an input variable.""" return len(self.formulas) == 0 @classmethod @@ -374,8 +385,8 @@ def get_introspection_data(cls): def get_formula( self, - period: Union[Instant, Period, str, int] = None, - ) -> Optional[Formula]: + period: Instant | Period | str | int = None, + ) -> Formula | None: """Returns the formula to compute the variable at the given period. If no period is given and the variable has several formulas, the method @@ -388,14 +399,15 @@ def get_formula( Formula used to compute the variable. """ - - instant: Optional[Instant] + instant: Instant | None if not self.formulas: return None if period is None: - return self.formulas.peekitem(index=0)[ + return self.formulas.peekitem( + index=0, + )[ 1 ] # peekitem gets the 1st key-value tuple (the oldest start_date and formula). Return the formula. @@ -422,8 +434,7 @@ def get_formula( return None def clone(self): - clone = self.__class__() - return clone + return self.__class__() def check_set_value(self, value): if self.value_type == Enum and isinstance(value, str): @@ -431,39 +442,33 @@ def check_set_value(self, value): value = self.possible_values[value].index except KeyError: possible_values = [item.name for item in self.possible_values] + msg = "'{}' is not a known value for '{}'. Possible values are ['{}'].".format( + value, + self.name, + "', '".join(possible_values), + ) raise ValueError( - "'{}' is not a known value for '{}'. Possible values are ['{}'].".format( - value, self.name, "', '".join(possible_values) - ) + msg, ) if self.value_type in (float, int) and isinstance(value, str): try: value = tools.eval_expression(value) except SyntaxError: + msg = f"I couldn't understand '{value}' as a value for '{self.name}'" raise ValueError( - "I couldn't understand '{}' as a value for '{}'".format( - value, self.name - ) + msg, ) try: value = numpy.array([value], dtype=self.dtype)[0] except (TypeError, ValueError): if self.value_type == datetime.date: - error_message = "Can't deal with date: '{}'.".format(value) + error_message = f"Can't deal with date: '{value}'." else: - error_message = ( - "Can't deal with value: expected type {}, received '{}'.".format( - self.json_type, value - ) - ) + error_message = f"Can't deal with value: expected type {self.json_type}, received '{value}'." raise ValueError(error_message) except OverflowError: - error_message = ( - "Can't deal with value: '{}', it's too large for type '{}'.".format( - value, self.json_type - ) - ) + error_message = f"Can't deal with value: '{value}', it's too large for type '{self.json_type}'." raise ValueError(error_message) return value diff --git a/openfisca_core/warnings/libyaml_warning.py b/openfisca_core/warnings/libyaml_warning.py index 7bbf1a5610..7ea797b667 100644 --- a/openfisca_core/warnings/libyaml_warning.py +++ b/openfisca_core/warnings/libyaml_warning.py @@ -1,6 +1,2 @@ class LibYAMLWarning(UserWarning): - """ - Custom warning for LibYAML not installed. - """ - - pass + """Custom warning for LibYAML not installed.""" diff --git a/openfisca_core/warnings/memory_warning.py b/openfisca_core/warnings/memory_warning.py index ef4bcf28af..23e82bf3e0 100644 --- a/openfisca_core/warnings/memory_warning.py +++ b/openfisca_core/warnings/memory_warning.py @@ -1,6 +1,2 @@ class MemoryConfigWarning(UserWarning): - """ - Custom warning for MemoryConfig. - """ - - pass + """Custom warning for MemoryConfig.""" diff --git a/openfisca_core/warnings/tempfile_warning.py b/openfisca_core/warnings/tempfile_warning.py index 433cf54772..9f4aad3820 100644 --- a/openfisca_core/warnings/tempfile_warning.py +++ b/openfisca_core/warnings/tempfile_warning.py @@ -1,6 +1,2 @@ class TempfileWarning(UserWarning): - """ - Custom warning when using a tempfile on disk. - """ - - pass + """Custom warning when using a tempfile on disk.""" diff --git a/openfisca_web_api/app.py b/openfisca_web_api/app.py index b4682b17e7..a76f255a0c 100644 --- a/openfisca_web_api/app.py +++ b/openfisca_web_api/app.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import logging import os import traceback @@ -31,7 +29,7 @@ def init_tracker(url, idsite, tracker_token): "You chose to activate the `tracker` module. ", "Tracking data will be sent to: " + url, "For more information, see .", - ] + ], ) log.info(info) return tracker @@ -42,9 +40,9 @@ def init_tracker(url, idsite, tracker_token): traceback.format_exc(), "You chose to activate the `tracker` module, but it is not installed.", "For more information, see .", - ] + ], ) - log.warn(message) + log.warning(message) def create_app( @@ -77,6 +75,7 @@ def create_app( def before_request(): if request.path != "/" and request.path.endswith("/"): return redirect(request.path[:-1]) + return None @app.route("/") def get_root(): @@ -84,8 +83,8 @@ def get_root(): jsonify( { "welcome": welcome_message - or DEFAULT_WELCOME_MESSAGE.format(request.host_url) - } + or DEFAULT_WELCOME_MESSAGE.format(request.host_url), + }, ), 300, ) @@ -95,7 +94,7 @@ def get_parameters(): parameters = { parameter["id"]: { "description": parameter["description"], - "href": "{}parameter/{}".format(request.host_url, name), + "href": f"{request.host_url}parameter/{name}", } for name, parameter in data["parameters"].items() if parameter.get("subparams") @@ -120,7 +119,7 @@ def get_variables(): variables = { name: { "description": variable["description"], - "href": "{}variable/{}".format(request.host_url, name), + "href": f"{request.host_url}variable/{name}", } for name, variable in data["variables"].items() } @@ -146,15 +145,15 @@ def get_spec(): return jsonify( { **data["openAPI_spec"], - **{"servers": [{"url": url}]}, - } + "servers": [{"url": url}], + }, ) - def handle_invalid_json(error): + def handle_invalid_json(error) -> None: json_response = jsonify( { - "error": "Invalid JSON: {}".format(error.args[0]), - } + "error": f"Invalid JSON: {error.args[0]}", + }, ) abort(make_response(json_response, 400)) @@ -173,7 +172,7 @@ def calculate(): make_response( jsonify({"error": "'" + e[1] + "' is not a valid ASCII value."}), 400, - ) + ), ) return jsonify(result) @@ -194,7 +193,7 @@ def apply_headers(response): { "Country-Package": data["country_package_metadata"]["name"], "Country-Package-Version": data["country_package_metadata"]["version"], - } + }, ) return response diff --git a/openfisca_web_api/errors.py b/openfisca_web_api/errors.py index ba804a7b08..ac93ebd833 100644 --- a/openfisca_web_api/errors.py +++ b/openfisca_web_api/errors.py @@ -1,13 +1,12 @@ -# -*- coding: utf-8 -*- +from typing import NoReturn import logging log = logging.getLogger("gunicorn.error") -def handle_import_error(error): +def handle_import_error(error) -> NoReturn: + msg = f"OpenFisca is missing some dependencies to run the Web API: '{error}'. To install them, run `pip install openfisca_core[web-api]`." raise ImportError( - "OpenFisca is missing some dependencies to run the Web API: '{}'. To install them, run `pip install openfisca_core[web-api]`.".format( - error - ) + msg, ) diff --git a/openfisca_web_api/handlers.py b/openfisca_web_api/handlers.py index a336a490b0..2f6fc4403a 100644 --- a/openfisca_web_api/handlers.py +++ b/openfisca_web_api/handlers.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import dpath.util from openfisca_core.indexed_enums import Enum @@ -7,12 +5,13 @@ def calculate(tax_benefit_system, input_data: dict) -> dict: - """ - Returns the input_data where the None values are replaced by the calculated values. - """ + """Returns the input_data where the None values are replaced by the calculated values.""" simulation = SimulationBuilder().build_from_entities(tax_benefit_system, input_data) requested_computations = dpath.util.search( - input_data, "*/*/*/*", afilter=lambda t: t is None, yielded=True + input_data, + "*/*/*/*", + afilter=lambda t: t is None, + yielded=True, ) computation_results: dict = {} for computation in requested_computations: @@ -29,7 +28,7 @@ def calculate(tax_benefit_system, input_data: dict) -> dict: entity_result = result.decode()[entity_index].name elif variable.value_type == float: entity_result = float( - str(result[entity_index]) + str(result[entity_index]), ) # To turn the float32 into a regular float without adding confusing extra decimals. There must be a better way. elif variable.value_type == str: entity_result = str(result[entity_index]) @@ -40,27 +39,26 @@ def calculate(tax_benefit_system, input_data: dict) -> dict: # See https://github.com/dpath-maintainers/dpath-python/issues/160 if computation_results == {}: computation_results = { - entity_plural: {entity_id: {variable_name: {period: entity_result}}} + entity_plural: {entity_id: {variable_name: {period: entity_result}}}, } - else: - if entity_plural in computation_results: - if entity_id in computation_results[entity_plural]: - if variable_name in computation_results[entity_plural][entity_id]: - computation_results[entity_plural][entity_id][variable_name][ - period - ] = entity_result - else: - computation_results[entity_plural][entity_id][variable_name] = { - period: entity_result - } + elif entity_plural in computation_results: + if entity_id in computation_results[entity_plural]: + if variable_name in computation_results[entity_plural][entity_id]: + computation_results[entity_plural][entity_id][variable_name][ + period + ] = entity_result else: - computation_results[entity_plural][entity_id] = { - variable_name: {period: entity_result} + computation_results[entity_plural][entity_id][variable_name] = { + period: entity_result, } else: - computation_results[entity_plural] = { - entity_id: {variable_name: {period: entity_result}} + computation_results[entity_plural][entity_id] = { + variable_name: {period: entity_result}, } + else: + computation_results[entity_plural] = { + entity_id: {variable_name: {period: entity_result}}, + } dpath.util.merge(input_data, computation_results) return input_data @@ -72,12 +70,15 @@ def trace(tax_benefit_system, input_data): requested_calculations = [] requested_computations = dpath.util.search( - input_data, "*/*/*/*", afilter=lambda t: t is None, yielded=True + input_data, + "*/*/*/*", + afilter=lambda t: t is None, + yielded=True, ) for computation in requested_computations: path = computation[0] entity_plural, entity_id, variable_name, period = path.split("/") - requested_calculations.append(f"{variable_name}<{str(period)}>") + requested_calculations.append(f"{variable_name}<{period!s}>") simulation.calculate(variable_name, period) trace = simulation.tracer.get_serialized_flat_trace() diff --git a/openfisca_web_api/loader/__init__.py b/openfisca_web_api/loader/__init__.py index fcea068e21..8d9318d9ae 100644 --- a/openfisca_web_api/loader/__init__.py +++ b/openfisca_web_api/loader/__init__.py @@ -1,6 +1,3 @@ -# -*- coding: utf-8 -*- - - from openfisca_web_api.loader.entities import build_entities from openfisca_web_api.loader.parameters import build_parameters from openfisca_web_api.loader.spec import build_openAPI_specification diff --git a/openfisca_web_api/loader/entities.py b/openfisca_web_api/loader/entities.py index 683537aa0e..98ce4e6fb9 100644 --- a/openfisca_web_api/loader/entities.py +++ b/openfisca_web_api/loader/entities.py @@ -1,11 +1,5 @@ -# -*- coding: utf-8 -*- - - def build_entities(tax_benefit_system): - entities = { - entity.key: build_entity(entity) for entity in tax_benefit_system.entities - } - return entities + return {entity.key: build_entity(entity) for entity in tax_benefit_system.entities} def build_entity(entity): diff --git a/openfisca_web_api/loader/parameters.py b/openfisca_web_api/loader/parameters.py index 8841f7ebe8..193f12915f 100644 --- a/openfisca_web_api/loader/parameters.py +++ b/openfisca_web_api/loader/parameters.py @@ -1,4 +1,5 @@ -# -*- coding: utf-8 -*- +import functools +import operator from openfisca_core.parameters import Parameter, ParameterNode, Scale @@ -24,8 +25,7 @@ def get_value(date, values): if candidates: return candidates[0][1] - else: - return None + return None def build_api_scale(scale, value_key_name): @@ -39,13 +39,14 @@ def build_api_scale(scale, value_key_name): ] dates = set( - sum( + functools.reduce( + operator.iadd, [ list(bracket["thresholds"].keys()) + list(bracket["values"].keys()) for bracket in brackets ], [], - ) + ), ) # flatten the dates and remove duplicates # We iterate on all dates as we need to build the whole scale for each of them @@ -87,7 +88,8 @@ def build_api_parameter(parameter, country_package_metadata): } if parameter.file_path: api_parameter["source"] = build_source_url( - parameter.file_path, country_package_metadata + parameter.file_path, + country_package_metadata, ) if isinstance(parameter, Parameter): if parameter.documentation: @@ -113,7 +115,8 @@ def build_api_parameter(parameter, country_package_metadata): def build_parameters(tax_benefit_system, country_package_metadata): return { parameter.name.replace(".", "/"): build_api_parameter( - parameter, country_package_metadata + parameter, + country_package_metadata, ) for parameter in tax_benefit_system.parameters.get_descendants() } diff --git a/openfisca_web_api/loader/spec.py b/openfisca_web_api/loader/spec.py index 335317d2fe..4a163bd91f 100644 --- a/openfisca_web_api/loader/spec.py +++ b/openfisca_web_api/loader/spec.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import os from copy import deepcopy @@ -10,13 +8,15 @@ from openfisca_web_api import handlers OPEN_API_CONFIG_FILE = os.path.join( - os.path.dirname(os.path.abspath(__file__)), os.path.pardir, "openAPI.yml" + os.path.dirname(os.path.abspath(__file__)), + os.path.pardir, + "openAPI.yml", ) def build_openAPI_specification(api_data): tax_benefit_system = api_data["tax_benefit_system"] - file = open(OPEN_API_CONFIG_FILE, "r") + file = open(OPEN_API_CONFIG_FILE) spec = yaml.safe_load(file) country_package_name = api_data["country_package_metadata"]["name"].title() country_package_version = api_data["country_package_metadata"]["version"] @@ -29,21 +29,24 @@ def build_openAPI_specification(api_data): spec, "info/description", spec["info"]["description"].replace( - "{COUNTRY_PACKAGE_NAME}", country_package_name + "{COUNTRY_PACKAGE_NAME}", + country_package_name, ), ) dpath.util.new( spec, "info/version", spec["info"]["version"].replace( - "{COUNTRY_PACKAGE_VERSION}", country_package_version + "{COUNTRY_PACKAGE_VERSION}", + country_package_version, ), ) for entity in tax_benefit_system.entities: name = entity.key.title() spec["components"]["schemas"][name] = get_entity_json_schema( - entity, tax_benefit_system + entity, + tax_benefit_system, ) situation_schema = get_situation_json_schema(tax_benefit_system) @@ -79,7 +82,9 @@ def build_openAPI_specification(api_data): if tax_benefit_system.open_api_config.get("simulation_example"): simulation_example = tax_benefit_system.open_api_config["simulation_example"] dpath.util.new( - spec, "components/schemas/SituationInput/example", simulation_example + spec, + "components/schemas/SituationInput/example", + simulation_example, ) dpath.util.new( spec, @@ -92,9 +97,7 @@ def build_openAPI_specification(api_data): handlers.trace(tax_benefit_system, simulation_example), ) else: - message = "No simulation example has been defined for this tax and benefit system. If you are the maintainer of {}, you can define an example by following this documentation: https://openfisca.org/doc/openfisca-web-api/config-openapi.html".format( - country_package_name - ) + message = f"No simulation example has been defined for this tax and benefit system. If you are the maintainer of {country_package_name}, you can define an example by following this documentation: https://openfisca.org/doc/openfisca-web-api/config-openapi.html" dpath.util.new(spec, "components/schemas/SituationInput/example", message) dpath.util.new(spec, "components/schemas/SituationOutput/example", message) dpath.util.new(spec, "components/schemas/Trace/example", message) @@ -122,32 +125,31 @@ def get_entity_json_schema(entity, tax_benefit_system): "properties": { variable_name: get_variable_json_schema(variable) for variable_name, variable in tax_benefit_system.get_variables( - entity + entity, ).items() }, "additionalProperties": False, } - else: - properties = {} - properties.update( - { - role.plural or role.key: {"type": "array", "items": {"type": "string"}} - for role in entity.roles - } - ) - properties.update( - { - variable_name: get_variable_json_schema(variable) - for variable_name, variable in tax_benefit_system.get_variables( - entity - ).items() - } - ) - return { - "type": "object", - "properties": properties, - "additionalProperties": False, - } + properties = {} + properties.update( + { + role.plural or role.key: {"type": "array", "items": {"type": "string"}} + for role in entity.roles + }, + ) + properties.update( + { + variable_name: get_variable_json_schema(variable) + for variable_name, variable in tax_benefit_system.get_variables( + entity, + ).items() + }, + ) + return { + "type": "object", + "properties": properties, + "additionalProperties": False, + } def get_situation_json_schema(tax_benefit_system): @@ -158,7 +160,7 @@ def get_situation_json_schema(tax_benefit_system): entity.plural: { "type": "object", "additionalProperties": { - "$ref": "#/components/schemas/{}".format(entity.key.title()) + "$ref": f"#/components/schemas/{entity.key.title()}", }, } for entity in tax_benefit_system.entities diff --git a/openfisca_web_api/loader/tax_benefit_system.py b/openfisca_web_api/loader/tax_benefit_system.py index 3cbd0edb81..358f960501 100644 --- a/openfisca_web_api/loader/tax_benefit_system.py +++ b/openfisca_web_api/loader/tax_benefit_system.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import importlib import logging import traceback @@ -15,15 +13,15 @@ def build_tax_benefit_system(country_package_name): message = linesep.join( [ traceback.format_exc(), - "Could not import module `{}`.".format(country_package_name), + f"Could not import module `{country_package_name}`.", "Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.", "See more at .", linesep, - ] + ], ) raise ValueError(message) try: return country_package.CountryTaxBenefitSystem() except NameError: # Gunicorn swallows NameErrors. Force printing the stack trace. - log.error(traceback.format_exc()) + log.exception(traceback.format_exc()) raise diff --git a/openfisca_web_api/loader/variables.py b/openfisca_web_api/loader/variables.py index f9b6e05887..6730dc0811 100644 --- a/openfisca_web_api/loader/variables.py +++ b/openfisca_web_api/loader/variables.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import datetime import inspect import textwrap @@ -26,7 +24,10 @@ def get_default_value(variable): def build_source_url( - country_package_metadata, source_file_path, start_line_number, source_code + country_package_metadata, + source_file_path, + start_line_number, + source_code, ): nb_lines = source_code.count("\n") return "{}/blob/{}{}#L{}-L{}".format( @@ -45,7 +46,10 @@ def build_formula(formula, country_package_metadata, source_file_path): api_formula = { "source": build_source_url( - country_package_metadata, source_file_path, start_line_number, source_code + country_package_metadata, + source_file_path, + start_line_number, + source_code, ), "content": source_code, } @@ -80,7 +84,10 @@ def build_variable(variable, country_package_metadata): if source_code: result["source"] = build_source_url( - country_package_metadata, source_file_path, start_line_number, source_code + country_package_metadata, + source_file_path, + start_line_number, + source_code, ) if variable.documentation: @@ -91,7 +98,9 @@ def build_variable(variable, country_package_metadata): if len(variable.formulas) > 0: result["formulas"] = build_formulas( - variable.formulas, country_package_metadata, source_file_path + variable.formulas, + country_package_metadata, + source_file_path, ) if variable.end: diff --git a/openfisca_web_api/scripts/serve.py b/openfisca_web_api/scripts/serve.py index ea594d9d6c..6ba89f440a 100644 --- a/openfisca_web_api/scripts/serve.py +++ b/openfisca_web_api/scripts/serve.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import logging import sys @@ -33,7 +31,7 @@ def read_user_configuration(default_configuration, command_line_parser): if args.configuration_file: file_configuration = {} - with open(args.configuration_file, "r") as file: + with open(args.configuration_file) as file: exec(file.read(), {}, file_configuration) # Configuration file overloads default configuration @@ -43,7 +41,8 @@ def read_user_configuration(default_configuration, command_line_parser): gunicorn_parser = config.Config().parser() configuration = update(configuration, vars(args)) configuration = update( - configuration, vars(gunicorn_parser.parse_args(unknown_args)) + configuration, + vars(gunicorn_parser.parse_args(unknown_args)), ) if configuration["args"]: command_line_parser.print_help() @@ -59,17 +58,17 @@ def update(configuration, new_options): configuration[key] = value if key == "port": configuration["bind"] = configuration["bind"][:-4] + str( - configuration["port"] + configuration["port"], ) return configuration class OpenFiscaWebAPIApplication(BaseApplication): - def __init__(self, options): + def __init__(self, options) -> None: self.options = options - super(OpenFiscaWebAPIApplication, self).__init__() + super().__init__() - def load_config(self): + def load_config(self) -> None: for key, value in self.options.items(): if key in self.cfg.settings: self.cfg.set(key.lower(), value) @@ -89,10 +88,10 @@ def load(self): ) -def main(parser): +def main(parser) -> None: configuration = { "port": DEFAULT_PORT, - "bind": "{}:{}".format(HOST, DEFAULT_PORT), + "bind": f"{HOST}:{DEFAULT_PORT}", "workers": DEFAULT_WORKERS_NUMBER, "timeout": DEFAULT_TIMEOUT, } diff --git a/setup.cfg b/setup.cfg index cc850c06a1..596ce99153 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,11 +11,15 @@ convention = google docstring_style = google extend-ignore = D -ignore = E203, E501, F405, RST301, W503 +ignore = B019, E203, E501, F405, E701, E704, RST212, RST301, W503 in-place = true include-in-doctest = openfisca_core/commons openfisca_core/entities openfisca_core/holders openfisca_core/periods openfisca_core/projectors max-line-length = 88 -per-file-ignores = */types.py:D101,D102,E704, */test_*.py:D101,D102,D103, */__init__.py:F401 +per-file-ignores = + */types.py:D101,D102,E704 + */test_*.py:D101,D102,D103 + */__init__.py:F401 + */__init__.pyi:E302,E704 rst-directives = attribute, deprecated, seealso, versionadded, versionchanged rst-roles = any, attr, class, exc, func, meth, mod, obj strictness = short @@ -33,6 +37,7 @@ score = no [isort] case_sensitive = true +combine_as_imports = true force_alphabetical_sort_within_sections = false group_by_package = true honor_noqa = true @@ -41,6 +46,7 @@ known_first_party = openfisca_core known_openfisca = openfisca_country_template, openfisca_extension_template known_typing = *collections.abc*, *typing*, *typing_extensions* known_types = *types* +multi_line_output = 3 profile = black py_version = 39 sections = FUTURE, TYPING, TYPES, STDLIB, THIRDPARTY, OPENFISCA, FIRSTPARTY, LOCALFOLDER @@ -74,6 +80,7 @@ follow_imports = skip ignore_missing_imports = true implicit_reexport = false install_types = true +mypy_path = stubs non_interactive = true plugins = numpy.typing.mypy_plugin pretty = true diff --git a/setup.py b/setup.py index cca107bee8..2f1c7089bf 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ "numpy >=1.24.2, <1.25", "pendulum >=2.1.2, <3.0.0", "psutil >=5.9.4, <6.0", - "pytest >=7.2.2, <8.0", + "pytest >=8.3.3, <9.0", "sortedcontainers >=2.4.0, <3.0", "typing_extensions >=4.5.0, <5.0", "StrEnum >=0.4.8, <0.5.0", # 3.11.x backport @@ -47,27 +47,30 @@ ] dev_requirements = [ - "black >=23.1.0, <24.0", - "coverage >=6.5.0, <7.0", + "black >=24.8.0, <25.0", + "coverage >=7.6.1, <8.0", "darglint >=1.8.1, <2.0", - "flake8 >=6.0.0, <7.0.0", - "flake8-bugbear >=23.3.23, <24.0", + "flake8 >=7.1.1, <8.0.0", + "flake8-bugbear >=24.8.19, <25.0", "flake8-docstrings >=1.7.0, <2.0", "flake8-print >=5.0.0, <6.0", "flake8-rst-docstrings >=0.3.0, <0.4.0", - "idna >=3.4, <4.0", - "isort >=5.12.0, <6.0", - "mypy >=1.1.1, <2.0", + "idna >=3.10, <4.0", + "isort >=5.13.2, <6.0", + "mypy >=1.11.2, <2.0", "openapi-spec-validator >=0.7.1, <0.8.0", - "pycodestyle >=2.10.0, <3.0", - "pylint >=2.17.1, <3.0", + "pylint >=3.3.1, <4.0", "pylint-per-file-ignores >=1.3.2, <2.0", - "xdoctest >=1.1.1, <2.0", -] + api_requirements + "pyright >=1.1.381, <2.0", + "ruff >=0.6.7, <1.0", + "ruff-lsp >=0.0.57, <1.0", + "xdoctest >=1.2.0, <2.0", + *api_requirements, +] setup( name="OpenFisca-Core", - version="41.5.5", + version="41.5.6", author="OpenFisca Team", author_email="contact@openfisca.org", classifiers=[ @@ -104,7 +107,7 @@ "dev": dev_requirements, "ci": [ "build >=0.10.0, <0.11.0", - "coveralls >=3.3.1, <4.0", + "coveralls >=4.0.1, <5.0", "twine >=5.1.1, <6.0", "wheel >=0.40.0, <0.41.0", ], diff --git a/tests/core/parameter_validation/test_parameter_clone.py b/tests/core/parameter_validation/test_parameter_clone.py index 1c74d861a3..6c77b4bb0b 100644 --- a/tests/core/parameter_validation/test_parameter_clone.py +++ b/tests/core/parameter_validation/test_parameter_clone.py @@ -6,7 +6,7 @@ year = 2016 -def test_clone(): +def test_clone() -> None: path = os.path.join(BASE_DIR, "filesystem_hierarchy") parameters = ParameterNode("", directory_path=path) parameters_at_instant = parameters("2016-01-01") @@ -19,7 +19,7 @@ def test_clone(): assert id(clone.node1.param) != id(parameters.node1.param) -def test_clone_parameter(tax_benefit_system): +def test_clone_parameter(tax_benefit_system) -> None: param = tax_benefit_system.parameters.taxes.income_tax_rate clone = param.clone() @@ -30,7 +30,7 @@ def test_clone_parameter(tax_benefit_system): assert clone.values_list == param.values_list -def test_clone_parameter_node(tax_benefit_system): +def test_clone_parameter_node(tax_benefit_system) -> None: node = tax_benefit_system.parameters.taxes clone = node.clone() @@ -39,7 +39,7 @@ def test_clone_parameter_node(tax_benefit_system): assert clone.children["income_tax_rate"] is not node.children["income_tax_rate"] -def test_clone_scale(tax_benefit_system): +def test_clone_scale(tax_benefit_system) -> None: scale = tax_benefit_system.parameters.taxes.social_security_contribution clone = scale.clone() @@ -47,7 +47,7 @@ def test_clone_scale(tax_benefit_system): assert clone.brackets[0].rate is not scale.brackets[0].rate -def test_deep_edit(tax_benefit_system): +def test_deep_edit(tax_benefit_system) -> None: parameters = tax_benefit_system.parameters clone = parameters.clone() diff --git a/tests/core/parameter_validation/test_parameter_validation.py b/tests/core/parameter_validation/test_parameter_validation.py index f4b8a82d50..d3419312d2 100644 --- a/tests/core/parameter_validation/test_parameter_validation.py +++ b/tests/core/parameter_validation/test_parameter_validation.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import os import pytest @@ -14,7 +12,7 @@ year = 2016 -def check_fails_with_message(file_name, keywords): +def check_fails_with_message(file_name, keywords) -> None: path = os.path.join(BASE_DIR, file_name) + ".yaml" try: load_parameter_file(path, file_name) @@ -65,24 +63,24 @@ def check_fails_with_message(file_name, keywords): ("duplicate_key", {"duplicate"}), ], ) -def test_parsing_errors(test): +def test_parsing_errors(test) -> None: with pytest.raises(ParameterParsingError): check_fails_with_message(*test) -def test_array_type(): +def test_array_type() -> None: path = os.path.join(BASE_DIR, "array_type.yaml") load_parameter_file(path, "array_type") -def test_filesystem_hierarchy(): +def test_filesystem_hierarchy() -> None: path = os.path.join(BASE_DIR, "filesystem_hierarchy") parameters = ParameterNode("", directory_path=path) parameters_at_instant = parameters("2016-01-01") assert parameters_at_instant.node1.param == 1.0 -def test_yaml_hierarchy(): +def test_yaml_hierarchy() -> None: path = os.path.join(BASE_DIR, "yaml_hierarchy") parameters = ParameterNode("", directory_path=path) parameters_at_instant = parameters("2016-01-01") diff --git a/tests/core/parameters_date_indexing/test_date_indexing.py b/tests/core/parameters_date_indexing/test_date_indexing.py index 05bb770823..cefec26648 100644 --- a/tests/core/parameters_date_indexing/test_date_indexing.py +++ b/tests/core/parameters_date_indexing/test_date_indexing.py @@ -16,10 +16,11 @@ def get_message(error): return error.args[0] -def test_on_leaf(): +def test_on_leaf() -> None: parameter_at_instant = parameters.full_rate_required_duration("1995-01-01") birthdate = numpy.array( - ["1930-01-01", "1935-01-01", "1940-01-01", "1945-01-01"], dtype="datetime64[D]" + ["1930-01-01", "1935-01-01", "1940-01-01", "1945-01-01"], + dtype="datetime64[D]", ) assert_near( parameter_at_instant.contribution_quarters_required_by_birthdate[birthdate], @@ -27,9 +28,10 @@ def test_on_leaf(): ) -def test_on_node(): +def test_on_node() -> None: birthdate = numpy.array( - ["1950-01-01", "1953-01-01", "1956-01-01", "1959-01-01"], dtype="datetime64[D]" + ["1950-01-01", "1953-01-01", "1956-01-01", "1959-01-01"], + dtype="datetime64[D]", ) parameter_at_instant = parameters.full_rate_age("2012-03-01") node = parameter_at_instant.full_rate_age_by_birthdate[birthdate] diff --git a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py index 73a4ccb323..b7e7cf4e45 100644 --- a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py +++ b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - import os import re -import numpy as np +import numpy import pytest from openfisca_core.indexed_enums import Enum @@ -21,27 +19,27 @@ def get_message(error): return error.args[0] -def test_on_leaf(): - zone = np.asarray(["z1", "z2", "z2", "z1"]) +def test_on_leaf() -> None: + zone = numpy.asarray(["z1", "z2", "z2", "z1"]) assert_near(P.single.owner[zone], [100, 200, 200, 100]) -def test_on_node(): - housing_occupancy_status = np.asarray(["owner", "owner", "tenant", "tenant"]) +def test_on_node() -> None: + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) node = P.single[housing_occupancy_status] assert_near(node.z1, [100, 100, 300, 300]) assert_near(node["z1"], [100, 100, 300, 300]) -def test_double_fancy_indexing(): - zone = np.asarray(["z1", "z2", "z2", "z1"]) - housing_occupancy_status = np.asarray(["owner", "owner", "tenant", "tenant"]) +def test_double_fancy_indexing() -> None: + zone = numpy.asarray(["z1", "z2", "z2", "z1"]) + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) assert_near(P.single[housing_occupancy_status][zone], [100, 200, 400, 300]) -def test_double_fancy_indexing_on_node(): - family_status = np.asarray(["single", "couple", "single", "couple"]) - housing_occupancy_status = np.asarray(["owner", "owner", "tenant", "tenant"]) +def test_double_fancy_indexing_on_node() -> None: + family_status = numpy.asarray(["single", "couple", "single", "couple"]) + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) node = P[family_status][housing_occupancy_status] assert_near(node.z1, [100, 500, 300, 700]) assert_near(node["z1"], [100, 500, 300, 700]) @@ -49,28 +47,37 @@ def test_double_fancy_indexing_on_node(): assert_near(node["z2"], [200, 600, 400, 800]) -def test_triple_fancy_indexing(): - family_status = np.asarray( - ["single", "single", "single", "single", "couple", "couple", "couple", "couple"] +def test_triple_fancy_indexing() -> None: + family_status = numpy.asarray( + [ + "single", + "single", + "single", + "single", + "couple", + "couple", + "couple", + "couple", + ], ) - housing_occupancy_status = np.asarray( - ["owner", "owner", "tenant", "tenant", "owner", "owner", "tenant", "tenant"] + housing_occupancy_status = numpy.asarray( + ["owner", "owner", "tenant", "tenant", "owner", "owner", "tenant", "tenant"], ) - zone = np.asarray(["z1", "z2", "z1", "z2", "z1", "z2", "z1", "z2"]) + zone = numpy.asarray(["z1", "z2", "z1", "z2", "z1", "z2", "z1", "z2"]) assert_near( P[family_status][housing_occupancy_status][zone], [100, 200, 300, 400, 500, 600, 700, 800], ) -def test_wrong_key(): - zone = np.asarray(["z1", "z2", "z2", "toto"]) +def test_wrong_key() -> None: + zone = numpy.asarray(["z1", "z2", "z2", "toto"]) with pytest.raises(ParameterNotFound) as e: P.single.owner[zone] assert "'rate.single.owner.toto' was not found" in get_message(e.value) -def test_inhomogenous(): +def test_inhomogenous() -> None: parameters = ParameterNode(directory_path=LOCAL_DIR) parameters.rate.couple.owner.add_child( "toto", @@ -79,20 +86,20 @@ def test_inhomogenous(): { "values": { "2015-01-01": {"value": 1000}, - } + }, }, ), ) P = parameters.rate("2015-01-01") - housing_occupancy_status = np.asarray(["owner", "owner", "tenant", "tenant"]) + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) with pytest.raises(ValueError) as error: P.couple[housing_occupancy_status] assert "'rate.couple.owner.toto' exists" in get_message(error.value) assert "'rate.couple.tenant.toto' doesn't" in get_message(error.value) -def test_inhomogenous_2(): +def test_inhomogenous_2() -> None: parameters = ParameterNode(directory_path=LOCAL_DIR) parameters.rate.couple.tenant.add_child( "toto", @@ -101,20 +108,20 @@ def test_inhomogenous_2(): { "values": { "2015-01-01": {"value": 1000}, - } + }, }, ), ) P = parameters.rate("2015-01-01") - housing_occupancy_status = np.asarray(["owner", "owner", "tenant", "tenant"]) + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) with pytest.raises(ValueError) as e: P.couple[housing_occupancy_status] assert "'rate.couple.tenant.toto' exists" in get_message(e.value) assert "'rate.couple.owner.toto' doesn't" in get_message(e.value) -def test_inhomogenous_3(): +def test_inhomogenous_3() -> None: parameters = ParameterNode(directory_path=LOCAL_DIR) parameters.rate.couple.tenant.add_child( "z4", @@ -125,14 +132,14 @@ def test_inhomogenous_3(): "values": { "2015-01-01": {"value": 550}, "2016-01-01": {"value": 600}, - } - } + }, + }, }, ), ) P = parameters.rate("2015-01-01") - zone = np.asarray(["z1", "z2", "z2", "z1"]) + zone = numpy.asarray(["z1", "z2", "z2", "z1"]) with pytest.raises(ValueError) as e: P.couple.tenant[zone] assert "'rate.couple.tenant.z4' is a node" in get_message(e.value) @@ -142,28 +149,29 @@ def test_inhomogenous_3(): P_2 = parameters.local_tax("2015-01-01") -def test_with_properties_starting_by_number(): - city_code = np.asarray(["75012", "75007", "75015"]) +def test_with_properties_starting_by_number() -> None: + city_code = numpy.asarray(["75012", "75007", "75015"]) assert_near(P_2[city_code], [100, 300, 200]) P_3 = parameters.bareme("2015-01-01") -def test_with_bareme(): - city_code = np.asarray(["75012", "75007", "75015"]) +def test_with_bareme() -> None: + city_code = numpy.asarray(["75012", "75007", "75015"]) with pytest.raises(NotImplementedError) as e: P_3[city_code] assert re.findall( - r"'bareme.7501\d' is a 'MarginalRateTaxScale'", get_message(e.value) + r"'bareme.7501\d' is a 'MarginalRateTaxScale'", + get_message(e.value), ) assert "has not been implemented" in get_message(e.value) -def test_with_enum(): +def test_with_enum() -> None: class TypesZone(Enum): z1 = "Zone 1" z2 = "Zone 2" - zone = np.asarray([TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]) + zone = numpy.asarray([TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]) assert_near(P.single.owner[zone], [100, 200, 200, 100]) diff --git a/tests/core/tax_scales/test_abstract_rate_tax_scale.py b/tests/core/tax_scales/test_abstract_rate_tax_scale.py index ad755075be..c966aa30f3 100644 --- a/tests/core/tax_scales/test_abstract_rate_tax_scale.py +++ b/tests/core/tax_scales/test_abstract_rate_tax_scale.py @@ -3,7 +3,7 @@ from openfisca_core import taxscales -def test_abstract_tax_scale(): +def test_abstract_tax_scale() -> None: with pytest.warns(DeprecationWarning): result = taxscales.AbstractRateTaxScale() assert isinstance(result, taxscales.AbstractRateTaxScale) diff --git a/tests/core/tax_scales/test_abstract_tax_scale.py b/tests/core/tax_scales/test_abstract_tax_scale.py index f1bfc4e4af..aad04d58ed 100644 --- a/tests/core/tax_scales/test_abstract_tax_scale.py +++ b/tests/core/tax_scales/test_abstract_tax_scale.py @@ -3,7 +3,7 @@ from openfisca_core import taxscales -def test_abstract_tax_scale(): +def test_abstract_tax_scale() -> None: with pytest.warns(DeprecationWarning): result = taxscales.AbstractTaxScale() assert isinstance(result, taxscales.AbstractTaxScale) diff --git a/tests/core/tax_scales/test_linear_average_rate_tax_scale.py b/tests/core/tax_scales/test_linear_average_rate_tax_scale.py index 18e6bd5a4a..6205d6de9b 100644 --- a/tests/core/tax_scales/test_linear_average_rate_tax_scale.py +++ b/tests/core/tax_scales/test_linear_average_rate_tax_scale.py @@ -4,7 +4,7 @@ from openfisca_core import taxscales, tools -def test_bracket_indices(): +def test_bracket_indices() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -16,7 +16,7 @@ def test_bracket_indices(): tools.assert_near(result, [0, 0, 0, 1, 1, 2]) -def test_bracket_indices_with_factor(): +def test_bracket_indices_with_factor() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -28,7 +28,7 @@ def test_bracket_indices_with_factor(): tools.assert_near(result, [0, 0, 0, 0, 1, 1]) -def test_bracket_indices_with_round_decimals(): +def test_bracket_indices_with_round_decimals() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -40,7 +40,7 @@ def test_bracket_indices_with_round_decimals(): tools.assert_near(result, [0, 0, 1, 1, 2, 2]) -def test_bracket_indices_without_tax_base(): +def test_bracket_indices_without_tax_base() -> None: tax_base = numpy.array([]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -51,7 +51,7 @@ def test_bracket_indices_without_tax_base(): tax_scale.bracket_indices(tax_base) -def test_bracket_indices_without_brackets(): +def test_bracket_indices_without_brackets() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() @@ -59,7 +59,7 @@ def test_bracket_indices_without_brackets(): tax_scale.bracket_indices(tax_base) -def test_to_dict(): +def test_to_dict() -> None: tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) @@ -69,7 +69,7 @@ def test_to_dict(): assert result == {"0": 0.0, "100": 0.1} -def test_to_marginal(): +def test_to_marginal() -> None: tax_base = numpy.array([1, 1.5, 2, 2.5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) diff --git a/tests/core/tax_scales/test_marginal_amount_tax_scale.py b/tests/core/tax_scales/test_marginal_amount_tax_scale.py index e00a8371c4..0a3275c901 100644 --- a/tests/core/tax_scales/test_marginal_amount_tax_scale.py +++ b/tests/core/tax_scales/test_marginal_amount_tax_scale.py @@ -15,12 +15,12 @@ def data(): "amount": { "2017-10-01": {"value": 6}, }, - } + }, ], } -def test_calc(): +def test_calc() -> None: tax_base = array([1, 8, 10]) tax_scale = taxscales.MarginalAmountTaxScale() tax_scale.add_bracket(6, 0.23) @@ -32,7 +32,7 @@ def test_calc(): # TODO: move, as we're testing Scale, not MarginalAmountTaxScale -def test_dispatch_scale_type_on_creation(data): +def test_dispatch_scale_type_on_creation(data) -> None: scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) diff --git a/tests/core/tax_scales/test_marginal_rate_tax_scale.py b/tests/core/tax_scales/test_marginal_rate_tax_scale.py index 3ed4a3f12f..7696e95fc4 100644 --- a/tests/core/tax_scales/test_marginal_rate_tax_scale.py +++ b/tests/core/tax_scales/test_marginal_rate_tax_scale.py @@ -4,7 +4,7 @@ from openfisca_core import taxscales, tools -def test_bracket_indices(): +def test_bracket_indices() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -16,7 +16,7 @@ def test_bracket_indices(): tools.assert_near(result, [0, 0, 0, 1, 1, 2]) -def test_bracket_indices_with_factor(): +def test_bracket_indices_with_factor() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -28,7 +28,7 @@ def test_bracket_indices_with_factor(): tools.assert_near(result, [0, 0, 0, 0, 1, 1]) -def test_bracket_indices_with_round_decimals(): +def test_bracket_indices_with_round_decimals() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -40,7 +40,7 @@ def test_bracket_indices_with_round_decimals(): tools.assert_near(result, [0, 0, 1, 1, 2, 2]) -def test_bracket_indices_without_tax_base(): +def test_bracket_indices_without_tax_base() -> None: tax_base = numpy.array([]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -51,7 +51,7 @@ def test_bracket_indices_without_tax_base(): tax_scale.bracket_indices(tax_base) -def test_bracket_indices_without_brackets(): +def test_bracket_indices_without_brackets() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() @@ -59,7 +59,7 @@ def test_bracket_indices_without_brackets(): tax_scale.bracket_indices(tax_base) -def test_to_dict(): +def test_to_dict() -> None: tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) @@ -69,7 +69,7 @@ def test_to_dict(): assert result == {"0": 0.0, "100": 0.1} -def test_calc(): +def test_calc() -> None: tax_base = numpy.array([1, 1.5, 2, 2.5, 3.0, 4.0]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -86,7 +86,7 @@ def test_calc(): ) -def test_calc_without_round(): +def test_calc_without_round() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -101,7 +101,7 @@ def test_calc_without_round(): ) -def test_calc_when_round_is_1(): +def test_calc_when_round_is_1() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -116,7 +116,7 @@ def test_calc_when_round_is_1(): ) -def test_calc_when_round_is_2(): +def test_calc_when_round_is_2() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -131,7 +131,7 @@ def test_calc_when_round_is_2(): ) -def test_calc_when_round_is_3(): +def test_calc_when_round_is_3() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -146,7 +146,7 @@ def test_calc_when_round_is_3(): ) -def test_marginal_rates(): +def test_marginal_rates() -> None: tax_base = numpy.array([0, 10, 50, 125, 250]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -158,7 +158,7 @@ def test_marginal_rates(): tools.assert_near(result, [0, 0, 0, 0.1, 0.2]) -def test_inverse(): +def test_inverse() -> None: gross_tax_base = numpy.array([1, 2, 3, 4, 5, 6]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -171,7 +171,7 @@ def test_inverse(): tools.assert_near(result.calc(net_tax_base), gross_tax_base, 1e-15) -def test_scale_tax_scales(): +def test_scale_tax_scales() -> None: tax_base = numpy.array([1, 2, 3]) tax_base_scale = 12.345 scaled_tax_base = tax_base * tax_base_scale @@ -185,7 +185,7 @@ def test_scale_tax_scales(): tools.assert_near(result.thresholds, scaled_tax_base) -def test_inverse_scaled_marginal_tax_scales(): +def test_inverse_scaled_marginal_tax_scales() -> None: gross_tax_base = numpy.array([1, 2, 3, 4, 5, 6]) gross_tax_base_scale = 12.345 scaled_gross_tax_base = gross_tax_base * gross_tax_base_scale @@ -195,7 +195,7 @@ def test_inverse_scaled_marginal_tax_scales(): tax_scale.add_bracket(3, 0.05) scaled_tax_scale = tax_scale.scale_tax_scales(gross_tax_base_scale) scaled_net_tax_base = +scaled_gross_tax_base - scaled_tax_scale.calc( - scaled_gross_tax_base + scaled_gross_tax_base, ) result = scaled_tax_scale.inverse() @@ -203,7 +203,7 @@ def test_inverse_scaled_marginal_tax_scales(): tools.assert_near(result.calc(scaled_net_tax_base), scaled_gross_tax_base, 1e-13) -def test_to_average(): +def test_to_average() -> None: tax_base = numpy.array([1, 1.5, 2, 2.5]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -222,7 +222,7 @@ def test_to_average(): ) -def test_rate_from_bracket_indice(): +def test_rate_from_bracket_indice() -> None: tax_base = numpy.array([0, 1_000, 1_500, 50_000]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -236,7 +236,7 @@ def test_rate_from_bracket_indice(): assert (result == numpy.array([0.0, 0.1, 0.1, 0.4])).all() -def test_rate_from_tax_base(): +def test_rate_from_tax_base() -> None: tax_base = numpy.array([0, 3_000, 15_500, 500_000]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) diff --git a/tests/core/tax_scales/test_rate_tax_scale_like.py b/tests/core/tax_scales/test_rate_tax_scale_like.py index 075fc802d2..9f5bc61286 100644 --- a/tests/core/tax_scales/test_rate_tax_scale_like.py +++ b/tests/core/tax_scales/test_rate_tax_scale_like.py @@ -3,7 +3,7 @@ from openfisca_core import taxscales -def test_threshold_from_tax_base(): +def test_threshold_from_tax_base() -> None: tax_base = numpy.array([0, 33_000, 500, 400_000]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) diff --git a/tests/core/tax_scales/test_single_amount_tax_scale.py b/tests/core/tax_scales/test_single_amount_tax_scale.py index ffcd32e092..2b384f6374 100644 --- a/tests/core/tax_scales/test_single_amount_tax_scale.py +++ b/tests/core/tax_scales/test_single_amount_tax_scale.py @@ -19,12 +19,12 @@ def data(): "amount": { "2017-10-01": {"value": 6}, }, - } + }, ], } -def test_calc(): +def test_calc() -> None: tax_base = numpy.array([1, 8, 10]) tax_scale = taxscales.SingleAmountTaxScale() tax_scale.add_bracket(6, 0.23) @@ -35,7 +35,7 @@ def test_calc(): tools.assert_near(result, [0, 0.23, 0.29]) -def test_to_dict(): +def test_to_dict() -> None: tax_scale = taxscales.SingleAmountTaxScale() tax_scale.add_bracket(6, 0.23) tax_scale.add_bracket(9, 0.29) @@ -46,7 +46,7 @@ def test_to_dict(): # TODO: move, as we're testing Scale, not SingleAmountTaxScale -def test_assign_thresholds_on_creation(data): +def test_assign_thresholds_on_creation(data) -> None: scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) scale_at_instant = scale.get_at_instant(first_jan) @@ -57,7 +57,7 @@ def test_assign_thresholds_on_creation(data): # TODO: move, as we're testing Scale, not SingleAmountTaxScale -def test_assign_amounts_on_creation(data): +def test_assign_amounts_on_creation(data) -> None: scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) scale_at_instant = scale.get_at_instant(first_jan) @@ -68,7 +68,7 @@ def test_assign_amounts_on_creation(data): # TODO: move, as we're testing Scale, not SingleAmountTaxScale -def test_dispatch_scale_type_on_creation(data): +def test_dispatch_scale_type_on_creation(data) -> None: scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) diff --git a/tests/core/tax_scales/test_tax_scales_commons.py b/tests/core/tax_scales/test_tax_scales_commons.py index e4426cd49d..544e5a07fe 100644 --- a/tests/core/tax_scales/test_tax_scales_commons.py +++ b/tests/core/tax_scales/test_tax_scales_commons.py @@ -12,19 +12,19 @@ def node(): "brackets": [ {"rate": {"2015-01-01": 0.05}, "threshold": {"2015-01-01": 0}}, {"rate": {"2015-01-01": 0.10}, "threshold": {"2015-01-01": 2000}}, - ] + ], }, "retirement": { "brackets": [ {"rate": {"2015-01-01": 0.02}, "threshold": {"2015-01-01": 0}}, {"rate": {"2015-01-01": 0.04}, "threshold": {"2015-01-01": 3000}}, - ] + ], }, }, )(2015) -def test_combine_tax_scales(node): +def test_combine_tax_scales(node) -> None: result = taxscales.combine_tax_scales(node) tools.assert_near(result.thresholds, [0, 2000, 3000]) diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index 799439e9c4..11590daf51 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -7,46 +7,47 @@ # With periods -def test_add_axis_without_period(persons): +def test_add_axis_without_period(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.set_default_period("2018-11") simulation_builder.add_person_entity(persons, {"Alicia": {}}) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 3, "name": "salary", "min": 0, "max": 3000} + {"count": 3, "name": "salary", "min": 0, "max": 3000}, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 1500, 3000] + [0, 1500, 3000], ) # With variables -def test_add_axis_on_a_non_existing_variable(persons): +def test_add_axis_on_a_non_existing_variable(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, {"Alicia": {}}) simulation_builder.add_parallel_axis( - {"count": 3, "name": "ubi", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 3, "name": "ubi", "min": 0, "max": 3000, "period": "2018-11"}, ) with pytest.raises(KeyError): simulation_builder.expand_axes() -def test_add_axis_on_an_existing_variable_with_input(persons): +def test_add_axis_on_an_existing_variable_with_input(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( - persons, {"Alicia": {"salary": {"2018-11": 1000}}} + persons, + {"Alicia": {"salary": {"2018-11": 1000}}}, ) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 1500, 3000] + [0, 1500, 3000], ) assert simulation_builder.get_count("persons") == 3 assert simulation_builder.get_ids("persons") == ["Alicia0", "Alicia1", "Alicia2"] @@ -55,46 +56,46 @@ def test_add_axis_on_an_existing_variable_with_input(persons): # With entities -def test_add_axis_on_persons(persons): +def test_add_axis_on_persons(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, {"Alicia": {}}) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 1500, 3000] + [0, 1500, 3000], ) assert simulation_builder.get_count("persons") == 3 assert simulation_builder.get_ids("persons") == ["Alicia0", "Alicia1", "Alicia2"] -def test_add_two_axes(persons): +def test_add_two_axes(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, {"Alicia": {}}) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.add_parallel_axis( - {"count": 3, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"} + {"count": 3, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 1500, 3000] + [0, 1500, 3000], ) assert simulation_builder.get_input("pension", "2018-11") == pytest.approx( - [0, 1000, 2000] + [0, 1000, 2000], ) -def test_add_axis_with_group(persons): +def test_add_axis_with_group(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, {"Alicia": {}, "Javier": {}}) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.add_parallel_axis( { @@ -104,7 +105,7 @@ def test_add_axis_with_group(persons): "max": 3000, "period": "2018-11", "index": 1, - } + }, ) simulation_builder.expand_axes() assert simulation_builder.get_count("persons") == 4 @@ -115,16 +116,16 @@ def test_add_axis_with_group(persons): "Javier3", ] assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 0, 3000, 3000] + [0, 0, 3000, 3000], ) -def test_add_axis_with_group_int_period(persons): +def test_add_axis_with_group_int_period(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, {"Alicia": {}, "Javier": {}}) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": 2018} + {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": 2018}, ) simulation_builder.add_parallel_axis( { @@ -134,18 +135,19 @@ def test_add_axis_with_group_int_period(persons): "max": 3000, "period": 2018, "index": 1, - } + }, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018") == pytest.approx( - [0, 0, 3000, 3000] + [0, 0, 3000, 3000], ) -def test_add_axis_on_households(persons, households): +def test_add_axis_on_households(persons, households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( - persons, {"Alicia": {}, "Javier": {}, "Tom": {}} + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, ) simulation_builder.add_group_entity( "persons", @@ -158,7 +160,7 @@ def test_add_axis_on_households(persons, households): ) simulation_builder.register_variable("rent", households) simulation_builder.add_parallel_axis( - {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_count("households") == 4 @@ -169,14 +171,15 @@ def test_add_axis_on_households(persons, households): "houseb3", ] assert simulation_builder.get_input("rent", "2018-11") == pytest.approx( - [0, 0, 3000, 0] + [0, 0, 3000, 0], ) -def test_axis_on_group_expands_persons(persons, households): +def test_axis_on_group_expands_persons(persons, households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( - persons, {"Alicia": {}, "Javier": {}, "Tom": {}} + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, ) simulation_builder.add_group_entity( "persons", @@ -189,16 +192,17 @@ def test_axis_on_group_expands_persons(persons, households): ) simulation_builder.register_variable("rent", households) simulation_builder.add_parallel_axis( - {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_count("persons") == 6 -def test_add_axis_distributes_roles(persons, households): +def test_add_axis_distributes_roles(persons, households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( - persons, {"Alicia": {}, "Javier": {}, "Tom": {}} + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, ) simulation_builder.add_group_entity( "persons", @@ -211,7 +215,7 @@ def test_add_axis_distributes_roles(persons, households): ) simulation_builder.register_variable("rent", households) simulation_builder.add_parallel_axis( - {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert [role.key for role in simulation_builder.get_roles("households")] == [ @@ -224,10 +228,11 @@ def test_add_axis_distributes_roles(persons, households): ] -def test_add_axis_on_persons_distributes_roles(persons, households): +def test_add_axis_on_persons_distributes_roles(persons, households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( - persons, {"Alicia": {}, "Javier": {}, "Tom": {}} + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, ) simulation_builder.add_group_entity( "persons", @@ -240,7 +245,7 @@ def test_add_axis_on_persons_distributes_roles(persons, households): ) simulation_builder.register_variable("salary", persons) simulation_builder.add_parallel_axis( - {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert [role.key for role in simulation_builder.get_roles("households")] == [ @@ -253,10 +258,11 @@ def test_add_axis_on_persons_distributes_roles(persons, households): ] -def test_add_axis_distributes_memberships(persons, households): +def test_add_axis_distributes_memberships(persons, households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( - persons, {"Alicia": {}, "Javier": {}, "Tom": {}} + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, ) simulation_builder.add_group_entity( "persons", @@ -269,33 +275,33 @@ def test_add_axis_distributes_memberships(persons, households): ) simulation_builder.register_variable("rent", households) simulation_builder.add_parallel_axis( - {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_memberships("households") == [0, 1, 1, 2, 3, 3] -def test_add_perpendicular_axes(persons): +def test_add_perpendicular_axes(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, {"Alicia": {}}) simulation_builder.register_variable("salary", persons) simulation_builder.register_variable("pension", persons) simulation_builder.add_parallel_axis( - {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.add_perpendicular_axis( - {"count": 2, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"} + {"count": 2, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 1500, 3000, 0, 1500, 3000] + [0, 1500, 3000, 0, 1500, 3000], ) assert simulation_builder.get_input("pension", "2018-11") == pytest.approx( - [0, 0, 0, 2000, 2000, 2000] + [0, 0, 0, 2000, 2000, 2000], ) -def test_add_perpendicular_axis_on_an_existing_variable_with_input(persons): +def test_add_perpendicular_axis_on_an_existing_variable_with_input(persons) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_person_entity( persons, @@ -309,24 +315,24 @@ def test_add_perpendicular_axis_on_an_existing_variable_with_input(persons): simulation_builder.register_variable("salary", persons) simulation_builder.register_variable("pension", persons) simulation_builder.add_parallel_axis( - {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"} + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, ) simulation_builder.add_perpendicular_axis( - {"count": 2, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"} + {"count": 2, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"}, ) simulation_builder.expand_axes() assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( - [0, 1500, 3000, 0, 1500, 3000] + [0, 1500, 3000, 0, 1500, 3000], ) assert simulation_builder.get_input("pension", "2018-11") == pytest.approx( - [0, 0, 0, 2000, 2000, 2000] + [0, 0, 0, 2000, 2000, 2000], ) # Integration tests -def test_simulation_with_axes(tax_benefit_system): +def test_simulation_with_axes(tax_benefit_system) -> None: input_yaml = """ persons: Alicia: {salary: {2018-11: 0}} @@ -348,7 +354,7 @@ def test_simulation_with_axes(tax_benefit_system): data = test_runner.yaml.safe_load(input_yaml) simulation = SimulationBuilder().build_from_dict(tax_benefit_system, data) assert simulation.get_array("salary", "2018-11") == pytest.approx( - [0, 0, 0, 0, 0, 0] + [0, 0, 0, 0, 0, 0], ) assert simulation.get_array("rent", "2018-11") == pytest.approx([0, 0, 3000, 0]) @@ -356,7 +362,7 @@ def test_simulation_with_axes(tax_benefit_system): # Test for missing group entities with build_from_entities() -def test_simulation_with_axes_missing_entities(tax_benefit_system): +def test_simulation_with_axes_missing_entities(tax_benefit_system) -> None: input_yaml = """ persons: Alicia: {salary: {2018-11: 0}} diff --git a/tests/core/test_calculate_output.py b/tests/core/test_calculate_output.py index ecf59b5f7d..54d868ba92 100644 --- a/tests/core/test_calculate_output.py +++ b/tests/core/test_calculate_output.py @@ -29,7 +29,7 @@ class variable_with_calculate_output_divide(Variable): @pytest.fixture(scope="module", autouse=True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables( simple_variable, variable_with_calculate_output_add, @@ -40,25 +40,27 @@ def add_variables_to_tax_benefit_system(tax_benefit_system): @pytest.fixture def simulation(tax_benefit_system): return SimulationBuilder().build_from_entities( - tax_benefit_system, situation_examples.single + tax_benefit_system, + situation_examples.single, ) -def test_calculate_output_default(simulation): +def test_calculate_output_default(simulation) -> None: with pytest.raises(ValueError): simulation.calculate_output("simple_variable", 2017) -def test_calculate_output_add(simulation): +def test_calculate_output_add(simulation) -> None: simulation.set_input("variable_with_calculate_output_add", "2017-01", [10]) simulation.set_input("variable_with_calculate_output_add", "2017-05", [20]) simulation.set_input("variable_with_calculate_output_add", "2017-12", [70]) tools.assert_near( - simulation.calculate_output("variable_with_calculate_output_add", 2017), 100 + simulation.calculate_output("variable_with_calculate_output_add", 2017), + 100, ) -def test_calculate_output_divide(simulation): +def test_calculate_output_divide(simulation) -> None: simulation.set_input("variable_with_calculate_output_divide", 2017, [12000]) tools.assert_near( simulation.calculate_output("variable_with_calculate_output_divide", "2017-06"), diff --git a/tests/core/test_countries.py b/tests/core/test_countries.py index 8263ac3c44..d206a8cb35 100644 --- a/tests/core/test_countries.py +++ b/tests/core/test_countries.py @@ -10,19 +10,19 @@ @pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect=True) -def test_input_variable(simulation): +def test_input_variable(simulation) -> None: result = simulation.calculate("salary", PERIOD) tools.assert_near(result, [2000], absolute_error_margin=0.01) @pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect=True) -def test_basic_calculation(simulation): +def test_basic_calculation(simulation) -> None: result = simulation.calculate("income_tax", PERIOD) tools.assert_near(result, [300], absolute_error_margin=0.01) @pytest.mark.parametrize("simulation", [({"salary": 24000}, PERIOD)], indirect=True) -def test_calculate_add(simulation): +def test_calculate_add(simulation) -> None: result = simulation.calculate_add("income_tax", PERIOD) tools.assert_near(result, [3600], absolute_error_margin=0.01) @@ -32,26 +32,26 @@ def test_calculate_add(simulation): [({"accommodation_size": 100, "housing_occupancy_status": "tenant"}, PERIOD)], indirect=True, ) -def test_calculate_divide(simulation): +def test_calculate_divide(simulation) -> None: result = simulation.calculate_divide("housing_tax", PERIOD) tools.assert_near(result, [1000 / 12.0], absolute_error_margin=0.01) @pytest.mark.parametrize("simulation", [({"salary": 20000}, PERIOD)], indirect=True) -def test_bareme(simulation): +def test_bareme(simulation) -> None: result = simulation.calculate("social_security_contribution", PERIOD) expected = [0.02 * 6000 + 0.06 * 6400 + 0.12 * 7600] tools.assert_near(result, expected, absolute_error_margin=0.01) @pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True) -def test_non_existing_variable(simulation): +def test_non_existing_variable(simulation) -> None: with pytest.raises(VariableNotFoundError): simulation.calculate("non_existent_variable", PERIOD) @pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True) -def test_calculate_variable_with_wrong_definition_period(simulation): +def test_calculate_variable_with_wrong_definition_period(simulation) -> None: year = str(PERIOD.this_year) with pytest.raises(ValueError) as error: @@ -67,7 +67,7 @@ def test_calculate_variable_with_wrong_definition_period(simulation): @pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True) -def test_divide_option_with_complex_period(simulation): +def test_divide_option_with_complex_period(simulation) -> None: quarter = PERIOD.last_3_months with pytest.raises(ValueError) as error: @@ -82,7 +82,7 @@ def test_divide_option_with_complex_period(simulation): ), f"Expected '{word}' in error message '{error_message}'" -def test_input_with_wrong_period(tax_benefit_system): +def test_input_with_wrong_period(tax_benefit_system) -> None: year = str(PERIOD.this_year) variables = {"basic_income": {year: 12000}} simulation_builder = SimulationBuilder() @@ -92,7 +92,7 @@ def test_input_with_wrong_period(tax_benefit_system): simulation_builder.build_from_variables(tax_benefit_system, variables) -def test_variable_with_reference(make_simulation, isolated_tax_benefit_system): +def test_variable_with_reference(make_simulation, isolated_tax_benefit_system) -> None: variables = {"salary": 4000} simulation = make_simulation(isolated_tax_benefit_system, variables, PERIOD) @@ -103,8 +103,8 @@ def test_variable_with_reference(make_simulation, isolated_tax_benefit_system): class disposable_income(Variable): definition_period = DateUnit.MONTH - def formula(household, period): - return household.empty_array() + def formula(self, period): + return self.empty_array() isolated_tax_benefit_system.update_variable(disposable_income) simulation = make_simulation(isolated_tax_benefit_system, variables, PERIOD) @@ -114,13 +114,13 @@ def formula(household, period): assert result == 0 -def test_variable_name_conflict(tax_benefit_system): +def test_variable_name_conflict(tax_benefit_system) -> None: class disposable_income(Variable): reference = "disposable_income" definition_period = DateUnit.MONTH - def formula(household, period): - return household.empty_array() + def formula(self, period): + return self.empty_array() with pytest.raises(VariableNameConflictError): tax_benefit_system.add_variable(disposable_income) diff --git a/tests/core/test_cycles.py b/tests/core/test_cycles.py index 14886532c6..acb08c6424 100644 --- a/tests/core/test_cycles.py +++ b/tests/core/test_cycles.py @@ -25,8 +25,8 @@ class variable1(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - return person("variable2", period) + def formula(self, period): + return self("variable2", period) class variable2(Variable): @@ -34,8 +34,8 @@ class variable2(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - return person("variable1", period) + def formula(self, period): + return self("variable1", period) # 3 <--> 4 with a period offset @@ -44,8 +44,8 @@ class variable3(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - return person("variable4", period.last_month) + def formula(self, period): + return self("variable4", period.last_month) class variable4(Variable): @@ -53,8 +53,8 @@ class variable4(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - return person("variable3", period) + def formula(self, period): + return self("variable3", period) # 5 -f-> 6 with a period offset @@ -64,8 +64,8 @@ class variable5(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - variable6 = person("variable6", period.last_month) + def formula(self, period): + variable6 = self("variable6", period.last_month) return 5 + variable6 @@ -74,8 +74,8 @@ class variable6(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - variable5 = person("variable5", period) + def formula(self, period): + variable5 = self("variable5", period) return 6 + variable5 @@ -84,8 +84,8 @@ class variable7(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): - variable5 = person("variable5", period) + def formula(self, period): + variable5 = self("variable5", period) return 7 + variable5 @@ -95,15 +95,14 @@ class cotisation(Variable): entity = entities.Person definition_period = DateUnit.MONTH - def formula(person, period): + def formula(self, period): if period.start.month == 12: - return 2 * person("cotisation", period.last_month) - else: - return person.empty_array() + 1 + return 2 * self("cotisation", period.last_month) + return self.empty_array() + 1 @pytest.fixture(scope="module", autouse=True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables( variable1, variable2, @@ -116,34 +115,35 @@ def add_variables_to_tax_benefit_system(tax_benefit_system): ) -def test_pure_cycle(simulation, reference_period): +def test_pure_cycle(simulation, reference_period) -> None: with pytest.raises(CycleError): simulation.calculate("variable1", period=reference_period) -def test_spirals_result_in_default_value(simulation, reference_period): +def test_spirals_result_in_default_value(simulation, reference_period) -> None: variable3 = simulation.calculate("variable3", period=reference_period) tools.assert_near(variable3, [0]) -def test_spiral_heuristic(simulation, reference_period): +def test_spiral_heuristic(simulation, reference_period) -> None: variable5 = simulation.calculate("variable5", period=reference_period) variable6 = simulation.calculate("variable6", period=reference_period) variable6_last_month = simulation.calculate( - "variable6", reference_period.last_month + "variable6", + reference_period.last_month, ) tools.assert_near(variable5, [11]) tools.assert_near(variable6, [11]) tools.assert_near(variable6_last_month, [11]) -def test_spiral_cache(simulation, reference_period): +def test_spiral_cache(simulation, reference_period) -> None: simulation.calculate("variable7", period=reference_period) cached_variable7 = simulation.get_holder("variable7").get_array(reference_period) assert cached_variable7 is not None -def test_cotisation_1_level(simulation, reference_period): +def test_cotisation_1_level(simulation, reference_period) -> None: month = reference_period.last_month cotisation = simulation.calculate("cotisation", period=month) tools.assert_near(cotisation, [0]) diff --git a/tests/core/test_dump_restore.py b/tests/core/test_dump_restore.py index b03c55a831..c84044165c 100644 --- a/tests/core/test_dump_restore.py +++ b/tests/core/test_dump_restore.py @@ -9,10 +9,11 @@ from openfisca_core.tools import simulation_dumper -def test_dump(tax_benefit_system): +def test_dump(tax_benefit_system) -> None: directory = tempfile.mkdtemp(prefix="openfisca_") simulation = SimulationBuilder().build_from_entities( - tax_benefit_system, situation_examples.couple + tax_benefit_system, + situation_examples.couple, ) calculated_value = simulation.calculate("disposable_income", "2018-01") simulation_dumper.dump_simulation(simulation, directory) @@ -26,13 +27,16 @@ def test_dump(tax_benefit_system): testing.assert_array_equal(simulation.household.ids, simulation_2.household.ids) testing.assert_array_equal(simulation.household.count, simulation_2.household.count) testing.assert_array_equal( - simulation.household.members_position, simulation_2.household.members_position + simulation.household.members_position, + simulation_2.household.members_position, ) testing.assert_array_equal( - simulation.household.members_entity_id, simulation_2.household.members_entity_id + simulation.household.members_entity_id, + simulation_2.household.members_entity_id, ) testing.assert_array_equal( - simulation.household.members_role, simulation_2.household.members_role + simulation.household.members_role, + simulation_2.household.members_role, ) # Check calculated values are in cache diff --git a/tests/core/test_entities.py b/tests/core/test_entities.py index 1b7b646311..aba17dc4dc 100644 --- a/tests/core/test_entities.py +++ b/tests/core/test_entities.py @@ -34,7 +34,7 @@ def new_simulation(tax_benefit_system, test_case, period=MONTH): return simulation_builder.build_from_entities(tax_benefit_system, test_case) -def test_role_index_and_positions(tax_benefit_system): +def test_role_index_and_positions(tax_benefit_system) -> None: simulation = new_simulation(tax_benefit_system, TEST_CASE) tools.assert_near(simulation.household.members_entity_id, [0, 0, 0, 0, 1, 1]) assert ( @@ -46,7 +46,7 @@ def test_role_index_and_positions(tax_benefit_system): assert simulation.household.ids == ["h1", "h2"] -def test_entity_structure_with_constructor(tax_benefit_system): +def test_entity_structure_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: {} @@ -68,7 +68,8 @@ def test_entity_structure_with_constructor(tax_benefit_system): """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), ) household = simulation.household @@ -81,7 +82,7 @@ def test_entity_structure_with_constructor(tax_benefit_system): tools.assert_near(household.members_position, [0, 1, 0, 2, 3]) -def test_entity_variables_with_constructor(tax_benefit_system): +def test_entity_variables_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: {} @@ -107,13 +108,14 @@ def test_entity_variables_with_constructor(tax_benefit_system): """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), ) household = simulation.household tools.assert_near(household("rent", "2017-06"), [800, 600]) -def test_person_variable_with_constructor(tax_benefit_system): +def test_person_variable_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: @@ -142,14 +144,15 @@ def test_person_variable_with_constructor(tax_benefit_system): """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), ) person = simulation.person tools.assert_near(person("salary", "2017-11"), [1500, 0, 3000, 0, 0]) tools.assert_near(person("salary", "2017-12"), [2000, 0, 4000, 0, 0]) -def test_set_input_with_constructor(tax_benefit_system): +def test_set_input_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: @@ -183,34 +186,38 @@ def test_set_input_with_constructor(tax_benefit_system): """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), ) person = simulation.person tools.assert_near(person("salary", "2017-12"), [2000, 0, 4000, 0, 0]) tools.assert_near(person("salary", "2017-10"), [2000, 3000, 1600, 0, 0]) -def test_has_role(tax_benefit_system): +def test_has_role(tax_benefit_system) -> None: simulation = new_simulation(tax_benefit_system, TEST_CASE) individu = simulation.persons tools.assert_near(individu.has_role(CHILD), [False, False, True, True, False, True]) -def test_has_role_with_subrole(tax_benefit_system): +def test_has_role_with_subrole(tax_benefit_system) -> None: simulation = new_simulation(tax_benefit_system, TEST_CASE) individu = simulation.persons tools.assert_near( - individu.has_role(PARENT), [True, True, False, False, True, False] + individu.has_role(PARENT), + [True, True, False, False, True, False], ) tools.assert_near( - individu.has_role(FIRST_PARENT), [True, False, False, False, True, False] + individu.has_role(FIRST_PARENT), + [True, False, False, False, True, False], ) tools.assert_near( - individu.has_role(SECOND_PARENT), [False, True, False, False, False, False] + individu.has_role(SECOND_PARENT), + [False, True, False, False, False, False], ) -def test_project(tax_benefit_system): +def test_project(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) test_case["households"]["h1"]["housing_tax"] = 20000 @@ -226,7 +233,7 @@ def test_project(tax_benefit_system): tools.assert_near(housing_tax_projected_on_parents, [20000, 20000, 0, 0, 0, 0]) -def test_implicit_projection(tax_benefit_system): +def test_implicit_projection(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) test_case["households"]["h1"]["housing_tax"] = 20000 @@ -237,7 +244,7 @@ def test_implicit_projection(tax_benefit_system): tools.assert_near(housing_tax, [20000, 20000, 20000, 20000, 0, 0]) -def test_sum(tax_benefit_system): +def test_sum(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) test_case["persons"]["ind0"]["salary"] = 1000 test_case["persons"]["ind1"]["salary"] = 1500 @@ -257,7 +264,7 @@ def test_sum(tax_benefit_system): tools.assert_near(total_salary_parents_by_household, [2500, 3000]) -def test_any(tax_benefit_system): +def test_any(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household @@ -272,7 +279,7 @@ def test_any(tax_benefit_system): tools.assert_near(has_household_CHILD_with_age_sup_18, [False, True]) -def test_all(tax_benefit_system): +def test_all(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household @@ -287,7 +294,7 @@ def test_all(tax_benefit_system): tools.assert_near(all_parents_age_sup_18, [True, True]) -def test_max(tax_benefit_system): +def test_max(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household @@ -301,7 +308,7 @@ def test_max(tax_benefit_system): tools.assert_near(age_max_child, [9, 20]) -def test_min(tax_benefit_system): +def test_min(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household @@ -315,7 +322,7 @@ def test_min(tax_benefit_system): tools.assert_near(age_min_parents, [37, 54]) -def test_value_nth_person(tax_benefit_system): +def test_value_nth_person(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household @@ -334,7 +341,7 @@ def test_value_nth_person(tax_benefit_system): tools.assert_near(result3, [9, -1]) -def test_rank(tax_benefit_system): +def test_rank(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) person = simulation.person @@ -344,12 +351,14 @@ def test_rank(tax_benefit_system): tools.assert_near(rank, [3, 2, 0, 1, 1, 0]) rank_in_siblings = person.get_rank( - person.household, -age, condition=person.has_role(entities.Household.CHILD) + person.household, + -age, + condition=person.has_role(entities.Household.CHILD), ) tools.assert_near(rank_in_siblings, [-1, -1, 1, 0, -1, 0]) -def test_partner(tax_benefit_system): +def test_partner(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) test_case["persons"]["ind0"]["salary"] = 1000 test_case["persons"]["ind1"]["salary"] = 1500 @@ -366,7 +375,7 @@ def test_partner(tax_benefit_system): tools.assert_near(salary_second_parent, [1500, 1000, 0, 0, 0, 0]) -def test_value_from_first_person(tax_benefit_system): +def test_value_from_first_person(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) test_case["persons"]["ind0"]["salary"] = 1000 test_case["persons"]["ind1"]["salary"] = 1500 @@ -382,9 +391,10 @@ def test_value_from_first_person(tax_benefit_system): tools.assert_near(salary_first_person, [1000, 3000]) -def test_projectors_methods(tax_benefit_system): +def test_projectors_methods(tax_benefit_system) -> None: simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, situation_examples.couple + tax_benefit_system, + situation_examples.couple, ) household = simulation.household person = simulation.person @@ -403,7 +413,7 @@ def test_projectors_methods(tax_benefit_system): ) # Must be of a person dimension -def test_sum_following_bug_ipp_1(tax_benefit_system): +def test_sum_following_bug_ipp_1(tax_benefit_system) -> None: test_case = { "persons": {"ind0": {}, "ind1": {}, "ind2": {}, "ind3": {}}, "households": { @@ -425,7 +435,7 @@ def test_sum_following_bug_ipp_1(tax_benefit_system): tools.assert_near(nb_eligibles_by_household, [0, 2]) -def test_sum_following_bug_ipp_2(tax_benefit_system): +def test_sum_following_bug_ipp_2(tax_benefit_system) -> None: test_case = { "persons": {"ind0": {}, "ind1": {}, "ind2": {}, "ind3": {}}, "households": { @@ -447,7 +457,7 @@ def test_sum_following_bug_ipp_2(tax_benefit_system): tools.assert_near(nb_eligibles_by_household, [2, 0]) -def test_get_memory_usage(tax_benefit_system): +def test_get_memory_usage(tax_benefit_system) -> None: test_case = deepcopy(situation_examples.single) test_case["persons"]["Alicia"]["salary"] = {"2017-01": 0} simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_case) @@ -457,7 +467,7 @@ def test_get_memory_usage(tax_benefit_system): assert len(memory_usage["by_variable"]) == 1 -def test_unordered_persons(tax_benefit_system): +def test_unordered_persons(tax_benefit_system) -> None: test_case = { "persons": { "ind4": {}, @@ -527,11 +537,14 @@ def test_unordered_persons(tax_benefit_system): # Projection entity -> persons tools.assert_near( - household.project(accommodation_size), [60, 160, 160, 160, 60, 160] + household.project(accommodation_size), + [60, 160, 160, 160, 60, 160], ) tools.assert_near( - household.project(accommodation_size, role=PARENT), [60, 0, 160, 0, 0, 160] + household.project(accommodation_size, role=PARENT), + [60, 0, 160, 0, 0, 160], ) tools.assert_near( - household.project(accommodation_size, role=CHILD), [0, 160, 0, 160, 60, 0] + household.project(accommodation_size, role=CHILD), + [0, 160, 0, 160, 60, 0], ) diff --git a/tests/core/test_extensions.py b/tests/core/test_extensions.py index 2bb2689b15..4854815ac3 100644 --- a/tests/core/test_extensions.py +++ b/tests/core/test_extensions.py @@ -1,7 +1,7 @@ import pytest -def test_load_extension(tax_benefit_system): +def test_load_extension(tax_benefit_system) -> None: tbs = tax_benefit_system.clone() assert tbs.get_variable("local_town_child_allowance") is None @@ -11,7 +11,7 @@ def test_load_extension(tax_benefit_system): assert tax_benefit_system.get_variable("local_town_child_allowance") is None -def test_access_to_parameters(tax_benefit_system): +def test_access_to_parameters(tax_benefit_system) -> None: tbs = tax_benefit_system.clone() tbs.load_extension("openfisca_extension_template") @@ -19,6 +19,8 @@ def test_access_to_parameters(tax_benefit_system): assert tbs.parameters.local_town.child_allowance.amount("2016-01") == 100.0 -def test_failure_to_load_extension_when_directory_doesnt_exist(tax_benefit_system): +def test_failure_to_load_extension_when_directory_doesnt_exist( + tax_benefit_system, +) -> None: with pytest.raises(ValueError): tax_benefit_system.load_extension("/this/is/not/a/real/path") diff --git a/tests/core/test_formulas.py b/tests/core/test_formulas.py index c8a5379801..32e6fd35e7 100644 --- a/tests/core/test_formulas.py +++ b/tests/core/test_formulas.py @@ -21,10 +21,9 @@ class uses_multiplication(Variable): label = "Variable with formula that uses multiplication" definition_period = DateUnit.MONTH - def formula(person, period): - choice = person("choice", period) - result = (choice == 1) * 80 + (choice == 2) * 90 - return result + def formula(self, period): + choice = self("choice", period) + return (choice == 1) * 80 + (choice == 2) * 90 class returns_scalar(Variable): @@ -33,7 +32,7 @@ class returns_scalar(Variable): label = "Variable with formula that returns a scalar value" definition_period = DateUnit.MONTH - def formula(person, period): + def formula(self, period) -> int: return 666 @@ -43,27 +42,29 @@ class uses_switch(Variable): label = "Variable with formula that uses switch" definition_period = DateUnit.MONTH - def formula(person, period): - choice = person("choice", period) - result = commons.switch( + def formula(self, period): + choice = self("choice", period) + return commons.switch( choice, { 1: 80, 2: 90, }, ) - return result @fixture(scope="module", autouse=True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables( - choice, uses_multiplication, uses_switch, returns_scalar + choice, + uses_multiplication, + uses_switch, + returns_scalar, ) @fixture -def month(): +def month() -> str: return "2013-01" @@ -72,35 +73,36 @@ def simulation(tax_benefit_system, month): simulation_builder = SimulationBuilder() simulation_builder.default_period = month simulation = simulation_builder.build_from_variables( - tax_benefit_system, {"choice": numpy.random.randint(2, size=1000) + 1} + tax_benefit_system, + {"choice": numpy.random.randint(2, size=1000) + 1}, ) simulation.debug = True return simulation -def test_switch(simulation, month): +def test_switch(simulation, month) -> None: uses_switch = simulation.calculate("uses_switch", period=month) assert isinstance(uses_switch, numpy.ndarray) -def test_multiplication(simulation, month): +def test_multiplication(simulation, month) -> None: uses_multiplication = simulation.calculate("uses_multiplication", period=month) assert isinstance(uses_multiplication, numpy.ndarray) -def test_broadcast_scalar(simulation, month): +def test_broadcast_scalar(simulation, month) -> None: array_value = simulation.calculate("returns_scalar", period=month) assert isinstance(array_value, numpy.ndarray) assert array_value == approx(numpy.repeat(666, 1000)) -def test_compare_multiplication_and_switch(simulation, month): +def test_compare_multiplication_and_switch(simulation, month) -> None: uses_multiplication = simulation.calculate("uses_multiplication", period=month) uses_switch = simulation.calculate("uses_switch", period=month) assert numpy.all(uses_switch == uses_multiplication) -def test_group_encapsulation(): +def test_group_encapsulation() -> None: """Projects a calculation to all members of an entity. When a household contains more than one family @@ -128,7 +130,7 @@ def test_group_encapsulation(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) household_entity = build_entity( @@ -140,7 +142,7 @@ def test_group_encapsulation(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) @@ -158,8 +160,8 @@ class projected_family_level_variable(Variable): entity = family_entity definition_period = DateUnit.ETERNITY - def formula(family, period): - return family.household("household_level_variable", period) + def formula(self, period): + return self.household("household_level_variable", period) system.add_variables(household_level_variable, projected_family_level_variable) @@ -175,7 +177,7 @@ def formula(family, period): "household1": { "members": ["person1", "person2", "person3"], "household_level_variable": {"eternity": 5}, - } + }, }, }, ) diff --git a/tests/core/test_holders.py b/tests/core/test_holders.py index 088ca15935..c72d053ad6 100644 --- a/tests/core/test_holders.py +++ b/tests/core/test_holders.py @@ -15,51 +15,56 @@ @pytest.fixture def single(tax_benefit_system): return SimulationBuilder().build_from_entities( - tax_benefit_system, situation_examples.single + tax_benefit_system, + situation_examples.single, ) @pytest.fixture def couple(tax_benefit_system): return SimulationBuilder().build_from_entities( - tax_benefit_system, situation_examples.couple + tax_benefit_system, + situation_examples.couple, ) period = periods.period("2017-12") -def test_set_input_enum_string(couple): +def test_set_input_enum_string(couple) -> None: simulation = couple status_occupancy = numpy.asarray(["free_lodger"]) simulation.household.get_holder("housing_occupancy_status").set_input( - period, status_occupancy + period, + status_occupancy, ) result = simulation.calculate("housing_occupancy_status", period) assert result == housing.HousingOccupancyStatus.free_lodger -def test_set_input_enum_int(couple): +def test_set_input_enum_int(couple) -> None: simulation = couple status_occupancy = numpy.asarray([2], dtype=numpy.int16) simulation.household.get_holder("housing_occupancy_status").set_input( - period, status_occupancy + period, + status_occupancy, ) result = simulation.calculate("housing_occupancy_status", period) assert result == housing.HousingOccupancyStatus.free_lodger -def test_set_input_enum_item(couple): +def test_set_input_enum_item(couple) -> None: simulation = couple status_occupancy = numpy.asarray([housing.HousingOccupancyStatus.free_lodger]) simulation.household.get_holder("housing_occupancy_status").set_input( - period, status_occupancy + period, + status_occupancy, ) result = simulation.calculate("housing_occupancy_status", period) assert result == housing.HousingOccupancyStatus.free_lodger -def test_yearly_input_month_variable(couple): +def test_yearly_input_month_variable(couple) -> None: with pytest.raises(PeriodMismatchError) as error: couple.set_input("rent", 2019, 3000) assert ( @@ -68,7 +73,7 @@ def test_yearly_input_month_variable(couple): ) -def test_3_months_input_month_variable(couple): +def test_3_months_input_month_variable(couple) -> None: with pytest.raises(PeriodMismatchError) as error: couple.set_input("rent", "month:2019-01:3", 3000) assert ( @@ -77,7 +82,7 @@ def test_3_months_input_month_variable(couple): ) -def test_month_input_year_variable(couple): +def test_month_input_year_variable(couple) -> None: with pytest.raises(PeriodMismatchError) as error: couple.set_input("housing_tax", "2019-01", 3000) assert ( @@ -86,23 +91,24 @@ def test_month_input_year_variable(couple): ) -def test_enum_dtype(couple): +def test_enum_dtype(couple) -> None: simulation = couple status_occupancy = numpy.asarray([2], dtype=numpy.int16) simulation.household.get_holder("housing_occupancy_status").set_input( - period, status_occupancy + period, + status_occupancy, ) result = simulation.calculate("housing_occupancy_status", period) assert result.dtype.kind is not None -def test_permanent_variable_empty(single): +def test_permanent_variable_empty(single) -> None: simulation = single holder = simulation.person.get_holder("birth") assert holder.get_array(None) is None -def test_permanent_variable_filled(single): +def test_permanent_variable_filled(single) -> None: simulation = single holder = simulation.person.get_holder("birth") value = numpy.asarray(["1980-01-01"], dtype=holder.variable.dtype) @@ -112,7 +118,7 @@ def test_permanent_variable_filled(single): assert holder.get_array("2016-01") == value -def test_delete_arrays(single): +def test_delete_arrays(single) -> None: simulation = single salary_holder = simulation.person.get_holder("salary") salary_holder.set_input(periods.period(2017), numpy.asarray([30000])) @@ -131,7 +137,7 @@ def test_delete_arrays(single): assert simulation.person("salary", "2018-01") == 1250 -def test_get_memory_usage(single): +def test_get_memory_usage(single) -> None: simulation = single salary_holder = simulation.person.get_holder("salary") memory_usage = salary_holder.get_memory_usage() @@ -145,7 +151,7 @@ def test_get_memory_usage(single): assert memory_usage["total_nb_bytes"] == 4 * 12 * 1 -def test_get_memory_usage_with_trace(single): +def test_get_memory_usage_with_trace(single) -> None: simulation = single simulation.trace = True salary_holder = simulation.person.get_holder("salary") @@ -159,24 +165,24 @@ def test_get_memory_usage_with_trace(single): assert memory_usage["nb_requests_by_array"] == 1.25 # 15 calculations / 12 arrays -def test_set_input_dispatch_by_period(single): +def test_set_input_dispatch_by_period(single) -> None: simulation = single variable = simulation.tax_benefit_system.get_variable("housing_occupancy_status") entity = simulation.household holder = Holder(variable, entity) holders.set_input_dispatch_by_period(holder, periods.period(2019), "owner") assert holder.get_array("2019-01") == holder.get_array( - "2019-12" + "2019-12", ) # Check the feature assert holder.get_array("2019-01") is holder.get_array( - "2019-12" + "2019-12", ) # Check that the vectors are the same in memory, to avoid duplication force_storage_on_disk = MemoryConfig(max_memory_occupation=0) -def test_delete_arrays_on_disk(single): +def test_delete_arrays_on_disk(single) -> None: simulation = single simulation.memory_config = force_storage_on_disk salary_holder = simulation.person.get_holder("salary") @@ -190,7 +196,7 @@ def test_delete_arrays_on_disk(single): assert simulation.person("salary", "2018-01") == 1250 -def test_cache_disk(couple): +def test_cache_disk(couple) -> None: simulation = couple simulation.memory_config = force_storage_on_disk month = periods.period("2017-01") @@ -201,7 +207,7 @@ def test_cache_disk(couple): tools.assert_near(data, stored_data) -def test_known_periods(couple): +def test_known_periods(couple) -> None: simulation = couple simulation.memory_config = force_storage_on_disk month = periods.period("2017-01") @@ -214,20 +220,22 @@ def test_known_periods(couple): assert sorted(holder.get_known_periods()), [month == month_2] -def test_cache_enum_on_disk(single): +def test_cache_enum_on_disk(single) -> None: simulation = single simulation.memory_config = force_storage_on_disk month = periods.period("2017-01") simulation.calculate("housing_occupancy_status", month) # First calculation housing_occupancy_status = simulation.calculate( - "housing_occupancy_status", month + "housing_occupancy_status", + month, ) # Read from cache assert housing_occupancy_status == housing.HousingOccupancyStatus.tenant -def test_set_not_cached_variable(single): +def test_set_not_cached_variable(single) -> None: dont_cache_variable = MemoryConfig( - max_memory_occupation=1, variables_to_drop=["salary"] + max_memory_occupation=1, + variables_to_drop=["salary"], ) simulation = single simulation.memory_config = dont_cache_variable @@ -237,7 +245,7 @@ def test_set_not_cached_variable(single): assert simulation.calculate("salary", "2015-01") == array -def test_set_input_float_to_int(single): +def test_set_input_float_to_int(single) -> None: simulation = single age = numpy.asarray([50.6]) simulation.person.get_holder("age").set_input(period, age) diff --git a/tests/core/test_opt_out_cache.py b/tests/core/test_opt_out_cache.py index e9fe3a2469..2f61da2898 100644 --- a/tests/core/test_opt_out_cache.py +++ b/tests/core/test_opt_out_cache.py @@ -22,8 +22,8 @@ class intermediate(Variable): label = "Intermediate result that don't need to be cached" definition_period = DateUnit.MONTH - def formula(person, period): - return person("input", period) + def formula(self, period): + return self("input", period) class output(Variable): @@ -32,29 +32,29 @@ class output(Variable): label = "Output variable" definition_period = DateUnit.MONTH - def formula(person, period): - return person("intermediate", period) + def formula(self, period): + return self("intermediate", period) @pytest.fixture(scope="module", autouse=True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables(input, intermediate, output) @pytest.fixture(scope="module", autouse=True) -def add_variables_to_cache_blakclist(tax_benefit_system): - tax_benefit_system.cache_blacklist = set(["intermediate"]) +def add_variables_to_cache_blakclist(tax_benefit_system) -> None: + tax_benefit_system.cache_blacklist = {"intermediate"} @pytest.mark.parametrize("simulation", [({"input": 1}, PERIOD)], indirect=True) -def test_without_cache_opt_out(simulation): +def test_without_cache_opt_out(simulation) -> None: simulation.calculate("output", period=PERIOD) intermediate_cache = simulation.persons.get_holder("intermediate") assert intermediate_cache.get_array(PERIOD) is not None @pytest.mark.parametrize("simulation", [({"input": 1}, PERIOD)], indirect=True) -def test_with_cache_opt_out(simulation): +def test_with_cache_opt_out(simulation) -> None: simulation.debug = True simulation.opt_out_cache = True simulation.calculate("output", period=PERIOD) @@ -63,7 +63,7 @@ def test_with_cache_opt_out(simulation): @pytest.mark.parametrize("simulation", [({"input": 1}, PERIOD)], indirect=True) -def test_with_no_blacklist(simulation): +def test_with_no_blacklist(simulation) -> None: simulation.calculate("output", period=PERIOD) intermediate_cache = simulation.persons.get_holder("intermediate") assert intermediate_cache.get_array(PERIOD) is not None diff --git a/tests/core/test_parameters.py b/tests/core/test_parameters.py index 13e2874787..7fe63a8180 100644 --- a/tests/core/test_parameters.py +++ b/tests/core/test_parameters.py @@ -10,18 +10,19 @@ ) -def test_get_at_instant(tax_benefit_system): +def test_get_at_instant(tax_benefit_system) -> None: parameters = tax_benefit_system.parameters assert isinstance(parameters, ParameterNode), parameters parameters_at_instant = parameters("2016-01-01") assert isinstance( - parameters_at_instant, ParameterNodeAtInstant + parameters_at_instant, + ParameterNodeAtInstant, ), parameters_at_instant assert parameters_at_instant.taxes.income_tax_rate == 0.15 assert parameters_at_instant.benefits.basic_income == 600 -def test_param_values(tax_benefit_system): +def test_param_values(tax_benefit_system) -> None: dated_values = { "2015-01-01": 0.15, "2014-01-01": 0.14, @@ -36,47 +37,47 @@ def test_param_values(tax_benefit_system): ) -def test_param_before_it_is_defined(tax_benefit_system): +def test_param_before_it_is_defined(tax_benefit_system) -> None: with pytest.raises(ParameterNotFound): tax_benefit_system.get_parameters_at_instant("1997-12-31").taxes.income_tax_rate # The placeholder should have no effect on the parameter computation -def test_param_with_placeholder(tax_benefit_system): +def test_param_with_placeholder(tax_benefit_system) -> None: assert ( tax_benefit_system.get_parameters_at_instant("2018-01-01").taxes.income_tax_rate == 0.15 ) -def test_stopped_parameter_before_end_value(tax_benefit_system): +def test_stopped_parameter_before_end_value(tax_benefit_system) -> None: assert ( tax_benefit_system.get_parameters_at_instant( - "2011-12-31" + "2011-12-31", ).benefits.housing_allowance == 0.25 ) -def test_stopped_parameter_after_end_value(tax_benefit_system): +def test_stopped_parameter_after_end_value(tax_benefit_system) -> None: with pytest.raises(ParameterNotFound): tax_benefit_system.get_parameters_at_instant( - "2016-12-01" + "2016-12-01", ).benefits.housing_allowance -def test_parameter_for_period(tax_benefit_system): +def test_parameter_for_period(tax_benefit_system) -> None: income_tax_rate = tax_benefit_system.parameters.taxes.income_tax_rate assert income_tax_rate("2015") == income_tax_rate("2015-01-01") -def test_wrong_value(tax_benefit_system): +def test_wrong_value(tax_benefit_system) -> None: income_tax_rate = tax_benefit_system.parameters.taxes.income_tax_rate with pytest.raises(ValueError): income_tax_rate("test") -def test_parameter_repr(tax_benefit_system): +def test_parameter_repr(tax_benefit_system) -> None: parameters = tax_benefit_system.parameters tf = tempfile.NamedTemporaryFile(delete=False) tf.write(repr(parameters).encode("utf-8")) @@ -85,7 +86,7 @@ def test_parameter_repr(tax_benefit_system): assert repr(parameters) == repr(tf_parameters) -def test_parameters_metadata(tax_benefit_system): +def test_parameters_metadata(tax_benefit_system) -> None: parameter = tax_benefit_system.parameters.benefits.basic_income assert ( parameter.metadata["reference"] == "https://law.gov.example/basic-income/amount" @@ -101,7 +102,7 @@ def test_parameters_metadata(tax_benefit_system): assert scale.metadata["rate_unit"] == "/1" -def test_parameter_node_metadata(tax_benefit_system): +def test_parameter_node_metadata(tax_benefit_system) -> None: parameter = tax_benefit_system.parameters.benefits assert parameter.description == "Social benefits" @@ -109,7 +110,7 @@ def test_parameter_node_metadata(tax_benefit_system): assert parameter_2.description == "Housing tax" -def test_parameter_documentation(tax_benefit_system): +def test_parameter_documentation(tax_benefit_system) -> None: parameter = tax_benefit_system.parameters.benefits.housing_allowance assert ( parameter.documentation @@ -117,16 +118,16 @@ def test_parameter_documentation(tax_benefit_system): ) -def test_get_descendants(tax_benefit_system): +def test_get_descendants(tax_benefit_system) -> None: all_parameters = { parameter.name for parameter in tax_benefit_system.parameters.get_descendants() } assert all_parameters.issuperset( - {"taxes", "taxes.housing_tax", "taxes.housing_tax.minimal_amount"} + {"taxes", "taxes.housing_tax", "taxes.housing_tax.minimal_amount"}, ) -def test_name(): +def test_name() -> None: parameter_data = { "description": "Parameter indexed by a numeric key", "2010": {"values": {"2006-01-01": 0.0075}}, diff --git a/tests/core/test_projectors.py b/tests/core/test_projectors.py index 27391711c3..c62e49d3a7 100644 --- a/tests/core/test_projectors.py +++ b/tests/core/test_projectors.py @@ -1,4 +1,4 @@ -import numpy as np +import numpy from openfisca_core.entities import build_entity from openfisca_core.indexed_enums import Enum @@ -8,9 +8,8 @@ from openfisca_core.variables import Variable -def test_shortcut_to_containing_entity_provided(): - """ - Tests that, when an entity provides a containing entity, +def test_shortcut_to_containing_entity_provided() -> None: + """Tests that, when an entity provides a containing entity, the shortcut to that containing entity is provided. """ person_entity = build_entity( @@ -29,7 +28,7 @@ def test_shortcut_to_containing_entity_provided(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) household_entity = build_entity( @@ -41,7 +40,7 @@ def test_shortcut_to_containing_entity_provided(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) @@ -52,9 +51,8 @@ def test_shortcut_to_containing_entity_provided(): assert simulation.populations["family"].household.entity.key == "household" -def test_shortcut_to_containing_entity_not_provided(): - """ - Tests that, when an entity doesn't provide a containing +def test_shortcut_to_containing_entity_not_provided() -> None: + """Tests that, when an entity doesn't provide a containing entity, the shortcut to that containing entity is not provided. """ person_entity = build_entity( @@ -73,7 +71,7 @@ def test_shortcut_to_containing_entity_not_provided(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) household_entity = build_entity( @@ -85,7 +83,7 @@ def test_shortcut_to_containing_entity_not_provided(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) @@ -95,17 +93,15 @@ def test_shortcut_to_containing_entity_not_provided(): simulation = SimulationBuilder().build_from_dict(system, {}) try: simulation.populations["family"].household - raise AssertionError() + raise AssertionError except AttributeError: pass -def test_enum_projects_downwards(): - """ - Test that an Enum-type household-level variable projects +def test_enum_projects_downwards() -> None: + """Test that an Enum-type household-level variable projects values onto its members correctly. """ - person = build_entity( key="person", plural="people", @@ -121,7 +117,7 @@ def test_enum_projects_downwards(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) @@ -147,8 +143,8 @@ class projected_enum_variable(Variable): entity = person definition_period = DateUnit.ETERNITY - def formula(person, period): - return person.household("household_enum_variable", period) + def formula(self, period): + return self.household("household_enum_variable", period) system.add_variables(household_enum_variable, projected_enum_variable) @@ -160,23 +156,21 @@ def formula(person, period): "household1": { "members": ["person1", "person2", "person3"], "household_enum_variable": {"eternity": "SECOND_OPTION"}, - } + }, }, }, ) assert ( simulation.calculate("projected_enum_variable", "2021-01-01").decode_to_str() - == np.array(["SECOND_OPTION"] * 3) + == numpy.array(["SECOND_OPTION"] * 3) ).all() -def test_enum_projects_upwards(): - """ - Test that an Enum-type person-level variable projects +def test_enum_projects_upwards() -> None: + """Test that an Enum-type person-level variable projects values onto its household (from the first person) correctly. """ - person = build_entity( key="person", plural="people", @@ -192,7 +186,7 @@ def test_enum_projects_upwards(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) @@ -211,9 +205,9 @@ class household_projected_variable(Variable): entity = household definition_period = DateUnit.ETERNITY - def formula(household, period): - return household.value_from_first_person( - household.members("person_enum_variable", period) + def formula(self, period): + return self.value_from_first_person( + self.members("person_enum_variable", period), ) class person_enum_variable(Variable): @@ -236,25 +230,24 @@ class person_enum_variable(Variable): "households": { "household1": { "members": ["person1", "person2", "person3"], - } + }, }, }, ) assert ( simulation.calculate( - "household_projected_variable", "2021-01-01" + "household_projected_variable", + "2021-01-01", ).decode_to_str() - == np.array(["SECOND_OPTION"]) + == numpy.array(["SECOND_OPTION"]) ).all() -def test_enum_projects_between_containing_groups(): - """ - Test that an Enum-type person-level variable projects +def test_enum_projects_between_containing_groups() -> None: + """Test that an Enum-type person-level variable projects values onto its household (from the first person) correctly. """ - person_entity = build_entity( key="person", plural="people", @@ -271,7 +264,7 @@ def test_enum_projects_between_containing_groups(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) household_entity = build_entity( @@ -283,7 +276,7 @@ def test_enum_projects_between_containing_groups(): "key": "member", "plural": "members", "label": "Member", - } + }, ], ) @@ -309,16 +302,16 @@ class projected_family_level_variable(Variable): entity = family_entity definition_period = DateUnit.ETERNITY - def formula(family, period): - return family.household("household_level_variable", period) + def formula(self, period): + return self.household("household_level_variable", period) class decoded_projected_family_level_variable(Variable): value_type = str entity = family_entity definition_period = DateUnit.ETERNITY - def formula(family, period): - return family.household("household_level_variable", period).decode_to_str() + def formula(self, period): + return self.household("household_level_variable", period).decode_to_str() system.add_variables( household_level_variable, @@ -338,18 +331,19 @@ def formula(family, period): "household1": { "members": ["person1", "person2", "person3"], "household_level_variable": {"eternity": "SECOND_OPTION"}, - } + }, }, }, ) assert ( simulation.calculate( - "projected_family_level_variable", "2021-01-01" + "projected_family_level_variable", + "2021-01-01", ).decode_to_str() - == np.array(["SECOND_OPTION"]) + == numpy.array(["SECOND_OPTION"]) ).all() assert ( simulation.calculate("decoded_projected_family_level_variable", "2021-01-01") - == np.array(["SECOND_OPTION"]) + == numpy.array(["SECOND_OPTION"]) ).all() diff --git a/tests/core/test_reforms.py b/tests/core/test_reforms.py index 0c17bb1169..1f31bcde2a 100644 --- a/tests/core/test_reforms.py +++ b/tests/core/test_reforms.py @@ -21,16 +21,16 @@ class goes_to_school(Variable): class WithBasicIncomeNeutralized(Reform): - def apply(self): + def apply(self) -> None: self.neutralize_variable("basic_income") @pytest.fixture(scope="module", autouse=True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables(goes_to_school) -def test_formula_neutralization(make_simulation, tax_benefit_system): +def test_formula_neutralization(make_simulation, tax_benefit_system) -> None: reform = WithBasicIncomeNeutralized(tax_benefit_system) period = "2017-01" @@ -48,16 +48,18 @@ def test_formula_neutralization(make_simulation, tax_benefit_system): basic_income_reform = reform_simulation.calculate("basic_income", period="2013-01") assert_near(basic_income_reform, 0, absolute_error_margin=0) disposable_income_reform = reform_simulation.calculate( - "disposable_income", period=period + "disposable_income", + period=period, ) assert_near(disposable_income_reform, 0) def test_neutralization_variable_with_default_value( - make_simulation, tax_benefit_system -): + make_simulation, + tax_benefit_system, +) -> None: class test_goes_to_school_neutralization(Reform): - def apply(self): + def apply(self) -> None: self.neutralize_variable("goes_to_school") reform = test_goes_to_school_neutralization(tax_benefit_system) @@ -69,7 +71,7 @@ def apply(self): assert_near(goes_to_school, [True], absolute_error_margin=0) -def test_neutralization_optimization(make_simulation, tax_benefit_system): +def test_neutralization_optimization(make_simulation, tax_benefit_system) -> None: reform = WithBasicIncomeNeutralized(tax_benefit_system) period = "2017-01" @@ -84,9 +86,9 @@ def test_neutralization_optimization(make_simulation, tax_benefit_system): assert basic_income_holder.get_known_periods() == [] -def test_input_variable_neutralization(make_simulation, tax_benefit_system): +def test_input_variable_neutralization(make_simulation, tax_benefit_system) -> None: class test_salary_neutralization(Reform): - def apply(self): + def apply(self) -> None: self.neutralize_variable("salary") reform = test_salary_neutralization(tax_benefit_system) @@ -107,21 +109,24 @@ def apply(self): [0, 0], ) disposable_income_reform = reform_simulation.calculate( - "disposable_income", period=period + "disposable_income", + period=period, ) assert_near(disposable_income_reform, [600, 600]) -def test_permanent_variable_neutralization(make_simulation, tax_benefit_system): +def test_permanent_variable_neutralization(make_simulation, tax_benefit_system) -> None: class test_date_naissance_neutralization(Reform): - def apply(self): + def apply(self) -> None: self.neutralize_variable("birth") reform = test_date_naissance_neutralization(tax_benefit_system) period = "2017-01" simulation = make_simulation( - reform.base_tax_benefit_system, {"birth": "1980-01-01"}, period + reform.base_tax_benefit_system, + {"birth": "1980-01-01"}, + period, ) with warnings.catch_warnings(record=True) as raised_warnings: reform_simulation = make_simulation(reform, {"birth": "1980-01-01"}, period) @@ -133,25 +138,35 @@ def apply(self): assert str(reform_simulation.calculate("birth", None)[0]) == "1970-01-01" -def test_update_items(): +def test_update_items() -> None: def check_update_items( - description, value_history, start_instant, stop_instant, value, expected_items - ): + description, + value_history, + start_instant, + stop_instant, + value, + expected_items, + ) -> None: value_history.update( - period=None, start=start_instant, stop=stop_instant, value=value + period=None, + start=start_instant, + stop=stop_instant, + value=value, ) assert value_history == expected_items check_update_items( "Replace an item by a new item", ValuesHistory( - "dummy_name", {"2013-01-01": {"value": 0.0}, "2014-01-01": {"value": None}} + "dummy_name", + {"2013-01-01": {"value": 0.0}, "2014-01-01": {"value": None}}, ), periods.period(2013).start, periods.period(2013).stop, 1.0, ValuesHistory( - "dummy_name", {"2013-01-01": {"value": 1.0}, "2014-01-01": {"value": None}} + "dummy_name", + {"2013-01-01": {"value": 1.0}, "2014-01-01": {"value": None}}, ), ) check_update_items( @@ -179,7 +194,8 @@ def check_update_items( check_update_items( "Open the stop instant to the future", ValuesHistory( - "dummy_name", {"2013-01-01": {"value": 0.0}, "2014-01-01": {"value": None}} + "dummy_name", + {"2013-01-01": {"value": 0.0}, "2014-01-01": {"value": None}}, ), periods.period(2013).start, None, # stop instant @@ -189,7 +205,8 @@ def check_update_items( check_update_items( "Insert a new item in the middle of an existing item", ValuesHistory( - "dummy_name", {"2010-01-01": {"value": 0.0}, "2014-01-01": {"value": None}} + "dummy_name", + {"2010-01-01": {"value": 0.0}, "2014-01-01": {"value": None}}, ), periods.period(2011).start, periods.period(2011).stop, @@ -250,7 +267,8 @@ def check_update_items( None, # stop instant 1.0, ValuesHistory( - "dummy_name", {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 1.0}} + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 1.0}}, ), ) check_update_items( @@ -314,18 +332,18 @@ def check_update_items( ) -def test_add_variable(make_simulation, tax_benefit_system): +def test_add_variable(make_simulation, tax_benefit_system) -> None: class new_variable(Variable): value_type = int label = "Nouvelle variable introduite par la réforme" entity = Household definition_period = DateUnit.MONTH - def formula(household, period): - return household.empty_array() + 10 + def formula(self, period): + return self.empty_array() + 10 class test_add_variable(Reform): - def apply(self): + def apply(self) -> None: self.add_variable(new_variable) reform = test_add_variable(tax_benefit_system) @@ -337,21 +355,21 @@ def apply(self): assert_near(new_variable1, 10, absolute_error_margin=0) -def test_add_dated_variable(make_simulation, tax_benefit_system): +def test_add_dated_variable(make_simulation, tax_benefit_system) -> None: class new_dated_variable(Variable): value_type = int label = "Nouvelle variable introduite par la réforme" entity = Household definition_period = DateUnit.MONTH - def formula_2010_01_01(household, period): - return household.empty_array() + 10 + def formula_2010_01_01(self, period): + return self.empty_array() + 10 - def formula_2011_01_01(household, period): - return household.empty_array() + 15 + def formula_2011_01_01(self, period): + return self.empty_array() + 15 class test_add_variable(Reform): - def apply(self): + def apply(self) -> None: self.add_variable(new_dated_variable) reform = test_add_variable(tax_benefit_system) @@ -359,20 +377,21 @@ def apply(self): reform_simulation = make_simulation(reform, {}, "2013-01") reform_simulation.debug = True new_dated_variable1 = reform_simulation.calculate( - "new_dated_variable", period="2013-01" + "new_dated_variable", + period="2013-01", ) assert_near(new_dated_variable1, 15, absolute_error_margin=0) -def test_update_variable(make_simulation, tax_benefit_system): +def test_update_variable(make_simulation, tax_benefit_system) -> None: class disposable_income(Variable): definition_period = DateUnit.MONTH - def formula_2018(household, period): - return household.empty_array() + 10 + def formula_2018(self, period): + return self.empty_array() + 10 class test_update_variable(Reform): - def apply(self): + def apply(self) -> None: self.update_variable(disposable_income) reform = test_update_variable(tax_benefit_system) @@ -390,29 +409,31 @@ def apply(self): reform_simulation = make_simulation(reform, {}, 2018) disposable_income1 = reform_simulation.calculate( - "disposable_income", period="2018-01" + "disposable_income", + period="2018-01", ) assert_near(disposable_income1, 10, absolute_error_margin=0) disposable_income2 = reform_simulation.calculate( - "disposable_income", period="2017-01" + "disposable_income", + period="2017-01", ) # Before 2018, the former formula is used assert disposable_income2 > 100 -def test_replace_variable(tax_benefit_system): +def test_replace_variable(tax_benefit_system) -> None: class disposable_income(Variable): definition_period = DateUnit.MONTH entity = Person label = "Disposable income" value_type = float - def formula_2018(household, period): - return household.empty_array() + 10 + def formula_2018(self, period): + return self.empty_array() + 10 class test_update_variable(Reform): - def apply(self): + def apply(self) -> None: self.replace_variable(disposable_income) reform = test_update_variable(tax_benefit_system) @@ -421,7 +442,7 @@ def apply(self): assert disposable_income_reform.get_formula("2017") is None -def test_wrong_reform(tax_benefit_system): +def test_wrong_reform(tax_benefit_system) -> None: class wrong_reform(Reform): # A Reform must implement an `apply` method pass @@ -430,7 +451,7 @@ class wrong_reform(Reform): wrong_reform(tax_benefit_system) -def test_modify_parameters(tax_benefit_system): +def test_modify_parameters(tax_benefit_system) -> None: def modify_parameters(reference_parameters): reform_parameters_subtree = ParameterNode( "new_node", @@ -439,7 +460,7 @@ def modify_parameters(reference_parameters): "values": { "2000-01-01": {"value": True}, "2015-01-01": {"value": None}, - } + }, }, }, ) @@ -447,7 +468,7 @@ def modify_parameters(reference_parameters): return reference_parameters class test_modify_parameters(Reform): - def apply(self): + def apply(self) -> None: self.modify_parameters(modifier_function=modify_parameters) reform = test_modify_parameters(tax_benefit_system) @@ -460,7 +481,7 @@ def apply(self): assert parameters_at_instant.new_node.new_param is True -def test_attributes_conservation(tax_benefit_system): +def test_attributes_conservation(tax_benefit_system) -> None: class some_variable(Variable): value_type = int entity = Person @@ -475,7 +496,7 @@ class reform(Reform): class some_variable(Variable): default_value = 10 - def apply(self): + def apply(self) -> None: self.update_variable(some_variable) reformed_tbs = reform(tax_benefit_system) @@ -489,9 +510,9 @@ def apply(self): assert reform_variable.calculate_output == baseline_variable.calculate_output -def test_formulas_removal(tax_benefit_system): +def test_formulas_removal(tax_benefit_system) -> None: class reform(Reform): - def apply(self): + def apply(self) -> None: class basic_income(Variable): pass diff --git a/tests/core/test_simulation_builder.py b/tests/core/test_simulation_builder.py index d1dc0cde75..b905b29b84 100644 --- a/tests/core/test_simulation_builder.py +++ b/tests/core/test_simulation_builder.py @@ -1,4 +1,4 @@ -from typing import Iterable +from collections.abc import Iterable import datetime @@ -23,7 +23,7 @@ class intvar(Variable): value_type = int entity = persons - def __init__(self): + def __init__(self) -> None: super().__init__() return intvar() @@ -36,7 +36,7 @@ class datevar(Variable): value_type = datetime.date entity = persons - def __init__(self): + def __init__(self) -> None: super().__init__() return datevar() @@ -54,15 +54,16 @@ class TestEnum(Variable): possible_values = Enum("foo", "bar") name = "enum" - def __init__(self): + def __init__(self) -> None: pass return TestEnum() -def test_build_default_simulation(tax_benefit_system): +def test_build_default_simulation(tax_benefit_system) -> None: one_person_simulation = SimulationBuilder().build_default_simulation( - tax_benefit_system, 1 + tax_benefit_system, + 1, ) assert one_person_simulation.persons.count == 1 assert one_person_simulation.household.count == 1 @@ -72,7 +73,8 @@ def test_build_default_simulation(tax_benefit_system): ) several_persons_simulation = SimulationBuilder().build_default_simulation( - tax_benefit_system, 4 + tax_benefit_system, + 4, ) assert several_persons_simulation.persons.count == 4 assert several_persons_simulation.household.count == 4 @@ -85,7 +87,7 @@ def test_build_default_simulation(tax_benefit_system): ).all() -def test_explicit_singular_entities(tax_benefit_system): +def test_explicit_singular_entities(tax_benefit_system) -> None: assert SimulationBuilder().explicit_singular_entities( tax_benefit_system, {"persons": {"Javier": {}}, "household": {"parents": ["Javier"]}}, @@ -95,7 +97,7 @@ def test_explicit_singular_entities(tax_benefit_system): } -def test_add_person_entity(persons): +def test_add_person_entity(persons) -> None: persons_json = {"Alicia": {"salary": {}}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) @@ -103,7 +105,7 @@ def test_add_person_entity(persons): assert simulation_builder.get_ids("persons") == ["Alicia", "Javier"] -def test_numeric_ids(persons): +def test_numeric_ids(persons) -> None: persons_json = {1: {"salary": {}}, 2: {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) @@ -111,14 +113,14 @@ def test_numeric_ids(persons): assert simulation_builder.get_ids("persons") == ["1", "2"] -def test_add_person_entity_with_values(persons): +def test_add_person_entity_with_values(persons) -> None: persons_json = {"Alicia": {"salary": {"2018-11": 3000}}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0]) -def test_add_person_values_with_default_period(persons): +def test_add_person_values_with_default_period(persons) -> None: persons_json = {"Alicia": {"salary": 3000}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.set_default_period("2018-11") @@ -126,7 +128,7 @@ def test_add_person_values_with_default_period(persons): tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0]) -def test_add_person_values_with_default_period_old_syntax(persons): +def test_add_person_values_with_default_period_old_syntax(persons) -> None: persons_json = {"Alicia": {"salary": 3000}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.set_default_period("month:2018-11") @@ -134,7 +136,7 @@ def test_add_person_values_with_default_period_old_syntax(persons): tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0]) -def test_add_group_entity(households): +def test_add_group_entity(households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_group_entity( "persons", @@ -156,7 +158,7 @@ def test_add_group_entity(households): ] -def test_add_group_entity_loose_syntax(households): +def test_add_group_entity_loose_syntax(households) -> None: simulation_builder = SimulationBuilder() simulation_builder.add_group_entity( "persons", @@ -178,71 +180,91 @@ def test_add_group_entity_loose_syntax(households): ] -def test_add_variable_value(persons): +def test_add_variable_value(persons) -> None: salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 simulation_builder.add_variable_value( - persons, salary, instance_index, "Alicia", "2018-11", 3000 + persons, + salary, + instance_index, + "Alicia", + "2018-11", + 3000, ) input_array = simulation_builder.get_input("salary", "2018-11") assert input_array[instance_index] == pytest.approx(3000) -def test_add_variable_value_as_expression(persons): +def test_add_variable_value_as_expression(persons) -> None: salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 simulation_builder.add_variable_value( - persons, salary, instance_index, "Alicia", "2018-11", "3 * 1000" + persons, + salary, + instance_index, + "Alicia", + "2018-11", + "3 * 1000", ) input_array = simulation_builder.get_input("salary", "2018-11") assert input_array[instance_index] == pytest.approx(3000) -def test_fail_on_wrong_data(persons): +def test_fail_on_wrong_data(persons) -> None: salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: simulation_builder.add_variable_value( - persons, salary, instance_index, "Alicia", "2018-11", "alicia" + persons, + salary, + instance_index, + "Alicia", + "2018-11", + "alicia", ) assert excinfo.value.error == { "persons": { "Alicia": { "salary": { - "2018-11": "Can't deal with value: expected type number, received 'alicia'." - } - } - } + "2018-11": "Can't deal with value: expected type number, received 'alicia'.", + }, + }, + }, } -def test_fail_on_ill_formed_expression(persons): +def test_fail_on_ill_formed_expression(persons) -> None: salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: simulation_builder.add_variable_value( - persons, salary, instance_index, "Alicia", "2018-11", "2 * / 1000" + persons, + salary, + instance_index, + "Alicia", + "2018-11", + "2 * / 1000", ) assert excinfo.value.error == { "persons": { "Alicia": { "salary": { - "2018-11": "I couldn't understand '2 * / 1000' as a value for 'salary'" - } - } - } + "2018-11": "I couldn't understand '2 * / 1000' as a value for 'salary'", + }, + }, + }, } -def test_fail_on_integer_overflow(persons, int_variable): +def test_fail_on_integer_overflow(persons, int_variable) -> None: instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 @@ -259,39 +281,49 @@ def test_fail_on_integer_overflow(persons, int_variable): "persons": { "Alicia": { "intvar": { - "2018-11": "Can't deal with value: '9223372036854775808', it's too large for type 'integer'." - } - } - } + "2018-11": "Can't deal with value: '9223372036854775808', it's too large for type 'integer'.", + }, + }, + }, } -def test_fail_on_date_parsing(persons, date_variable): +def test_fail_on_date_parsing(persons, date_variable) -> None: instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: simulation_builder.add_variable_value( - persons, date_variable, instance_index, "Alicia", "2018-11", "2019-02-30" + persons, + date_variable, + instance_index, + "Alicia", + "2018-11", + "2019-02-30", ) assert excinfo.value.error == { "persons": { - "Alicia": {"datevar": {"2018-11": "Can't deal with date: '2019-02-30'."}} - } + "Alicia": {"datevar": {"2018-11": "Can't deal with date: '2019-02-30'."}}, + }, } -def test_add_unknown_enum_variable_value(persons, enum_variable): +def test_add_unknown_enum_variable_value(persons, enum_variable) -> None: instance_index = 0 simulation_builder = SimulationBuilder() simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError): simulation_builder.add_variable_value( - persons, enum_variable, instance_index, "Alicia", "2018-11", "baz" + persons, + enum_variable, + instance_index, + "Alicia", + "2018-11", + "baz", ) -def test_finalize_person_entity(persons): +def test_finalize_person_entity(persons) -> None: persons_json = {"Alicia": {"salary": {"2018-11": 3000}}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) @@ -302,7 +334,7 @@ def test_finalize_person_entity(persons): assert population.ids == ["Alicia", "Javier"] -def test_canonicalize_period_keys(persons): +def test_canonicalize_period_keys(persons) -> None: persons_json = {"Alicia": {"salary": {"year:2018-01": 100}}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) @@ -311,9 +343,10 @@ def test_canonicalize_period_keys(persons): tools.assert_near(population.get_holder("salary").get_array("2018-12"), [100]) -def test_finalize_households(tax_benefit_system): +def test_finalize_households(tax_benefit_system) -> None: simulation = Simulation( - tax_benefit_system, tax_benefit_system.instantiate_entities() + tax_benefit_system, + tax_benefit_system.instantiate_entities(), ) simulation_builder = SimulationBuilder() simulation_builder.add_group_entity( @@ -333,7 +366,7 @@ def test_finalize_households(tax_benefit_system): ) -def test_check_persons_to_allocate(): +def test_check_persons_to_allocate() -> None: entity_plural = "familles" persons_plural = "individus" person_id = "Alicia" @@ -354,7 +387,7 @@ def test_check_persons_to_allocate(): ) -def test_allocate_undeclared_person(): +def test_allocate_undeclared_person() -> None: entity_plural = "familles" persons_plural = "individus" person_id = "Alicia" @@ -377,13 +410,13 @@ def test_allocate_undeclared_person(): assert exception.value.error == { "familles": { "famille1": { - "parents": "Unexpected value: Alicia. Alicia has been declared in famille1 parents, but has not been declared in individus." - } - } + "parents": "Unexpected value: Alicia. Alicia has been declared in famille1 parents, but has not been declared in individus.", + }, + }, } -def test_allocate_person_twice(): +def test_allocate_person_twice() -> None: entity_plural = "familles" persons_plural = "individus" person_id = "Alicia" @@ -406,37 +439,39 @@ def test_allocate_person_twice(): assert exception.value.error == { "familles": { "famille1": { - "parents": "Alicia has been declared more than once in familles" - } - } + "parents": "Alicia has been declared more than once in familles", + }, + }, } -def test_one_person_without_household(tax_benefit_system): +def test_one_person_without_household(tax_benefit_system) -> None: simulation_dict = {"persons": {"Alicia": {}}} simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, simulation_dict + tax_benefit_system, + simulation_dict, ) assert simulation.household.count == 1 parents_in_households = simulation.household.nb_persons( - role=entities.Household.PARENT + role=entities.Household.PARENT, ) assert parents_in_households.tolist() == [ - 1 + 1, ] # household member default role is first_parent -def test_some_person_without_household(tax_benefit_system): +def test_some_person_without_household(tax_benefit_system) -> None: input_yaml = """ persons: {'Alicia': {}, 'Bob': {}} household: {'parents': ['Alicia']} """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(input_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), ) assert simulation.household.count == 2 parents_in_households = simulation.household.nb_persons( - role=entities.Household.PARENT + role=entities.Household.PARENT, ) assert parents_in_households.tolist() == [ 1, @@ -444,7 +479,7 @@ def test_some_person_without_household(tax_benefit_system): ] # household member default role is first_parent -def test_nb_persons_in_households(tax_benefit_system): +def test_nb_persons_in_households(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] households_ids: Iterable = ["c", "a", "b"] persons_households: Iterable = ["c", "a", "a", "b", "a"] @@ -454,7 +489,9 @@ def test_nb_persons_in_households(tax_benefit_system): simulation_builder.declare_person_entity("person", persons_ids) household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( - household_instance, persons_households, ["first_parent"] * 5 + household_instance, + persons_households, + ["first_parent"] * 5, ) persons_in_households = simulation_builder.nb_persons("household") @@ -462,7 +499,7 @@ def test_nb_persons_in_households(tax_benefit_system): assert persons_in_households.tolist() == [1, 3, 1] -def test_nb_persons_no_role(tax_benefit_system): +def test_nb_persons_no_role(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] households_ids: Iterable = ["c", "a", "b"] persons_households: Iterable = ["c", "a", "a", "b", "a"] @@ -473,10 +510,12 @@ def test_nb_persons_no_role(tax_benefit_system): household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( - household_instance, persons_households, ["first_parent"] * 5 + household_instance, + persons_households, + ["first_parent"] * 5, ) parents_in_households = household_instance.nb_persons( - role=entities.Household.PARENT + role=entities.Household.PARENT, ) assert parents_in_households.tolist() == [ @@ -486,7 +525,7 @@ def test_nb_persons_no_role(tax_benefit_system): ] # household member default role is first_parent -def test_nb_persons_by_role(tax_benefit_system): +def test_nb_persons_by_role(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] households_ids: Iterable = ["c", "a", "b"] persons_households: Iterable = ["c", "a", "a", "b", "a"] @@ -504,16 +543,18 @@ def test_nb_persons_by_role(tax_benefit_system): household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( - household_instance, persons_households, persons_households_roles + household_instance, + persons_households, + persons_households_roles, ) parents_in_households = household_instance.nb_persons( - role=entities.Household.FIRST_PARENT + role=entities.Household.FIRST_PARENT, ) assert parents_in_households.tolist() == [0, 1, 1] -def test_integral_roles(tax_benefit_system): +def test_integral_roles(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] households_ids: Iterable = ["c", "a", "b"] persons_households: Iterable = ["c", "a", "a", "b", "a"] @@ -526,10 +567,12 @@ def test_integral_roles(tax_benefit_system): household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( - household_instance, persons_households, persons_households_roles + household_instance, + persons_households, + persons_households_roles, ) parents_in_households = household_instance.nb_persons( - role=entities.Household.FIRST_PARENT + role=entities.Household.FIRST_PARENT, ) assert parents_in_households.tolist() == [0, 1, 1] @@ -538,7 +581,7 @@ def test_integral_roles(tax_benefit_system): # Test Intégration -def test_from_person_variable_to_group(tax_benefit_system): +def test_from_person_variable_to_group(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] households_ids: Iterable = ["c", "a", "b"] @@ -555,7 +598,9 @@ def test_from_person_variable_to_group(tax_benefit_system): household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( - household_instance, persons_households, ["first_parent"] * 5 + household_instance, + persons_households, + ["first_parent"] * 5, ) simulation = simulation_builder.build(tax_benefit_system) @@ -567,14 +612,15 @@ def test_from_person_variable_to_group(tax_benefit_system): assert total_taxes / simulation.calculate("rent", period) == pytest.approx(1) -def test_simulation(tax_benefit_system): +def test_simulation(tax_benefit_system) -> None: input_yaml = """ salary: 2016-10: 12000 """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(input_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), ) assert simulation.get_array("salary", "2016-10") == 12000 @@ -582,14 +628,15 @@ def test_simulation(tax_benefit_system): simulation.calculate("total_taxes", "2016-10") -def test_vectorial_input(tax_benefit_system): +def test_vectorial_input(tax_benefit_system) -> None: input_yaml = """ salary: 2016-10: [12000, 20000] """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(input_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), ) tools.assert_near(simulation.get_array("salary", "2016-10"), [12000, 20000]) @@ -597,15 +644,16 @@ def test_vectorial_input(tax_benefit_system): simulation.calculate("total_taxes", "2016-10") -def test_fully_specified_entities(tax_benefit_system): +def test_fully_specified_entities(tax_benefit_system) -> None: simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, situation_examples.couple + tax_benefit_system, + situation_examples.couple, ) assert simulation.household.count == 1 assert simulation.persons.count == 2 -def test_single_entity_shortcut(tax_benefit_system): +def test_single_entity_shortcut(tax_benefit_system) -> None: input_yaml = """ persons: Alicia: {} @@ -615,12 +663,13 @@ def test_single_entity_shortcut(tax_benefit_system): """ simulation = SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(input_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), ) assert simulation.household.count == 1 -def test_order_preserved(tax_benefit_system): +def test_order_preserved(tax_benefit_system) -> None: input_yaml = """ persons: Javier: {} @@ -638,7 +687,7 @@ def test_order_preserved(tax_benefit_system): assert simulation.persons.ids == ["Javier", "Alicia", "Sarah", "Tom"] -def test_inconsistent_input(tax_benefit_system): +def test_inconsistent_input(tax_benefit_system) -> None: input_yaml = """ salary: 2016-10: [12000, 20000] @@ -647,6 +696,7 @@ def test_inconsistent_input(tax_benefit_system): """ with pytest.raises(ValueError) as error: SimulationBuilder().build_from_dict( - tax_benefit_system, test_runner.yaml.safe_load(input_yaml) + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), ) assert "its length is 3 while there are 2" in error.value.args[0] diff --git a/tests/core/test_simulations.py b/tests/core/test_simulations.py index 18050b6bc5..7f4897e776 100644 --- a/tests/core/test_simulations.py +++ b/tests/core/test_simulations.py @@ -6,7 +6,7 @@ from openfisca_core.simulations import SimulationBuilder -def test_calculate_full_tracer(tax_benefit_system): +def test_calculate_full_tracer(tax_benefit_system) -> None: simulation = SimulationBuilder().build_default_simulation(tax_benefit_system) simulation.trace = True simulation.calculate("income_tax", "2017-01") @@ -27,12 +27,12 @@ def test_calculate_full_tracer(tax_benefit_system): assert income_tax_node.parameters[0].value == 0.15 -def test_get_entity_not_found(tax_benefit_system): +def test_get_entity_not_found(tax_benefit_system) -> None: simulation = SimulationBuilder().build_default_simulation(tax_benefit_system) assert simulation.get_entity(plural="no_such_entities") is None -def test_clone(tax_benefit_system): +def test_clone(tax_benefit_system) -> None: simulation = SimulationBuilder().build_from_entities( tax_benefit_system, { @@ -59,7 +59,7 @@ def test_clone(tax_benefit_system): assert salary_holder_clone.population == simulation_clone.persons -def test_get_memory_usage(tax_benefit_system): +def test_get_memory_usage(tax_benefit_system) -> None: simulation = SimulationBuilder().build_from_entities(tax_benefit_system, single) simulation.calculate("disposable_income", "2017-01") memory_usage = simulation.get_memory_usage(variables=["salary"]) @@ -67,7 +67,7 @@ def test_get_memory_usage(tax_benefit_system): assert len(memory_usage["by_variable"]) == 1 -def test_invalidate_cache_when_spiral_error_detected(tax_benefit_system): +def test_invalidate_cache_when_spiral_error_detected(tax_benefit_system) -> None: simulation = SimulationBuilder().build_default_simulation(tax_benefit_system) tracer = simulation.tracer diff --git a/tests/core/test_tracers.py b/tests/core/test_tracers.py index 41acf68bc4..178b957ec4 100644 --- a/tests/core/test_tracers.py +++ b/tests/core/test_tracers.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import csv import json import os @@ -21,34 +19,33 @@ from .parameters_fancy_indexing.test_fancy_indexing import parameters -class TestException(Exception): - ... +class TestException(Exception): ... class StubSimulation(Simulation): - def __init__(self): + def __init__(self) -> None: self.exception = None self.max_spiral_loops = 1 - def _calculate(self, variable, period): + def _calculate(self, variable, period) -> None: if self.exception: raise self.exception - def invalidate_cache_entry(self, variable, period): + def invalidate_cache_entry(self, variable, period) -> None: pass - def purge_cache_of_invalid_values(self): + def purge_cache_of_invalid_values(self) -> None: pass class MockTracer: - def record_calculation_start(self, variable, period): + def record_calculation_start(self, variable, period) -> None: self.calculation_start_recorded = True - def record_calculation_result(self, value): + def record_calculation_result(self, value) -> None: self.recorded_result = True - def record_calculation_end(self): + def record_calculation_end(self) -> None: self.calculation_end_recorded = True @@ -58,7 +55,7 @@ def tracer(): @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_stack_one_level(tracer): +def test_stack_one_level(tracer) -> None: tracer.record_calculation_start("a", 2017) assert len(tracer.stack) == 1 @@ -70,7 +67,7 @@ def test_stack_one_level(tracer): @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_stack_two_levels(tracer): +def test_stack_two_levels(tracer) -> None: tracer.record_calculation_start("a", 2017) tracer.record_calculation_start("b", 2017) @@ -87,7 +84,7 @@ def test_stack_two_levels(tracer): @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_tracer_contract(tracer): +def test_tracer_contract(tracer) -> None: simulation = StubSimulation() simulation.tracer = MockTracer() @@ -97,7 +94,7 @@ def test_tracer_contract(tracer): assert simulation.tracer.calculation_end_recorded -def test_exception_robustness(): +def test_exception_robustness() -> None: simulation = StubSimulation() simulation.tracer = MockTracer() simulation.exception = TestException(":-o") @@ -110,7 +107,7 @@ def test_exception_robustness(): @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_cycle_error(tracer): +def test_cycle_error(tracer) -> None: simulation = StubSimulation() simulation.tracer = tracer @@ -131,7 +128,7 @@ def test_cycle_error(tracer): @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_spiral_error(tracer): +def test_spiral_error(tracer) -> None: simulation = StubSimulation() simulation.tracer = tracer @@ -150,7 +147,7 @@ def test_spiral_error(tracer): ] -def test_full_tracer_one_calculation(tracer): +def test_full_tracer_one_calculation(tracer) -> None: tracer._enter_calculation("a", 2017) tracer._exit_calculation() @@ -161,7 +158,7 @@ def test_full_tracer_one_calculation(tracer): assert tracer.trees[0].children == [] -def test_full_tracer_2_branches(tracer): +def test_full_tracer_2_branches(tracer) -> None: tracer._enter_calculation("a", 2017) tracer._enter_calculation("b", 2017) tracer._exit_calculation() @@ -173,7 +170,7 @@ def test_full_tracer_2_branches(tracer): assert len(tracer.trees[0].children) == 2 -def test_full_tracer_2_trees(tracer): +def test_full_tracer_2_trees(tracer) -> None: tracer._enter_calculation("b", 2017) tracer._exit_calculation() tracer._enter_calculation("c", 2017) @@ -182,7 +179,7 @@ def test_full_tracer_2_trees(tracer): assert len(tracer.trees) == 2 -def test_full_tracer_3_generations(tracer): +def test_full_tracer_3_generations(tracer) -> None: tracer._enter_calculation("a", 2017) tracer._enter_calculation("b", 2017) tracer._enter_calculation("c", 2017) @@ -195,14 +192,14 @@ def test_full_tracer_3_generations(tracer): assert len(tracer.trees[0].children[0].children) == 1 -def test_full_tracer_variable_nb_requests(tracer): +def test_full_tracer_variable_nb_requests(tracer) -> None: tracer._enter_calculation("a", "2017-01") tracer._enter_calculation("a", "2017-02") assert tracer.get_nb_requests("a") == 2 -def test_simulation_calls_record_calculation_result(): +def test_simulation_calls_record_calculation_result() -> None: simulation = StubSimulation() simulation.tracer = MockTracer() @@ -211,7 +208,7 @@ def test_simulation_calls_record_calculation_result(): assert simulation.tracer.recorded_result -def test_record_calculation_result(tracer): +def test_record_calculation_result(tracer) -> None: tracer._enter_calculation("a", 2017) tracer.record_calculation_result(numpy.asarray(100)) tracer._exit_calculation() @@ -219,7 +216,7 @@ def test_record_calculation_result(tracer): assert tracer.trees[0].value == 100 -def test_flat_trace(tracer): +def test_flat_trace(tracer) -> None: tracer._enter_calculation("a", 2019) tracer._enter_calculation("b", 2019) tracer._exit_calculation() @@ -232,7 +229,7 @@ def test_flat_trace(tracer): assert trace["b<2019>"]["dependencies"] == [] -def test_flat_trace_serialize_vectorial_values(tracer): +def test_flat_trace_serialize_vectorial_values(tracer) -> None: tracer._enter_calculation("a", 2019) tracer.record_parameter_access("x.y.z", 2019, numpy.asarray([100, 200, 300])) tracer.record_calculation_result(numpy.asarray([10, 20, 30])) @@ -244,7 +241,7 @@ def test_flat_trace_serialize_vectorial_values(tracer): assert json.dumps(trace["a<2019>"]["parameters"]["x.y.z<2019>"]) -def test_flat_trace_with_parameter(tracer): +def test_flat_trace_with_parameter(tracer) -> None: tracer._enter_calculation("a", 2019) tracer.record_parameter_access("p", "2019-01-01", 100) tracer._exit_calculation() @@ -255,7 +252,7 @@ def test_flat_trace_with_parameter(tracer): assert trace["a<2019>"]["parameters"] == {"p<2019-01-01>": 100} -def test_flat_trace_with_cache(tracer): +def test_flat_trace_with_cache(tracer) -> None: tracer._enter_calculation("a", 2019) tracer._enter_calculation("b", 2019) tracer._enter_calculation("c", 2019) @@ -270,7 +267,7 @@ def test_flat_trace_with_cache(tracer): assert trace["b<2019>"]["dependencies"] == ["c<2019>"] -def test_calculation_time(): +def test_calculation_time() -> None: tracer = FullTracer() tracer._enter_calculation("a", 2019) @@ -322,7 +319,7 @@ def tracer_calc_time(): return tracer -def test_calculation_time_with_depth(tracer_calc_time): +def test_calculation_time_with_depth(tracer_calc_time) -> None: tracer = tracer_calc_time performance_json = tracer.performance_log._json() simulation_grand_children = performance_json["children"][0]["children"] @@ -331,7 +328,7 @@ def test_calculation_time_with_depth(tracer_calc_time): assert simulation_grand_children[0]["value"] == 700 -def test_flat_trace_calc_time(tracer_calc_time): +def test_flat_trace_calc_time(tracer_calc_time) -> None: tracer = tracer_calc_time flat_trace = tracer.get_flat_trace() @@ -343,11 +340,11 @@ def test_flat_trace_calc_time(tracer_calc_time): assert flat_trace["c<2019>"]["formula_time"] == 100 -def test_generate_performance_table(tracer_calc_time, tmpdir): +def test_generate_performance_table(tracer_calc_time, tmpdir) -> None: tracer = tracer_calc_time tracer.generate_performance_tables(tmpdir) - with open(os.path.join(tmpdir, "performance_table.csv"), "r") as csv_file: + with open(os.path.join(tmpdir, "performance_table.csv")) as csv_file: csv_reader = csv.DictReader(csv_file) csv_rows = list(csv_reader) @@ -358,9 +355,7 @@ def test_generate_performance_table(tracer_calc_time, tmpdir): assert float(a_row["calculation_time"]) == 1000 assert float(a_row["formula_time"]) == 190 - with open( - os.path.join(tmpdir, "aggregated_performance_table.csv"), "r" - ) as csv_file: + with open(os.path.join(tmpdir, "aggregated_performance_table.csv")) as csv_file: aggregated_csv_reader = csv.DictReader(csv_file) aggregated_csv_rows = list(aggregated_csv_reader) @@ -372,10 +367,10 @@ def test_generate_performance_table(tracer_calc_time, tmpdir): assert float(a_row["formula_time"]) == 190 + 200 -def test_get_aggregated_calculation_times(tracer_calc_time): +def test_get_aggregated_calculation_times(tracer_calc_time) -> None: perf_log = tracer_calc_time.performance_log aggregated_calculation_times = perf_log.aggregate_calculation_times( - tracer_calc_time.get_flat_trace() + tracer_calc_time.get_flat_trace(), ) assert aggregated_calculation_times["a"]["calculation_time"] == 1000 + 200 @@ -384,7 +379,7 @@ def test_get_aggregated_calculation_times(tracer_calc_time): assert aggregated_calculation_times["a"]["avg_formula_time"] == (190 + 200) / 2 -def test_rounding(): +def test_rounding() -> None: node_a = TraceNode("a", 2017) node_a.start = 1.23456789 node_a.end = node_a.start + 1.23456789e-03 @@ -401,7 +396,7 @@ def test_rounding(): ) # The rounding should not prevent from calculating a precise formula_time -def test_variable_stats(tracer): +def test_variable_stats(tracer) -> None: tracer._enter_calculation("A", 2017) tracer._enter_calculation("B", 2017) tracer._enter_calculation("B", 2017) @@ -412,7 +407,7 @@ def test_variable_stats(tracer): assert tracer.get_nb_requests("C") == 0 -def test_log_format(tracer): +def test_log_format(tracer) -> None: tracer._enter_calculation("A", 2017) tracer._enter_calculation("B", 2017) tracer.record_calculation_result(numpy.asarray([1])) @@ -425,7 +420,7 @@ def test_log_format(tracer): assert lines[1] == " B<2017> >> [1]" -def test_log_format_forest(tracer): +def test_log_format_forest(tracer) -> None: tracer._enter_calculation("A", 2017) tracer.record_calculation_result(numpy.asarray([1])) tracer._exit_calculation() @@ -438,7 +433,7 @@ def test_log_format_forest(tracer): assert lines[1] == " B<2017> >> [2]" -def test_log_aggregate(tracer): +def test_log_aggregate(tracer) -> None: tracer._enter_calculation("A", 2017) tracer.record_calculation_result(numpy.asarray([1])) tracer._exit_calculation() @@ -447,10 +442,10 @@ def test_log_aggregate(tracer): assert lines[0] == " A<2017> >> {'avg': 1.0, 'max': 1, 'min': 1}" -def test_log_aggregate_with_enum(tracer): +def test_log_aggregate_with_enum(tracer) -> None: tracer._enter_calculation("A", 2017) tracer.record_calculation_result( - HousingOccupancyStatus.encode(numpy.repeat("tenant", 100)) + HousingOccupancyStatus.encode(numpy.repeat("tenant", 100)), ) tracer._exit_calculation() lines = tracer.computation_log.lines(aggregate=True) @@ -461,7 +456,7 @@ def test_log_aggregate_with_enum(tracer): ) -def test_log_aggregate_with_strings(tracer): +def test_log_aggregate_with_strings(tracer) -> None: tracer._enter_calculation("A", 2017) tracer.record_calculation_result(numpy.repeat("foo", 100)) tracer._exit_calculation() @@ -470,7 +465,7 @@ def test_log_aggregate_with_strings(tracer): assert lines[0] == " A<2017> >> {'avg': '?', 'max': '?', 'min': '?'}" -def test_log_max_depth(tracer): +def test_log_max_depth(tracer) -> None: tracer._enter_calculation("A", 2017) tracer._enter_calculation("B", 2017) tracer._enter_calculation("C", 2017) @@ -489,10 +484,10 @@ def test_log_max_depth(tracer): assert len(tracer.computation_log.lines(max_depth=0)) == 0 -def test_no_wrapping(tracer): +def test_no_wrapping(tracer) -> None: tracer._enter_calculation("A", 2017) tracer.record_calculation_result( - HousingOccupancyStatus.encode(numpy.repeat("tenant", 100)) + HousingOccupancyStatus.encode(numpy.repeat("tenant", 100)), ) tracer._exit_calculation() lines = tracer.computation_log.lines() @@ -501,10 +496,10 @@ def test_no_wrapping(tracer): assert "\n" not in lines[0] -def test_trace_enums(tracer): +def test_trace_enums(tracer) -> None: tracer._enter_calculation("A", 2017) tracer.record_calculation_result( - HousingOccupancyStatus.encode(numpy.array(["tenant"])) + HousingOccupancyStatus.encode(numpy.array(["tenant"])), ) tracer._exit_calculation() lines = tracer.computation_log.lines() @@ -518,7 +513,7 @@ def test_trace_enums(tracer): family_status = numpy.asarray(["single", "couple", "single", "couple"]) -def check_tracing_params(accessor, param_key): +def check_tracing_params(accessor, param_key) -> None: tracer = FullTracer() tracer._enter_calculation("A", "2015-01") @@ -556,11 +551,11 @@ def check_tracing_params(accessor, param_key): ), # triple ], ) -def test_parameters(test): +def test_parameters(test) -> None: check_tracing_params(*test) -def test_browse_trace(): +def test_browse_trace() -> None: tracer = FullTracer() tracer._enter_calculation("B", 2017) diff --git a/tests/core/test_yaml.py b/tests/core/test_yaml.py index 4673665fcb..1672ea3453 100644 --- a/tests/core/test_yaml.py +++ b/tests/core/test_yaml.py @@ -19,82 +19,83 @@ def run_yaml_test(tax_benefit_system, path, options=None): if options is None: options = {} - result = run_tests(tax_benefit_system, yaml_path, options) - return result + return run_tests(tax_benefit_system, yaml_path, options) -def test_success(tax_benefit_system): +def test_success(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, "test_success.yml") == EXIT_OK -def test_fail(tax_benefit_system): +def test_fail(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, "test_failure.yaml") == EXIT_TESTSFAILED -def test_relative_error_margin_success(tax_benefit_system): +def test_relative_error_margin_success(tax_benefit_system) -> None: assert ( run_yaml_test(tax_benefit_system, "test_relative_error_margin.yaml") == EXIT_OK ) -def test_relative_error_margin_fail(tax_benefit_system): +def test_relative_error_margin_fail(tax_benefit_system) -> None: assert ( run_yaml_test(tax_benefit_system, "failing_test_relative_error_margin.yaml") == EXIT_TESTSFAILED ) -def test_absolute_error_margin_success(tax_benefit_system): +def test_absolute_error_margin_success(tax_benefit_system) -> None: assert ( run_yaml_test(tax_benefit_system, "test_absolute_error_margin.yaml") == EXIT_OK ) -def test_absolute_error_margin_fail(tax_benefit_system): +def test_absolute_error_margin_fail(tax_benefit_system) -> None: assert ( run_yaml_test(tax_benefit_system, "failing_test_absolute_error_margin.yaml") == EXIT_TESTSFAILED ) -def test_run_tests_from_directory(tax_benefit_system): +def test_run_tests_from_directory(tax_benefit_system) -> None: dir_path = os.path.join(yaml_tests_dir, "directory") assert run_yaml_test(tax_benefit_system, dir_path) == EXIT_OK -def test_with_reform(tax_benefit_system): +def test_with_reform(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, "test_with_reform.yaml") == EXIT_OK -def test_with_extension(tax_benefit_system): +def test_with_extension(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, "test_with_extension.yaml") == EXIT_OK -def test_with_anchors(tax_benefit_system): +def test_with_anchors(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, "test_with_anchors.yaml") == EXIT_OK -def test_run_tests_from_directory_fail(tax_benefit_system): +def test_run_tests_from_directory_fail(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, yaml_tests_dir) == EXIT_TESTSFAILED -def test_name_filter(tax_benefit_system): +def test_name_filter(tax_benefit_system) -> None: assert ( run_yaml_test( - tax_benefit_system, yaml_tests_dir, options={"name_filter": "success"} + tax_benefit_system, + yaml_tests_dir, + options={"name_filter": "success"}, ) == EXIT_OK ) -def test_shell_script(): +def test_shell_script() -> None: yaml_path = os.path.join(yaml_tests_dir, "test_success.yml") command = ["openfisca", "test", yaml_path, "-c", "openfisca_country_template"] with open(os.devnull, "wb") as devnull: subprocess.check_call(command, stdout=devnull, stderr=devnull) -def test_failing_shell_script(): +def test_failing_shell_script() -> None: yaml_path = os.path.join(yaml_tests_dir, "test_failure.yaml") command = ["openfisca", "test", yaml_path, "-c", "openfisca_dummy_country"] with open(os.devnull, "wb") as devnull: @@ -102,7 +103,7 @@ def test_failing_shell_script(): subprocess.check_call(command, stdout=devnull, stderr=devnull) -def test_shell_script_with_reform(): +def test_shell_script_with_reform() -> None: yaml_path = os.path.join(yaml_tests_dir, "test_with_reform_2.yaml") command = [ "openfisca", @@ -117,7 +118,7 @@ def test_shell_script_with_reform(): subprocess.check_call(command, stdout=devnull, stderr=devnull) -def test_shell_script_with_extension(): +def test_shell_script_with_extension() -> None: tests_dir = os.path.join(openfisca_extension_template.__path__[0], "tests") command = [ "openfisca", diff --git a/tests/core/tools/test_assert_near.py b/tests/core/tools/test_assert_near.py index 0d540a49e8..c351be0f9c 100644 --- a/tests/core/tools/test_assert_near.py +++ b/tests/core/tools/test_assert_near.py @@ -3,11 +3,11 @@ from openfisca_core.tools import assert_near -def test_date(): +def test_date() -> None: assert_near(numpy.array("2012-03-24", dtype="datetime64[D]"), "2012-03-24") -def test_enum(tax_benefit_system): +def test_enum(tax_benefit_system) -> None: possible_values = tax_benefit_system.variables[ "housing_occupancy_status" ].possible_values @@ -16,7 +16,7 @@ def test_enum(tax_benefit_system): assert_near(value, expected_value) -def test_enum_2(tax_benefit_system): +def test_enum_2(tax_benefit_system) -> None: possible_values = tax_benefit_system.variables[ "housing_occupancy_status" ].possible_values diff --git a/tests/core/tools/test_runner/test_yaml_runner.py b/tests/core/tools/test_runner/test_yaml_runner.py index bf04ade9bb..6a02d14cef 100644 --- a/tests/core/tools/test_runner/test_yaml_runner.py +++ b/tests/core/tools/test_runner/test_yaml_runner.py @@ -1,5 +1,3 @@ -from typing import List - import os import numpy @@ -14,7 +12,7 @@ class TaxBenefitSystem: - def __init__(self): + def __init__(self) -> None: self.variables = {"salary": TestVariable()} self.person_entity = Entity("person", "persons", None, "") self.person_entity.set_tax_benefit_system(self) @@ -25,7 +23,7 @@ def get_package_metadata(self): def apply_reform(self, path): return Reform(self) - def load_extension(self, extension): + def load_extension(self, extension) -> None: pass def entities_by_singular(self): @@ -45,27 +43,27 @@ def clone(self): class Reform(TaxBenefitSystem): - def __init__(self, baseline): + def __init__(self, baseline) -> None: self.baseline = baseline class Simulation: - def __init__(self): + def __init__(self) -> None: self.populations = {"person": None} - def get_population(self, plural=None): + def get_population(self, plural=None) -> None: return None class TestFile(YamlFile): - def __init__(self): + def __init__(self) -> None: self.config = None self.session = None self._nodeid = "testname" class TestItem(YamlItem): - def __init__(self, test): + def __init__(self, test) -> None: super().__init__("", TestFile(), TaxBenefitSystem(), test, {}) self.tax_benefit_system = self.baseline_tax_benefit_system @@ -76,7 +74,7 @@ class TestVariable(Variable): definition_period = DateUnit.ETERNITY value_type = float - def __init__(self): + def __init__(self) -> None: self.end = None self.entity = Entity("person", "persons", None, "") self.is_neutralized = False @@ -85,7 +83,7 @@ def __init__(self): @pytest.mark.skip(reason="Deprecated node constructor") -def test_variable_not_found(): +def test_variable_not_found() -> None: test = {"output": {"unknown_variable": 0}} with pytest.raises(errors.VariableNotFoundError) as excinfo: test_item = TestItem(test) @@ -93,7 +91,7 @@ def test_variable_not_found(): assert excinfo.value.variable_name == "unknown_variable" -def test_tax_benefit_systems_with_reform_cache(): +def test_tax_benefit_systems_with_reform_cache() -> None: baseline = TaxBenefitSystem() ab_tax_benefit_system = _get_tax_benefit_system(baseline, "ab", []) @@ -101,7 +99,7 @@ def test_tax_benefit_systems_with_reform_cache(): assert ab_tax_benefit_system != ba_tax_benefit_system -def test_reforms_formats(): +def test_reforms_formats() -> None: baseline = TaxBenefitSystem() lonely_reform_tbs = _get_tax_benefit_system(baseline, "lonely_reform", []) @@ -109,7 +107,7 @@ def test_reforms_formats(): assert lonely_reform_tbs == list_lonely_reform_tbs -def test_reforms_order(): +def test_reforms_order() -> None: baseline = TaxBenefitSystem() abba_tax_benefit_system = _get_tax_benefit_system(baseline, ["ab", "ba"], []) @@ -119,7 +117,7 @@ def test_reforms_order(): ) # keep reforms order in cache -def test_tax_benefit_systems_with_extensions_cache(): +def test_tax_benefit_systems_with_extensions_cache() -> None: baseline = TaxBenefitSystem() xy_tax_benefit_system = _get_tax_benefit_system(baseline, [], "xy") @@ -127,17 +125,19 @@ def test_tax_benefit_systems_with_extensions_cache(): assert xy_tax_benefit_system != yx_tax_benefit_system -def test_extensions_formats(): +def test_extensions_formats() -> None: baseline = TaxBenefitSystem() lonely_extension_tbs = _get_tax_benefit_system(baseline, [], "lonely_extension") list_lonely_extension_tbs = _get_tax_benefit_system( - baseline, [], ["lonely_extension"] + baseline, + [], + ["lonely_extension"], ) assert lonely_extension_tbs == list_lonely_extension_tbs -def test_extensions_order(): +def test_extensions_order() -> None: baseline = TaxBenefitSystem() xy_tax_benefit_system = _get_tax_benefit_system(baseline, [], ["x", "y"]) @@ -148,7 +148,7 @@ def test_extensions_order(): @pytest.mark.skip(reason="Deprecated node constructor") -def test_performance_graph_option_output(): +def test_performance_graph_option_output() -> None: test = { "input": {"salary": {"2017-01": 2000}}, "output": {"salary": {"2017-01": 2000}}, @@ -170,7 +170,7 @@ def test_performance_graph_option_output(): @pytest.mark.skip(reason="Deprecated node constructor") -def test_performance_tables_option_output(): +def test_performance_tables_option_output() -> None: test = { "input": {"salary": {"2017-01": 2000}}, "output": {"salary": {"2017-01": 2000}}, @@ -191,7 +191,7 @@ def test_performance_tables_option_output(): clean_performance_files(paths) -def clean_performance_files(paths: List[str]): +def clean_performance_files(paths: list[str]) -> None: for path in paths: if os.path.isfile(path): os.remove(path) diff --git a/tests/core/variables/test_annualize.py b/tests/core/variables/test_annualize.py index 7bf85d9a46..58ea1372dd 100644 --- a/tests/core/variables/test_annualize.py +++ b/tests/core/variables/test_annualize.py @@ -1,4 +1,4 @@ -import numpy as np +import numpy from pytest import fixture from openfisca_country_template.entities import Person @@ -17,9 +17,9 @@ class monthly_variable(Variable): entity = Person definition_period = DateUnit.MONTH - def formula(person, period, parameters): + def formula(self, period, parameters): variable.calculation_count += 1 - return np.asarray([100]) + return numpy.asarray([100]) variable = monthly_variable() variable.calculation_count = calculation_count @@ -30,17 +30,16 @@ def formula(person, period, parameters): class PopulationMock: # Simulate a population for whom a variable has already been put in cache for January. - def __init__(self, variable): + def __init__(self, variable) -> None: self.variable = variable def __call__(self, variable_name: str, period): if period.start.month == 1: - return np.asarray([100]) - else: - return self.variable.get_formula(period)(self, period, None) + return numpy.asarray([100]) + return self.variable.get_formula(period)(self, period, None) -def test_without_annualize(monthly_variable): +def test_without_annualize(monthly_variable) -> None: period = periods.period(2019) person = PopulationMock(monthly_variable) @@ -54,7 +53,7 @@ def test_without_annualize(monthly_variable): assert yearly_sum == 1200 -def test_with_annualize(monthly_variable): +def test_with_annualize(monthly_variable) -> None: period = periods.period(2019) annualized_variable = get_annualized_variable(monthly_variable) @@ -69,10 +68,11 @@ def test_with_annualize(monthly_variable): assert yearly_sum == 100 * 12 -def test_with_partial_annualize(monthly_variable): +def test_with_partial_annualize(monthly_variable) -> None: period = periods.period("year:2018:2") annualized_variable = get_annualized_variable( - monthly_variable, periods.period(2018) + monthly_variable, + periods.period(2018), ) person = PopulationMock(annualized_variable) diff --git a/tests/core/variables/test_definition_period.py b/tests/core/variables/test_definition_period.py index 7938aaeaef..8ef9bfaa87 100644 --- a/tests/core/variables/test_definition_period.py +++ b/tests/core/variables/test_definition_period.py @@ -13,31 +13,31 @@ class TestVariable(Variable): return TestVariable -def test_weekday_variable(variable): +def test_weekday_variable(variable) -> None: variable.definition_period = periods.WEEKDAY assert variable() -def test_week_variable(variable): +def test_week_variable(variable) -> None: variable.definition_period = periods.WEEK assert variable() -def test_day_variable(variable): +def test_day_variable(variable) -> None: variable.definition_period = periods.DAY assert variable() -def test_month_variable(variable): +def test_month_variable(variable) -> None: variable.definition_period = periods.MONTH assert variable() -def test_year_variable(variable): +def test_year_variable(variable) -> None: variable.definition_period = periods.YEAR assert variable() -def test_eternity_variable(variable): +def test_eternity_variable(variable) -> None: variable.definition_period = periods.ETERNITY assert variable() diff --git a/tests/core/variables/test_variables.py b/tests/core/variables/test_variables.py index 3b2790bae7..d5d85a70d9 100644 --- a/tests/core/variables/test_variables.py +++ b/tests/core/variables/test_variables.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import datetime from pytest import fixture, mark, raises @@ -26,14 +24,16 @@ @fixture def couple(): return SimulationBuilder().build_from_entities( - tax_benefit_system, openfisca_country_template.situation_examples.couple + tax_benefit_system, + openfisca_country_template.situation_examples.couple, ) @fixture def simulation(): return SimulationBuilder().build_from_entities( - tax_benefit_system, openfisca_country_template.situation_examples.single + tax_benefit_system, + openfisca_country_template.situation_examples.single, ) @@ -41,16 +41,17 @@ def vectorize(individu, number): return individu.filled_array(number) -def check_error_at_add_variable(tax_benefit_system, variable, error_message_prefix): +def check_error_at_add_variable( + tax_benefit_system, variable, error_message_prefix +) -> None: try: tax_benefit_system.add_variable(variable) except ValueError as e: message = get_message(e) if not message or not message.startswith(error_message_prefix): + msg = f'Incorrect error message. Was expecting something starting by "{error_message_prefix}". Got: "{message}"' raise AssertionError( - 'Incorrect error message. Was expecting something starting by "{}". Got: "{}"'.format( - error_message_prefix, message - ) + msg, ) @@ -71,11 +72,11 @@ class variable__no_date(Variable): label = "Variable without date." -def test_before_add__variable__no_date(): +def test_before_add__variable__no_date() -> None: assert tax_benefit_system.variables.get("variable__no_date") is None -def test_variable__no_date(): +def test_variable__no_date() -> None: tax_benefit_system.add_variable(variable__no_date) variable = tax_benefit_system.variables["variable__no_date"] assert variable.end is None @@ -93,14 +94,14 @@ class variable__strange_end_attribute(Variable): end = "1989-00-00" -def test_variable__strange_end_attribute(): +def test_variable__strange_end_attribute() -> None: try: tax_benefit_system.add_variable(variable__strange_end_attribute) except ValueError as e: message = get_message(e) assert message.startswith( - "Incorrect 'end' attribute format in 'variable__strange_end_attribute'." + "Incorrect 'end' attribute format in 'variable__strange_end_attribute'.", ) # Check that Error at variable adding prevents it from registration in the taxbenefitsystem. @@ -121,12 +122,12 @@ class variable__end_attribute(Variable): tax_benefit_system.add_variable(variable__end_attribute) -def test_variable__end_attribute(): +def test_variable__end_attribute() -> None: variable = tax_benefit_system.variables["variable__end_attribute"] assert variable.end == datetime.date(1989, 12, 31) -def test_variable__end_attribute_set_input(simulation): +def test_variable__end_attribute_set_input(simulation) -> None: month_before_end = "1989-01" month_after_end = "1990-01" simulation.set_input("variable__end_attribute", month_before_end, 10) @@ -145,21 +146,21 @@ class end_attribute__one_simple_formula(Variable): label = "Variable with end attribute, one formula without date." end = "1989-12-31" - def formula(individu, period): - return vectorize(individu, 100) + def formula(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(end_attribute__one_simple_formula) -def test_formulas_attributes_single_formula(): +def test_formulas_attributes_single_formula() -> None: formulas = tax_benefit_system.variables[ "end_attribute__one_simple_formula" ].formulas assert formulas["0001-01-01"] is not None -def test_call__end_attribute__one_simple_formula(simulation): +def test_call__end_attribute__one_simple_formula(simulation) -> None: month = "1979-12" assert simulation.calculate("end_attribute__one_simple_formula", month) == 100 @@ -170,7 +171,7 @@ def test_call__end_attribute__one_simple_formula(simulation): assert simulation.calculate("end_attribute__one_simple_formula", month) == 0 -def test_dates__end_attribute__one_simple_formula(): +def test_dates__end_attribute__one_simple_formula() -> None: variable = tax_benefit_system.variables["end_attribute__one_simple_formula"] assert variable.end == datetime.date(1989, 12, 31) @@ -190,11 +191,11 @@ class no_end_attribute__one_formula__strange_name(Variable): definition_period = DateUnit.MONTH label = "Variable without end attribute, one stangely named formula." - def formula_2015_toto(individu, period): - return vectorize(individu, 100) + def formula_2015_toto(self, period): + return vectorize(self, 100) -def test_add__no_end_attribute__one_formula__strange_name(): +def test_add__no_end_attribute__one_formula__strange_name() -> None: check_error_at_add_variable( tax_benefit_system, no_end_attribute__one_formula__strange_name, @@ -211,14 +212,14 @@ class no_end_attribute__one_formula__start(Variable): definition_period = DateUnit.MONTH label = "Variable without end attribute, one dated formula." - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(no_end_attribute__one_formula__start) -def test_call__no_end_attribute__one_formula__start(simulation): +def test_call__no_end_attribute__one_formula__start(simulation) -> None: month = "1999-12" assert simulation.calculate("no_end_attribute__one_formula__start", month) == 0 @@ -229,7 +230,7 @@ def test_call__no_end_attribute__one_formula__start(simulation): assert simulation.calculate("no_end_attribute__one_formula__start", month) == 100 -def test_dates__no_end_attribute__one_formula__start(): +def test_dates__no_end_attribute__one_formula__start() -> None: variable = tax_benefit_system.variables["no_end_attribute__one_formula__start"] assert variable.end is None @@ -245,15 +246,15 @@ class no_end_attribute__one_formula__eternity(Variable): ) # For this entity, this variable shouldn't evolve through time label = "Variable without end attribute, one dated formula." - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(no_end_attribute__one_formula__eternity) @mark.xfail() -def test_call__no_end_attribute__one_formula__eternity(simulation): +def test_call__no_end_attribute__one_formula__eternity(simulation) -> None: month = "1999-12" assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 0 @@ -262,12 +263,12 @@ def test_call__no_end_attribute__one_formula__eternity(simulation): assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 100 -def test_call__no_end_attribute__one_formula__eternity_before(simulation): +def test_call__no_end_attribute__one_formula__eternity_before(simulation) -> None: month = "1999-12" assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 0 -def test_call__no_end_attribute__one_formula__eternity_after(simulation): +def test_call__no_end_attribute__one_formula__eternity_after(simulation) -> None: month = "2000-01" assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 100 @@ -281,17 +282,17 @@ class no_end_attribute__formulas__start_formats(Variable): definition_period = DateUnit.MONTH label = "Variable without end attribute, multiple dated formulas." - def formula_2000(individu, period): - return vectorize(individu, 100) + def formula_2000(self, period): + return vectorize(self, 100) - def formula_2010_01(individu, period): - return vectorize(individu, 200) + def formula_2010_01(self, period): + return vectorize(self, 200) tax_benefit_system.add_variable(no_end_attribute__formulas__start_formats) -def test_formulas_attributes_dated_formulas(): +def test_formulas_attributes_dated_formulas() -> None: formulas = tax_benefit_system.variables[ "no_end_attribute__formulas__start_formats" ].formulas @@ -300,7 +301,7 @@ def test_formulas_attributes_dated_formulas(): assert formulas["2010-01-01"] is not None -def test_get_formulas(): +def test_get_formulas() -> None: variable = tax_benefit_system.variables["no_end_attribute__formulas__start_formats"] formula_2000 = variable.formulas["2000-01-01"] formula_2010 = variable.formulas["2010-01-01"] @@ -313,7 +314,7 @@ def test_get_formulas(): assert variable.get_formula("2010-01-01") == formula_2010 -def test_call__no_end_attribute__formulas__start_formats(simulation): +def test_call__no_end_attribute__formulas__start_formats(simulation) -> None: month = "1999-12" assert simulation.calculate("no_end_attribute__formulas__start_formats", month) == 0 @@ -342,14 +343,14 @@ class no_attribute__formulas__different_names__dates_overlap(Variable): definition_period = DateUnit.MONTH label = "Variable, no end attribute, multiple dated formulas with different names but same dates." - def formula_2000(individu, period): - return vectorize(individu, 100) + def formula_2000(self, period): + return vectorize(self, 100) - def formula_2000_01_01(individu, period): - return vectorize(individu, 200) + def formula_2000_01_01(self, period): + return vectorize(self, 200) -def test_add__no_attribute__formulas__different_names__dates_overlap(): +def test_add__no_attribute__formulas__different_names__dates_overlap() -> None: # Variable isn't registered in the taxbenefitsystem check_error_at_add_variable( tax_benefit_system, @@ -367,21 +368,22 @@ class no_attribute__formulas__different_names__no_overlap(Variable): definition_period = DateUnit.MONTH label = "Variable, no end attribute, multiple dated formulas with different names and no date overlap." - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) - def formula_2010_01_01(individu, period): - return vectorize(individu, 200) + def formula_2010_01_01(self, period): + return vectorize(self, 200) tax_benefit_system.add_variable(no_attribute__formulas__different_names__no_overlap) -def test_call__no_attribute__formulas__different_names__no_overlap(simulation): +def test_call__no_attribute__formulas__different_names__no_overlap(simulation) -> None: month = "2009-12" assert ( simulation.calculate( - "no_attribute__formulas__different_names__no_overlap", month + "no_attribute__formulas__different_names__no_overlap", + month, ) == 100 ) @@ -389,7 +391,8 @@ def test_call__no_attribute__formulas__different_names__no_overlap(simulation): month = "2015-05" assert ( simulation.calculate( - "no_attribute__formulas__different_names__no_overlap", month + "no_attribute__formulas__different_names__no_overlap", + month, ) == 200 ) @@ -408,14 +411,14 @@ class end_attribute__one_formula__start(Variable): label = "Variable with end attribute, one dated formula." end = "2001-12-31" - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(end_attribute__one_formula__start) -def test_call__end_attribute__one_formula__start(simulation): +def test_call__end_attribute__one_formula__start(simulation) -> None: month = "1980-01" assert simulation.calculate("end_attribute__one_formula__start", month) == 0 @@ -436,11 +439,11 @@ class stop_attribute_before__one_formula__start(Variable): label = "Variable with stop attribute only coming before formula start." end = "1990-01-01" - def formula_2000_01_01(individu, period): - return vectorize(individu, 0) + def formula_2000_01_01(self, period): + return vectorize(self, 0) -def test_add__stop_attribute_before__one_formula__start(): +def test_add__stop_attribute_before__one_formula__start() -> None: check_error_at_add_variable( tax_benefit_system, stop_attribute_before__one_formula__start, @@ -460,14 +463,14 @@ class end_attribute_restrictive__one_formula(Variable): ) end = "2001-01-01" - def formula_2001_01_01(individu, period): - return vectorize(individu, 100) + def formula_2001_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(end_attribute_restrictive__one_formula) -def test_call__end_attribute_restrictive__one_formula(simulation): +def test_call__end_attribute_restrictive__one_formula(simulation) -> None: month = "2000-12" assert simulation.calculate("end_attribute_restrictive__one_formula", month) == 0 @@ -488,20 +491,20 @@ class end_attribute__formulas__different_names(Variable): label = "Variable with end attribute, multiple dated formulas with different names." end = "2010-12-31" - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) - def formula_2005_01_01(individu, period): - return vectorize(individu, 200) + def formula_2005_01_01(self, period): + return vectorize(self, 200) - def formula_2010_01_01(individu, period): - return vectorize(individu, 300) + def formula_2010_01_01(self, period): + return vectorize(self, 300) tax_benefit_system.add_variable(end_attribute__formulas__different_names) -def test_call__end_attribute__formulas__different_names(simulation): +def test_call__end_attribute__formulas__different_names(simulation) -> None: month = "2000-01" assert ( simulation.calculate("end_attribute__formulas__different_names", month) == 100 @@ -518,20 +521,22 @@ def test_call__end_attribute__formulas__different_names(simulation): ) -def test_get_formula(simulation): +def test_get_formula(simulation) -> None: person = simulation.person disposable_income_formula = tax_benefit_system.get_variable( - "disposable_income" + "disposable_income", ).get_formula() disposable_income = person("disposable_income", "2017-01") disposable_income_2 = disposable_income_formula( - person, "2017-01", None + person, + "2017-01", + None, ) # No need for parameters here assert_near(disposable_income, disposable_income_2) -def test_unexpected_attr(): +def test_unexpected_attr() -> None: class variable_with_strange_attr(Variable): value_type = int entity = Person diff --git a/tests/fixtures/appclient.py b/tests/fixtures/appclient.py index 5edcfc2c98..692747d393 100644 --- a/tests/fixtures/appclient.py +++ b/tests/fixtures/appclient.py @@ -15,8 +15,10 @@ def test_client(tax_benefit_system): from openfisca_country_template import entities from openfisca_core import periods from openfisca_core.variables import Variable + ... + class new_variable(Variable): value_type = float entity = entities.Person @@ -24,11 +26,11 @@ class new_variable(Variable): label = "New variable" reference = "https://law.gov.example/new_variable" # Always use the most official source + tax_benefit_system.add_variable(new_variable) flask_app = app.create_app(tax_benefit_system) """ - # Create the test API client flask_app = app.create_app(tax_benefit_system) return flask_app.test_client() diff --git a/tests/fixtures/entities.py b/tests/fixtures/entities.py index 4d103f10d3..6670a68da1 100644 --- a/tests/fixtures/entities.py +++ b/tests/fixtures/entities.py @@ -7,25 +7,29 @@ class TestEntity(Entity): def get_variable( - self, variable_name: str, check_existence: bool = False + self, + variable_name: str, + check_existence: bool = False, ) -> TestVariable: result = TestVariable(self) result.name = variable_name return result - def check_variable_defined_for_entity(self, variable_name: str): + def check_variable_defined_for_entity(self, variable_name: str) -> bool: return True class TestGroupEntity(GroupEntity): def get_variable( - self, variable_name: str, check_existence: bool = False + self, + variable_name: str, + check_existence: bool = False, ) -> TestVariable: result = TestVariable(self) result.name = variable_name return result - def check_variable_defined_for_entity(self, variable_name: str): + def check_variable_defined_for_entity(self, variable_name: str) -> bool: return True diff --git a/tests/fixtures/extensions.py b/tests/fixtures/extensions.py index 631dcbc0d7..bc4e85fe72 100644 --- a/tests/fixtures/extensions.py +++ b/tests/fixtures/extensions.py @@ -3,16 +3,16 @@ import pytest -@pytest.fixture() -def test_country_package_name(): +@pytest.fixture +def test_country_package_name() -> str: return "openfisca_country_template" -@pytest.fixture() -def test_extension_package_name(): +@pytest.fixture +def test_extension_package_name() -> str: return "openfisca_extension_template" -@pytest.fixture() +@pytest.fixture def distribution(test_country_package_name): return metadata.distribution(test_country_package_name) diff --git a/tests/fixtures/simulations.py b/tests/fixtures/simulations.py index 6df19ba27f..53120b60d9 100644 --- a/tests/fixtures/simulations.py +++ b/tests/fixtures/simulations.py @@ -24,6 +24,4 @@ def make_simulation(): def _simulation(simulation_builder, tax_benefit_system, variables, period): simulation_builder.set_default_period(period) - simulation = simulation_builder.build_from_variables(tax_benefit_system, variables) - - return simulation + return simulation_builder.build_from_variables(tax_benefit_system, variables) diff --git a/tests/fixtures/variables.py b/tests/fixtures/variables.py index aab7cda58d..2deccf5891 100644 --- a/tests/fixtures/variables.py +++ b/tests/fixtures/variables.py @@ -6,6 +6,6 @@ class TestVariable(Variable): definition_period = DateUnit.ETERNITY value_type = float - def __init__(self, entity): + def __init__(self, entity) -> None: self.__class__.entity = entity super().__init__() diff --git a/tests/web_api/case_with_extension/test_extensions.py b/tests/web_api/case_with_extension/test_extensions.py index be5ee6bf24..2c688232f8 100644 --- a/tests/web_api/case_with_extension/test_extensions.py +++ b/tests/web_api/case_with_extension/test_extensions.py @@ -6,7 +6,7 @@ from openfisca_web_api.app import create_app -@pytest.fixture() +@pytest.fixture def tax_benefit_system(test_country_package_name, test_extension_package_name): return build_tax_benefit_system( test_country_package_name, @@ -15,25 +15,25 @@ def tax_benefit_system(test_country_package_name, test_extension_package_name): ) -@pytest.fixture() +@pytest.fixture def extended_subject(tax_benefit_system): return create_app(tax_benefit_system).test_client() -def test_return_code(extended_subject): +def test_return_code(extended_subject) -> None: parameters_response = extended_subject.get("/parameters") assert parameters_response.status_code == client.OK -def test_return_code_existing_parameter(extended_subject): +def test_return_code_existing_parameter(extended_subject) -> None: extension_parameter_response = extended_subject.get( - "/parameter/local_town.child_allowance.amount" + "/parameter/local_town.child_allowance.amount", ) assert extension_parameter_response.status_code == client.OK -def test_return_code_existing_variable(extended_subject): +def test_return_code_existing_variable(extended_subject) -> None: extension_variable_response = extended_subject.get( - "/variable/local_town_child_allowance" + "/variable/local_town_child_allowance", ) assert extension_variable_response.status_code == client.OK diff --git a/tests/web_api/case_with_reform/test_reforms.py b/tests/web_api/case_with_reform/test_reforms.py index afcb811443..f0895cf189 100644 --- a/tests/web_api/case_with_reform/test_reforms.py +++ b/tests/web_api/case_with_reform/test_reforms.py @@ -6,7 +6,7 @@ from openfisca_web_api import app -@pytest.fixture() +@pytest.fixture def test_reforms_path(test_country_package_name): return [ f"{test_country_package_name}.reforms.add_dynamic_variable.add_dynamic_variable", @@ -29,37 +29,37 @@ def client(test_country_package_name, test_reforms_path): return app.create_app(tax_benefit_system).test_client() -def test_return_code_of_dynamic_variable(client): +def test_return_code_of_dynamic_variable(client) -> None: result = client.get("/variable/goes_to_school") assert result.status_code == http.client.OK -def test_return_code_of_has_car_variable(client): +def test_return_code_of_has_car_variable(client) -> None: result = client.get("/variable/has_car") assert result.status_code == http.client.OK -def test_return_code_of_new_tax_variable(client): +def test_return_code_of_new_tax_variable(client) -> None: result = client.get("/variable/new_tax") assert result.status_code == http.client.OK -def test_return_code_of_social_security_contribution_variable(client): +def test_return_code_of_social_security_contribution_variable(client) -> None: result = client.get("/variable/social_security_contribution") assert result.status_code == http.client.OK -def test_return_code_of_social_security_contribution_parameter(client): +def test_return_code_of_social_security_contribution_parameter(client) -> None: result = client.get("/parameter/taxes.social_security_contribution") assert result.status_code == http.client.OK -def test_return_code_of_basic_income_variable(client): +def test_return_code_of_basic_income_variable(client) -> None: result = client.get("/variable/basic_income") assert result.status_code == http.client.OK diff --git a/tests/web_api/loader/test_parameters.py b/tests/web_api/loader/test_parameters.py index 2b6be58916..f44632ce49 100644 --- a/tests/web_api/loader/test_parameters.py +++ b/tests/web_api/loader/test_parameters.py @@ -1,46 +1,44 @@ -# -*- coding: utf-8 -*- - from openfisca_core.parameters import Scale from openfisca_web_api.loader.parameters import build_api_parameter, build_api_scale -def test_build_rate_scale(): - """Extracts a 'rate' children from a bracket collection""" +def test_build_rate_scale() -> None: + """Extracts a 'rate' children from a bracket collection.""" data = { "brackets": [ { "rate": {"2014-01-01": {"value": 0.5}}, "threshold": {"2014-01-01": {"value": 1}}, - } - ] + }, + ], } rate = Scale("this rate", data, None) assert build_api_scale(rate, "rate") == {"2014-01-01": {1: 0.5}} -def test_build_amount_scale(): - """Extracts an 'amount' children from a bracket collection""" +def test_build_amount_scale() -> None: + """Extracts an 'amount' children from a bracket collection.""" data = { "brackets": [ { "amount": {"2014-01-01": {"value": 0}}, "threshold": {"2014-01-01": {"value": 1}}, - } - ] + }, + ], } rate = Scale("that amount", data, None) assert build_api_scale(rate, "amount") == {"2014-01-01": {1: 0}} -def test_full_rate_scale(): - """Serializes a 'rate' scale parameter""" +def test_full_rate_scale() -> None: + """Serializes a 'rate' scale parameter.""" data = { "brackets": [ { "rate": {"2014-01-01": {"value": 0.5}}, "threshold": {"2014-01-01": {"value": 1}}, - } - ] + }, + ], } scale = Scale("rate", data, None) api_scale = build_api_parameter(scale, {}) @@ -52,15 +50,15 @@ def test_full_rate_scale(): } -def test_walk_node_amount_scale(): - """Serializes an 'amount' scale parameter""" +def test_walk_node_amount_scale() -> None: + """Serializes an 'amount' scale parameter.""" data = { "brackets": [ { "amount": {"2014-01-01": {"value": 0}}, "threshold": {"2014-01-01": {"value": 1}}, - } - ] + }, + ], } scale = Scale("amount", data, None) api_scale = build_api_parameter(scale, {}) diff --git a/tests/web_api/test_calculate.py b/tests/web_api/test_calculate.py index d5d64c3c38..4d69dae9ab 100644 --- a/tests/web_api/test_calculate.py +++ b/tests/web_api/test_calculate.py @@ -12,14 +12,18 @@ def post_json(client, data=None, file=None): if file: file_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "assets", file + os.path.dirname(os.path.abspath(__file__)), + "assets", + file, ) - with open(file_path, "r") as file: + with open(file_path) as file: data = file.read() return client.post("/calculate", data=data, content_type="application/json") -def check_response(client, data, expected_error_code, path_to_check, content_to_check): +def check_response( + client, data, expected_error_code, path_to_check, content_to_check +) -> None: response = post_json(client, data) assert response.status_code == expected_error_code json_response = json.loads(response.data.decode("utf-8")) @@ -138,11 +142,11 @@ def check_response(client, data, expected_error_code, path_to_check, content_to_ ), ], ) -def test_responses(test_client, test): +def test_responses(test_client, test) -> None: check_response(test_client, *test) -def test_basic_calculation(test_client): +def test_basic_calculation(test_client) -> None: simulation_json = json.dumps( { "persons": { @@ -166,7 +170,7 @@ def test_basic_calculation(test_client): "accommodation_size": {"2017-01": 300}, }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -184,7 +188,8 @@ def test_basic_calculation(test_client): assert dpath.util.get(response_json, "persons/bob/basic_income/2017-12") == 600 assert ( dpath.util.get( - response_json, "persons/bob/social_security_contribution/2017-12" + response_json, + "persons/bob/social_security_contribution/2017-12", ) == 816 ) # From social_security_contribution.yaml test @@ -194,7 +199,7 @@ def test_basic_calculation(test_client): ) -def test_enums_sending_identifier(test_client): +def test_enums_sending_identifier(test_client) -> None: simulation_json = json.dumps( { "persons": {"bill": {}}, @@ -204,9 +209,9 @@ def test_enums_sending_identifier(test_client): "housing_tax": {"2017": None}, "accommodation_size": {"2017-01": 300}, "housing_occupancy_status": {"2017-01": "free_lodger"}, - } + }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -215,7 +220,7 @@ def test_enums_sending_identifier(test_client): assert dpath.util.get(response_json, "households/_/housing_tax/2017") == 0 -def test_enum_output(test_client): +def test_enum_output(test_client) -> None: simulation_json = json.dumps( { "persons": { @@ -227,7 +232,7 @@ def test_enum_output(test_client): "housing_occupancy_status": {"2017-01": None}, }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -239,7 +244,7 @@ def test_enum_output(test_client): ) -def test_enum_wrong_value(test_client): +def test_enum_wrong_value(test_client) -> None: simulation_json = json.dumps( { "persons": { @@ -251,7 +256,7 @@ def test_enum_wrong_value(test_client): "housing_occupancy_status": {"2017-01": "Unknown value lodger"}, }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -259,26 +264,27 @@ def test_enum_wrong_value(test_client): response_json = json.loads(response.data.decode("utf-8")) message = "Possible values are ['owner', 'tenant', 'free_lodger', 'homeless']" text = dpath.util.get( - response_json, "households/_/housing_occupancy_status/2017-01" + response_json, + "households/_/housing_occupancy_status/2017-01", ) assert message in text -def test_encoding_variable_value(test_client): +def test_encoding_variable_value(test_client) -> None: simulation_json = json.dumps( { "persons": {"toto": {}}, "households": { "_": { "housing_occupancy_status": { - "2017-07": "Locataire ou sous-locataire d‘un logement loué vide non-HLM" + "2017-07": "Locataire ou sous-locataire d‘un logement loué vide non-HLM", }, "parent": [ "toto", ], - } + }, }, - } + }, ) # No UnicodeDecodeError @@ -287,17 +293,18 @@ def test_encoding_variable_value(test_client): response_json = json.loads(response.data.decode("utf-8")) message = "'Locataire ou sous-locataire d‘un logement loué vide non-HLM' is not a known value for 'housing_occupancy_status'. Possible values are " text = dpath.util.get( - response_json, "households/_/housing_occupancy_status/2017-07" + response_json, + "households/_/housing_occupancy_status/2017-07", ) assert message in text -def test_encoding_entity_name(test_client): +def test_encoding_entity_name(test_client) -> None: simulation_json = json.dumps( { "persons": {"O‘Ryan": {}, "Renée": {}}, "households": {"_": {"parents": ["O‘Ryan", "Renée"]}}, - } + }, ) # No UnicodeDecodeError @@ -311,7 +318,7 @@ def test_encoding_entity_name(test_client): assert message in text -def test_encoding_period_id(test_client): +def test_encoding_period_id(test_client) -> None: simulation_json = json.dumps( { "persons": { @@ -324,9 +331,9 @@ def test_encoding_period_id(test_client): "housing_tax": {"à": 400}, "accommodation_size": {"2017-01": 300}, "housing_occupancy_status": {"2017-01": "tenant"}, - } + }, }, - } + }, ) # No UnicodeDecodeError @@ -341,19 +348,21 @@ def test_encoding_period_id(test_client): assert message in text -def test_str_variable(test_client): +def test_str_variable(test_client) -> None: new_couple = copy.deepcopy(couple) new_couple["households"]["_"]["postal_code"] = {"2017-01": None} simulation_json = json.dumps(new_couple) response = test_client.post( - "/calculate", data=simulation_json, content_type="application/json" + "/calculate", + data=simulation_json, + content_type="application/json", ) assert response.status_code == client.OK -def test_periods(test_client): +def test_periods(test_client) -> None: simulation_json = json.dumps( { "persons": {"bill": {}}, @@ -362,9 +371,9 @@ def test_periods(test_client): "parents": ["bill"], "housing_tax": {"2017": None}, "housing_occupancy_status": {"2017-01": None}, - } + }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -373,19 +382,20 @@ def test_periods(test_client): response_json = json.loads(response.data.decode("utf-8")) yearly_variable = dpath.util.get( - response_json, "households/_/housing_tax" + response_json, + "households/_/housing_tax", ) # web api year is an int assert yearly_variable == {"2017": 200.0} monthly_variable = dpath.util.get( - response_json, "households/_/housing_occupancy_status" + response_json, + "households/_/housing_occupancy_status", ) # web api month is a string assert monthly_variable == {"2017-01": "tenant"} -def test_two_periods(test_client): - """ - Test `calculate` on a request with mixed types periods: yearly periods following +def test_two_periods(test_client) -> None: + """Test `calculate` on a request with mixed types periods: yearly periods following monthly or daily periods to check dpath limitation on numeric keys (yearly periods). Made to test the case where we have more than one path with a numeric in it. See https://github.com/dpath-maintainers/dpath-python/issues/160 for more informations. @@ -398,9 +408,9 @@ def test_two_periods(test_client): "parents": ["bill"], "housing_tax": {"2017": None, "2018": None}, "housing_occupancy_status": {"2017-01": None, "2018-01": None}, - } + }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -409,17 +419,19 @@ def test_two_periods(test_client): response_json = json.loads(response.data.decode("utf-8")) yearly_variable = dpath.util.get( - response_json, "households/_/housing_tax" + response_json, + "households/_/housing_tax", ) # web api year is an int assert yearly_variable == {"2017": 200.0, "2018": 200.0} monthly_variable = dpath.util.get( - response_json, "households/_/housing_occupancy_status" + response_json, + "households/_/housing_occupancy_status", ) # web api month is a string assert monthly_variable == {"2017-01": "tenant", "2018-01": "tenant"} -def test_handle_period_mismatch_error(test_client): +def test_handle_period_mismatch_error(test_client) -> None: variable = "housing_tax" period = "2017-01" @@ -430,9 +442,9 @@ def test_handle_period_mismatch_error(test_client): "_": { "parents": ["bill"], variable: {period: 400}, - } + }, }, - } + }, ) response = post_json(test_client, simulation_json) @@ -445,9 +457,8 @@ def test_handle_period_mismatch_error(test_client): assert message in error -def test_gracefully_handle_unexpected_errors(test_client): - """ - Context +def test_gracefully_handle_unexpected_errors(test_client) -> None: + """Context. ======= Whenever an exception is raised by the calculation engine, the API will try @@ -466,7 +477,7 @@ def test_gracefully_handle_unexpected_errors(test_client): In the `country-template`, Housing Tax is only defined from 2010 onwards. The calculation engine should therefore raise an exception `ParameterNotFound`. The API is not expecting this, but she should handle the situation nonetheless. - """ # noqa RST399 + """ variable = "housing_tax" period = "1234-05-06" @@ -481,9 +492,9 @@ def test_gracefully_handle_unexpected_errors(test_client): variable: { period: None, }, - } + }, }, - } + }, ) response = post_json(test_client, simulation_json) diff --git a/tests/web_api/test_entities.py b/tests/web_api/test_entities.py index afb909ef57..e7d0ef5b9b 100644 --- a/tests/web_api/test_entities.py +++ b/tests/web_api/test_entities.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import json from http import client @@ -8,12 +6,12 @@ # /entities -def test_return_code(test_client): +def test_return_code(test_client) -> None: entities_response = test_client.get("/entities") assert entities_response.status_code == client.OK -def test_response_data(test_client): +def test_response_data(test_client) -> None: entities_response = test_client.get("/entities") entities_dict = json.loads(entities_response.data.decode("utf-8")) test_documentation = entities.Household.doc.strip() diff --git a/tests/web_api/test_headers.py b/tests/web_api/test_headers.py index c5464d91b1..dc95437a09 100644 --- a/tests/web_api/test_headers.py +++ b/tests/web_api/test_headers.py @@ -1,10 +1,10 @@ -def test_package_name_header(test_client, distribution): +def test_package_name_header(test_client, distribution) -> None: name = distribution.metadata.get("Name").lower() parameters_response = test_client.get("/parameters") assert parameters_response.headers.get("Country-Package") == name -def test_package_version_header(test_client, distribution): +def test_package_version_header(test_client, distribution) -> None: version = distribution.metadata.get("Version") parameters_response = test_client.get("/parameters") assert parameters_response.headers.get("Country-Package-Version") == version diff --git a/tests/web_api/test_helpers.py b/tests/web_api/test_helpers.py index 5b22a57b47..a1725cdfbf 100644 --- a/tests/web_api/test_helpers.py +++ b/tests/web_api/test_helpers.py @@ -6,7 +6,7 @@ dir_path = os.path.join(os.path.dirname(__file__), "assets") -def test_build_api_values_history(): +def test_build_api_values_history() -> None: file_path = os.path.join(dir_path, "test_helpers.yaml") parameter = load_parameter_file(name="dummy_name", file_path=file_path) @@ -18,7 +18,7 @@ def test_build_api_values_history(): assert parameters.build_api_values_history(parameter) == values -def test_build_api_values_history_with_stop_date(): +def test_build_api_values_history_with_stop_date() -> None: file_path = os.path.join(dir_path, "test_helpers_with_stop_date.yaml") parameter = load_parameter_file(name="dummy_name", file_path=file_path) @@ -32,7 +32,7 @@ def test_build_api_values_history_with_stop_date(): assert parameters.build_api_values_history(parameter) == values -def test_get_value(): +def test_get_value() -> None: values = {"2013-01-01": 0.03, "2017-01-01": 0.02, "2015-01-01": 0.04} assert parameters.get_value("2013-01-01", values) == 0.03 @@ -43,7 +43,7 @@ def test_get_value(): assert parameters.get_value("2018-01-01", values) == 0.02 -def test_get_value_with_none(): +def test_get_value_with_none() -> None: values = {"2015-01-01": 0.04, "2017-01-01": None} assert parameters.get_value("2016-12-31", values) == 0.04 diff --git a/tests/web_api/test_parameters.py b/tests/web_api/test_parameters.py index 762193fc2d..77fee8f7ea 100644 --- a/tests/web_api/test_parameters.py +++ b/tests/web_api/test_parameters.py @@ -10,12 +10,12 @@ GITHUB_URL_REGEX = r"^https://github\.com/openfisca/country-template/blob/\d+\.\d+\.\d+((.dev|rc)\d+)?/openfisca_country_template/parameters/(.)+\.yaml$" -def test_return_code(test_client): +def test_return_code(test_client) -> None: parameters_response = test_client.get("/parameters") assert parameters_response.status_code == client.OK -def test_response_data(test_client): +def test_response_data(test_client) -> None: parameters_response = test_client.get("/parameters") parameters = json.loads(parameters_response.data.decode("utf-8")) @@ -29,25 +29,25 @@ def test_response_data(test_client): # /parameter/ -def test_error_code_non_existing_parameter(test_client): +def test_error_code_non_existing_parameter(test_client) -> None: response = test_client.get("/parameter/non/existing.parameter") assert response.status_code == client.NOT_FOUND -def test_return_code_existing_parameter(test_client): +def test_return_code_existing_parameter(test_client) -> None: response = test_client.get("/parameter/taxes/income_tax_rate") assert response.status_code == client.OK -def test_legacy_parameter_route(test_client): +def test_legacy_parameter_route(test_client) -> None: response = test_client.get("/parameter/taxes.income_tax_rate") assert response.status_code == client.OK -def test_parameter_values(test_client): +def test_parameter_values(test_client) -> None: response = test_client.get("/parameter/taxes/income_tax_rate") parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), [ + assert sorted(parameter.keys()), [ "description", "id", "metadata", @@ -69,7 +69,7 @@ def test_parameter_values(test_client): # 'documentation' attribute exists only when a value is defined response = test_client.get("/parameter/benefits/housing_allowance") parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), [ + assert sorted(parameter.keys()), [ "description", "documentation", "id", @@ -82,11 +82,11 @@ def test_parameter_values(test_client): ) -def test_parameter_node(tax_benefit_system, test_client): +def test_parameter_node(tax_benefit_system, test_client) -> None: response = test_client.get("/parameter/benefits") assert response.status_code == client.OK parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), [ + assert sorted(parameter.keys()), [ "description", "documentation", "id", @@ -107,20 +107,22 @@ def test_parameter_node(tax_benefit_system, test_client): assert "description" in parameter["subparams"]["basic_income"] assert parameter["subparams"]["basic_income"]["description"] == getattr( - model_benefits.basic_income, "description", None + model_benefits.basic_income, + "description", + None, ), parameter["subparams"]["basic_income"]["description"] -def test_stopped_parameter_values(test_client): +def test_stopped_parameter_values(test_client) -> None: response = test_client.get("/parameter/benefits/housing_allowance") parameter = json.loads(response.data) assert parameter["values"] == {"2016-12-01": None, "2010-01-01": 0.25} -def test_scale(test_client): +def test_scale(test_client) -> None: response = test_client.get("/parameter/taxes/social_security_contribution") parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), [ + assert sorted(parameter.keys()), [ "brackets", "description", "id", @@ -135,7 +137,7 @@ def test_scale(test_client): } -def check_code(client, route, code): +def check_code(client, route, code) -> None: response = client.get(route) assert response.status_code == code @@ -153,10 +155,10 @@ def check_code(client, route, code): ("/parameter//taxes/income_tax_rate/", client.FOUND), ], ) -def test_routes_robustness(test_client, expected_code): +def test_routes_robustness(test_client, expected_code) -> None: check_code(test_client, *expected_code) -def test_parameter_encoding(test_client): +def test_parameter_encoding(test_client) -> None: parameter_response = test_client.get("/parameter/general/age_of_retirement") assert parameter_response.status_code == client.OK diff --git a/tests/web_api/test_spec.py b/tests/web_api/test_spec.py index 228cf27eb8..75a0f00e64 100644 --- a/tests/web_api/test_spec.py +++ b/tests/web_api/test_spec.py @@ -6,11 +6,11 @@ from openapi_spec_validator import OpenAPIV30SpecValidator -def assert_items_equal(x, y): +def assert_items_equal(x, y) -> None: assert sorted(x) == sorted(y) -def test_return_code(test_client): +def test_return_code(test_client) -> None: openAPI_response = test_client.get("/spec") assert openAPI_response.status_code == client.OK @@ -21,7 +21,7 @@ def body(test_client): return json.loads(openAPI_response.data.decode("utf-8")) -def test_paths(body): +def test_paths(body) -> None: assert_items_equal( body["paths"], [ @@ -37,29 +37,41 @@ def test_paths(body): ) -def test_entity_definition(body): +def test_entity_definition(body) -> None: assert "parents" in dpath.util.get(body, "components/schemas/Household/properties") assert "children" in dpath.util.get(body, "components/schemas/Household/properties") assert "salary" in dpath.util.get(body, "components/schemas/Person/properties") assert "rent" in dpath.util.get(body, "components/schemas/Household/properties") - assert "number" == dpath.util.get( - body, "components/schemas/Person/properties/salary/additionalProperties/type" + assert ( + dpath.util.get( + body, + "components/schemas/Person/properties/salary/additionalProperties/type", + ) + == "number" ) -def test_situation_definition(body): +def test_situation_definition(body) -> None: situation_input = body["components"]["schemas"]["SituationInput"] situation_output = body["components"]["schemas"]["SituationOutput"] for situation in situation_input, situation_output: assert "households" in dpath.util.get(situation, "/properties") assert "persons" in dpath.util.get(situation, "/properties") - assert "#/components/schemas/Household" == dpath.util.get( - situation, "/properties/households/additionalProperties/$ref" + assert ( + dpath.util.get( + situation, + "/properties/households/additionalProperties/$ref", + ) + == "#/components/schemas/Household" ) - assert "#/components/schemas/Person" == dpath.util.get( - situation, "/properties/persons/additionalProperties/$ref" + assert ( + dpath.util.get( + situation, + "/properties/persons/additionalProperties/$ref", + ) + == "#/components/schemas/Person" ) -def test_respects_spec(body): - assert not [error for error in OpenAPIV30SpecValidator(body).iter_errors()] +def test_respects_spec(body) -> None: + assert not list(OpenAPIV30SpecValidator(body).iter_errors()) diff --git a/tests/web_api/test_trace.py b/tests/web_api/test_trace.py index ee6c6ab21f..9463e69dfb 100644 --- a/tests/web_api/test_trace.py +++ b/tests/web_api/test_trace.py @@ -7,24 +7,28 @@ from openfisca_country_template.situation_examples import couple, single -def assert_items_equal(x, y): +def assert_items_equal(x, y) -> None: assert set(x) == set(y) -def test_trace_basic(test_client): +def test_trace_basic(test_client) -> None: simulation_json = json.dumps(single) response = test_client.post( - "/trace", data=simulation_json, content_type="application/json" + "/trace", + data=simulation_json, + content_type="application/json", ) assert response.status_code == client.OK response_json = json.loads(response.data.decode("utf-8")) disposable_income_value = dpath.util.get( - response_json, "trace/disposable_income<2017-01>/value" + response_json, + "trace/disposable_income<2017-01>/value", ) assert isinstance(disposable_income_value, list) assert isinstance(disposable_income_value[0], float) disposable_income_dep = dpath.util.get( - response_json, "trace/disposable_income<2017-01>/dependencies" + response_json, + "trace/disposable_income<2017-01>/dependencies", ) assert_items_equal( disposable_income_dep, @@ -36,29 +40,35 @@ def test_trace_basic(test_client): ], ) basic_income_dep = dpath.util.get( - response_json, "trace/basic_income<2017-01>/dependencies" + response_json, + "trace/basic_income<2017-01>/dependencies", ) assert_items_equal(basic_income_dep, ["age<2017-01>"]) -def test_trace_enums(test_client): +def test_trace_enums(test_client) -> None: new_single = copy.deepcopy(single) new_single["households"]["_"]["housing_occupancy_status"] = {"2017-01": None} simulation_json = json.dumps(new_single) response = test_client.post( - "/trace", data=simulation_json, content_type="application/json" + "/trace", + data=simulation_json, + content_type="application/json", ) response_json = json.loads(response.data) housing_status = dpath.util.get( - response_json, "trace/housing_occupancy_status<2017-01>/value" + response_json, + "trace/housing_occupancy_status<2017-01>/value", ) assert housing_status[0] == "tenant" # The default value -def test_entities_description(test_client): +def test_entities_description(test_client) -> None: simulation_json = json.dumps(couple) response = test_client.post( - "/trace", data=simulation_json, content_type="application/json" + "/trace", + data=simulation_json, + content_type="application/json", ) response_json = json.loads(response.data.decode("utf-8")) assert_items_equal( @@ -67,10 +77,12 @@ def test_entities_description(test_client): ) -def test_root_nodes(test_client): +def test_root_nodes(test_client) -> None: simulation_json = json.dumps(couple) response = test_client.post( - "/trace", data=simulation_json, content_type="application/json" + "/trace", + data=simulation_json, + content_type="application/json", ) response_json = json.loads(response.data.decode("utf-8")) assert_items_equal( @@ -83,25 +95,29 @@ def test_root_nodes(test_client): ) -def test_str_variable(test_client): +def test_str_variable(test_client) -> None: new_couple = copy.deepcopy(couple) new_couple["households"]["_"]["postal_code"] = {"2017-01": None} simulation_json = json.dumps(new_couple) response = test_client.post( - "/trace", data=simulation_json, content_type="application/json" + "/trace", + data=simulation_json, + content_type="application/json", ) assert response.status_code == client.OK -def test_trace_parameters(test_client): +def test_trace_parameters(test_client) -> None: new_couple = copy.deepcopy(couple) new_couple["households"]["_"]["housing_tax"] = {"2017": None} simulation_json = json.dumps(new_couple) response = test_client.post( - "/trace", data=simulation_json, content_type="application/json" + "/trace", + data=simulation_json, + content_type="application/json", ) response_json = json.loads(response.data.decode("utf-8")) diff --git a/tests/web_api/test_variables.py b/tests/web_api/test_variables.py index d53581618d..d3b46dfff9 100644 --- a/tests/web_api/test_variables.py +++ b/tests/web_api/test_variables.py @@ -5,7 +5,7 @@ import pytest -def assert_items_equal(x, y): +def assert_items_equal(x, y) -> None: assert set(x) == set(y) @@ -17,15 +17,14 @@ def assert_items_equal(x, y): @pytest.fixture(scope="module") def variables_response(test_client): - variables_response = test_client.get("/variables") - return variables_response + return test_client.get("/variables") -def test_return_code(variables_response): +def test_return_code(variables_response) -> None: assert variables_response.status_code == client.OK -def test_response_data(variables_response): +def test_response_data(variables_response) -> None: variables = json.loads(variables_response.data.decode("utf-8")) assert variables["birth"] == { "description": "Birth date", @@ -36,22 +35,21 @@ def test_response_data(variables_response): # /variable/ -def test_error_code_non_existing_variable(test_client): +def test_error_code_non_existing_variable(test_client) -> None: response = test_client.get("/variable/non_existing_variable") assert response.status_code == client.NOT_FOUND @pytest.fixture(scope="module") def input_variable_response(test_client): - input_variable_response = test_client.get("/variable/birth") - return input_variable_response + return test_client.get("/variable/birth") -def test_return_code_existing_input_variable(input_variable_response): +def test_return_code_existing_input_variable(input_variable_response) -> None: assert input_variable_response.status_code == client.OK -def check_input_variable_value(key, expected_value, input_variable=None): +def check_input_variable_value(key, expected_value, input_variable=None) -> None: assert input_variable[key] == expected_value @@ -66,25 +64,25 @@ def check_input_variable_value(key, expected_value, input_variable=None): ("references", ["https://en.wiktionary.org/wiki/birthdate"]), ], ) -def test_input_variable_value(expected_values, input_variable_response): +def test_input_variable_value(expected_values, input_variable_response) -> None: input_variable = json.loads(input_variable_response.data.decode("utf-8")) check_input_variable_value(*expected_values, input_variable=input_variable) -def test_input_variable_github_url(test_client): +def test_input_variable_github_url(test_client) -> None: input_variable_response = test_client.get("/variable/income_tax") input_variable = json.loads(input_variable_response.data.decode("utf-8")) assert re.match(GITHUB_URL_REGEX, input_variable["source"]) -def test_return_code_existing_variable(test_client): +def test_return_code_existing_variable(test_client) -> None: variable_response = test_client.get("/variable/income_tax") assert variable_response.status_code == client.OK -def check_variable_value(key, expected_value, variable=None): +def check_variable_value(key, expected_value, variable=None) -> None: assert variable[key] == expected_value @@ -98,19 +96,19 @@ def check_variable_value(key, expected_value, variable=None): ("entity", "person"), ], ) -def test_variable_value(expected_values, test_client): +def test_variable_value(expected_values, test_client) -> None: variable_response = test_client.get("/variable/income_tax") variable = json.loads(variable_response.data.decode("utf-8")) check_variable_value(*expected_values, variable=variable) -def test_variable_formula_github_link(test_client): +def test_variable_formula_github_link(test_client) -> None: variable_response = test_client.get("/variable/income_tax") variable = json.loads(variable_response.data.decode("utf-8")) assert re.match(GITHUB_URL_REGEX, variable["formulas"]["0001-01-01"]["source"]) -def test_variable_formula_content(test_client): +def test_variable_formula_content(test_client) -> None: variable_response = test_client.get("/variable/income_tax") variable = json.loads(variable_response.data.decode("utf-8")) content = variable["formulas"]["0001-01-01"]["content"] @@ -121,13 +119,13 @@ def test_variable_formula_content(test_client): ) -def test_null_values_are_dropped(test_client): +def test_null_values_are_dropped(test_client) -> None: variable_response = test_client.get("/variable/age") variable = json.loads(variable_response.data.decode("utf-8")) - assert "references" not in variable.keys() + assert "references" not in variable -def test_variable_with_start_and_stop_date(test_client): +def test_variable_with_start_and_stop_date(test_client) -> None: response = test_client.get("/variable/housing_allowance") variable = json.loads(response.data.decode("utf-8")) assert_items_equal(variable["formulas"], ["1980-01-01", "2016-12-01"]) @@ -135,12 +133,12 @@ def test_variable_with_start_and_stop_date(test_client): assert "formula" in variable["formulas"]["1980-01-01"]["content"] -def test_variable_with_enum(test_client): +def test_variable_with_enum(test_client) -> None: response = test_client.get("/variable/housing_occupancy_status") variable = json.loads(response.data.decode("utf-8")) assert variable["valueType"] == "String" assert variable["defaultValue"] == "tenant" - assert "possibleValues" in variable.keys() + assert "possibleValues" in variable assert variable["possibleValues"] == { "free_lodger": "Free lodger", "homeless": "Homeless", @@ -151,20 +149,19 @@ def test_variable_with_enum(test_client): @pytest.fixture(scope="module") def dated_variable_response(test_client): - dated_variable_response = test_client.get("/variable/basic_income") - return dated_variable_response + return test_client.get("/variable/basic_income") -def test_return_code_existing_dated_variable(dated_variable_response): +def test_return_code_existing_dated_variable(dated_variable_response) -> None: assert dated_variable_response.status_code == client.OK -def test_dated_variable_formulas_dates(dated_variable_response): +def test_dated_variable_formulas_dates(dated_variable_response) -> None: dated_variable = json.loads(dated_variable_response.data.decode("utf-8")) assert_items_equal(dated_variable["formulas"], ["2016-12-01", "2015-12-01"]) -def test_dated_variable_formulas_content(dated_variable_response): +def test_dated_variable_formulas_content(dated_variable_response) -> None: dated_variable = json.loads(dated_variable_response.data.decode("utf-8")) formula_code_2016 = dated_variable["formulas"]["2016-12-01"]["content"] formula_code_2015 = dated_variable["formulas"]["2015-12-01"]["content"] @@ -175,12 +172,12 @@ def test_dated_variable_formulas_content(dated_variable_response): assert "return" in formula_code_2015 -def test_variable_encoding(test_client): +def test_variable_encoding(test_client) -> None: variable_response = test_client.get("/variable/pension") assert variable_response.status_code == client.OK -def test_variable_documentation(test_client): +def test_variable_documentation(test_client) -> None: response = test_client.get("/variable/housing_allowance") variable = json.loads(response.data.decode("utf-8")) assert (