Skip to content

Commit

Permalink
fixup! fix: use Annotated from typing_extensions for compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
art049 committed Dec 11, 2023
1 parent 2bbbd8a commit a4e3b85
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 21 deletions.
2 changes: 2 additions & 0 deletions odmantic/bson.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import decimal
import re
from dataclasses import dataclass
Expand Down
3 changes: 2 additions & 1 deletion odmantic/field.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import abc
from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
Expand Down
4 changes: 3 additions & 1 deletion odmantic/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
import decimal
import enum
Expand Down Expand Up @@ -217,7 +219,7 @@ def __validate_cls_namespace__( # noqa C901
config = validate_config(namespace.get("model_config", ODMConfigDict()), name)
odm_fields: Dict[str, ODMBaseField] = {}
references: List[str] = []
bson_serializers = dict[str, Callable[[Any], Any]]()
bson_serializers: Dict[str, Callable[[Any], Any]] = {}
mutable_fields: Set[str] = set()

# Make sure all fields are defined with type annotation
Expand Down
8 changes: 4 additions & 4 deletions odmantic/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
TypeVar,
Union,
_eval_type,
get_args,
get_origin,
)

if sys.version_info < (3, 11):
Expand All @@ -31,10 +29,12 @@
if sys.version_info < (3, 9):
from typing import _GenericAlias as GenericAlias # type: ignore # noqa: F401

from typing_extensions import Annotated
# Even if get_args and get_origin are available in typing, it's important to
# import them from typing_extensions to have proper origins with Annotated fields
from typing_extensions import Annotated, get_args, get_origin
else:
from typing import Annotated # noqa: F401
from typing import GenericAlias # type: ignore
from typing import Annotated, get_args, get_origin # noqa: F401


if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class TypeTestCase(Generic[T]):


def id_from_test_case(case: TypeTestCase):
return f"{case.python_type.__name__}|{case.bson_type}"
return f"{case.bson_type}"


@pytest.mark.parametrize("case", type_test_data, ids=id_from_test_case)
Expand Down
37 changes: 23 additions & 14 deletions tests/unit/test_model_type_validation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from typing import (
Callable,
Dict,
Expand Down Expand Up @@ -91,26 +92,34 @@ def test_mutable_types_immutables(t: Type):
assert not is_type_mutable(t)


@pytest.mark.parametrize(
"t",
(
List,
Set,
List[int],
Tuple[List[int]],
FrozenSet[Set[int]],
Dict[Tuple[int, ...], str],
DummyEmbedded,
Tuple[DummyEmbedded, ...],
Dict[str, DummyEmbedded],
FrozenSet[DummyEmbedded],
TEST_TYPES = [
List,
Set,
List[int],
Tuple[List[int]],
FrozenSet[Set[int]],
Dict[Tuple[int, ...], str],
DummyEmbedded,
Tuple[DummyEmbedded, ...],
Dict[str, DummyEmbedded],
FrozenSet[DummyEmbedded],
]

# if generic with builtin types are supported add them to the list
if sys.version_info >= (3, 9):
TEST_TYPES += [
list,
set,
list[int],
tuple[list[int]],
frozenset[set[int]],
dict[tuple[int, ...], str],
),
]


@pytest.mark.parametrize(
"t",
TEST_TYPES,
)
def test_mutable_types_mutables(t: Type):
assert is_type_mutable(t)
Expand Down

0 comments on commit a4e3b85

Please sign in to comment.