diff --git a/hdx_hapi/db/dao/util/util.py b/hdx_hapi/db/dao/util/util.py index 5b767744..21dc658d 100644 --- a/hdx_hapi/db/dao/util/util.py +++ b/hdx_hapi/db/dao/util/util.py @@ -1,7 +1,7 @@ -from typing import Type -from sqlalchemy import Column, Select +from typing import Protocol, Type +from sqlalchemy import DateTime, Select +from sqlalchemy.orm import Mapped -from hdx_hapi.db.models.views.all_views import Admin1View, Admin2View, LocationView from hdx_hapi.endpoints.util.util import PaginationParams, ReferencePeriodParameters @@ -16,10 +16,15 @@ def apply_pagination(query: Select, pagination_parameters: PaginationParams) -> return query.limit(limit).offset(offset) +class EntityWithReferencePeriod(Protocol): + reference_period_start: Mapped[DateTime] + reference_period_end: Mapped[DateTime] + + def apply_reference_period_filter( query: Select, ref_period_parameters: ReferencePeriodParameters, - db_class: Type[LocationView] | Type[Admin1View] | Type[Admin2View], + db_class: Type[EntityWithReferencePeriod], ) -> Select: if ref_period_parameters.reference_period_start_min: query = query.where(db_class.reference_period_start >= ref_period_parameters.reference_period_start_min) @@ -32,6 +37,6 @@ def apply_reference_period_filter( return query -def case_insensitive_filter(query: Select, column: Column, value: str) -> Select: +def case_insensitive_filter(query: Select, column: Mapped[str], value: str) -> Select: query = query.where(column.ilike(value)) return query