Skip to content

refactor: replace exception raising with error flag resolution #474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 88 additions & 32 deletions openfeature/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,12 +446,12 @@ def _establish_hooks_and_provider(

def _assert_provider_status(
self,
) -> None:
) -> typing.Optional[OpenFeatureError]:
status = self.get_provider_status()
if status == ProviderStatus.NOT_READY:
raise ProviderNotReadyError()
return ProviderNotReadyError()
if status == ProviderStatus.FATAL:
raise ProviderFatalError()
return ProviderFatalError()
return None

def _before_hooks_and_merge_context(
Expand Down Expand Up @@ -511,7 +511,22 @@ async def evaluate_flag_details_async(
)

try:
self._assert_provider_status()
if provider_err := self._assert_provider_status():
error_hooks(
flag_type,
hook_context,
provider_err,
reversed_merged_hooks,
hook_hints,
)
flag_evaluation = FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=provider_err.error_code,
error_message=provider_err.error_message,
)
return flag_evaluation

merged_context = self._before_hooks_and_merge_context(
flag_type,
Expand All @@ -528,6 +543,11 @@ async def evaluate_flag_details_async(
default_value,
merged_context,
)
if err := flag_evaluation.get_exception():
error_hooks(
flag_type, hook_context, err, reversed_merged_hooks, hook_hints
)
return flag_evaluation

after_hooks(
flag_type,
Expand Down Expand Up @@ -607,7 +627,22 @@ def evaluate_flag_details(
)

try:
self._assert_provider_status()
if provider_err := self._assert_provider_status():
error_hooks(
flag_type,
hook_context,
provider_err,
reversed_merged_hooks,
hook_hints,
)
flag_evaluation = FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=provider_err.error_code,
error_message=provider_err.error_message,
)
return flag_evaluation

merged_context = self._before_hooks_and_merge_context(
flag_type,
Expand All @@ -624,6 +659,12 @@ def evaluate_flag_details(
default_value,
merged_context,
)
if err := flag_evaluation.get_exception():
error_hooks(
flag_type, hook_context, err, reversed_merged_hooks, hook_hints
)
flag_evaluation.value = default_value
return flag_evaluation

after_hooks(
flag_type,
Expand Down Expand Up @@ -693,27 +734,33 @@ async def _create_provider_evaluation_async(
}
get_details_callable = get_details_callables_async.get(flag_type)
if not get_details_callable:
raise GeneralError(error_message="Unknown flag type")
return FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=ErrorCode.GENERAL,
error_message="Unknown flag type",
)

resolution = await get_details_callable(
flag_key=flag_key,
default_value=default_value,
evaluation_context=evaluation_context,
)
resolution.raise_for_error()
if resolution.error_code:
return resolution.to_flag_evaluation_details(flag_key)

# we need to check the get_args to be compatible with union types.
_typecheck_flag_value(resolution.value, flag_type)
if err := _typecheck_flag_value(value=resolution.value, flag_type=flag_type):
return FlagEvaluationDetails(
flag_key=flag_key,
value=resolution.value,
reason=Reason.ERROR,
error_code=err.error_code,
error_message=err.error_message,
)

return FlagEvaluationDetails(
flag_key=flag_key,
value=resolution.value,
variant=resolution.variant,
flag_metadata=resolution.flag_metadata or {},
reason=resolution.reason,
error_code=resolution.error_code,
error_message=resolution.error_message,
)
return resolution.to_flag_evaluation_details(flag_key)

def _create_provider_evaluation(
self,
Expand Down Expand Up @@ -743,27 +790,33 @@ def _create_provider_evaluation(

get_details_callable = get_details_callables.get(flag_type)
if not get_details_callable:
raise GeneralError(error_message="Unknown flag type")
return FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=ErrorCode.GENERAL,
error_message="Unknown flag type",
)

resolution = get_details_callable(
flag_key=flag_key,
default_value=default_value,
evaluation_context=evaluation_context,
)
resolution.raise_for_error()
if resolution.error_code:
return resolution.to_flag_evaluation_details(flag_key)

# we need to check the get_args to be compatible with union types.
_typecheck_flag_value(resolution.value, flag_type)
if err := _typecheck_flag_value(value=resolution.value, flag_type=flag_type):
return FlagEvaluationDetails(
flag_key=flag_key,
value=resolution.value,
reason=Reason.ERROR,
error_code=err.error_code,
error_message=err.error_message,
)

return FlagEvaluationDetails(
flag_key=flag_key,
value=resolution.value,
variant=resolution.variant,
flag_metadata=resolution.flag_metadata or {},
reason=resolution.reason,
error_code=resolution.error_code,
error_message=resolution.error_message,
)
return resolution.to_flag_evaluation_details(flag_key)

def add_handler(self, event: ProviderEvent, handler: EventHandler) -> None:
_event_support.add_client_handler(self, event, handler)
Expand All @@ -772,7 +825,9 @@ def remove_handler(self, event: ProviderEvent, handler: EventHandler) -> None:
_event_support.remove_client_handler(self, event, handler)


def _typecheck_flag_value(value: typing.Any, flag_type: FlagType) -> None:
def _typecheck_flag_value(
value: typing.Any, flag_type: FlagType
) -> typing.Optional[OpenFeatureError]:
type_map: TypeMap = {
FlagType.BOOLEAN: bool,
FlagType.STRING: str,
Expand All @@ -782,6 +837,7 @@ def _typecheck_flag_value(value: typing.Any, flag_type: FlagType) -> None:
}
_type = type_map.get(flag_type)
if not _type:
raise GeneralError(error_message="Unknown flag type")
return GeneralError(error_message="Unknown flag type")
if not isinstance(value, _type):
raise TypeMismatchError(f"Expected type {_type} but got {type(value)}")
return TypeMismatchError(f"Expected type {_type} but got {type(value)}")
return None
18 changes: 17 additions & 1 deletion openfeature/flag_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass, field

from openfeature._backports.strenum import StrEnum
from openfeature.exception import ErrorCode
from openfeature.exception import ErrorCode, OpenFeatureError

if typing.TYPE_CHECKING: # pragma: no cover
# resolves a circular dependency in type annotations
Expand Down Expand Up @@ -56,6 +56,11 @@ class FlagEvaluationDetails(typing.Generic[T_co]):
error_code: typing.Optional[ErrorCode] = None
error_message: typing.Optional[str] = None

def get_exception(self) -> typing.Optional[OpenFeatureError]:
if self.error_code:
return ErrorCode.to_exception(self.error_code, self.error_message or "")
return None


@dataclass
class FlagEvaluationOptions:
Expand All @@ -79,3 +84,14 @@ def raise_for_error(self) -> None:
if self.error_code:
raise ErrorCode.to_exception(self.error_code, self.error_message or "")
return None

def to_flag_evaluation_details(self, flag_key: str) -> FlagEvaluationDetails[U_co]:
return FlagEvaluationDetails(
flag_key=flag_key,
value=self.value,
variant=self.variant,
flag_metadata=self.flag_metadata,
reason=self.reason,
error_code=self.error_code,
error_message=self.error_message,
)
33 changes: 20 additions & 13 deletions openfeature/provider/in_memory_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from openfeature._backports.strenum import StrEnum
from openfeature.evaluation_context import EvaluationContext
from openfeature.exception import FlagNotFoundError
from openfeature.exception import ErrorCode
from openfeature.flag_evaluation import FlagMetadata, FlagResolutionDetails, Reason
from openfeature.hook import Hook
from openfeature.provider import AbstractProvider, Metadata
Expand Down Expand Up @@ -74,93 +74,100 @@ def resolve_boolean_details(
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return self._resolve(flag_key, evaluation_context)
return self._resolve(flag_key, default_value, evaluation_context)

async def resolve_boolean_details_async(
self,
flag_key: str,
default_value: bool,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return await self._resolve_async(flag_key, evaluation_context)
return await self._resolve_async(flag_key, default_value, evaluation_context)

def resolve_string_details(
self,
flag_key: str,
default_value: str,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
return self._resolve(flag_key, evaluation_context)
return self._resolve(flag_key, default_value, evaluation_context)

async def resolve_string_details_async(
self,
flag_key: str,
default_value: str,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
return await self._resolve_async(flag_key, evaluation_context)
return await self._resolve_async(flag_key, default_value, evaluation_context)

def resolve_integer_details(
self,
flag_key: str,
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
return self._resolve(flag_key, evaluation_context)
return self._resolve(flag_key, default_value, evaluation_context)

async def resolve_integer_details_async(
self,
flag_key: str,
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
return await self._resolve_async(flag_key, evaluation_context)
return await self._resolve_async(flag_key, default_value, evaluation_context)

def resolve_float_details(
self,
flag_key: str,
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
return self._resolve(flag_key, evaluation_context)
return self._resolve(flag_key, default_value, evaluation_context)

async def resolve_float_details_async(
self,
flag_key: str,
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
return await self._resolve_async(flag_key, evaluation_context)
return await self._resolve_async(flag_key, default_value, evaluation_context)

def resolve_object_details(
self,
flag_key: str,
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[typing.Union[dict, list]]:
return self._resolve(flag_key, evaluation_context)
return self._resolve(flag_key, default_value, evaluation_context)

async def resolve_object_details_async(
self,
flag_key: str,
default_value: typing.Union[dict, list],
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[typing.Union[dict, list]]:
return await self._resolve_async(flag_key, evaluation_context)
return await self._resolve_async(flag_key, default_value, evaluation_context)

def _resolve(
self,
flag_key: str,
default_value: V,
evaluation_context: typing.Optional[EvaluationContext],
) -> FlagResolutionDetails[V]:
flag = self._flags.get(flag_key)
if flag is None:
raise FlagNotFoundError(f"Flag '{flag_key}' not found")
return FlagResolutionDetails(
value=default_value,
reason=Reason.ERROR,
error_code=ErrorCode.FLAG_NOT_FOUND,
error_message=f"Flag '{flag_key}' not found",
)
return flag.resolve(evaluation_context)

async def _resolve_async(
self,
flag_key: str,
default_value: V,
evaluation_context: typing.Optional[EvaluationContext],
) -> FlagResolutionDetails[V]:
return self._resolve(flag_key, evaluation_context)
return self._resolve(flag_key, default_value, evaluation_context)
17 changes: 12 additions & 5 deletions tests/provider/test_in_memory_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from openfeature.exception import FlagNotFoundError
from openfeature.exception import ErrorCode
from openfeature.flag_evaluation import FlagResolutionDetails, Reason
from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider

Expand All @@ -22,11 +22,18 @@ async def test_should_handle_unknown_flags_correctly():
# Given
provider = InMemoryProvider({})
# When
with pytest.raises(FlagNotFoundError):
provider.resolve_boolean_details(flag_key="Key", default_value=True)
with pytest.raises(FlagNotFoundError):
await provider.resolve_integer_details_async(flag_key="Key", default_value=1)
flag_sync = provider.resolve_boolean_details(flag_key="Key", default_value=True)
flag_async = await provider.resolve_boolean_details_async(
flag_key="Key", default_value=True
)
# Then
assert flag_sync == flag_async
for flag in [flag_sync, flag_async]:
assert flag is not None
assert flag.value is True
assert flag.reason == Reason.ERROR
assert flag.error_code == ErrorCode.FLAG_NOT_FOUND
assert flag.error_message == "Flag 'Key' not found"


@pytest.mark.asyncio
Expand Down
Loading