Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

from __future__ import annotations #477

Merged
merged 4 commits into from
Oct 1, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions waffle/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Optional, Type, Union
from __future__ import annotations

from typing import TYPE_CHECKING

import django
from django.core.exceptions import ImproperlyConfigured
Expand All @@ -14,7 +16,7 @@
__version__ = '.'.join(map(str, VERSION))


def flag_is_active(request: HttpRequest, flag_name: str, read_only: bool = False) -> Optional[bool]:
def flag_is_active(request: HttpRequest, flag_name: str, read_only: bool = False) -> bool | None:
flag = get_waffle_flag_model().get(flag_name)
return flag.is_active(request, read_only=read_only)

Expand All @@ -29,21 +31,21 @@ def sample_is_active(sample_name: str) -> bool:
return sample.is_active()


def get_waffle_flag_model() -> Type['AbstractBaseFlag']:
def get_waffle_flag_model() -> type[AbstractBaseFlag]:
return get_waffle_model('FLAG_MODEL')


def get_waffle_switch_model() -> Type['AbstractBaseSwitch']:
def get_waffle_switch_model() -> type[AbstractBaseSwitch]:
return get_waffle_model('SWITCH_MODEL')


def get_waffle_sample_model() -> Type['AbstractBaseSample']:
def get_waffle_sample_model() -> type[AbstractBaseSample]:
return get_waffle_model('SAMPLE_MODEL')


def get_waffle_model(setting_name: str) -> Union[
Type['AbstractBaseFlag'], Type['AbstractBaseSwitch'], Type['AbstractBaseSample']
]:
def get_waffle_model(setting_name: str) -> (
type[AbstractBaseFlag] | type[AbstractBaseSwitch] | type[AbstractBaseSample]
):
"""
Returns the waffle Flag model that is active in this project.
"""
Expand Down
8 changes: 5 additions & 3 deletions waffle/admin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Dict, Tuple
from __future__ import annotations

from typing import Any

from django.contrib import admin
from django.contrib.admin.models import LogEntry, CHANGE, DELETION
Expand All @@ -15,7 +17,7 @@
class BaseAdmin(admin.ModelAdmin):
search_fields = ('name', 'note')

def get_actions(self, request: HttpRequest) -> Dict[str, Any]:
def get_actions(self, request: HttpRequest) -> dict[str, Any]:
actions = super().get_actions(request)
if 'delete_selected' in actions:
del actions['delete_selected']
Expand Down Expand Up @@ -73,7 +75,7 @@ class InformativeManyToManyRawIdWidget(ManyToManyRawIdWidget):
Will display the names of the users in a parenthesised list after the
input field. This widget works with all models that have a "name" field.
"""
def label_and_url_for_value(self, values: Any) -> Tuple[str, str]:
def label_and_url_for_value(self, values: Any) -> tuple[str, str]:
names = []
key = self.rel.get_related_field().name
for value in values:
Expand Down
12 changes: 7 additions & 5 deletions waffle/decorators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from functools import wraps, WRAPPER_ASSIGNMENTS
from typing import Any, Callable, Optional, Union
from typing import Any, Callable

from django.http import Http404, HttpRequest, HttpResponse, HttpResponsePermanentRedirect, HttpResponseRedirect
from django.shortcuts import redirect
Expand All @@ -9,7 +11,7 @@


def waffle_flag(
flag_name: str, redirect_to: Optional[Union[Callable, str]] = None,
flag_name: str, redirect_to: Callable | str | None = None,
) -> Callable[[Callable[[HttpRequest], HttpResponse]], Callable[[HttpRequest], HttpResponse]]:
def decorator(view: Callable[[HttpRequest], HttpResponse]) -> Callable[[HttpRequest], HttpResponse]:
@wraps(view, assigned=WRAPPER_ASSIGNMENTS)
Expand All @@ -32,7 +34,7 @@ def _wrapped_view(request, *args, **kwargs):


def waffle_switch(
switch_name: str, redirect_to: Optional[Union[Callable, str]] = None,
switch_name: str, redirect_to: Callable | str | None = None,
) -> Callable[[Callable[[HttpRequest], HttpResponse]], Callable[[HttpRequest], HttpResponse]]:
def decorator(view: Callable[[HttpRequest], HttpResponse]) -> Callable[[HttpRequest], HttpResponse]:
@wraps(view, assigned=WRAPPER_ASSIGNMENTS)
Expand All @@ -55,8 +57,8 @@ def _wrapped_view(request, *args, **kwargs):


def get_response_to_redirect(
view: Optional[Union[Callable, str]], *args: Any, **kwargs: Any,
) -> Optional[Union[HttpResponseRedirect, HttpResponsePermanentRedirect]]:
view: Callable | str | None, *args: Any, **kwargs: Any,
) -> HttpResponseRedirect | HttpResponsePermanentRedirect | None:
try:
return redirect(reverse(view, args=args, kwargs=kwargs)) if view else None
except NoReverseMatch:
Expand Down
9 changes: 5 additions & 4 deletions waffle/mixins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import partial
from typing import Optional

from django.http import Http404

Expand All @@ -25,7 +26,7 @@ class WaffleFlagMixin(BaseWaffleMixin):
waffle_flag
"""

