From 3a8d57455ad5fdc22c52f3427da791f7b8354bd7 Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Sun, 20 Oct 2024 17:57:11 +0300 Subject: [PATCH] Add cli config to use generic collections --- .pre-commit-config.yaml | 8 ++++---- tests/formats/dataclass/cases/attribute.py | 2 ++ tests/formats/dataclass/cases/attributes.py | 2 ++ tests/formats/dataclass/cases/element.py | 2 ++ tests/formats/dataclass/cases/wildcard.py | 2 ++ tests/formats/dataclass/test_filters.py | 14 ++++++++++++++ tests/models/test_config.py | 16 ++++++++++++++-- xsdata/formats/dataclass/client.py | 2 +- xsdata/formats/dataclass/filters.py | 16 ++++++++++++++-- xsdata/formats/dataclass/typing.py | 16 +++++++++++++--- xsdata/models/config.py | 9 +++++++++ 11 files changed, 77 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a8d0a35..1117e8f6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,24 +2,24 @@ exclude: tests/fixtures repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer - id: debug-statements - repo: https://github.com/crate-ci/typos - rev: v1.24.6 + rev: v1.26.0 hooks: - id: typos exclude: ^tests/|.xsd|xsdata/models/datatype.py$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.6 + rev: v0.7.0 hooks: - id: ruff args: [ --fix, --show-fixes] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.12.1 hooks: - id: mypy files: ^(xsdata/) diff --git a/tests/formats/dataclass/cases/attribute.py b/tests/formats/dataclass/cases/attribute.py index e4d75a58..beae5aa2 100644 --- a/tests/formats/dataclass/cases/attribute.py +++ b/tests/formats/dataclass/cases/attribute.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable from typing import Dict, List, Literal, Optional, Set, Tuple, Union from tests.formats.dataclass.cases import PY39, PY310 @@ -12,6 +13,7 @@ (List[Union[List[int], int]], False), (List[List[int]], False), (Tuple[int, ...], ((int,), None, tuple)), + (Iterable[int], ((int,), None, list)), (List[int], ((int,), None, list)), (List[Union[str, int]], ((str, int), None, list)), (Optional[List[Union[str, int]]], ((str, int), None, list)), diff --git a/tests/formats/dataclass/cases/attributes.py b/tests/formats/dataclass/cases/attributes.py index f6d1a155..5f0f8fac 100644 --- a/tests/formats/dataclass/cases/attributes.py +++ b/tests/formats/dataclass/cases/attributes.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from typing import Dict, List, Set, Tuple from tests.formats.dataclass.cases import PY39 @@ -10,6 +11,7 @@ (Dict[str, int], False), (Dict, ((str,), dict, None)), (Dict[str, str], ((str,), dict, None)), + (Mapping[str, str], ((str,), dict, None)), ] if PY39: diff --git a/tests/formats/dataclass/cases/element.py b/tests/formats/dataclass/cases/element.py index 24310ab5..a2640b26 100644 --- a/tests/formats/dataclass/cases/element.py +++ b/tests/formats/dataclass/cases/element.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable from typing import Dict, List, Optional, Set, Tuple, Union from tests.formats.dataclass.cases import PY39 @@ -12,6 +13,7 @@ (List[List[str]], ((str,), list, list)), (Optional[List[List[Union[str, int]]]], ((str, int), list, list)), (List[Tuple[str, ...]], ((str,), list, tuple)), + (Iterable[Iterable[str, ...]], ((str,), list, list)), (Tuple[List[str], ...], ((str,), tuple, list)), (Optional[Tuple[List[str], ...]], ((str,), tuple, list)), ] diff --git a/tests/formats/dataclass/cases/wildcard.py b/tests/formats/dataclass/cases/wildcard.py index d84e8c4e..47940d05 100644 --- a/tests/formats/dataclass/cases/wildcard.py +++ b/tests/formats/dataclass/cases/wildcard.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable from typing import Dict, List, Literal, Optional, Set, Tuple from tests.formats.dataclass.cases import PY310 @@ -11,6 +12,7 @@ (object, ((object,), None, None)), (List[object], ((object,), list, None)), (Tuple[object, ...], ((object,), tuple, None)), + (Iterable[object, ...], ((object,), list, None)), (Optional[object], ((object,), None, None)), ] diff --git a/tests/formats/dataclass/test_filters.py b/tests/formats/dataclass/test_filters.py index ebd10c60..eeb4ea31 100644 --- a/tests/formats/dataclass/test_filters.py +++ b/tests/formats/dataclass/test_filters.py @@ -643,6 +643,9 @@ def test_field_type_with_array_type(self): self.filters.format.frozen = False self.assertEqual('list["A.B.C"]', self.filters.field_type(self.obj, attr)) + self.filters.generic_collections = True + self.assertEqual('Iterable["A.B.C"]', self.filters.field_type(self.obj, attr)) + def test_field_type_with_token_attr(self): attr = AttrFactory.create( types=AttrTypeFactory.list(1, qname="foo_bar"), @@ -710,6 +713,9 @@ def test_field_type_with_any_attribute(self): self.filters.subscriptable_types = True self.assertEqual("dict[str, str]", self.filters.field_type(self.obj, attr)) + self.filters.generic_collections = True + self.assertEqual("Mapping[str, str]", self.filters.field_type(self.obj, attr)) + def test_field_type_with_native_type(self): attr = AttrFactory.create( types=[ @@ -903,6 +909,14 @@ def test_default_imports_with_typing(self): expected = "from typing import Any" self.assertIn(expected, self.filters.default_imports(output)) + output = ": Iterable[str] = " + expected = "from collections.abc import Iterable" + self.assertIn(expected, self.filters.default_imports(output)) + + output = ": Mapping[str, str] = " + expected = "from collections.abc import Mapping" + self.assertIn(expected, self.filters.default_imports(output)) + def test_default_imports_combo(self): output = ( "@dataclass\n" diff --git a/tests/models/test_config.py b/tests/models/test_config.py index 656ae7b7..c4c89596 100644 --- a/tests/models/test_config.py +++ b/tests/models/test_config.py @@ -29,7 +29,7 @@ def test_create(self): expected = ( '\n' f'\n' - ' \n' + ' \n' " generated\n" ' dataclasses\n' " filenames\n" @@ -89,7 +89,7 @@ def test_read(self): expected = ( '\n' f'\n' - ' \n' + ' \n' " foo.bar\n" ' dataclasses\n' @@ -172,6 +172,18 @@ def test_use_union_type_requires_310_and_postponed_annotations(self): str(w[-1].message), ) + def test_generic_collections_requires_frozen_false(self): + with warnings.catch_warnings(record=True) as w: + output = GeneratorOutput( + generic_collections=True, format=OutputFormat(frozen=True) + ) + self.assertFalse(output.generic_collections) + + self.assertEqual( + "Generic Collections, requires frozen=False, reverting...", + str(w[-1].message), + ) + def test_format_slots_requires_310(self): if sys.version_info < (3, 10): self.assertTrue(OutputFormat(slots=True, value="attrs").slots) diff --git a/xsdata/formats/dataclass/client.py b/xsdata/formats/dataclass/client.py index 64a14f6f..98a90ff6 100644 --- a/xsdata/formats/dataclass/client.py +++ b/xsdata/formats/dataclass/client.py @@ -51,7 +51,7 @@ def from_service(cls, obj: Any, **kwargs: Any) -> "Config": for f in fields(cls) } - return cls(**params) + return cls(**params) # type: ignore class TransportTypes: diff --git a/xsdata/formats/dataclass/filters.py b/xsdata/formats/dataclass/filters.py index 32715c88..574a47b7 100644 --- a/xsdata/formats/dataclass/filters.py +++ b/xsdata/formats/dataclass/filters.py @@ -58,6 +58,7 @@ class Filters: "max_line_length", "union_type", "subscriptable_types", + "generic_collections", "relative_imports", "postponed_annotations", "format", @@ -106,6 +107,7 @@ def __init__(self, config: GeneratorConfig): self.max_line_length: int = config.output.max_line_length self.union_type: bool = config.output.union_type self.subscriptable_types: bool = config.output.subscriptable_types + self.generic_collections: bool = config.output.generic_collections self.relative_imports: bool = config.output.relative_imports self.postponed_annotations: bool = config.output.postponed_annotations self.format = config.output.format @@ -758,6 +760,9 @@ def field_type(self, obj: Class, attr: Attr) -> str: return result if attr.is_dict: + if self.generic_collections: + return "Mapping[str, str]" + return "dict[str, str]" if self.subscriptable_types else "Dict[str, str]" if attr.is_nillable or ( @@ -889,7 +894,10 @@ def default_imports(self, output: str) -> str: return "\n".join(collections.unique_sequence(imports)) - def _get_iterable_format(self): + def _get_iterable_format(self) -> str: + if self.generic_collections: + return "Iterable[{}]" + fmt = "Tuple[{}, ...]" if self.format.frozen else "List[{}]" return fmt.lower() if self.subscriptable_types else fmt @@ -902,7 +910,7 @@ def build_import_patterns(cls) -> Dict[str, Dict]: "decimal": {"Decimal": type_patterns("Decimal")}, "enum": {"Enum": ["(Enum)"]}, "typing": { - "Dict": [": Dict"], + "Dict": [": Dict["], "List": [": List["], "Optional": ["Optional["], "Tuple": ["Tuple["], @@ -910,6 +918,10 @@ def build_import_patterns(cls) -> Dict[str, Dict]: "ForwardRef": [": ForwardRef("], "Any": type_patterns("Any"), }, + "collections.abc": { + "Iterable": [": Iterable["], + "Mapping": [": Mapping["], + }, "xml.etree.ElementTree": {"QName": type_patterns("QName")}, "xsdata.models.datatype": { "XmlDate": type_patterns("XmlDate"), diff --git a/xsdata/formats/dataclass/typing.py b/xsdata/formats/dataclass/typing.py index 8cb39c6e..cb1202fd 100644 --- a/xsdata/formats/dataclass/typing.py +++ b/xsdata/formats/dataclass/typing.py @@ -1,4 +1,5 @@ import sys +from collections.abc import Iterable, Iterator, Mapping from typing import ( Any, Callable, @@ -45,7 +46,7 @@ def _eval_type(tp: Any, globalns: Any, localns: Any) -> Any: NONE_TYPE = type(None) UNION_TYPES = (Union, UnionType) -ITERABLE_TYPES = (list, tuple) +ITERABLE_TYPES = (list, tuple, Iterable) def evaluate(tp: Any, globalns: Any, localns: Any = None) -> Any: @@ -143,6 +144,9 @@ def evaluate_attribute(annotation: Any, tokens: bool = False) -> Result: args = analyze_token_args(origin, args) tokens_factory = origin + if tokens_factory is Iterable: + tokens_factory = list + origin = get_origin(args[0]) if origin in UNION_TYPES: @@ -171,7 +175,7 @@ def evaluate_attributes(annotation: Any, **_: Any) -> Result: origin = get_origin(annotation) args = get_args(annotation) - if origin is not dict and annotation is not dict: + if origin is not dict and origin is not Mapping: raise TypeError if args and not all(arg is str for arg in args): @@ -222,6 +226,12 @@ def evaluate_element(annotation: Any, tokens: bool = False) -> Result: elif origin: raise TypeError + if factory is Iterable: + factory = list + + if tokens_factory is Iterable: + tokens_factory = list + return Result(types=types, factory=factory, tokens_factory=tokens_factory) @@ -247,7 +257,7 @@ def evaluate_wildcard(annotation: Any, **_: Any) -> Result: if origin in UNION_TYPES: types = filter_none_type(get_args(annotation)) elif origin in ITERABLE_TYPES: - factory = origin + factory = list if origin is Iterable else origin types = filter_ellipsis(get_args(annotation)) elif origin is None: types = (annotation,) diff --git a/xsdata/models/config.py b/xsdata/models/config.py index 95d51b9e..1f8d1b5b 100644 --- a/xsdata/models/config.py +++ b/xsdata/models/config.py @@ -225,6 +225,7 @@ class GeneratorOutput: max_line_length: Adjust the maximum line length subscriptable_types: Use PEP-585 generics for standard collections, python>=3.9 Only + generic_collections: Use generic collections (Iterable, Mapping) union_type: Use PEP-604 union type, python>=3.10 Only postponed_annotations: Use 563 postponed evaluation of annotations unnest_classes: Move inner classes to upper level @@ -243,6 +244,7 @@ class GeneratorOutput: wrapper_fields: bool = element(default=False) max_line_length: int = attribute(default=79) subscriptable_types: bool = attribute(default=False) + generic_collections: bool = attribute(default=False) union_type: bool = attribute(default=False) postponed_annotations: bool = element(default=False) unnest_classes: bool = element(default=False) @@ -276,6 +278,13 @@ def validate(self): CodegenWarning, ) + if self.generic_collections and self.format.frozen: + self.generic_collections = False + warnings.warn( + "Generic Collections, requires frozen=False, reverting...", + CodegenWarning, + ) + def update(self, **kwargs: Any): """Update instance attributes recursively.""" objects.update(self, **kwargs)