Skip to content

Commit

Permalink
refactor: move filtered searching into user/group/queue repositories
Browse files Browse the repository at this point in the history
  • Loading branch information
chisholm committed Oct 2, 2024
1 parent ec7de19 commit 82637c5
Show file tree
Hide file tree
Showing 12 changed files with 369 additions and 162 deletions.
14 changes: 14 additions & 0 deletions src/dioptra/restapi/db/repository/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,17 @@ class UserEmailNotAvailableError(Exception):

class QueueAlreadyExistsError(Exception):
"""The queue name already exists."""


class QueueSortError(Exception):
"""The requested sortBy column is not a sortable field."""


class UnsupportedFilterField(Exception):
"""A filter field is not supported for a particular repository method"""

def __init__(self, field_name: str) -> None:
self.field_name = field_name

message = f"{self.field_name!r} is not a valid field"
super().__init__(message)
42 changes: 41 additions & 1 deletion src/dioptra/restapi/db/repository/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
"""
The group repository: data operations related to groups
"""
from typing import Final
from collections.abc import Sequence
from typing import Any, Final

import sqlalchemy as sa

Expand All @@ -33,6 +34,7 @@
assert_group_exists,
assert_user_exists,
check_user_collision,
construct_sql_query_filters,
group_exists,
user_exists,
)
Expand All @@ -52,6 +54,10 @@

class GroupRepository:

SEARCHABLE_FIELDS: Final[dict[str, Any]] = {
"name": lambda x: Group.name.like(x, escape="/"),
}

def __init__(self, session: CompatibleSession[S]):
self.session = session