waffle_flag: Optional[str] = None
waffle_flag: str | None = None

def dispatch(self, request, *args, **kwargs):
func = partial(flag_is_active, request)
Expand All @@ -43,7 +44,7 @@ class WaffleSampleMixin(BaseWaffleMixin):
waffle_sample.
"""

waffle_sample: Optional[str] = None
waffle_sample: str | None = None

def dispatch(self, request, *args, **kwargs):
active = self.validate_waffle(self.waffle_sample, sample_is_active)
Expand All @@ -60,7 +61,7 @@ class WaffleSwitchMixin(BaseWaffleMixin):
waffle_switch.
"""

waffle_switch: Optional[str] = None
waffle_switch: str | None = None

def dispatch(self, request, *args, **kwargs):
active = self.validate_waffle(self.waffle_switch, switch_is_active)
Expand Down
36 changes: 19 additions & 17 deletions waffle/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import logging
import random
from decimal import Decimal
from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar
from typing import Any, TypeVar

from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser, Group
Expand Down Expand Up @@ -36,15 +38,15 @@ class Meta:
def __str__(self) -> str:
return self.name

def natural_key(self) -> Tuple[str]:
def natural_key(self) -> tuple[str]:
return (self.name,)

@classmethod
def _cache_key(cls, name: str) -> str:
return keyfmt(get_setting(cls.SINGLE_CACHE_KEY), name)

@classmethod
def get(cls: Type[_BaseModelType], name: str) -> _BaseModelType:
def get(cls: type[_BaseModelType], name: str) -> _BaseModelType:
cache = get_cache()
cache_key = cls._cache_key(name)
cached = cache.get(cache_key)
Expand All @@ -63,14 +65,14 @@ def get(cls: Type[_BaseModelType], name: str) -> _BaseModelType:
return obj

@classmethod
def get_from_db(cls: Type[_BaseModelType], name: str) -> _BaseModelType:
def get_from_db(cls: type[_BaseModelType], name: str) -> _BaseModelType:
objects = cls.objects
if get_setting('READ_FROM_WRITE_DB'):
objects = objects.using(router.db_for_write(cls))
return objects.get(name=name)

@classmethod
def get_all(cls: Type[_BaseModelType]) -> List[_BaseModelType]:
def get_all(cls: type[_BaseModelType]) -> list[_BaseModelType]:
cache = get_cache()
cache_key = get_setting(cls.ALL_CACHE_KEY)
cached = cache.get(cache_key)
Expand All @@ -88,7 +90,7 @@ def get_all(cls: Type[_BaseModelType]) -> List[_BaseModelType]:
return objs

@classmethod
def get_all_from_db(cls: Type[_BaseModelType]) -> List[_BaseModelType]:
def get_all_from_db(cls: type[_BaseModelType]) -> list[_BaseModelType]:
objects = cls.objects
if get_setting('READ_FROM_WRITE_DB'):
objects = objects.using(router.db_for_write(cls))
Expand All @@ -111,7 +113,7 @@ def save(self, *args: Any, **kwargs: Any) -> None:
self.flush()
return ret

def delete(self, *args: Any, **kwargs: Any) -> Tuple[int, Dict[str, int]]:
def delete(self, *args: Any, **kwargs: Any) -> tuple[int, dict[str, int]]:
ret = super().delete(*args, **kwargs)
if hasattr(transaction, 'on_commit'):
transaction.on_commit(self.flush)
Expand All @@ -120,7 +122,7 @@ def delete(self, *args: Any, **kwargs: Any) -> Tuple[int, Dict[str, int]]:
return ret


def set_flag(request: HttpRequest, flag_name: str, active: Optional[bool] = True, session_only: bool = False) -> None:
def set_flag(request: HttpRequest, flag_name: str, active: bool | None = True, session_only: bool = False) -> None:
"""Set a flag value on a request object."""
if not hasattr(request, 'waffles'):
request.waffles = {}
Expand Down Expand Up @@ -219,15 +221,15 @@ def flush(self) -> None:
keys = self.get_flush_keys()
cache.delete_many(keys)

