diff --git a/tests/formats/dataclass/models/cases/__init__.py b/tests/formats/dataclass/models/cases/__init__.py new file mode 100644 index 00000000..dacb7683 --- /dev/null +++ b/tests/formats/dataclass/models/cases/__init__.py @@ -0,0 +1,4 @@ +import sys + +PY39 = sys.version_info[:2] >= (3, 9) +PY310 = sys.version_info[:2] >= (3, 10) diff --git a/tests/formats/dataclass/models/cases/attribute.py b/tests/formats/dataclass/models/cases/attribute.py new file mode 100644 index 00000000..19ddf65a --- /dev/null +++ b/tests/formats/dataclass/models/cases/attribute.py @@ -0,0 +1,51 @@ +from typing import Dict, List, Literal, Set, Tuple, Union + +from tests.formats.dataclass.models.cases import PY39, PY310 +from xsdata.models.enums import Mode + +tokens = [ + (int, False), + (Dict[int, int], False), + (Dict, False), + (Literal["foo"], False), + (Set[str], False), + (List[Union[List[int], int]], False), + (List[List[int]], False), + (Tuple[int, ...], ((int,), None, tuple)), + (List[int], ((int,), None, list)), + (List[Union[str, int]], ((str, int), None, list)), +] + +not_tokens = [ + (List[int], False), + (Dict[int, str], False), + (int, ((int,), None, None)), + (str, ((str,), None, None)), + (Union[str, Mode], ((str, Mode), None, None)), +] + +if PY39: + tokens.extend( + [ + (list[int, int], False), + (dict[str, str], False), + (dict, False), + (set[str], False), + (tuple[int, ...], ((int,), None, tuple)), + (list[int], ((int,), None, list)), + (list[Union[str, int]], ((str, int), None, list)), + ] + ) + +if PY310: + tokens.extend( + [ + (tuple[int | str], ((int, str), None, tuple)), + ] + ) + + not_tokens.extend( + [ + (int | str, ((int, str), None, None)), + ] + ) diff --git a/tests/formats/dataclass/models/cases/attributes.py b/tests/formats/dataclass/models/cases/attributes.py new file mode 100644 index 00000000..1113ec7c --- /dev/null +++ b/tests/formats/dataclass/models/cases/attributes.py @@ -0,0 +1,20 @@ +from typing import Dict, List, Set, Tuple + +from tests.formats.dataclass.models.cases import PY39 + +cases = [ + (int, False), + (Set, False), + (List, False), + (Tuple, False), + (Dict[str, int], False), + (Dict, ((str,), dict, None)), + (Dict[str, str], ((str,), dict, None)), +] + +if PY39: + cases.extend( + [ + (dict[str, str], ((str,), dict, None)), + ] + ) diff --git a/tests/formats/dataclass/models/cases/element.py b/tests/formats/dataclass/models/cases/element.py new file mode 100644 index 00000000..7fb05977 --- /dev/null +++ b/tests/formats/dataclass/models/cases/element.py @@ -0,0 +1,45 @@ +from typing import Dict, List, Set, Tuple, Union + +from tests.formats.dataclass.models.cases import PY39 + +tokens = [ + (Set, False), + (Dict[str, int], False), + (Tuple[str, str], False), + (List[str], ((str,), None, list)), + (Tuple[str, ...], ((str,), None, tuple)), + (List[List[str]], ((str,), list, list)), + (List[Tuple[str, ...]], ((str,), list, tuple)), + (Tuple[List[str], ...], ((str,), tuple, list)), +] + +not_tokens = [ + (Set, False), + (Dict[str, int], False), + (Tuple[str, int], False), + (List[List[str]], False), + (List[Tuple[str, ...]], False), + (Tuple[List[str], ...], False), + (str, ((str,), None, None)), + (List[str], ((str,), list, None)), + (List[Union[str, int]], ((str, int), list, None)), + (Tuple[str, ...], ((str,), tuple, None)), +] + +if PY39: + tokens.extend( + [ + (list[str], ((str,), None, list)), + (tuple[str, ...], ((str,), None, tuple)), + (list[list[str]], ((str,), list, list)), + (list[tuple[str, ...]], ((str,), list, tuple)), + (tuple[list[str], ...], ((str,), tuple, list)), + ] + ) + + not_tokens.extend( + [ + (list[str], ((str,), list, None)), + (tuple[str, ...], ((str,), tuple, None)), + ] + ) diff --git a/tests/formats/dataclass/models/cases/elements.py b/tests/formats/dataclass/models/cases/elements.py new file mode 100644 index 00000000..2e446de3 --- /dev/null +++ b/tests/formats/dataclass/models/cases/elements.py @@ -0,0 +1,32 @@ +from typing import Dict, List, Optional, Tuple, Union + +from tests.formats.dataclass.models.cases import PY39, PY310 + +cases = [ + (Dict, False), + (str, ((object,), None, None)), + (List[str], ((object,), list, None)), + (Tuple[str, ...], ((object,), tuple, None)), + (Optional[Union[str, int]], ((object,), None, None)), + (Union[str, int, None], ((object,), None, None)), + (List[Union[List[str], Tuple[str, ...]]], ((object,), list, None)), +] + + +if PY39: + cases.extend( + [ + (list[str], ((object,), list, None)), + (tuple[str, ...], ((object,), tuple, None)), + (list[Union[list[str], tuple[str, ...]]], ((object,), list, None)), + ] + ) + + +if PY310: + cases.extend( + [ + (str | int | None, ((object,), None, None)), + (list[list[str] | tuple[str, ...]], ((object,), list, None)), + ] + ) diff --git a/tests/formats/dataclass/models/cases/wildcard.py b/tests/formats/dataclass/models/cases/wildcard.py new file mode 100644 index 00000000..4e95bac4 --- /dev/null +++ b/tests/formats/dataclass/models/cases/wildcard.py @@ -0,0 +1,24 @@ +from typing import Dict, List, Literal, Optional, Set, Tuple + +from tests.formats.dataclass.models.cases import PY310 + +cases = [ + (int, False), + (Dict[int, int], False), + (Dict, False), + (Set, False), + (Literal["foo"], False), + (object, ((object,), None, None)), + (List[object], ((object,), list, None)), + (Tuple[object, ...], ((object,), tuple, None)), + (Optional[object], ((object,), None, None)), +] + +if PY310: + cases.extend( + [ + (list[object], ((object,), list, None)), + (tuple[object, ...], ((object,), tuple, None)), + (object | None, ((object,), None, None)), + ] + ) diff --git a/tests/formats/dataclass/models/test_builders.py b/tests/formats/dataclass/models/test_builders.py index df539dc0..5523f018 100644 --- a/tests/formats/dataclass/models/test_builders.py +++ b/tests/formats/dataclass/models/test_builders.py @@ -1,9 +1,7 @@ -import functools import sys -import uuid from dataclasses import dataclass, field, fields, make_dataclass from decimal import Decimal -from typing import Dict, Iterator, List, Union, get_type_hints +from typing import Iterator, List, get_type_hints from unittest import TestCase, mock from xml.etree.ElementTree import QName @@ -448,74 +446,3 @@ def test_resolve_namespaces(self): actual = func(XmlType.WILDCARD, "##targetNamespace foo", "p") self.assertEqual(("foo", "p"), tuple(sorted(actual))) - - def test_analyze_types(self): - func = functools.partial(self.builder.analyze_types, BookForm, "foo") - - actual = func(List[List[Union[str, int]]], None) - self.assertEqual((list, list, (int, str)), actual) - - actual = func(Union[str, int], None) - self.assertEqual((None, None, (int, str)), actual) - - actual = func(Dict[str, int], None) - self.assertEqual((dict, None, (int, str)), actual) - - with self.assertRaises(XmlContextError) as cm: - func(List[List[List[int]]], None) - - self.assertEqual( - "Error on BookForm::foo: Unsupported field typing `typing.List[typing.List[typing.List[int]]]`", - str(cm.exception), - ) - - def test_is_valid(self): - # Attributes need origin dict - self.assertFalse( - self.builder.is_valid(XmlType.ATTRIBUTES, None, None, (), False, True) - ) - - # Attributes don't support any origin - self.assertFalse( - self.builder.is_valid(XmlType.ATTRIBUTES, dict, list, (), False, True) - ) - - # Attributes don't support xs:NMTOKENS - self.assertFalse( - self.builder.is_valid(XmlType.ATTRIBUTES, dict, None, (), True, True) - ) - - self.assertTrue( - self.builder.is_valid( - XmlType.ATTRIBUTES, dict, None, (str, str), False, True - ) - ) - - # xs:NMTOKENS need origin list - self.assertFalse( - self.builder.is_valid(XmlType.TEXT, dict, None, (), True, True) - ) - - # xs:NMTOKENS need origin list - self.assertFalse(self.builder.is_valid(XmlType.TEXT, set, None, (), True, True)) - - # Any type object is a superset, it's only supported alone - self.assertFalse( - self.builder.is_valid( - XmlType.ELEMENT, None, None, (object, int), False, True - ) - ) - - # Type is not registered in converter. - self.assertFalse( - self.builder.is_valid( - XmlType.TEXT, None, None, (int, uuid.UUID), False, True - ) - ) - - # init false vars are ignored! - self.assertTrue( - self.builder.is_valid( - XmlType.TEXT, None, None, (int, uuid.UUID), False, False - ) - ) diff --git a/tests/formats/dataclass/models/test_typing.py b/tests/formats/dataclass/models/test_typing.py new file mode 100644 index 00000000..22f300dc --- /dev/null +++ b/tests/formats/dataclass/models/test_typing.py @@ -0,0 +1,90 @@ +from typing import Type + +import pytest + +from tests.formats.dataclass.models.cases import ( + attribute, + attributes, + element, + elements, + wildcard, +) +from xsdata.formats.dataclass.models.typing import ( + evaluate, + evaluate_attribute, + evaluate_attributes, + evaluate_element, + evaluate_elements, + evaluate_wildcard, +) + + +def test_evaluate_with_typevar(): + result = evaluate(Type["str"], None) + assert str == result + + with pytest.raises(TypeError): + evaluate_attribute(Type, tokens=True) + + +@pytest.mark.parametrize("case,expected", attribute.tokens) +def test_evaluate_attribute_with_tokens(case, expected): + if expected: + assert expected == evaluate_attribute(case, tokens=True) + else: + with pytest.raises(TypeError): + evaluate_attribute(case, tokens=True) + + +@pytest.mark.parametrize("case,expected", attribute.not_tokens) +def test_evaluate_attribute_without_tokens(case, expected): + if expected: + assert expected == evaluate_attribute(case, tokens=False) + else: + with pytest.raises(TypeError): + evaluate_attribute(case, tokens=False) + + +@pytest.mark.parametrize("case,expected", attributes.cases) +def test_evaluate_attributes(case, expected): + if expected: + assert expected == evaluate_attributes(case) + else: + with pytest.raises(TypeError): + evaluate_attributes(case) + + +@pytest.mark.parametrize("case,expected", element.tokens) +def test_evaluate_element_with_tokens(case, expected): + if expected: + assert expected == evaluate_element(case, tokens=True) + else: + with pytest.raises(TypeError): + evaluate_element(case, tokens=True) + + +@pytest.mark.parametrize("case,expected", element.not_tokens) +def test_evaluate_element_without_tokens(case, expected): + if expected: + assert expected == evaluate_element(case, tokens=False) + else: + with pytest.raises(TypeError): + evaluate_element(case, tokens=False) + + +@pytest.mark.parametrize("case,expected", elements.cases) +def test_evaluate_elements(case, expected): + if expected: + assert expected == evaluate_elements(case) + else: + with pytest.raises(TypeError): + evaluate_elements(case) + + +@pytest.mark.parametrize("case,expected", wildcard.cases) +def test_evaluate_wildcard(case, expected): + if expected: + assert expected == evaluate_wildcard(case) + else: + with pytest.raises(TypeError): + evaluate_wildcard(case) diff --git a/tests/formats/dataclass/parsers/test_dict.py b/tests/formats/dataclass/parsers/test_dict.py index dcb363bf..d4454cdb 100644 --- a/tests/formats/dataclass/parsers/test_dict.py +++ b/tests/formats/dataclass/parsers/test_dict.py @@ -1,5 +1,5 @@ import json -from dataclasses import asdict, make_dataclass +from dataclasses import asdict, dataclass, field from decimal import Decimal from typing import List, Optional, Union from xml.etree.ElementTree import QName @@ -381,7 +381,10 @@ def test_bind_any_type_with_derived_dataclass(self): self.assertEqual("Unable to locate xsi:type `notexists`", str(cm.exception)) def test_bind_text_with_unions(self): - Fixture = make_dataclass("Fixture", [("x", List[Union[int, float, str, bool]])]) + @dataclass + class Fixture: + x: List[Union[int, float, str, bool]] = field(metadata={"tokens": True}) + values = ["foo", 12.2, "12.2", 12, "12", True, "false"] result = self.decoder.decode({"x": values}, Fixture) diff --git a/tests/formats/dataclass/test_typing.py b/tests/formats/dataclass/test_typing.py deleted file mode 100644 index ffb10e53..00000000 --- a/tests/formats/dataclass/test_typing.py +++ /dev/null @@ -1,195 +0,0 @@ -import datetime -import sys -from decimal import Decimal -from enum import Enum -from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union -from unittest import TestCase -from xml.etree.ElementTree import QName - -from xsdata.formats.dataclass.typing import evaluate -from xsdata.formats.types import T -from xsdata.models.datatype import XmlDate, XmlDateTime, XmlDuration, XmlPeriod, XmlTime -from xsdata.models.enums import Namespace - - -class TypingTests(TestCase): - def assertCases(self, cases): - for tp, result in cases.items(): - if result is False: - with self.assertRaises(TypeError): - evaluate(tp) - else: - self.assertEqual(result, evaluate(tp), msg=tp) - - def test_evaluate_simple(self): - types = ( - int, - str, - int, - bool, - float, - bytes, - object, - datetime.time, - datetime.date, - datetime.datetime, - XmlTime, - XmlDate, - XmlDateTime, - XmlDuration, - XmlPeriod, - QName, - Decimal, - Enum, - Namespace, - ) - cases = {tp: (tp,) for tp in types} - self.assertCases(cases) - - def test_evaluate_unsupported_typing(self): - cases = [Any, Set[str]] - - for case in cases: - with self.assertRaises(TypeError): - evaluate(case) - - def test_evaluate_dict(self): - cases = { - Dict: (dict, str, str), - Dict[str, int]: (dict, str, int), - Dict[Any, Any]: False, - Dict[Union[str, int], int]: False, - Dict[int, Union[str, int]]: False, - Dict[TypeVar("A", bound=int), str]: False, - Dict[TypeVar("A"), str]: (dict, str, str), - } - - if sys.version_info[:2] >= (3, 9): - cases.update( - { - dict: (dict, str, str), - dict[str, int]: (dict, str, int), - dict[Any, Any]: False, - dict[Union[str, int], int]: False, - dict[int, Union[str, int]]: False, - dict[TypeVar("A", bound=int), str]: False, - dict[TypeVar("A"), str]: (dict, str, str), - } - ) - - if sys.version_info[:2] >= (3, 10): - cases.update({dict[str | int, int]: False}) - - self.assertCases(cases) - - def test_evaluate_list(self): - A = TypeVar("A", int, str) - - cases = { - List[A]: (list, int, str), - List[int]: (list, int), - List[Union[float, str]]: (list, float, str), - List[Optional[int]]: (list, int), - List[Tuple[int]]: (list, tuple, int), - List[List[Union[bool, str]]]: (list, list, bool, str), - List: (list, str), - List[Dict[str, str]]: False, - List[Any]: False, - } - - if sys.version_info[:2] >= (3, 9): - cases.update( - { - list[A]: (list, int, str), - list[int]: (list, int), - list[Union[float, str]]: (list, float, str), - list[Optional[int]]: (list, int), - list[Tuple[int]]: (list, tuple, int), - list[list[Union[bool, str]]]: (list, list, bool, str), - list: (list, str), - list["str"]: (list, str), - list[dict[str, str]]: False, - list[Any]: False, - } - ) - - self.assertCases(cases) - - def test_evaluate_tuple(self): - A = TypeVar("A", int, str) - - cases = { - Tuple[A]: (tuple, int, str), - Tuple[int]: (tuple, int), - Tuple[int, ...]: (tuple, int), - Tuple[List[int], ...]: (tuple, list, int), - Tuple[Union[float, str]]: (tuple, float, str), - Tuple[Optional[int]]: (tuple, int), - Tuple[Tuple[int]]: (tuple, tuple, int), - Tuple[Tuple[Union[bool, str]]]: (tuple, tuple, bool, str), - Tuple: (tuple, str), - Tuple[Dict[str, str]]: False, - Tuple[Any, ...]: False, - } - - if sys.version_info[:2] >= (3, 9): - cases.update( - { - tuple[A]: (tuple, int, str), - tuple[int]: (tuple, int), - tuple[int, ...]: (tuple, int), - tuple[List[int], ...]: (tuple, list, int), - tuple[Union[float, str]]: (tuple, float, str), - tuple[Optional[int]]: (tuple, int), - tuple[tuple[int]]: (tuple, tuple, int), - tuple[tuple[Union[bool, str]]]: (tuple, tuple, bool, str), - tuple: (tuple, str), - tuple[dict[str, str]]: False, - tuple[Any, ...]: False, - } - ) - - self.assertCases(cases) - - def test_evaluate_union(self): - A = TypeVar("A", int, str) - - cases = { - Optional[Union[bool, str]]: (bool, str), - Optional[List[Union[int, float]]]: (list, int, float), - Optional[A]: (int, str), - Union[List[int], None]: (list, int), - Union[Tuple[int, ...], None]: (tuple, int), - Union[List[int], List[str]]: (list, int, list, str), - Union[List[Dict]]: False, - } - - if sys.version_info[:2] >= (3, 10): - cases.update( - { - None | bool | str: (bool, str), - None | List[int | float]: (list, int, float), - None | A: (int, str), - List[int] | None: (list, int), - Tuple[int, ...] | None: (tuple, int), - List[int] | List[str]: (list, int, list, str), - } - ) - - self.assertCases(cases) - - def test_evaluate_type(self): - self.assertEqual((str,), evaluate(Type["str"])) - - with self.assertRaises(TypeError): - evaluate(Type) - - def test_evaluate_typevar(self): - A = TypeVar("A", int, str) - B = TypeVar("B", bound=object) - - self.assertEqual((int, str), evaluate(A)) - self.assertEqual((object,), evaluate(B)) - - with self.assertRaises(TypeError): - evaluate(T) diff --git a/xsdata/formats/dataclass/models/builders.py b/xsdata/formats/dataclass/models/builders.py index f2b1d566..7a35e1c3 100644 --- a/xsdata/formats/dataclass/models/builders.py +++ b/xsdata/formats/dataclass/models/builders.py @@ -20,12 +20,29 @@ from xsdata.formats.converter import converter from xsdata.formats.dataclass.compat import ClassType from xsdata.formats.dataclass.models.elements import XmlMeta, XmlType, XmlVar -from xsdata.formats.dataclass.typing import evaluate +from xsdata.formats.dataclass.models.typing import ( + evaluate, + evaluate_attribute, + evaluate_attributes, + evaluate_element, + evaluate_elements, + evaluate_text, + evaluate_wildcard, +) from xsdata.models.enums import NamespaceType from xsdata.utils.collections import first from xsdata.utils.constants import EMPTY_SEQUENCE, return_input from xsdata.utils.namespaces import build_qname +evaluations: Dict[str, Callable] = { + XmlType.TEXT: evaluate_text, + XmlType.ELEMENT: evaluate_element, + XmlType.ELEMENTS: evaluate_elements, + XmlType.WILDCARD: evaluate_wildcard, + XmlType.ATTRIBUTE: evaluate_attribute, + XmlType.ATTRIBUTES: evaluate_attributes, +} + class ClassMeta: """The binding model combined metadata. @@ -349,7 +366,7 @@ def build( parent_namespace: Optional[str], default_value: Any, globalns: Any, - factory: Optional[Callable] = None, + parent_factory: Optional[Callable] = None, ) -> Optional[XmlVar]: """Build the binding metadata for a class field. @@ -362,7 +379,7 @@ def build( parent_namespace: The class namespace default_value: The field default value or factory globalns: Python's global namespace - factory: The value factory + parent_factory: The value factory Returns: The field binding metadata instance. @@ -383,27 +400,23 @@ def build( sequence = metadata.get("sequence", None) wrapper = metadata.get("wrapper", None) - origin, sub_origin, types = self.analyze_types(model, name, type_hint, globalns) + annotation = evaluate(type_hint, globalns) + + try: + analyze = evaluations[xml_type] + types, factory, tokens_factory = analyze(annotation, tokens=tokens) + types = tuple(converter.sort_types(types)) + if not self.is_typing_supported(types): + raise TypeError - if not self.is_valid(xml_type, origin, sub_origin, types, tokens, init): + except TypeError: raise XmlContextError( f"Error on {model.__qualname__}::{name}: " f"Xml {xml_type} does not support typing `{type_hint}`" ) - if xml_type == XmlType.ELEMENTS: - sub_origin = None - types = (object,) - + factory = factory or parent_factory local_name = local_name or self.build_local_name(xml_type, name) - - if tokens and sub_origin is None: - sub_origin = origin - origin = None - - if origin is None: - origin = factory - any_type = self.is_any_type(types, xml_type) clazz = first(tp for tp in types if self.class_type.is_model(tp)) namespaces = self.resolve_namespaces(xml_type, namespace, parent_namespace) @@ -413,7 +426,7 @@ def build( self.index += 1 cur_index = self.index for choice in self.build_choices( - model, name, choices, origin, globalns, parent_namespace + model, name, choices, factory, globalns, parent_namespace ): if choice.is_element: elements[choice.qname] = choice @@ -434,8 +447,8 @@ def build( required=required, nillable=nillable, sequence=sequence, - factory=origin, - tokens_factory=sub_origin, + factory=factory, + tokens_factory=tokens_factory, default=default_value, types=types, elements=elements, @@ -449,7 +462,7 @@ def build_choices( model: Type, name: str, choices: List[Dict], - factory: Callable, + factory: Optional[Callable], globalns: Any, parent_namespace: Optional[str], ) -> Iterator[XmlVar]: @@ -570,78 +583,6 @@ def is_any_type(cls, types: Sequence[Type], xml_type: str) -> bool: return False - @classmethod - def analyze_types( - cls, - model: Type, - name: str, - type_hint: Any, - globalns: Any, - ) -> Tuple[Any, Any, Tuple[Type, ...]]: - """Analyze a type hint and return the origin, sub origin and the type args. - - The only case we support a sub origin is for fields derived from - xs:NMTOKENS! - - # Todo please rewrite this in a way that makes sense :( - - Raises: - XmlContextError: if the typing is not supported for binding - """ - try: - types = evaluate(type_hint, globalns) - origin = None - sub_origin = None - - while types[0] in (tuple, list, dict): - if origin is None: - origin = types[0] - elif sub_origin is None: - sub_origin = types[0] - else: - raise TypeError() - - types = types[1:] - - return origin, sub_origin, tuple(converter.sort_types(types)) - except Exception: - raise XmlContextError( - f"Error on {model.__qualname__}::{name}: " - f"Unsupported field typing `{type_hint}`" - ) - - def is_valid( - self, - xml_type: str, - origin: Any, - sub_origin: Any, - types: Sequence[Type], - tokens: bool, - init: bool, - ) -> bool: - """Validate the given xml type against common unsupported cases.""" - if not init: - # Ignore init==false vars - return True - - if xml_type == XmlType.ATTRIBUTES: - # Attributes need origin dict, no sub origin and tokens - if origin is not dict or sub_origin or tokens: - return False - elif origin is dict or tokens and origin not in (list, tuple): - # Origin dict is only supported by Attributes - # xs:NMTOKENS need origin list - return False - - if object in types and xml_type != XmlType.ELEMENTS: - # Any type, secondary types are not allowed except for 'Elements' XML type - return len(types) == 1 - - if xml_type == XmlType.ELEMENTS: - return True - - return self.is_typing_supported(types) - def is_typing_supported(self, types: Sequence[Type]) -> bool: """Validate all types are registered in the converter.""" for tp in types: diff --git a/xsdata/formats/dataclass/models/typing.py b/xsdata/formats/dataclass/models/typing.py new file mode 100644 index 00000000..542aaf13 --- /dev/null +++ b/xsdata/formats/dataclass/models/typing.py @@ -0,0 +1,222 @@ +import sys +from typing import ( + Any, + Callable, + NamedTuple, + Optional, + Tuple, + Type, + Union, +) + +from typing_extensions import get_args, get_origin + +try: + from types import UnionType # type: ignore +except ImportError: + UnionType = () # type: ignore +from typing_extensions import ForwardRef + +if (3, 9) <= sys.version_info[:2] <= (3, 10): + # Backport this fix for python 3.9 and 3.10 + # https://github.com/python/cpython/pull/30900 + + from types import GenericAlias + from typing import ForwardRef + from typing import _eval_type as __eval_type # type: ignore + + def _eval_type(tp: Any, globalns: Any, localns: Any) -> Any: + if isinstance(tp, GenericAlias): + args = tuple( + ForwardRef(arg) if isinstance(arg, str) else arg for arg in tp.__args__ + ) + tp = tp.__origin__[args] # type: ignore + + return __eval_type(tp, globalns, localns) + +else: + from typing import _eval_type # type: ignore + + +NONE_TYPE = type(None) +UNION_TYPES = (Union, UnionType) +ITERABLE_TYPES = (list, tuple) + + +def evaluate(tp: Any, globalns: Any, localns: Any = None) -> Any: + """Analyze/Validate the typing annotation.""" + result = _eval_type(tp, globalns, localns) + + # Ugly hack for the Type["str"] + # Let's switch to ForwardRef("str") + if get_origin(result) is type: + args = get_args(result) + if len(args) != 1: + raise TypeError + + return args[0] + + return result + + +class Result(NamedTuple): + types: Tuple[Type[Any], ...] + factory: Optional[Callable] = None + tokens_factory: Optional[Callable] = None + + +def analyze_token_args(origin: Any, args: Tuple[Any, ...]) -> Tuple[Any]: + """Analyze token arguments. + + Ensure it only has one argument, filter out ellipsis. + + Args + origin: The annotation origin + args: The annotation arguments + + Returns: + A tuple that contains only one arg + + Raises: + TypeError: If the origin is not list or tuple, + and it has more than one argument + + """ + if origin in ITERABLE_TYPES: + args = filter_ellipsis(args) + if len(args) == 1: + return args + + raise TypeError + + +def filter_none_type(args: Tuple[Any, ...]) -> Tuple[Any, ...]: + return tuple(arg for arg in args if arg is not NONE_TYPE) + + +def filter_ellipsis(args: Tuple[Any, ...]) -> Tuple[Any]: + return tuple(arg for arg in args if arg is not Ellipsis) + + +def evaluate_text(annotation: Any, tokens: bool = False) -> Result: + """Run exactly the same validations with attribute.""" + return evaluate_attribute(annotation, tokens) + + +def evaluate_attribute(annotation: Any, tokens: bool = False) -> Result: + """Validate annotations for a xml attribute.""" + origin = get_origin(annotation) + args = get_args(annotation) + tokens_factory = None + + if tokens: + args = analyze_token_args(origin, args) + tokens_factory = origin + origin = get_origin(args[0]) + + if origin in UNION_TYPES: + args = get_args(args[0]) + elif origin: + raise TypeError + + if origin in UNION_TYPES: + types = filter_none_type(args) + elif origin is None: + types = args or (annotation,) + else: + raise TypeError + + if any(get_origin(tp) for tp in types): + raise TypeError + + return Result(types=types, tokens_factory=tokens_factory) + + +def evaluate_attributes(annotation: Any, **_: Any) -> Result: + """Validate annotations for xml wildcard attributes.""" + if annotation is dict: + args = () + else: + origin = get_origin(annotation) + args = get_args(annotation) + + if origin is not dict and annotation is not dict: + raise TypeError + + if args and not all(arg is str for arg in args): + raise TypeError + + return Result(types=(str,), factory=dict) + + +def evaluate_element(annotation: Any, tokens: bool = False) -> Result: + """Validate annotations for a xml element.""" + types = (annotation,) + origin = get_origin(annotation) + args = get_args(annotation) + tokens_factory = factory = None + + if tokens: + args = analyze_token_args(origin, args) + + tokens_factory = origin + origin = get_origin(args[0]) + types = args + args = get_args(args[0]) + + if origin in ITERABLE_TYPES: + args = tuple(arg for arg in args if arg is not Ellipsis) + if len(args) != 1: + raise TypeError + + if tokens_factory: + factory = tokens_factory + tokens_factory = origin + else: + factory = origin + + types = args + origin = get_origin(args[0]) + args = get_args(args[0]) + + if origin in UNION_TYPES: + types = filter_none_type(args) + elif origin: + raise TypeError + + return Result(types=types, factory=factory, tokens_factory=tokens_factory) + + +def evaluate_elements(annotation: Any, **_: Any) -> Result: + """Validate annotations for a xml compound field.""" + ( + types, + factory, + __, + ) = evaluate_element(annotation, tokens=False) + + for tp in types: + evaluate_element(tp, tokens=False) + + return Result(types=(object,), factory=factory) + + +def evaluate_wildcard(annotation: Any, **_: Any) -> Result: + """Validate annotations for a xml wildcard.""" + origin = get_origin(annotation) + factory = None + + if origin in UNION_TYPES: + types = filter_none_type(get_args(annotation)) + elif origin in ITERABLE_TYPES: + factory = origin + types = filter_ellipsis(get_args(annotation)) + elif origin is None: + types = (annotation,) + else: + raise TypeError + + if len(types) != 1 or object not in types: + raise TypeError + + return Result(types=types, factory=factory) diff --git a/xsdata/formats/dataclass/parsers/dict.py b/xsdata/formats/dataclass/parsers/dict.py index 7f6f6d37..4bb7ad9b 100644 --- a/xsdata/formats/dataclass/parsers/dict.py +++ b/xsdata/formats/dataclass/parsers/dict.py @@ -2,13 +2,14 @@ from dataclasses import dataclass, field from typing import Any, Dict, Iterable, List, Optional, Type, Union +from typing_extensions import get_args, get_origin + from xsdata.exceptions import ConverterWarning, ParserError from xsdata.formats.converter import converter from xsdata.formats.dataclass.context import XmlContext from xsdata.formats.dataclass.models.elements import XmlMeta, XmlVar from xsdata.formats.dataclass.parsers.config import ParserConfig from xsdata.formats.dataclass.parsers.utils import ParserUtils -from xsdata.formats.dataclass.typing import get_args, get_origin from xsdata.formats.types import T from xsdata.utils import collections from xsdata.utils.constants import EMPTY_MAP diff --git a/xsdata/formats/dataclass/typing.py b/xsdata/formats/dataclass/typing.py deleted file mode 100644 index 3da7edeb..00000000 --- a/xsdata/formats/dataclass/typing.py +++ /dev/null @@ -1,177 +0,0 @@ -import sys -from typing import Any, Iterator, Tuple, Type, TypeVar, Union - -from typing_extensions import get_args, get_origin - -NONE_TYPE = type(None) - - -try: - from types import UnionType # type: ignore -except ImportError: - UnionType = () # type: ignore - - -if (3, 9) <= sys.version_info[:2] <= (3, 10): - # Backport this fix for python 3.9 and 3.10 - # https://github.com/python/cpython/pull/30900 - - from types import GenericAlias - from typing import ForwardRef - from typing import _eval_type as __eval_type # type: ignore - - def _eval_type(tp: Any, globalns: Any, localns: Any) -> Any: - if isinstance(tp, GenericAlias): - args = tuple( - ForwardRef(arg) if isinstance(arg, str) else arg for arg in tp.__args__ - ) - tp = tp.__origin__[args] # type: ignore - - return __eval_type(tp, globalns, localns) - -else: - from typing import _eval_type # type: ignore - - -intern_typing = sys.intern("typing.") - - -def is_from_typing(tp: Any) -> bool: - """Return whether the type is from the typing module.""" - return str(tp).startswith(intern_typing) - - -def evaluate( - tp: Any, - globalns: Any = None, - localns: Any = None, -) -> Tuple[Type, ...]: - """Analyze/Validate the typing annotation.""" - return tuple(_evaluate(_eval_type(tp, globalns, localns))) - - -def _evaluate(tp: Any) -> Iterator[Type]: - if tp in (dict, list, tuple): - origin = tp - elif isinstance(tp, TypeVar): - origin = TypeVar - else: - origin = get_origin(tp) - - if origin: - try: - yield from __evaluations__[origin](tp) - except KeyError: - raise TypeError() - elif is_from_typing(tp): - raise TypeError() - else: - yield tp - - -def _evaluate_type(tp: Any) -> Iterator[Type]: - args = get_args(tp) - if not args or isinstance(args[0], TypeVar): - raise TypeError() - yield from _evaluate(args[0]) - - -def _evaluate_mapping(tp: Any) -> Iterator[Type]: - yield dict - args = get_args(tp) - - if not args: - yield str - yield str - - for arg in args: - if isinstance(arg, TypeVar): - try: - next(_evaluate_typevar(arg)) - except TypeError: - yield str - else: - raise TypeError() - elif is_from_typing(arg) or get_origin(arg) is not None: - raise TypeError() - else: - yield arg - - -def _evaluate_list(tp: Any) -> Iterator[Type]: - yield list - - args = get_args(tp) - if not args: - yield str - - for arg in args: - yield from _evaluate_array_arg(arg) - - -def _evaluate_array_arg(arg: Any) -> Iterator[Type]: - if isinstance(arg, TypeVar): - yield from _evaluate_typevar(arg) - else: - origin = get_origin(arg) - - if origin is None and not is_from_typing(arg): - yield arg - elif origin in (Union, UnionType, list, tuple): - yield from __evaluations__[origin](arg) - else: - raise TypeError() - - -def _evaluate_tuple(tp: Any) -> Iterator[Type]: - yield tuple - - args = get_args(tp) - if not args: - yield str - - for arg in args: - if arg is Ellipsis: - continue - - yield from _evaluate_array_arg(arg) - - -def _evaluate_union(tp: Any) -> Iterator[Type]: - for arg in get_args(tp): - if arg is NONE_TYPE: - continue - - if isinstance(arg, TypeVar): - yield from _evaluate_typevar(arg) - else: - origin = get_origin(arg) - if origin is list: - yield from _evaluate_list(arg) - elif origin is tuple: - yield from _evaluate_tuple(arg) - elif origin is None and not is_from_typing(arg): - yield arg - else: - raise TypeError() - - -def _evaluate_typevar(tp: TypeVar): - if tp.__bound__: - yield from _evaluate(tp.__bound__) - elif tp.__constraints__: - for arg in tp.__constraints__: - yield from _evaluate(arg) - else: - raise TypeError() - - -__evaluations__ = { - tuple: _evaluate_tuple, - list: _evaluate_list, - dict: _evaluate_mapping, - Union: _evaluate_union, - UnionType: _evaluate_union, - type: _evaluate_type, - TypeVar: _evaluate_typevar, -} diff --git a/xsdata/models/xsd.py b/xsdata/models/xsd.py index 0effca9d..134dd1b2 100644 --- a/xsdata/models/xsd.py +++ b/xsdata/models/xsd.py @@ -962,7 +962,7 @@ class Import(AnnotationBase): namespace: Optional[str] = attribute() schema_location: Optional[str] = attribute(name="schemaLocation") - location: Optional[str] = field(default=None, metadata={"type": "ignore"}) + location: Optional[str] = field(default=None, metadata={"type": "Ignore"}) @dataclass @@ -970,7 +970,7 @@ class Include(AnnotationBase): """XSD Include model representation.""" schema_location: Optional[str] = attribute(name="schemaLocation") - location: Optional[str] = field(default=None, metadata={"type": "ignore"}) + location: Optional[str] = field(default=None, metadata={"type": "Ignore"}) @dataclass @@ -982,7 +982,7 @@ class Redefine(AnnotationBase): complex_types: Array[ComplexType] = array_element(name="complexType") groups: Array[Group] = array_element(name="group") attribute_groups: Array[AttributeGroup] = array_element(name="attributeGroup") - location: Optional[str] = field(default=None, metadata={"type": "ignore"}) + location: Optional[str] = field(default=None, metadata={"type": "Ignore"}) @dataclass @@ -997,7 +997,7 @@ class Override(AnnotationBase): elements: Array[Element] = array_element(name="element") attributes: Array[Attribute] = array_element(name="attribute") notations: Array[Notation] = array_element(name="notation") - location: Optional[str] = field(default=None, metadata={"type": "ignore"}) + location: Optional[str] = field(default=None, metadata={"type": "Ignore"}) @dataclass @@ -1040,7 +1040,7 @@ class Meta: elements: Array[Element] = array_element(name="element") attributes: Array[Attribute] = array_element(name="attribute") notations: Array[Notation] = array_element(name="notation") - location: Optional[str] = field(default=None, metadata={"type": "ignore"}) + location: Optional[str] = field(default=None, metadata={"type": "Ignore"}) def included(self) -> Iterator[UnionType[Import, Include, Redefine, Override]]: """Yields an iterator of included resources."""