Expand Down Expand Up @@ -180,6 +186,40 @@ def get_by_name(

return group

def get_by_filters_paged(
self,
filters: list[dict],
page_start: int,
page_length: int,
deletion_policy: DeletionPolicy = DeletionPolicy.NOT_DELETED,
) -> tuple[Sequence[User], int]:
sql_filter = construct_sql_query_filters(filters, self.SEARCHABLE_FIELDS)

count_stmt = sa.select(sa.func.count()).select_from(Group)
if sql_filter is not None:
count_stmt = count_stmt.where(sql_filter)
count_stmt = _apply_deletion_policy(count_stmt, deletion_policy)
current_count = self.session.scalar(count_stmt)

# For mypy: a "SELECT count(*)..." query should never return NULL.
assert current_count is not None

groups: Sequence[Group]
if current_count == 0:
groups = []
else:
page_stmt = sa.select(Group)
if sql_filter is not None:
page_stmt = page_stmt.where(sql_filter)
page_stmt = _apply_deletion_policy(page_stmt, deletion_policy)
# *must* enforce a sort order for consistent paging
page_stmt = page_stmt.order_by(Group.group_id)
page_stmt = page_stmt.offset(page_start).limit(page_length)

groups = self.session.scalars(page_stmt).all()

return groups, current_count

def num_groups(
self, deletion_policy: DeletionPolicy = DeletionPolicy.NOT_DELETED
) -> int:
Expand Down
112 changes: 110 additions & 2 deletions src/dioptra/restapi/db/repository/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
"""

from collections.abc import Iterable, Sequence
from typing import Any, Final

import sqlalchemy as sa

from dioptra.restapi.db.models import Group, Queue, Resource
from dioptra.restapi.db.repository.errors import QueueAlreadyExistsError
from dioptra.restapi.db.models import Group, Queue, Resource, Tag
from dioptra.restapi.db.repository.errors import QueueAlreadyExistsError, QueueSortError
from dioptra.restapi.db.repository.utils import (
CompatibleSession,
DeletionPolicy,
Expand All @@ -35,6 +36,7 @@
assert_snapshot_does_not_exist,
assert_user_exists,
assert_user_in_group,
construct_sql_query_filters,
delete_resource,
get_group_id,
get_resource_id,
Expand All @@ -43,6 +45,21 @@


class QueueRepository:

SEARCHABLE_FIELDS: Final[dict[str, Any]] = {
"name": lambda x: Queue.name.like(x, escape="/"),
"description": lambda x: Queue.description.like(x, escape="/"),
"tag": lambda x: Queue.tags.any(Tag.name.like(x, escape="/")),
}

# Maps a general sort criterion name to a Queue attribute name
SORTABLE_FIELDS: Final[dict[str, str]] = {
"name": "name",
"createdOn": "created_on",
"lastModifiedOn": "last_modified_on",
"description": "description",
}

def __init__(self, session: CompatibleSession[S]):
self.session = session

Expand Down Expand Up @@ -261,3 +278,94 @@ def get_by_name(
queue = self.session.scalar(stmt)

return queue

def get_by_filters_paged(
self,
group: Group | int | None,
filters: list[dict],
page_start: int,
page_length: int,
sort_by: str | None,
descending: bool,
deletion_policy: DeletionPolicy = DeletionPolicy.NOT_DELETED,
) -> tuple[Sequence[Queue], int]:
"""
Get a page of queues according to more complex criteria.
Args:
group: Limit queues to those owned by this group; None to not limit
the search
filters: Search criteria, see parse_search_text()
page_start: Zero-based row index where the page should start
page_length: Maximum number of rows in the page
sort_by: Sort criterion; must be a key of SORTABLE_FIELDS. None
to sort in an implementation-dependent way.
descending: Whether to sort in descending order; only applicable
if sort_by is given
deletion_policy: Whether to look at deleted queues, non-deleted
queue, or all queues
Returns:
A 2-tuple including the page of queues and total count of matching
queues which exist
"""
sql_filter = construct_sql_query_filters(filters, self.SEARCHABLE_FIELDS)
if sort_by:
sort_by = self.SORTABLE_FIELDS.get(sort_by)
if not sort_by:
raise QueueSortError
group_id = None if group is None else get_group_id(group)

if group_id is not None:
assert_group_exists(self.session, group_id, DeletionPolicy.NOT_DELETED)

count_stmt = (
sa.select(sa.func.count())
.select_from(Queue, Resource)
.where(Queue.resource_snapshot_id == Resource.latest_snapshot_id)
)

if group_id is not None:
count_stmt = count_stmt.where(Resource.group_id == group_id)

if sql_filter is not None:
count_stmt = count_stmt.where(sql_filter)

count_stmt = apply_resource_deletion_policy(count_stmt, deletion_policy)
current_count = self.session.scalar(count_stmt)

# For mypy: a "SELECT count(*)..." query should never return NULL.
assert current_count is not None

queues: Sequence[Queue]
if current_count == 0:
queues = []
else:
page_stmt = (
sa.select(Queue)
.join(Resource)
.where(Queue.resource_snapshot_id == Resource.latest_snapshot_id)
)

if group_id is not None:
page_stmt = page_stmt.where(Resource.group_id == group_id)

if sql_filter is not None:
page_stmt = page_stmt.where(sql_filter)

page_stmt = apply_resource_deletion_policy(page_stmt, deletion_policy)

if sort_by:
sort_criteria = getattr(Queue, sort_by)
if descending:
sort_criteria = sort_criteria.desc()
else:
# *must* enforce a sort order for consistent paging
sort_criteria = Queue.resource_snapshot_id
page_stmt = page_stmt.order_by(sort_criteria)

page_stmt = page_stmt.offset(page_start).limit(page_length)

queues = self.session.scalars(page_stmt).all()

return queues, current_count
56 changes: 56 additions & 0 deletions src/dioptra/restapi/db/repository/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""
import uuid
from collections.abc import Sequence
from typing import Any, Final

import sqlalchemy as sa

Expand All @@ -33,12 +34,18 @@
assert_user_does_not_exist,
assert_user_exists,
check_user_collision,
construct_sql_query_filters,
user_exists,
)


class UserRepository:

SEARCHABLE_FIELDS: Final[dict[str, Any]] = {
"username": lambda x: User.username.like(x, escape="/"),
"email": lambda x: User.email_address.like(x, escape="/"),
}

def __init__(self, session: CompatibleSession[S]):
self.session = session

Expand Down Expand Up @@ -208,6 +215,55 @@ def get_by_email(

return user

def get_by_filters_paged(
self,
filters: list[dict],
page_start: int,
page_length: int,
deletion_policy: DeletionPolicy = DeletionPolicy.NOT_DELETED,
) -> tuple[Sequence[User], int]:
"""
Get some users according to search criteria.
Args:
filters: A structure representing search criteria. See
parse_search_text().
page_start: A row index where the returned page should start
page_length: A row count representing the page length
deletion_policy: Whether to look at deleted users, non-deleted
users, or all users
Returns:
A 2-tuple including a page of User objects, and a count of the
total number of users matching the criteria
"""
sql_filter = construct_sql_query_filters(filters, self.SEARCHABLE_FIELDS)

count_stmt = sa.select(sa.func.count()).select_from(User)
if sql_filter is not None:
count_stmt = count_stmt.where(sql_filter)
count_stmt = _apply_deletion_policy(count_stmt, deletion_policy)
current_count = self.session.scalar(count_stmt)

# For mypy: a "SELECT count(*)..." query should never return NULL.
assert current_count is not None

users: Sequence[User]
if current_count == 0:
users = []
else:
page_stmt = sa.select(User)
if sql_filter is not None:
page_stmt = page_stmt.where(sql_filter)
page_stmt = _apply_deletion_policy(page_stmt, deletion_policy)
# *must* enforce a sort order for consistent paging
page_stmt = page_stmt.order_by(User.user_id)
page_stmt = page_stmt.offset(page_start).limit(page_length)

users = self.session.scalars(page_stmt).all()

return users, current_count

def num_users(
self, deletion_policy: DeletionPolicy = DeletionPolicy.NOT_DELETED
) -> int:
Expand Down
Loading

0 comments on commit 82637c5

Please sign in to comment.