Skip to content

Commit

Permalink
misc: Add initial type hints (#470)
Browse files Browse the repository at this point in the history
Introducing partial type hints to the project, in preparation for
publishing typing stubs for any applications using `django-waffle`.
  • Loading branch information
adamantike authored Sep 17, 2022
1 parent 25dbbe2 commit c8c444e
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 48 deletions.
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ ignore = E731

[mypy]
python_version = 3.7
exclude = waffle/tests
disallow_incomplete_defs = True
disallow_untyped_calls = True
disallow_untyped_decorators = True
strict_equality = True
[mypy-django.*]
Expand Down
22 changes: 15 additions & 7 deletions waffle/__init__.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,49 @@
from typing import TYPE_CHECKING, Optional, Type, Union

import django
from django.core.exceptions import ImproperlyConfigured
from django.http import HttpRequest

from waffle.utils import get_setting
from django.apps import apps as django_apps

if TYPE_CHECKING:
from waffle.models import AbstractBaseFlag, AbstractBaseSample, AbstractBaseSwitch

VERSION = (3, 0, 0)
__version__ = '.'.join(map(str, VERSION))


def flag_is_active(request, flag_name, read_only=False):
def flag_is_active(request: HttpRequest, flag_name: str, read_only: bool = False) -> Optional[bool]:
flag = get_waffle_flag_model().get(flag_name)
return flag.is_active(request, read_only=read_only)


def switch_is_active(switch_name):
def switch_is_active(switch_name: str) -> bool:
switch = get_waffle_switch_model().get(switch_name)
return switch.is_active()


def sample_is_active(sample_name):
def sample_is_active(sample_name: str) -> bool:
sample = get_waffle_sample_model().get(sample_name)
return sample.is_active()


def get_waffle_flag_model():
def get_waffle_flag_model() -> Type['AbstractBaseFlag']:
return get_waffle_model('FLAG_MODEL')


def get_waffle_switch_model():
def get_waffle_switch_model() -> Type['AbstractBaseSwitch']:
return get_waffle_model('SWITCH_MODEL')


def get_waffle_sample_model():
def get_waffle_sample_model() -> Type['AbstractBaseSample']:
return get_waffle_model('SAMPLE_MODEL')


def get_waffle_model(setting_name):
def get_waffle_model(setting_name: str) -> Union[
Type['AbstractBaseFlag'], Type['AbstractBaseSwitch'], Type['AbstractBaseSample']
]:
"""
Returns the waffle Flag model that is active in this project.
"""
Expand Down
2 changes: 1 addition & 1 deletion waffle/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ class WaffleConfig(AppConfig):
verbose_name = 'django-waffle'
default_auto_field = 'django.db.models.AutoField'

def ready(self):
def ready(self) -> None:
import waffle.signals # noqa: F401
8 changes: 5 additions & 3 deletions waffle/management/commands/waffle_delete.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from django.core.management.base import BaseCommand
from typing import Any

from django.core.management.base import BaseCommand, CommandParser

from waffle import (
get_waffle_flag_model,
Expand All @@ -8,7 +10,7 @@


class Command(BaseCommand):
def add_arguments(self, parser):
def add_arguments(self, parser: CommandParser) -> None:
parser.add_argument(
'--flags',
action='store',
Expand All @@ -30,7 +32,7 @@ def add_arguments(self, parser):

help = 'Delete flags, samples, and switches from database'

def handle(self, *args, **options):
def handle(self, *args: Any, **options: Any) -> None:
flags = options['flag_names']
if flags:
flag_queryset = get_waffle_flag_model().objects.filter(name__in=flags)
Expand Down
8 changes: 5 additions & 3 deletions waffle/management/commands/waffle_flag.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any

from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from django.core.management.base import BaseCommand, CommandError
from django.core.management.base import BaseCommand, CommandError, CommandParser
from django.db.models import Q

from waffle import get_waffle_flag_model
Expand All @@ -9,7 +11,7 @@


class Command(BaseCommand):
def add_arguments(self, parser):
def add_arguments(self, parser: CommandParser) -> None:
parser.add_argument(
'name',
nargs='?',
Expand Down Expand Up @@ -91,7 +93,7 @@ def add_arguments(self, parser):

help = 'Modify a flag.'

def handle(self, *args, **options):
def handle(self, *args: Any, **options: Any) -> None:
if options['list_flags']:
self.stdout.write('Flags:')
for flag in get_waffle_flag_model().objects.iterator():
Expand Down
8 changes: 5 additions & 3 deletions waffle/management/commands/waffle_sample.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from django.core.management.base import BaseCommand, CommandError
from typing import Any

from django.core.management.base import BaseCommand, CommandError, CommandParser

from waffle import get_waffle_sample_model


class Command(BaseCommand):
def add_arguments(self, parser):
def add_arguments(self, parser: CommandParser) -> None:
parser.add_argument(
'name',
nargs='?',
Expand All @@ -27,7 +29,7 @@ def add_arguments(self, parser):

help = 'Change percentage of a sample.'

def handle(self, *args, **options):
def handle(self, *args: Any, **options: Any) -> None:
if options['list_samples']:
self.stdout.write('Samples:')
for sample in get_waffle_sample_model().objects.iterator():
Expand Down
8 changes: 5 additions & 3 deletions waffle/management/commands/waffle_switch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any

from argparse import ArgumentTypeError
from django.core.management.base import BaseCommand, CommandError
from django.core.management.base import BaseCommand, CommandError, CommandParser

from waffle import get_waffle_switch_model

Expand All @@ -12,7 +14,7 @@ def on_off_bool(string):


class Command(BaseCommand):
def add_arguments(self, parser):
def add_arguments(self, parser: CommandParser) -> None:
parser.add_argument(
'name',
nargs='?',
Expand All @@ -37,7 +39,7 @@ def add_arguments(self, parser):

help = 'Activate or deactivate a switch.'

def handle(self, *args, **options):
def handle(self, *args: Any, **options: Any) -> None:
if options['list_switches']:
self.stdout.write('Switches:')
for switch in get_waffle_switch_model().objects.iterator():
Expand Down
3 changes: 2 additions & 1 deletion waffle/middleware.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from django.http import HttpRequest, HttpResponse
from django.utils.deprecation import MiddlewareMixin
from django.utils.encoding import smart_str

from waffle.utils import get_setting


class WaffleMiddleware(MiddlewareMixin):
def process_response(self, request, response):
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
secure = get_setting('SECURE')
max_age = get_setting('MAX_AGE')

Expand Down
51 changes: 28 additions & 23 deletions waffle/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
import random
from decimal import Decimal
import logging
from typing import Any, List, Optional, Set, Tuple, Type, TypeVar

from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import AbstractBaseUser, Group
from django.db import models, router, transaction
from django.http import HttpRequest
from django.utils import timezone
from django.utils.translation import gettext_lazy as _

Expand All @@ -21,25 +23,28 @@
CACHE_EMPTY = '-'


_BaseModelType = TypeVar("_BaseModelType", bound="BaseModel")


class BaseModel(models.Model):
SINGLE_CACHE_KEY = ''
ALL_CACHE_KEY = ''

class Meta:
abstract = True

def __str__(self):
def __str__(self) -> str:
return self.name

def natural_key(self):
def natural_key(self) -> Tuple[str]:
return (self.name,)

@classmethod
def _cache_key(cls, name):
def _cache_key(cls, name: str) -> str:
return keyfmt(get_setting(cls.SINGLE_CACHE_KEY), name)

@classmethod
def get(cls, name):
def get(cls: Type[_BaseModelType], name: str) -> _BaseModelType:
cache = get_cache()
cache_key = cls._cache_key(name)
cached = cache.get(cache_key)
Expand All @@ -58,14 +63,14 @@ def get(cls, name):
return obj

@classmethod
def get_from_db(cls, name):
def get_from_db(cls: Type[_BaseModelType], name: str) -> _BaseModelType:
objects = cls.objects
if get_setting('READ_FROM_WRITE_DB'):
objects = objects.using(router.db_for_write(cls))
return objects.get(name=name)

@classmethod
def get_all(cls):
def get_all(cls: Type[_BaseModelType]) -> List[_BaseModelType]:
cache = get_cache()
cache_key = get_setting(cls.ALL_CACHE_KEY)
cached = cache.get(cache_key)
Expand All @@ -83,13 +88,13 @@ def get_all(cls):
return objs

@classmethod
def get_all_from_db(cls):
def get_all_from_db(cls: Type[_BaseModelType]) -> List[_BaseModelType]:
objects = cls.objects
if get_setting('READ_FROM_WRITE_DB'):
objects = objects.using(router.db_for_write(cls))
return list(objects.all())

def flush(self):
def flush(self) -> None:
cache = get_cache()
keys = [
self._cache_key(self.name),
Expand All @@ -115,7 +120,7 @@ def delete(self, *args, **kwargs):
return ret


def set_flag(request, flag_name, active=True, session_only=False):
def set_flag(request: HttpRequest, flag_name: str, active: Optional[bool] = True, session_only: bool = False) -> None:
"""Set a flag value on a request object."""
if not hasattr(request, 'waffles'):
request.waffles = {}
Expand Down Expand Up @@ -209,20 +214,20 @@ class Meta:
verbose_name = _('Flag')
verbose_name_plural = _('Flags')

def flush(self):
def flush(self) -> None:
cache = get_cache()
keys = self.get_flush_keys()
cache.delete_many(keys)

def get_flush_keys(self, flush_keys=None):
def get_flush_keys(self, flush_keys: Optional[List[str]] = None) -> List[str]:
flush_keys = flush_keys or []
flush_keys.extend([
self._cache_key(self.name),
get_setting('ALL_FLAGS_CACHE_KEY'),
])
return flush_keys

def is_active_for_user(self, user):
def is_active_for_user(self, user: AbstractBaseUser) -> Optional[bool]:
if self.authenticated and user.is_authenticated:
return True

Expand All @@ -234,21 +239,21 @@ def is_active_for_user(self, user):

return None

def _is_active_for_user(self, request):
def _is_active_for_user(self, request: HttpRequest) -> Optional[bool]:
user = getattr(request, "user", None)
if user:
return self.is_active_for_user(user)
return False

def _is_active_for_language(self, request):
def _is_active_for_language(self, request: HttpRequest) -> Optional[bool]:
if self.languages:
languages = [ln.strip() for ln in self.languages.split(',')]
if (hasattr(request, 'LANGUAGE_CODE') and
request.LANGUAGE_CODE in languages):
return True
return None

def is_active(self, request, read_only=False):
def is_active(self, request: HttpRequest, read_only: bool = False) -> Optional[bool]:
if not self.pk:
log_level = get_setting('LOG_MISSING_FLAGS')
if log_level:
Expand Down Expand Up @@ -342,15 +347,15 @@ class Meta(AbstractBaseFlag.Meta):
verbose_name = _('Flag')
verbose_name_plural = _('Flags')

def get_flush_keys(self, flush_keys=None):
def get_flush_keys(self, flush_keys: Optional[List[str]] = None) -> List[str]:
flush_keys = super().get_flush_keys(flush_keys)
flush_keys.extend([
keyfmt(get_setting('FLAG_USERS_CACHE_KEY'), self.name),
keyfmt(get_setting('FLAG_GROUPS_CACHE_KEY'), self.name),
])
return flush_keys

def _get_user_ids(self):
def _get_user_ids(self) -> Set[Any]:
cache = get_cache()
cache_key = keyfmt(get_setting('FLAG_USERS_CACHE_KEY'), self.name)
cached = cache.get(cache_key)
Expand All @@ -367,7 +372,7 @@ def _get_user_ids(self):
cache.add(cache_key, user_ids)
return user_ids

def _get_group_ids(self):
def _get_group_ids(self) -> Set[Any]:
cache = get_cache()
cache_key = keyfmt(get_setting('FLAG_GROUPS_CACHE_KEY'), self.name)
cached = cache.get(cache_key)
Expand All @@ -384,7 +389,7 @@ def _get_group_ids(self):
cache.add(cache_key, group_ids)
return group_ids

def is_active_for_user(self, user):
def is_active_for_user(self, user: AbstractBaseUser) -> Optional[bool]:
is_active = super().is_active_for_user(user)
if is_active:
return is_active
Expand Down Expand Up @@ -460,7 +465,7 @@ class Meta:
verbose_name = _('Switch')
verbose_name_plural = _('Switches')

def is_active(self):
def is_active(self) -> bool:
if not self.pk:
log_level = get_setting('LOG_MISSING_SWITCHES')
if log_level:
Expand Down Expand Up @@ -525,7 +530,7 @@ class Meta:
verbose_name = _('Sample')
verbose_name_plural = _('Samples')

def is_active(self):
def is_active(self) -> bool:
if not self.pk:
log_level = get_setting('LOG_MISSING_SAMPLES')
if log_level:
Expand Down
Loading

0 comments on commit c8c444e

Please sign in to comment.