Skip to content

Commit

Permalink
Define a specific badge for each model using a badge mixin
Browse files Browse the repository at this point in the history
And use inheritance instead of a `get_badge_mixin` helper in the
inheritance declaration of the model.
  • Loading branch information
magopian committed Oct 17, 2024
1 parent 4cef617 commit 362acf2
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 96 deletions.
4 changes: 1 addition & 3 deletions udata/core/badges/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@

from udata.factories import ModelFactory

from .models import get_badge


def badge_factory(model_):
class BadgeFactory(ModelFactory):
class Meta:
model = get_badge(model_.__badges__)
model = model_._fields["badges"].field.document_type

kind = FuzzyChoice(model_.__badges__)

Expand Down
4 changes: 0 additions & 4 deletions udata/core/badges/forms.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from udata.forms import ModelForm, fields, validators
from udata.i18n import lazy_gettext as _
from udata.models import get_badge

__all__ = ("badge_form",)


def badge_form(model):
"""A form factory for a given model badges"""
badge = get_badge()

class BadgeForm(ModelForm):
model_class = badge

kind = fields.RadioField(
_("Kind"),
[validators.DataRequired()],
Expand Down
149 changes: 72 additions & 77 deletions udata/core/badges/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,80 +13,75 @@
log = logging.getLogger(__name__)


__all__ = ["get_badge_mixin", "get_badge"]


def get_badge(choices=None):
@generate_fields(default_filterable_field="kind")
class Badge(db.EmbeddedDocument):
kind = db.StringField(required=True, choices=list(choices.keys()) if choices else None)
created = db.DateTimeField(default=datetime.utcnow, required=True)
created_by = db.ReferenceField("User")

def __str__(self):
return self.kind

return Badge


def get_badge_mixin(badges_):
class BadgesList(db.EmbeddedDocumentListField):
def __init__(self, *args, **kwargs):
return super(BadgesList, self).__init__(get_badge(badges_), *args, **kwargs)

class BadgeMixin(object):
__badges__ = badges_
badge = get_badge(badges_)

badges = field(
BadgesList(),
readonly=True,
inner_field_info={"nested_fields": badge_fields},
)

def get_badge(self, kind):
"""Get a badge given its kind if present"""
candidates = [b for b in self.badges if b.kind == kind]
return candidates[0] if candidates else None

def add_badge(self, kind):
"""Perform an atomic prepend for a new badge"""
badge = self.get_badge(kind)
if badge:
return badge
if kind not in getattr(self, "__badges__", {}):
msg = "Unknown badge type for {model}: {kind}"
raise db.ValidationError(msg.format(model=self.__class__.__name__, kind=kind))
badge = self.badge(kind=kind)
if current_user.is_authenticated:
badge.created_by = current_user.id

self.update(
__raw__={"$push": {"badges": {"$each": [badge.to_mongo()], "$position": 0}}}
)
self.reload()
post_save.send(self.__class__, document=self)
on_badge_added.send(self, kind=kind)
return self.get_badge(kind)

def remove_badge(self, kind):
"""Perform an atomic removal for a given badge"""
self.update(__raw__={"$pull": {"badges": {"kind": kind}}})
self.reload()
on_badge_removed.send(self, kind=kind)
post_save.send(self.__class__, document=self)

def toggle_badge(self, kind):
"""Toggle a bdage given its kind"""
badge = self.get_badge(kind)
if badge:
return self.remove_badge(kind)
else:
return self.add_badge(kind)

def badge_label(self, badge):
"""Display the badge label for a given kind"""
kind = badge.kind if isinstance(badge, self.badge) else badge
return self.__badges__[kind]

return BadgeMixin
__all__ = ["Badge", "BadgeMixin", "BadgesList"]

DEFAULT_BADGES_LIST_PARAMS = {
"readonly": True,
"inner_field_info": {"nested_fields": badge_fields},
}


@generate_fields(default_filterable_field="kind")
class Badge(db.EmbeddedDocument):
meta = {"allow_inheritance": True}
# The following field should be overloaded in descendants.
kind = db.StringField(required=True)
created = db.DateTimeField(default=datetime.utcnow, required=True)
created_by = db.ReferenceField("User")

def __str__(self):
return self.kind


class BadgesList(db.EmbeddedDocumentListField):
def __init__(self, badge_model, *args, **kwargs):
return super(BadgesList, self).__init__(badge_model, *args, **kwargs)


class BadgeMixin:
default_badges_list_params = DEFAULT_BADGES_LIST_PARAMS
# The following field should be overloaded in descendants.
badges = field(BadgesList(Badge), **DEFAULT_BADGES_LIST_PARAMS)

def get_badge(self, kind):
"""Get a badge given its kind if present"""
candidates = [b for b in self.badges if b.kind == kind]
return candidates[0] if candidates else None

def add_badge(self, kind):
"""Perform an atomic prepend for a new badge"""
badge = self.get_badge(kind)
if badge:
return badge
if kind not in getattr(self, "__badges__", {}):
msg = "Unknown badge type for {model}: {kind}"
raise db.ValidationError(msg.format(model=self.__class__.__name__, kind=kind))
badge = self._fields["badges"].field.document_type(kind=kind)
if current_user.is_authenticated:
badge.created_by = current_user.id

self.update(__raw__={"$push": {"badges": {"$each": [badge.to_mongo()], "$position": 0}}})
self.reload()
post_save.send(self.__class__, document=self)
on_badge_added.send(self, kind=kind)
return self.get_badge(kind)

def remove_badge(self, kind):
"""Perform an atomic removal for a given badge"""
self.update(__raw__={"$pull": {"badges": {"kind": kind}}})
self.reload()
on_badge_removed.send(self, kind=kind)
post_save.send(self.__class__, document=self)

def toggle_badge(self, kind):
"""Toggle a bdage given its kind"""
badge = self.get_badge(kind)
if badge:
return self.remove_badge(kind)
else:
return self.add_badge(kind)

def badge_label(self, badge):
"""Display the badge label for a given kind"""
kind = badge.kind if isinstance(badge, self.badge) else badge
return self.__badges__[kind]
14 changes: 12 additions & 2 deletions udata/core/badges/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from udata.api_fields import field
from udata.auth import login_user
from udata.core.user.factories import UserFactory
from udata.mongo import db
from udata.tests import DBTestMixin, TestCase

from ..models import get_badge_mixin
from ..models import Badge, BadgeMixin, BadgesList

TEST = "test"
OTHER = "other"
Expand All @@ -14,7 +15,16 @@
}


