Skip to content

Commit

Permalink
Add cli config to use generic collections
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Oct 20, 2024
1 parent 534857b commit 3a8d574
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 12 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@ exclude: tests/fixtures

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: debug-statements
- repo: https://github.com/crate-ci/typos
rev: v1.24.6
rev: v1.26.0
hooks:
- id: typos
exclude: ^tests/|.xsd|xsdata/models/datatype.py$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.6
rev: v0.7.0
hooks:
- id: ruff
args: [ --fix, --show-fixes]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.12.1
hooks:
- id: mypy
files: ^(xsdata/)
Expand Down
2 changes: 2 additions & 0 deletions tests/formats/dataclass/cases/attribute.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterable
from typing import Dict, List, Literal, Optional, Set, Tuple, Union

from tests.formats.dataclass.cases import PY39, PY310
Expand All @@ -12,6 +13,7 @@
(List[Union[List[int], int]], False),
(List[List[int]], False),
(Tuple[int, ...], ((int,), None, tuple)),
(Iterable[int], ((int,), None, list)),
(List[int], ((int,), None, list)),
(List[Union[str, int]], ((str, int), None, list)),
(Optional[List[Union[str, int]]], ((str, int), None, list)),
Expand Down
2 changes: 2 additions & 0 deletions tests/formats/dataclass/cases/attributes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Mapping
from typing import Dict, List, Set, Tuple

from tests.formats.dataclass.cases import PY39
Expand All @@ -10,6 +11,7 @@
(Dict[str, int], False),
(Dict, ((str,), dict, None)),
(Dict[str, str], ((str,), dict, None)),
(Mapping[str, str], ((str,), dict, None)),
]

if PY39:
Expand Down
2 changes: 2 additions & 0 deletions tests/formats/dataclass/cases/element.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterable
from typing import Dict, List, Optional, Set, Tuple, Union

from tests.formats.dataclass.cases import PY39
Expand All @@ -12,6 +13,7 @@
(List[List[str]], ((str,), list, list)),
(Optional[List[List[Union[str, int]]]], ((str, int), list, list)),
(List[Tuple[str, ...]], ((str,), list, tuple)),
(Iterable[Iterable[str, ...]], ((str,), list, list)),
(Tuple[List[str], ...], ((str,), tuple, list)),
(Optional[Tuple[List[str], ...]], ((str,), tuple, list)),
]
Expand Down
2 changes: 2 additions & 0 deletions tests/formats/dataclass/cases/wildcard.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterable
from typing import Dict, List, Literal, Optional, Set, Tuple

from tests.formats.dataclass.cases import PY310
Expand All @@ -11,6 +12,7 @@
(object, ((object,), None, None)),
(List[object], ((object,), list, None)),
(Tuple[object, ...], ((object,), tuple, None)),
(Iterable[object, ...], ((object,), list, None)),
(Optional[object], ((object,), None, None)),
]

Expand Down
14 changes: 14 additions & 0 deletions tests/formats/dataclass/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,9 @@ def test_field_type_with_array_type(self):
self.filters.format.frozen = False
self.assertEqual('list["A.B.C"]', self.filters.field_type(self.obj, attr))

self.filters.generic_collections = True
self.assertEqual('Iterable["A.B.C"]', self.filters.field_type(self.obj, attr))