def get_flush_keys(self, flush_keys: Optional[List[str]] = None) -> List[str]:
def get_flush_keys(self, flush_keys: list[str] | None = None) -> list[str]:
flush_keys = flush_keys or []
flush_keys.extend([
self._cache_key(self.name),
get_setting('ALL_FLAGS_CACHE_KEY'),
])
return flush_keys

def is_active_for_user(self, user: AbstractBaseUser) -> Optional[bool]:
def is_active_for_user(self, user: AbstractBaseUser) -> bool | None:
if self.authenticated and user.is_authenticated:
return True

Expand All @@ -239,21 +241,21 @@ def is_active_for_user(self, user: AbstractBaseUser) -> Optional[bool]:

return None

def _is_active_for_user(self, request: HttpRequest) -> Optional[bool]:
def _is_active_for_user(self, request: HttpRequest) -> bool | None:
user = getattr(request, "user", None)
if user:
return self.is_active_for_user(user)
return False

def _is_active_for_language(self, request: HttpRequest) -> Optional[bool]:
def _is_active_for_language(self, request: HttpRequest) -> bool | None:
if self.languages:
languages = [ln.strip() for ln in self.languages.split(',')]
if (hasattr(request, 'LANGUAGE_CODE') and
request.LANGUAGE_CODE in languages):
return True
return None

def is_active(self, request: HttpRequest, read_only: bool = False) -> Optional[bool]:
def is_active(self, request: HttpRequest, read_only: bool = False) -> bool | None:
if not self.pk:
log_level = get_setting('LOG_MISSING_FLAGS')
if log_level:
Expand Down Expand Up @@ -347,15 +349,15 @@ class Meta(AbstractBaseFlag.Meta):
verbose_name = _('Flag')
verbose_name_plural = _('Flags')

def get_flush_keys(self, flush_keys: Optional[List[str]] = None) -> List[str]:
def get_flush_keys(self, flush_keys: list[str] | None = None) -> list[str]:
flush_keys = super().get_flush_keys(flush_keys)
flush_keys.extend([
keyfmt(get_setting('FLAG_USERS_CACHE_KEY'), self.name),
keyfmt(get_setting('FLAG_GROUPS_CACHE_KEY'), self.name),
])
return flush_keys

def _get_user_ids(self) -> Set[Any]:
def _get_user_ids(self) -> set[Any]:
cache = get_cache()
cache_key = keyfmt(get_setting('FLAG_USERS_CACHE_KEY'), self.name)
cached = cache.get(cache_key)
Expand All @@ -372,7 +374,7 @@ def _get_user_ids(self) -> Set[Any]:
cache.add(cache_key, user_ids)
return user_ids

def _get_group_ids(self) -> Set[Any]:
def _get_group_ids(self) -> set[Any]:
cache = get_cache()
cache_key = keyfmt(get_setting('FLAG_GROUPS_CACHE_KEY'), self.name)
cached = cache.get(cache_key)
Expand All @@ -389,7 +391,7 @@ def _get_group_ids(self) -> Set[Any]:
cache.add(cache_key, group_ids)
return group_ids

def is_active_for_user(self, user: AbstractBaseUser) -> Optional[bool]:
def is_active_for_user(self, user: AbstractBaseUser) -> bool | None:
is_active = super().is_active_for_user(user)
if is_active:
return is_active
Expand Down
6 changes: 4 additions & 2 deletions waffle/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import hashlib
from typing import Any, Optional
from typing import Any

from django.conf import settings
from django.core.cache import BaseCache, caches
Expand All @@ -15,7 +17,7 @@ def get_setting(name: str, default: Any = None) -> Any:
return getattr(defaults, name, default)


def keyfmt(k: str, v: Optional[str] = None) -> str:
def keyfmt(k: str, v: str | None = None) -> str:
prefix = get_setting('CACHE_PREFIX') + waffle.__version__
if v is None:
key = prefix + k
Expand Down
6 changes: 4 additions & 2 deletions waffle/views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Dict
from __future__ import annotations

from typing import Any

from django.http import HttpRequest, HttpResponse, JsonResponse
from django.template import loader
Expand Down Expand Up @@ -39,7 +41,7 @@ def waffle_json(request):
return JsonResponse(_generate_waffle_json(request))


def _generate_waffle_json(request: HttpRequest) -> Dict[str, Dict[str, Any]]:
def _generate_waffle_json(request: HttpRequest) -> dict[str, dict[str, Any]]:
flags = get_waffle_flag_model().get_all()
flag_values = {
f.name: {
Expand Down