diff --git a/openfisca_core/periods/period_.py b/openfisca_core/periods/period_.py index f7b901c58..11a7b671b 100644 --- a/openfisca_core/periods/period_.py +++ b/openfisca_core/periods/period_.py @@ -1,10 +1,10 @@ from __future__ import annotations import typing +from collections.abc import Sequence import calendar import datetime -from collections.abc import Sequence import pendulum diff --git a/openfisca_core/simulations/__init__.py b/openfisca_core/simulations/__init__.py index 913e90d1e..c5db39e11 100644 --- a/openfisca_core/simulations/__init__.py +++ b/openfisca_core/simulations/__init__.py @@ -27,11 +27,14 @@ SpiralError, ) +from .actions import RunFormula from .helpers import ( # noqa: F401 calculate_output_add, calculate_output_divide, check_type, transform_to_strict_syntax, ) -from .simulation import Simulation # noqa: F401 -from .simulation_builder import SimulationBuilder # noqa: F401 +from .simulation import Simulation +from .simulation_builder import SimulationBuilder + +__all__ = ["RunFormula", "Simulation", "SimulationBuilder"] diff --git a/openfisca_core/simulations/actions/__init__.py b/openfisca_core/simulations/actions/__init__.py new file mode 100644 index 000000000..df278c8cd --- /dev/null +++ b/openfisca_core/simulations/actions/__init__.py @@ -0,0 +1,3 @@ +from ._run_formula import RunFormula + +__all__ = ["RunFormula"] diff --git a/openfisca_core/simulations/actions/_run_formula.py b/openfisca_core/simulations/actions/_run_formula.py new file mode 100644 index 000000000..eab3c1a26 --- /dev/null +++ b/openfisca_core/simulations/actions/_run_formula.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from numpy.typing import NDArray +from typing import Any, cast + +import dataclasses +import inspect + +from ..typing import Formula2, Formula3, Instant, Params, Population + + +@dataclasses.dataclass(frozen=True) +class RunFormula: + """Run a Variable's given Formula.""" + + #: The formula we want to run. + formula: Formula3 | Formula2 | None = None + + def __call__( + self, population: Population, instant: Instant, params: Params + ) -> NDArray[Any] | None: + if self.formula is None: + return None + + if self.__arity() == 3: + return cast(Formula3, self.formula)(population, instant, params) + + else: + return cast(Formula2, self.formula)(population, instant) + + def __arity(self) -> int: + return len(inspect.getfullargspec(self.formula).args) diff --git a/openfisca_core/simulations/simulation.py b/openfisca_core/simulations/simulation.py index 95c90a9ee..2bd4e6fb7 100644 --- a/openfisca_core/simulations/simulation.py +++ b/openfisca_core/simulations/simulation.py @@ -1,5 +1,6 @@ from __future__ import annotations +from openfisca_core.types import Population, TaxBenefitSystem, Variable from typing import Dict, NamedTuple, Optional, Set import tempfile @@ -16,9 +17,11 @@ SimpleTracer, TracingParameterNodeAtInstant, ) -from openfisca_core.types import Population, TaxBenefitSystem, Variable from openfisca_core.warnings import TempfileWarning +from .actions import RunFormula +from .typing import Params + class Simulation: """ @@ -29,6 +32,13 @@ class Simulation: populations: Dict[str, Population] invalidated_caches: Set[Cache] + @property + def params(self) -> Params: + if self.trace: + return self.trace_parameters_at_instant + + return self.tax_benefit_system.get_parameters_at_instant + def __init__( self, tax_benefit_system: TaxBenefitSystem, @@ -144,7 +154,8 @@ def _calculate(self, variable_name: str, period: Period): # First, try to run a formula try: self._check_for_cycle(variable.name, period) - array = self._run_formula(variable, population, period) + run_formula = RunFormula(variable.get_formula(period)) + array = run_formula(population, period, self.params) # If no result, use the default value and cache it if array is None: @@ -306,27 +317,6 @@ def trace_parameters_at_instant(self, formula_period): self.tracer, ) - def _run_formula(self, variable, population, period): - """ - Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``. - """ - - formula = variable.get_formula(period) - if formula is None: - return None - - if self.trace: - parameters_at = self.trace_parameters_at_instant - else: - parameters_at = self.tax_benefit_system.get_parameters_at_instant - - if formula.__code__.co_argcount == 2: - array = formula(population, period) - else: - array = formula(population, period, parameters_at) - - return array - def _check_period_consistency(self, period, variable): """ Check that a period matches the variable definition_period diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index 41ca1e22e..71c37a352 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -1,8 +1,8 @@ -from typing import Dict, List, Iterable +from typing import Dict, Iterable, List import copy -import dpath.util +import dpath.util import numpy from openfisca_core import periods @@ -13,7 +13,7 @@ VariableNotFoundError, ) from openfisca_core.populations import Population -from openfisca_core.simulations import helpers, Simulation +from openfisca_core.simulations import Simulation, helpers from openfisca_core.variables import Variable diff --git a/openfisca_core/simulations/tests/__init__.py b/openfisca_core/simulations/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfisca_core/simulations/tests/test_run_formula.py b/openfisca_core/simulations/tests/test_run_formula.py new file mode 100644 index 000000000..9b9486086 --- /dev/null +++ b/openfisca_core/simulations/tests/test_run_formula.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from numpy.typing import NDArray +from openfisca_core.simulations.typing import ( + Formula2, + Formula3, + Instant, + ParameterNodeAtInstant, + Params, + Population, +) +from typing import Any + +import numpy +import pytest + +from openfisca_core import simulations + + +class TestPopulation: + ... + + +class TestInstant: + ... + + +class TestParams: + def __call__(self, instant: Instant) -> ParameterNodeAtInstant: + ... + + +@pytest.fixture +def population() -> Population: + return TestPopulation() + + +@pytest.fixture +def instant() -> Instant: + return TestInstant() + + +@pytest.fixture +def params() -> Params: + return TestParams() + + +def test_run_formula_without_formula( + population: Population, instant: Instant, params: Params +) -> None: + """Test that RunFormula runs without a formula.""" + + run_formula = simulations.RunFormula(None) + + assert not run_formula(population, instant, params) + + +def test_run_formula_with_two_arguments( + population: Population, instant: Instant, params: Params +) -> None: + """Test that RunFormula runs a formula with two arguments.""" + + def formula(a: Population, b: Instant) -> NDArray[Any]: + return numpy.array([1, 2, 3]) + + run_formula = simulations.RunFormula(formula) + + assert run_formula(population, instant, params) + + +def test_run_formula_with_three_arguments( + population: Population, instant: Instant, params: Params +) -> None: + """Test that RunFormula runs a formula with three arguments.""" + + def formula(a: Population, b: Instant, c: Params) -> NDArray[Any]: + return numpy.array([1, 2, 3]) + + run_formula = simulations.RunFormula(formula) + + assert run_formula(population, instant, params) diff --git a/openfisca_core/simulations/typing.py b/openfisca_core/simulations/typing.py new file mode 100644 index 000000000..cdca5a31d --- /dev/null +++ b/openfisca_core/simulations/typing.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import abc +from numpy.typing import NDArray +from typing import Any, Protocol + + +class Formula3(Protocol): + @abc.abstractmethod + def __call__( + self, __population: Population, __instant: Instant, __params: Params + ) -> NDArray[Any]: + ... + + +class Formula2(Protocol): + @abc.abstractmethod + def __call__(self, __population: Population, __instant: Instant) -> NDArray[Any]: + ... + + +class Instant(Protocol): + ... + + +class ParameterNodeAtInstant(Protocol): + ... + + +class Params(Protocol): + @abc.abstractmethod + def __call__(self, __instant: Instant) -> ParameterNodeAtInstant: + ... + + +class Population(Protocol): + ... diff --git a/openfisca_tasks/lint.mk b/openfisca_tasks/lint.mk index a15b70457..2bf126682 100644 --- a/openfisca_tasks/lint.mk +++ b/openfisca_tasks/lint.mk @@ -60,6 +60,6 @@ lint-typing-strict-%: ## Run code formatters to correct style errors. format-style: $(shell git ls-files "*.py") @$(call print_help,$@:) - @isort openfisca_core/periods + @isort openfisca_core/periods openfisca_core/simulations @black $? @$(call print_pass,$@:)