Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make feature flags easier to use and fix database errors #140

Merged
merged 12 commits into from
Oct 31, 2024
1 change: 1 addition & 0 deletions src/identity/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ classifiers = [
]

dependencies = [
"injector",
"pysaml2",
"requests"
]
Expand Down
2 changes: 2 additions & 0 deletions src/platform/Ligare/platform/feature_flag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .caching_feature_flag_router import FeatureFlag as CacheFeatureFlag
from .db_feature_flag_router import DBFeatureFlagRouter
from .db_feature_flag_router import FeatureFlag as DBFeatureFlag
from .decorators import feature_flag
from .feature_flag_router import FeatureFlag, FeatureFlagChange, FeatureFlagRouter

__all__ = (
Expand All @@ -12,4 +13,5 @@
"CacheFeatureFlag",
"DBFeatureFlag",
"FeatureFlagChange",
"feature_flag",
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy import Boolean, Column, String, Unicode
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.scoping import ScopedSession
from typing_extensions import override

from .caching_feature_flag_router import CachingFeatureFlagRouter
Expand All @@ -24,6 +24,8 @@ class FeatureFlag(FeatureFlagBaseData):


class FeatureFlagTableBase(ABC):
__tablename__: str

def __init__( # pyright: ignore[reportMissingSuperCall]
self,
/,
Expand All @@ -36,9 +38,9 @@ def __init__( # pyright: ignore[reportMissingSuperCall]
)

__tablename__: str
name: Column[Unicode] | str
description: Column[Unicode] | str
enabled: Column[Boolean] | bool
name: str
description: str
enabled: bool


class FeatureFlagTable:
Expand Down Expand Up @@ -70,17 +72,15 @@ def __repr__(self) -> str:


class DBFeatureFlagRouter(CachingFeatureFlagRouter[TFeatureFlag]):
# The SQLAlchemy table type used for querying from the type[FeatureFlag] database table
_feature_flag: type[FeatureFlagTableBase]
# The SQLAlchemy session used for connecting to and querying the database
_session: Session

@inject
def __init__(
self, feature_flag: type[FeatureFlagTableBase], session: Session, logger: Logger
self,
feature_flag: type[FeatureFlagTableBase],
scoped_session: ScopedSession,
logger: Logger,
) -> None:
self._feature_flag = feature_flag
self._session = session
self._scoped_session = scoped_session
super().__init__(logger)

@override
Expand All @@ -103,20 +103,21 @@ def set_feature_is_enabled(self, name: str, is_enabled: bool) -> FeatureFlagChan
raise ValueError("`name` parameter is required and cannot be empty.")

feature_flag: FeatureFlagTableBase
try:
feature_flag = (
self._session.query(self._feature_flag)
.filter(self._feature_flag.name == name)
.one()
)
except NoResultFound as e:
raise LookupError(
f"The feature flag `{name}` does not exist. It must be created before being accessed."
) from e

old_enabled_value = cast(bool | None, feature_flag.enabled)
feature_flag.enabled = is_enabled
self._session.commit()
with self._scoped_session() as session:
try:
feature_flag = (
session.query(self._feature_flag)
.filter(self._feature_flag.name == name)
.one()
)
except NoResultFound as e:
raise LookupError(
f"The feature flag `{name}` does not exist. It must be created before being accessed."
) from e

old_enabled_value = cast(bool | None, feature_flag.enabled)
feature_flag.enabled = is_enabled
session.commit()
_ = super().set_feature_is_enabled(name, is_enabled)

return FeatureFlagChange(
Expand Down Expand Up @@ -149,11 +150,12 @@ def feature_is_enabled(
if check_cache and super().feature_is_cached(name):
return super().feature_is_enabled(name, default)

feature_flag = (
self._session.query(self._feature_flag)
.filter(self._feature_flag.name == name)
.one_or_none()
)
with self._scoped_session() as session:
feature_flag = (
session.query(self._feature_flag)
.filter(self._feature_flag.name == name)
.one_or_none()
)

if feature_flag is None:
self._logger.warning(
Expand Down Expand Up @@ -192,20 +194,21 @@ def get_feature_flags(
If `names` is `None` this sequence contains _all_ feature flags in the database. Otherwise, the list is filtered.
"""
db_feature_flags: list[FeatureFlagTableBase]
if names is None:
db_feature_flags = self._session.query(self._feature_flag).all()
else:
db_feature_flags = (
self._session.query(self._feature_flag)
.filter(cast(Column[String], self._feature_flag.name).in_(names))
.all()
)
with self._scoped_session() as session:
if names is None:
db_feature_flags = session.query(self._feature_flag).all()
else:
db_feature_flags = (
session.query(self._feature_flag)
.filter(cast(Column[String], self._feature_flag.name).in_(names))
.all()
)

feature_flags = tuple(
self._create_feature_flag(
name=cast(str, feature_flag.name),
enabled=cast(bool, feature_flag.enabled),
description=cast(str, feature_flag.description),
name=feature_flag.name,
enabled=feature_flag.enabled,
description=feature_flag.description,
)
for feature_flag in db_feature_flags
)
Expand Down
40 changes: 40 additions & 0 deletions src/platform/Ligare/platform/feature_flag/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Any, Callable

from injector import Injector, inject
from typing_extensions import overload

from .feature_flag_router import FeatureFlag, FeatureFlagRouter


@overload
def feature_flag(
feature_flag_name: str, *, enabled_callback: Callable[..., Any]
) -> Callable[..., Callable[..., Any]]: ...
@overload
def feature_flag(
feature_flag_name: str, *, disabled_callback: Callable[..., Any]
) -> Callable[..., Callable[..., Any]]: ...


def feature_flag(
feature_flag_name: str,
*,
enabled_callback: Callable[..., None] = lambda: None,
disabled_callback: Callable[..., None] = lambda: None,
) -> Callable[..., Callable[..., Any]]:
def decorator(fn: Callable[..., Any]):
@inject
def wrapper(
feature_flag_router: FeatureFlagRouter[FeatureFlag],
injector: Injector,
):
if feature_flag_router.feature_is_enabled(feature_flag_name):
enabled_callback()
else:
disabled_callback()

return injector.call_with_injection(fn)

return wrapper

return decorator
4 changes: 3 additions & 1 deletion src/platform/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ classifiers = [
]

dependencies = [
"Ligare.database"
"Ligare.database",

"injector"
]

dynamic = ["version", "readme"]
Expand Down
Loading
Loading