class Fake(db.Document, get_badge_mixin(BADGES)):
class FakeBadge(Badge):
kind = db.StringField(required=True, choices=list(BADGES.keys()))


class FakeBadgeMixin(BadgeMixin):
badges = field(BadgesList(FakeBadge), **BadgeMixin.default_badges_list_params)
__badges__ = BADGES


class Fake(db.Document, FakeBadgeMixin):
pass


Expand Down
14 changes: 12 additions & 2 deletions udata/core/dataset/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
from stringdist import rdlevenshtein
from werkzeug.utils import cached_property

from udata.api_fields import field
from udata.app import cache
from udata.core import storages
from udata.core.owned import Owned, OwnedQuerySet
from udata.frontend.markdown import mdstrip
from udata.i18n import lazy_gettext as _
from udata.models import SpatialCoverage, WithMetrics, db, get_badge_mixin
from udata.models import Badge, BadgeMixin, BadgesList, SpatialCoverage, WithMetrics, db
from udata.mongo.errors import FieldValidationError
from udata.uris import ValidationError, endpoint_for
from udata.uris import validate as validate_url
Expand Down Expand Up @@ -502,7 +503,16 @@ def save(self, *args, **kwargs):
self.dataset.save(*args, **kwargs)


class Dataset(WithMetrics, get_badge_mixin(BADGES), Owned, db.Document):
class DatasetBadge(Badge):
kind = db.StringField(required=True, choices=list(BADGES.keys()))


class DatasetBadgeMixin(BadgeMixin):
badges = field(BadgesList(DatasetBadge), **BadgeMixin.default_badges_list_params)
__badges__ = BADGES


class Dataset(WithMetrics, DatasetBadgeMixin, Owned, db.Document):
title = db.StringField(required=True)
acronym = db.StringField(max_length=128)
# /!\ do not set directly the slug when creating or updating a dataset
Expand Down
14 changes: 12 additions & 2 deletions udata/core/organization/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from mongoengine.signals import post_save, pre_save
from werkzeug.utils import cached_property

from udata.core.badges.models import get_badge_mixin
from udata.api_fields import field
from udata.core.badges.models import Badge, BadgeMixin, BadgesList
from udata.core.metrics.models import WithMetrics
from udata.core.storages import avatars, default_image_basename
from udata.frontend.markdown import mdstrip
Expand Down Expand Up @@ -94,7 +95,16 @@ def with_badge(self, kind):
return self(badges__kind=kind)


class Organization(WithMetrics, get_badge_mixin(BADGES), db.Datetimed, db.Document):
class OrganizationBadge(Badge):
kind = db.StringField(required=True, choices=list(BADGES.keys()))


class OrganizationBadgeMixin(BadgeMixin):
badges = field(BadgesList(OrganizationBadge), **BadgeMixin.default_badges_list_params)
__badges__ = BADGES


class Organization(WithMetrics, OrganizationBadgeMixin, db.Datetimed, db.Document):
name = db.StringField(required=True)
acronym = db.StringField(max_length=128)
slug = db.SlugField(
Expand Down
13 changes: 11 additions & 2 deletions udata/core/reuse/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from udata.core.storages import default_image_basename, images
from udata.frontend.markdown import mdstrip
from udata.i18n import lazy_gettext as _
from udata.models import WithMetrics, db, get_badge_mixin
from udata.models import Badge, BadgeMixin, BadgesList, WithMetrics, db
from udata.mongo.errors import FieldValidationError
from udata.uris import endpoint_for
from udata.utils import hash_url
Expand All @@ -35,6 +35,15 @@ def check_url_does_not_exists(url):
raise FieldValidationError(_("This URL is already registered"), field="url")


class ReuseBadge(Badge):
kind = db.StringField(required=True, choices=list(BADGES.keys()))


class ReuseBadgeMixin(BadgeMixin):
badges = field(BadgesList(ReuseBadge), **BadgeMixin.default_badges_list_params)
__badges__ = BADGES


@generate_fields(
searchable=True,
additional_sorts=[
Expand All @@ -46,7 +55,7 @@ def check_url_does_not_exists(url):
"organization.badges",
],
)
class Reuse(db.Datetimed, WithMetrics, get_badge_mixin(BADGES), Owned, db.Document):
class Reuse(db.Datetimed, WithMetrics, ReuseBadgeMixin, Owned, db.Document):
title = field(
db.StringField(required=True),
sortable=True,
Expand Down
4 changes: 0 additions & 4 deletions udata/tests/api/test_organizations_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,10 +833,6 @@ class OrganizationBadgeAPITest:

@pytest.fixture(autouse=True)
def setUp(self, api, clean_db):
# Register at least two badges
Organization.__badges__["test-1"] = "Test 1"
Organization.__badges__["test-2"] = "Test 2"

self.factory = badge_factory(Organization)
self.user = api.login(AdminFactory())
self.organization = OrganizationFactory()
Expand Down

0 comments on commit 362acf2

Please sign in to comment.