From c8c444e4785a48e8ff0fff5912fda3b67c03d709 Mon Sep 17 00:00:00 2001 From: Michael Manganiello Date: Sat, 17 Sep 2022 17:07:54 -0300 Subject: [PATCH] misc: Add initial type hints (#470) Introducing partial type hints to the project, in preparation for publishing typing stubs for any applications using `django-waffle`. --- setup.cfg | 2 + waffle/__init__.py | 22 ++++++--- waffle/apps.py | 2 +- waffle/management/commands/waffle_delete.py | 8 ++-- waffle/management/commands/waffle_flag.py | 8 ++-- waffle/management/commands/waffle_sample.py | 8 ++-- waffle/management/commands/waffle_switch.py | 8 ++-- waffle/middleware.py | 3 +- waffle/models.py | 51 +++++++++++---------- waffle/utils.py | 9 ++-- 10 files changed, 73 insertions(+), 48 deletions(-) diff --git a/setup.cfg b/setup.cfg index dc57f6af..61d986f3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,7 +9,9 @@ ignore = E731 [mypy] python_version = 3.7 +exclude = waffle/tests disallow_incomplete_defs = True +disallow_untyped_calls = True disallow_untyped_decorators = True strict_equality = True [mypy-django.*] diff --git a/waffle/__init__.py b/waffle/__init__.py index e4b8535d..16ca877d 100755 --- a/waffle/__init__.py +++ b/waffle/__init__.py @@ -1,41 +1,49 @@ +from typing import TYPE_CHECKING, Optional, Type, Union + import django from django.core.exceptions import ImproperlyConfigured +from django.http import HttpRequest from waffle.utils import get_setting from django.apps import apps as django_apps +if TYPE_CHECKING: + from waffle.models import AbstractBaseFlag, AbstractBaseSample, AbstractBaseSwitch + VERSION = (3, 0, 0) __version__ = '.'.join(map(str, VERSION)) -def flag_is_active(request, flag_name, read_only=False): +def flag_is_active(request: HttpRequest, flag_name: str, read_only: bool = False) -> Optional[bool]: flag = get_waffle_flag_model().get(flag_name) return flag.is_active(request, read_only=read_only) -def switch_is_active(switch_name): +def switch_is_active(switch_name: str) -> bool: switch = get_waffle_switch_model().get(switch_name) return switch.is_active() -def sample_is_active(sample_name): +def sample_is_active(sample_name: str) -> bool: sample = get_waffle_sample_model().get(sample_name) return sample.is_active() -def get_waffle_flag_model(): +def get_waffle_flag_model() -> Type['AbstractBaseFlag']: return get_waffle_model('FLAG_MODEL') -def get_waffle_switch_model(): +def get_waffle_switch_model() -> Type['AbstractBaseSwitch']: return get_waffle_model('SWITCH_MODEL') -def get_waffle_sample_model(): +def get_waffle_sample_model() -> Type['AbstractBaseSample']: return get_waffle_model('SAMPLE_MODEL') -def get_waffle_model(setting_name): +def get_waffle_model(setting_name: str) -> Union[ + Type['AbstractBaseFlag'], Type['AbstractBaseSwitch'], Type['AbstractBaseSample'] +]: """ Returns the waffle Flag model that is active in this project. """ diff --git a/waffle/apps.py b/waffle/apps.py index 5af3f2a5..2dfbae6e 100644 --- a/waffle/apps.py +++ b/waffle/apps.py @@ -6,5 +6,5 @@ class WaffleConfig(AppConfig): verbose_name = 'django-waffle' default_auto_field = 'django.db.models.AutoField' - def ready(self): + def ready(self) -> None: import waffle.signals # noqa: F401 diff --git a/waffle/management/commands/waffle_delete.py b/waffle/management/commands/waffle_delete.py index a1f161da..dc9e77e5 100644 --- a/waffle/management/commands/waffle_delete.py +++ b/waffle/management/commands/waffle_delete.py @@ -1,4 +1,6 @@ -from django.core.management.base import BaseCommand +from typing import Any + +from django.core.management.base import BaseCommand, CommandParser from waffle import ( get_waffle_flag_model, @@ -8,7 +10,7 @@ class Command(BaseCommand): - def add_arguments(self, parser): + def add_arguments(self, parser: CommandParser) -> None: parser.add_argument( '--flags', action='store', @@ -30,7 +32,7 @@ def add_arguments(self, parser): help = 'Delete flags, samples, and switches from database' - def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: flags = options['flag_names'] if flags: flag_queryset = get_waffle_flag_model().objects.filter(name__in=flags) diff --git a/waffle/management/commands/waffle_flag.py b/waffle/management/commands/waffle_flag.py index 7025288b..cdf9c374 100644 --- a/waffle/management/commands/waffle_flag.py +++ b/waffle/management/commands/waffle_flag.py @@ -1,6 +1,8 @@ +from typing import Any + from django.contrib.auth import get_user_model from django.contrib.auth.models import Group -from django.core.management.base import BaseCommand, CommandError +from django.core.management.base import BaseCommand, CommandError, CommandParser from django.db.models import Q from waffle import get_waffle_flag_model @@ -9,7 +11,7 @@ class Command(BaseCommand): - def add_arguments(self, parser): + def add_arguments(self, parser: CommandParser) -> None: parser.add_argument( 'name', nargs='?', @@ -91,7 +93,7 @@ def add_arguments(self, parser): help = 'Modify a flag.' - def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: if options['list_flags']: self.stdout.write('Flags:') for flag in get_waffle_flag_model().objects.iterator(): diff --git a/waffle/management/commands/waffle_sample.py b/waffle/management/commands/waffle_sample.py index a0c223d8..311887b9 100644 --- a/waffle/management/commands/waffle_sample.py +++ b/waffle/management/commands/waffle_sample.py @@ -1,10 +1,12 @@ -from django.core.management.base import BaseCommand, CommandError +from typing import Any + +from django.core.management.base import BaseCommand, CommandError, CommandParser from waffle import get_waffle_sample_model class Command(BaseCommand): - def add_arguments(self, parser): + def add_arguments(self, parser: CommandParser) -> None: parser.add_argument( 'name', nargs='?', @@ -27,7 +29,7 @@ def add_arguments(self, parser): help = 'Change percentage of a sample.' - def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: if options['list_samples']: self.stdout.write('Samples:') for sample in get_waffle_sample_model().objects.iterator(): diff --git a/waffle/management/commands/waffle_switch.py b/waffle/management/commands/waffle_switch.py index 68b40db3..62ef80ab 100644 --- a/waffle/management/commands/waffle_switch.py +++ b/waffle/management/commands/waffle_switch.py @@ -1,5 +1,7 @@ +from typing import Any + from argparse import ArgumentTypeError -from django.core.management.base import BaseCommand, CommandError +from django.core.management.base import BaseCommand, CommandError, CommandParser from waffle import get_waffle_switch_model @@ -12,7 +14,7 @@ def on_off_bool(string): class Command(BaseCommand): - def add_arguments(self, parser): + def add_arguments(self, parser: CommandParser) -> None: parser.add_argument( 'name', nargs='?', @@ -37,7 +39,7 @@ def add_arguments(self, parser): help = 'Activate or deactivate a switch.' - def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: if options['list_switches']: self.stdout.write('Switches:') for switch in get_waffle_switch_model().objects.iterator(): diff --git a/waffle/middleware.py b/waffle/middleware.py index 6c7b5f26..ec31b1a7 100644 --- a/waffle/middleware.py +++ b/waffle/middleware.py @@ -1,3 +1,4 @@ +from django.http import HttpRequest, HttpResponse from django.utils.deprecation import MiddlewareMixin from django.utils.encoding import smart_str @@ -5,7 +6,7 @@ class WaffleMiddleware(MiddlewareMixin): - def process_response(self, request, response): + def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse: secure = get_setting('SECURE') max_age = get_setting('MAX_AGE') diff --git a/waffle/models.py b/waffle/models.py index e18d1797..53e47d4a 100644 --- a/waffle/models.py +++ b/waffle/models.py @@ -1,10 +1,12 @@ +import logging import random from decimal import Decimal -import logging +from typing import Any, List, Optional, Set, Tuple, Type, TypeVar from django.conf import settings -from django.contrib.auth.models import Group +from django.contrib.auth.models import AbstractBaseUser, Group from django.db import models, router, transaction +from django.http import HttpRequest from django.utils import timezone from django.utils.translation import gettext_lazy as _ @@ -21,6 +23,9 @@ CACHE_EMPTY = '-' +_BaseModelType = TypeVar("_BaseModelType", bound="BaseModel") + + class BaseModel(models.Model): SINGLE_CACHE_KEY = '' ALL_CACHE_KEY = '' @@ -28,18 +33,18 @@ class BaseModel(models.Model): class Meta: abstract = True - def __str__(self): + def __str__(self) -> str: return self.name - def natural_key(self): + def natural_key(self) -> Tuple[str]: return (self.name,) @classmethod - def _cache_key(cls, name): + def _cache_key(cls, name: str) -> str: return keyfmt(get_setting(cls.SINGLE_CACHE_KEY), name) @classmethod - def get(cls, name): + def get(cls: Type[_BaseModelType], name: str) -> _BaseModelType: cache = get_cache() cache_key = cls._cache_key(name) cached = cache.get(cache_key) @@ -58,14 +63,14 @@ def get(cls, name): return obj @classmethod - def get_from_db(cls, name): + def get_from_db(cls: Type[_BaseModelType], name: str) -> _BaseModelType: objects = cls.objects if get_setting('READ_FROM_WRITE_DB'): objects = objects.using(router.db_for_write(cls)) return objects.get(name=name) @classmethod - def get_all(cls): + def get_all(cls: Type[_BaseModelType]) -> List[_BaseModelType]: cache = get_cache() cache_key = get_setting(cls.ALL_CACHE_KEY) cached = cache.get(cache_key) @@ -83,13 +88,13 @@ def get_all(cls): return objs @classmethod - def get_all_from_db(cls): + def get_all_from_db(cls: Type[_BaseModelType]) -> List[_BaseModelType]: objects = cls.objects if get_setting('READ_FROM_WRITE_DB'): objects = objects.using(router.db_for_write(cls)) return list(objects.all()) - def flush(self): + def flush(self) -> None: cache = get_cache() keys = [ self._cache_key(self.name), @@ -115,7 +120,7 @@ def delete(self, *args, **kwargs): return ret -def set_flag(request, flag_name, active=True, session_only=False): +def set_flag(request: HttpRequest, flag_name: str, active: Optional[bool] = True, session_only: bool = False) -> None: """Set a flag value on a request object.""" if not hasattr(request, 'waffles'): request.waffles = {} @@ -209,12 +214,12 @@ class Meta: verbose_name = _('Flag') verbose_name_plural = _('Flags') - def flush(self): + def flush(self) -> None: cache = get_cache() keys = self.get_flush_keys() cache.delete_many(keys) - def get_flush_keys(self, flush_keys=None): + def get_flush_keys(self, flush_keys: Optional[List[str]] = None) -> List[str]: flush_keys = flush_keys or [] flush_keys.extend([ self._cache_key(self.name), @@ -222,7 +227,7 @@ def get_flush_keys(self, flush_keys=None): ]) return flush_keys - def is_active_for_user(self, user): + def is_active_for_user(self, user: AbstractBaseUser) -> Optional[bool]: if self.authenticated and user.is_authenticated: return True @@ -234,13 +239,13 @@ def is_active_for_user(self, user): return None - def _is_active_for_user(self, request): + def _is_active_for_user(self, request: HttpRequest) -> Optional[bool]: user = getattr(request, "user", None) if user: return self.is_active_for_user(user) return False - def _is_active_for_language(self, request): + def _is_active_for_language(self, request: HttpRequest) -> Optional[bool]: if self.languages: languages = [ln.strip() for ln in self.languages.split(',')] if (hasattr(request, 'LANGUAGE_CODE') and @@ -248,7 +253,7 @@ def _is_active_for_language(self, request): return True return None - def is_active(self, request, read_only=False): + def is_active(self, request: HttpRequest, read_only: bool = False) -> Optional[bool]: if not self.pk: log_level = get_setting('LOG_MISSING_FLAGS') if log_level: @@ -342,7 +347,7 @@ class Meta(AbstractBaseFlag.Meta): verbose_name = _('Flag') verbose_name_plural = _('Flags') - def get_flush_keys(self, flush_keys=None): + def get_flush_keys(self, flush_keys: Optional[List[str]] = None) -> List[str]: flush_keys = super().get_flush_keys(flush_keys) flush_keys.extend([ keyfmt(get_setting('FLAG_USERS_CACHE_KEY'), self.name), @@ -350,7 +355,7 @@ def get_flush_keys(self, flush_keys=None): ]) return flush_keys - def _get_user_ids(self): + def _get_user_ids(self) -> Set[Any]: cache = get_cache() cache_key = keyfmt(get_setting('FLAG_USERS_CACHE_KEY'), self.name) cached = cache.get(cache_key) @@ -367,7 +372,7 @@ def _get_user_ids(self): cache.add(cache_key, user_ids) return user_ids - def _get_group_ids(self): + def _get_group_ids(self) -> Set[Any]: cache = get_cache() cache_key = keyfmt(get_setting('FLAG_GROUPS_CACHE_KEY'), self.name) cached = cache.get(cache_key) @@ -384,7 +389,7 @@ def _get_group_ids(self): cache.add(cache_key, group_ids) return group_ids - def is_active_for_user(self, user): + def is_active_for_user(self, user: AbstractBaseUser) -> Optional[bool]: is_active = super().is_active_for_user(user) if is_active: return is_active @@ -460,7 +465,7 @@ class Meta: verbose_name = _('Switch') verbose_name_plural = _('Switches') - def is_active(self): + def is_active(self) -> bool: if not self.pk: log_level = get_setting('LOG_MISSING_SWITCHES') if log_level: @@ -525,7 +530,7 @@ class Meta: verbose_name = _('Sample') verbose_name_plural = _('Samples') - def is_active(self): + def is_active(self) -> bool: if not self.pk: log_level = get_setting('LOG_MISSING_SAMPLES') if log_level: diff --git a/waffle/utils.py b/waffle/utils.py index db122074..6ec976f5 100644 --- a/waffle/utils.py +++ b/waffle/utils.py @@ -1,20 +1,21 @@ import hashlib +from typing import Any, Optional from django.conf import settings -from django.core.cache import caches +from django.core.cache import BaseCache, caches import waffle from waffle import defaults -def get_setting(name, default=None): +def get_setting(name: str, default: Any = None) -> Any: try: return getattr(settings, 'WAFFLE_' + name) except AttributeError: return getattr(defaults, name, default) -def keyfmt(k, v=None): +def keyfmt(k: str, v: Optional[str] = None) -> str: prefix = get_setting('CACHE_PREFIX') + waffle.__version__ if v is None: key = prefix + k @@ -23,6 +24,6 @@ def keyfmt(k, v=None): return key -def get_cache(): +def get_cache() -> BaseCache: CACHE_NAME = get_setting('CACHE_NAME') return caches[CACHE_NAME]