Skip to content

Commit

Permalink
Fix imports ordering (#841)
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra authored Jul 25, 2023
1 parent 8ed7487 commit 548b736
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
17 changes: 13 additions & 4 deletions tests/formats/dataclass/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,12 +877,21 @@ def test_default_imports_with_module(self):
self.assertEqual(expected, self.filters.default_imports(output))

def test_default_imports_with_annotations(self):
config = GeneratorConfig()
config.output.postponed_annotations = True
filters = Filters(config)
self.filters.postponed_annotations = True

expected = "from __future__ import annotations"
self.assertEqual(expected, filters.default_imports(""))
self.assertEqual(expected, self.filters.default_imports(""))

def test_default_imports_ordering(self):
self.filters.postponed_annotations = True
self.filters.import_patterns["attrs"] = {"__module__": ["@attrs.s"]}

expected = (
"from __future__ import annotations\n"
"import attrs\n"
"from dataclasses import dataclass"
)
self.assertEqual(expected, self.filters.default_imports("@dataclass @attrs.s"))

def test_format_metadata(self):
data = dict(
Expand Down
17 changes: 9 additions & 8 deletions xsdata/formats/dataclass/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,11 +764,8 @@ def constant_value(self, attr: Attr) -> str:

def default_imports(self, output: str) -> str:
"""Generate the default imports for the given package output."""
result = []

if self.postponed_annotations:
result.append("from __future__ import annotations")

module_imports = set()
func_imports = set()
for library, types in self.import_patterns.items():
names = [
name
Expand All @@ -777,11 +774,15 @@ def default_imports(self, output: str) -> str:
]

if len(names) == 1 and names[0] == "__module__":
result.append(f"import {library}")
module_imports.add(f"import {library}")
elif names:
result.append(f"from {library} import {', '.join(names)}")
func_imports.add(f"from {library} import {', '.join(names)}")

imports = sorted(module_imports) + sorted(func_imports)
if self.postponed_annotations:
imports.insert(0, "from __future__ import annotations")

return "\n".join(sorted(result))
return "\n".join(imports)

@classmethod
def build_import_patterns(cls) -> Dict[str, Dict]:
Expand Down

0 comments on commit 548b736

Please sign in to comment.