Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better typing for util functions #95

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions hdx_hapi/db/dao/util/util.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks fine to me. I should probably apply this machinery to the population view I am working on now

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Type
from sqlalchemy import Column, Select
from typing import Protocol, Type
from sqlalchemy import DateTime, Select
from sqlalchemy.orm import Mapped

from hdx_hapi.db.models.views.all_views import Admin1View, Admin2View, LocationView
from hdx_hapi.endpoints.util.util import PaginationParams, ReferencePeriodParameters


Expand All @@ -16,10 +16,15 @@ def apply_pagination(query: Select, pagination_parameters: PaginationParams) ->
return query.limit(limit).offset(offset)


class EntityWithReferencePeriod(Protocol):
reference_period_start: Mapped[DateTime]
reference_period_end: Mapped[DateTime]


def apply_reference_period_filter(
query: Select,
ref_period_parameters: ReferencePeriodParameters,
db_class: Type[LocationView] | Type[Admin1View] | Type[Admin2View],
db_class: Type[EntityWithReferencePeriod],
) -> Select:
if ref_period_parameters.reference_period_start_min:
query = query.where(db_class.reference_period_start >= ref_period_parameters.reference_period_start_min)
Expand All @@ -32,6 +37,6 @@ def apply_reference_period_filter(
return query


def case_insensitive_filter(query: Select, column: Column, value: str) -> Select:
def case_insensitive_filter(query: Select, column: Mapped[str], value: str) -> Select:
query = query.where(column.ilike(value))
return query
Loading