Skip to content

Commit

Permalink
Add GroupMembership.roles column to model
Browse files Browse the repository at this point in the history
  • Loading branch information
seanh committed Oct 18, 2024
1 parent 8f2d01f commit 795eafc
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 1 deletion.
2 changes: 1 addition & 1 deletion h/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from h.models.feature import Feature
from h.models.feature_cohort import FeatureCohort, FeatureCohortUser
from h.models.flag import Flag
from h.models.group import Group, GroupMembership
from h.models.group import Group, GroupMembership, GroupMembershipRoles
from h.models.group_scope import GroupScope
from h.models.job import Job
from h.models.organization import Organization
Expand Down
21 changes: 21 additions & 0 deletions h/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import slugify
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB

from h import pubid
from h.db import Base, mixins
Expand Down Expand Up @@ -33,6 +34,15 @@ class WriteableBy(enum.Enum):
members = "members"


class GroupMembershipRoles(enum.StrEnum):
"""The valid role strings that're allowed in the GroupMembership.roles column."""

MEMBER = "member"
MODERATOR = "moderator"
ADMIN = "admin"
OWNER = "owner"


class GroupMembership(Base):
__tablename__ = "user_group"

Expand All @@ -47,6 +57,17 @@ class GroupMembership(Base):
nullable=False,
index=True,
)
roles = sa.Column(
JSONB,
sa.CheckConstraint(
" OR ".join(
f"""(roles = '["{role}"]'::jsonb)""" for role in GroupMembershipRoles
),
name="validate_role_strings",
),
server_default=sa.text("""'["member"]'::jsonb"""),
nullable=False,
)


class Group(Base, mixins.Timestamps):
Expand Down
59 changes: 59 additions & 0 deletions tests/unit/h/models/group_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from sqlalchemy.exc import IntegrityError

from h import models
from h.models.group import (
Expand Down Expand Up @@ -218,6 +219,64 @@ def test_non_public_group():
assert not group.is_public


class TestGroupMembership:
def test_defaults(self, db_session, user, group):
membership = models.GroupMembership(user_id=user.id, group_id=group.id)
db_session.add(membership)

db_session.flush()
assert membership.id
assert membership.user_id == user.id
assert membership.group_id == group.id
assert membership.roles == ["member"]

@pytest.mark.parametrize(
"roles",
(
["member"],
["moderator"],
["admin"],
["owner"],
),
)
def test_custom_roles(self, db_session, user, group, roles):
membership = models.GroupMembership(
user_id=user.id, group_id=group.id, roles=roles
)
db_session.add(membership)

db_session.flush()
assert membership.roles == roles

@pytest.mark.parametrize(
"roles",
(
["unknown_role"],
["moderator", "admin"], # Two valid roles, only one role is allowed.
[], # Every membership must have at least one role.
),
)
def test_invalid_roles(self, db_session, user, group, roles):
membership = models.GroupMembership(
user_id=user.id, group_id=group.id, roles=roles
)
db_session.add(membership)

with pytest.raises(
IntegrityError,
match='new row for relation "user_group" violates check constraint "ck__user_group__validate_role_strings"',
):
db_session.flush()

@pytest.fixture
def user(self, factories):
return factories.User()

@pytest.fixture
def group(self, factories):
return factories.Group()


@pytest.fixture()
def organization(factories):
return factories.Organization()

0 comments on commit 795eafc

Please sign in to comment.