def test_field_type_with_token_attr(self):
attr = AttrFactory.create(
types=AttrTypeFactory.list(1, qname="foo_bar"),
Expand Down Expand Up @@ -710,6 +713,9 @@ def test_field_type_with_any_attribute(self):
self.filters.subscriptable_types = True
self.assertEqual("dict[str, str]", self.filters.field_type(self.obj, attr))

self.filters.generic_collections = True
self.assertEqual("Mapping[str, str]", self.filters.field_type(self.obj, attr))

def test_field_type_with_native_type(self):
attr = AttrFactory.create(
types=[
Expand Down Expand Up @@ -903,6 +909,14 @@ def test_default_imports_with_typing(self):
expected = "from typing import Any"
self.assertIn(expected, self.filters.default_imports(output))

output = ": Iterable[str] = "
expected = "from collections.abc import Iterable"
self.assertIn(expected, self.filters.default_imports(output))

output = ": Mapping[str, str] = "
expected = "from collections.abc import Mapping"
self.assertIn(expected, self.filters.default_imports(output))

def test_default_imports_combo(self):
output = (
"@dataclass\n"
Expand Down
16 changes: 14 additions & 2 deletions tests/models/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_create(self):
expected = (
'<?xml version="1.0" encoding="UTF-8"?>\n'
f'<Config xmlns="http://pypi.org/project/xsdata" version="{__version__}">\n'
' <Output maxLineLength="79" subscriptableTypes="false" unionType="false">\n'
' <Output maxLineLength="79" subscriptableTypes="false" genericCollections="false" unionType="false">\n'
" <Package>generated</Package>\n"
' <Format repr="true" eq="true" order="false" unsafeHash="false" frozen="false" slots="false" kwOnly="false">dataclasses</Format>\n'
" <Structure>filenames</Structure>\n"
Expand Down Expand Up @@ -89,7 +89,7 @@ def test_read(self):
expected = (
'<?xml version="1.0" encoding="UTF-8"?>\n'
f'<Config xmlns="http://pypi.org/project/xsdata" version="{__version__}">\n'
' <Output maxLineLength="79" subscriptableTypes="false" unionType="false">\n'
' <Output maxLineLength="79" subscriptableTypes="false" genericCollections="false" unionType="false">\n'
" <Package>foo.bar</Package>\n"
' <Format repr="true" eq="true" order="false" unsafeHash="false"'
' frozen="false" slots="false" kwOnly="false">dataclasses</Format>\n'
Expand Down Expand Up @@ -172,6 +172,18 @@ def test_use_union_type_requires_310_and_postponed_annotations(self):
str(w[-1].message),
)

def test_generic_collections_requires_frozen_false(self):
with warnings.catch_warnings(record=True) as w:
output = GeneratorOutput(
generic_collections=True, format=OutputFormat(frozen=True)
)
self.assertFalse(output.generic_collections)

self.assertEqual(
"Generic Collections, requires frozen=False, reverting...",
str(w[-1].message),
)

def test_format_slots_requires_310(self):
if sys.version_info < (3, 10):
self.assertTrue(OutputFormat(slots=True, value="attrs").slots)
Expand Down
2 changes: 1 addition & 1 deletion xsdata/formats/dataclass/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def from_service(cls, obj: Any, **kwargs: Any) -> "Config":
for f in fields(cls)
}

return cls(**params)
return cls(**params) # type: ignore


class TransportTypes:
Expand Down
16 changes: 14 additions & 2 deletions xsdata/formats/dataclass/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class Filters:
"max_line_length",
"union_type",
"subscriptable_types",
"generic_collections",
"relative_imports",
"postponed_annotations",
"format",
Expand Down Expand Up @@ -106,6 +107,7 @@ def __init__(self, config: GeneratorConfig):
self.max_line_length: int = config.output.max_line_length
self.union_type: bool = config.output.union_type
self.subscriptable_types: bool = config.output.subscriptable_types
self.generic_collections: bool = config.output.generic_collections
self.relative_imports: bool = config.output.relative_imports
self.postponed_annotations: bool = config.output.postponed_annotations
self.format = config.output.format
Expand Down Expand Up @@ -758,6 +760,9 @@ def field_type(self, obj: Class, attr: Attr) -> str:
return result

if attr.is_dict:
if self.generic_collections:
return "Mapping[str, str]"

return "dict[str, str]" if self.subscriptable_types else "Dict[str, str]"

if attr.is_nillable or (
Expand Down Expand Up @@ -889,7 +894,10 @@ def default_imports(self, output: str) -> str:

return "\n".join(collections.unique_sequence(imports))

def _get_iterable_format(self):
def _get_iterable_format(self) -> str:
if self.generic_collections:
return "Iterable[{}]"

fmt = "Tuple[{}, ...]" if self.format.frozen else "List[{}]"
return fmt.lower() if self.subscriptable_types else fmt

Expand All @@ -902,14 +910,18 @@ def build_import_patterns(cls) -> Dict[str, Dict]:
"decimal": {"Decimal": type_patterns("Decimal")},
"enum": {"Enum": ["(Enum)"]},
"typing": {
"Dict": [": Dict"],
"Dict": [": Dict["],
"List": [": List["],
"Optional": ["Optional["],
"Tuple": ["Tuple["],
"Union": ["Union["],
"ForwardRef": [": ForwardRef("],
"Any": type_patterns("Any"),
},
"collections.abc": {
"Iterable": [": Iterable["],
"Mapping": [": Mapping["],
},
"xml.etree.ElementTree": {"QName": type_patterns("QName")},
"xsdata.models.datatype": {
"XmlDate": type_patterns("XmlDate"),
Expand Down
16 changes: 13 additions & 3 deletions xsdata/formats/dataclass/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from collections.abc import Iterable, Iterator, Mapping
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -45,7 +46,7 @@ def _eval_type(tp: Any, globalns: Any, localns: Any) -> Any:

NONE_TYPE = type(None)
UNION_TYPES = (Union, UnionType)
ITERABLE_TYPES = (list, tuple)
ITERABLE_TYPES = (list, tuple, Iterable)


def evaluate(tp: Any, globalns: Any, localns: Any = None) -> Any:
Expand Down Expand Up @@ -143,6 +144,9 @@ def evaluate_attribute(annotation: Any, tokens: bool = False) -> Result:

args = analyze_token_args(origin, args)
tokens_factory = origin
if tokens_factory is Iterable:
tokens_factory = list

origin = get_origin(args[0])

if origin in UNION_TYPES:
Expand Down Expand Up @@ -171,7 +175,7 @@ def evaluate_attributes(annotation: Any, **_: Any) -> Result:
origin = get_origin(annotation)
args = get_args(annotation)

if origin is not dict and annotation is not dict:
if origin is not dict and origin is not Mapping:
raise TypeError

if args and not all(arg is str for arg in args):
Expand Down Expand Up @@ -222,6 +226,12 @@ def evaluate_element(annotation: Any, tokens: bool = False) -> Result:
elif origin:
raise TypeError

if factory is Iterable:
factory = list

if tokens_factory is Iterable:
tokens_factory = list

return Result(types=types, factory=factory, tokens_factory=tokens_factory)


Expand All @@ -247,7 +257,7 @@ def evaluate_wildcard(annotation: Any, **_: Any) -> Result:
if origin in UNION_TYPES:
types = filter_none_type(get_args(annotation))
elif origin in ITERABLE_TYPES:
factory = origin
factory = list if origin is Iterable else origin
types = filter_ellipsis(get_args(annotation))
elif origin is None:
types = (annotation,)
Expand Down
9 changes: 9 additions & 0 deletions xsdata/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ class GeneratorOutput:
max_line_length: Adjust the maximum line length
subscriptable_types: Use PEP-585 generics for standard
collections, python>=3.9 Only
generic_collections: Use generic collections (Iterable, Mapping)
union_type: Use PEP-604 union type, python>=3.10 Only
postponed_annotations: Use 563 postponed evaluation of annotations
unnest_classes: Move inner classes to upper level
Expand All @@ -243,6 +244,7 @@ class GeneratorOutput:
wrapper_fields: bool = element(default=False)
max_line_length: int = attribute(default=79)
subscriptable_types: bool = attribute(default=False)
generic_collections: bool = attribute(default=False)
union_type: bool = attribute(default=False)
postponed_annotations: bool = element(default=False)
unnest_classes: bool = element(default=False)
Expand Down Expand Up @@ -276,6 +278,13 @@ def validate(self):
CodegenWarning,
)

if self.generic_collections and self.format.frozen:
self.generic_collections = False
warnings.warn(
"Generic Collections, requires frozen=False, reverting...",
CodegenWarning,
)

def update(self, **kwargs: Any):
"""Update instance attributes recursively."""
objects.update(self, **kwargs)
Expand Down

0 comments on commit 3a8d574

Please sign in to comment.