Skip to content

Commit

Permalink
Extract run formulas
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Oct 2, 2023
1 parent b0e5fdb commit d2625f4
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 30 deletions.
2 changes: 1 addition & 1 deletion openfisca_core/periods/period_.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

import typing
from collections.abc import Sequence

import calendar
import datetime
from collections.abc import Sequence

import pendulum

Expand Down
7 changes: 5 additions & 2 deletions openfisca_core/simulations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@
SpiralError,
)

from .actions import RunFormula
from .helpers import ( # noqa: F401
calculate_output_add,
calculate_output_divide,
check_type,
transform_to_strict_syntax,
)
from .simulation import Simulation # noqa: F401
from .simulation_builder import SimulationBuilder # noqa: F401
from .simulation import Simulation
from .simulation_builder import SimulationBuilder

__all__ = ["RunFormula", "Simulation", "SimulationBuilder"]
3 changes: 3 additions & 0 deletions openfisca_core/simulations/actions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._run_formula import RunFormula

__all__ = ["RunFormula"]
32 changes: 32 additions & 0 deletions openfisca_core/simulations/actions/_run_formula.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

from numpy.typing import NDArray
from typing import Any, cast

import dataclasses
import inspect

from ..typing import Formula2, Formula3, Instant, Params, Population


@dataclasses.dataclass(frozen=True)
class RunFormula:
"""Run a Variable's given Formula."""

#: The formula we want to run.
formula: Formula3 | Formula2 | None = None

def __call__(
self, population: Population, instant: Instant, params: Params
) -> NDArray[Any] | None:
if self.formula is None:
return None

if self.__arity() == 3:
return cast(Formula3, self.formula)(population, instant, params)

else:
return cast(Formula2, self.formula)(population, instant)

def __arity(self) -> int:
return len(inspect.getfullargspec(self.formula).args)
36 changes: 13 additions & 23 deletions openfisca_core/simulations/simulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from openfisca_core.types import Population, TaxBenefitSystem, Variable
from typing import Dict, NamedTuple, Optional, Set

import tempfile
Expand All @@ -16,9 +17,11 @@
SimpleTracer,
TracingParameterNodeAtInstant,
)
from openfisca_core.types import Population, TaxBenefitSystem, Variable
from openfisca_core.warnings import TempfileWarning

from .actions import RunFormula
from .typing import Params


class Simulation:
"""
Expand All @@ -29,6 +32,13 @@ class Simulation:
populations: Dict[str, Population]
invalidated_caches: Set[Cache]

@property
def params(self) -> Params:
if self.trace:
return self.trace_parameters_at_instant

return self.tax_benefit_system.get_parameters_at_instant

def __init__(
self,
tax_benefit_system: TaxBenefitSystem,
Expand Down Expand Up @@ -144,7 +154,8 @@ def _calculate(self, variable_name: str, period: Period):
# First, try to run a formula
try:
self._check_for_cycle(variable.name, period)
array = self._run_formula(variable, population, period)
run_formula = RunFormula(variable.get_formula(period))
array = run_formula(population, period, self.params)

# If no result, use the default value and cache it
if array is None:
Expand Down Expand Up @@ -306,27 +317,6 @@ def trace_parameters_at_instant(self, formula_period):
self.tracer,
)

def _run_formula(self, variable, population, period):
"""
Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``.
"""

formula = variable.get_formula(period)
if formula is None:
return None

if self.trace:
parameters_at = self.trace_parameters_at_instant
else:
parameters_at = self.tax_benefit_system.get_parameters_at_instant

if formula.__code__.co_argcount == 2:
array = formula(population, period)
else:
array = formula(population, period, parameters_at)

return array

def _check_period_consistency(self, period, variable):
"""
Check that a period matches the variable definition_period
Expand Down
6 changes: 3 additions & 3 deletions openfisca_core/simulations/simulation_builder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Dict, List, Iterable
from typing import Dict, Iterable, List

import copy
import dpath.util

import dpath.util
import numpy

from openfisca_core import periods
Expand All @@ -13,7 +13,7 @@
VariableNotFoundError,
)
from openfisca_core.populations import Population
from openfisca_core.simulations import helpers, Simulation
from openfisca_core.simulations import Simulation, helpers
from openfisca_core.variables import Variable


Expand Down
Empty file.
81 changes: 81 additions & 0 deletions openfisca_core/simulations/tests/test_run_formula.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

from numpy.typing import NDArray
from openfisca_core.simulations.typing import (
Formula2,
Formula3,
Instant,
ParameterNodeAtInstant,
Params,
Population,
)
from typing import Any

import numpy
import pytest

from openfisca_core import simulations


class TestPopulation:
...


class TestInstant:
...


class TestParams:
def __call__(self, instant: Instant) -> ParameterNodeAtInstant:
...


@pytest.fixture
def population() -> Population:
return TestPopulation()


@pytest.fixture
def instant() -> Instant:
return TestInstant()


@pytest.fixture
def params() -> Params:
return TestParams()


def test_run_formula_without_formula(
population: Population, instant: Instant, params: Params
) -> None:
"""Test that RunFormula runs without a formula."""

run_formula = simulations.RunFormula(None)

assert not run_formula(population, instant, params)


def test_run_formula_with_two_arguments(
population: Population, instant: Instant, params: Params
) -> None:
"""Test that RunFormula runs a formula with two arguments."""

def formula(a: Population, b: Instant) -> NDArray[Any]:
return numpy.array([1, 2, 3])

run_formula = simulations.RunFormula(formula)

assert run_formula(population, instant, params)


def test_run_formula_with_three_arguments(
population: Population, instant: Instant, params: Params
) -> None:
"""Test that RunFormula runs a formula with three arguments."""

def formula(a: Population, b: Instant, c: Params) -> NDArray[Any]:
return numpy.array([1, 2, 3])

run_formula = simulations.RunFormula(formula)

assert run_formula(population, instant, params)
37 changes: 37 additions & 0 deletions openfisca_core/simulations/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

import abc
from numpy.typing import NDArray
from typing import Any, Protocol


class Formula3(Protocol):
@abc.abstractmethod
def __call__(
self, __population: Population, __instant: Instant, __params: Params
) -> NDArray[Any]:
...


class Formula2(Protocol):
@abc.abstractmethod
def __call__(self, __population: Population, __instant: Instant) -> NDArray[Any]:
...


class Instant(Protocol):
...


class ParameterNodeAtInstant(Protocol):
...


class Params(Protocol):
@abc.abstractmethod
def __call__(self, __instant: Instant) -> ParameterNodeAtInstant:
...


class Population(Protocol):
...
2 changes: 1 addition & 1 deletion openfisca_tasks/lint.mk
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,6 @@ lint-typing-strict-%:
## Run code formatters to correct style errors.
format-style: $(shell git ls-files "*.py")
@$(call print_help,$@:)
@isort openfisca_core/periods
@isort openfisca_core/periods openfisca_core/simulations
@black $?
@$(call print_pass,$@:)

0 comments on commit d2625f4

Please sign in to comment.