Skip to content

Commit

Permalink
fix(simulations): use group pop instead of single
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Sep 16, 2024
1 parent bb00b12 commit 8ecc899
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 11 deletions.
14 changes: 6 additions & 8 deletions openfisca_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from openfisca_core import commons, errors, indexed_enums, periods, tracers
from openfisca_core import warnings as core_warnings

from .types import SinglePopulation, TaxBenefitSystem, Variable
from .types import GroupPopulation, TaxBenefitSystem, Variable


class Simulation:
Expand All @@ -19,13 +19,13 @@ class Simulation:
"""

tax_benefit_system: TaxBenefitSystem
populations: dict[str, SinglePopulation]
populations: dict[str, GroupPopulation]
invalidated_caches: Set[Cache]

def __init__(
self,
tax_benefit_system: TaxBenefitSystem,
populations: dict[str, SinglePopulation],
populations: dict[str, GroupPopulation],
):
"""
This constructor is reserved for internal use; see :any:`SimulationBuilder`,
Expand Down Expand Up @@ -531,7 +531,7 @@ def set_input(self, variable_name: str, period, value):
return
self.get_holder(variable_name).set_input(period, value)

def get_variable_population(self, variable_name: str) -> SinglePopulation:
def get_variable_population(self, variable_name: str) -> GroupPopulation:
variable: Optional[Variable]

variable = self.tax_benefit_system.get_variable(
Expand All @@ -543,9 +543,7 @@ def get_variable_population(self, variable_name: str) -> SinglePopulation:

return self.populations[variable.entity.key]

def get_population(
self, plural: Optional[str] = None
) -> Optional[SinglePopulation]:
def get_population(self, plural: Optional[str] = None) -> Optional[GroupPopulation]:
return next(
(
population
Expand All @@ -558,7 +556,7 @@ def get_population(
def get_entity(
self,
plural: Optional[str] = None,
) -> Optional[SinglePopulation]:
) -> Optional[GroupPopulation]:
population = self.get_population(plural)
return population and population.entity

Expand Down
11 changes: 8 additions & 3 deletions openfisca_core/simulations/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
from typing import Protocol, TypeVar, TypedDict, Union
from typing_extensions import NotRequired, Required, TypeAlias

Expand Down Expand Up @@ -160,6 +160,10 @@ class Axis(TypedDict, total=False):
period: NotRequired[str | int]


class Simulation(t.Simulation, Protocol):
...


# Tax-Benefit systems


Expand Down Expand Up @@ -188,8 +192,8 @@ def entities_plural(self) -> Iterable[str]:
def get_variable(
self,
__variable_name: str,
__check_existence: bool = ...,
) -> V | None:
check_existence: bool = ...,
) -> Variable[T] | None:
...

def instantiate_entities(
Expand All @@ -202,6 +206,7 @@ def instantiate_entities(


class Variable(t.Variable, Protocol[T]):
calculate_output: Callable[[Simulation, str, str], t.Array[T]] | None
definition_period: str
end: str
name: str
Expand Down
1 change: 1 addition & 0 deletions openfisca_tasks/lint.mk
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ check-types:
openfisca_core/simulations/_build_from_variables.py \
openfisca_core/simulations/_guards.py \
openfisca_core/simulations/helpers.py \
openfisca_core/simulations/simulation.py \
openfisca_core/simulations/types.py \
openfisca_core/types.py
@$(call print_pass,$@:)
Expand Down

0 comments on commit 8ecc899

Please sign in to comment.