diff --git a/docs/conf.py b/docs/conf.py index e26f6be6d..3054a02d6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -97,3 +97,5 @@ apidoc_output_dir = "reference" apidoc_excluded_paths = ["tests"] apidoc_separate_modules = True +set_type_checking_flag = True +always_document_param_types = True diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index 1fce8bc9e..e79b70b81 100644 --- a/tests/test_analyzer.py +++ b/tests/test_analyzer.py @@ -11,7 +11,6 @@ from xsdata.analyzer import ClassAnalyzer from xsdata.exceptions import AnalyzerError from xsdata.models.codegen import Class -from xsdata.models.codegen import Restrictions from xsdata.models.elements import ComplexType from xsdata.models.elements import Element from xsdata.models.elements import SimpleType @@ -30,16 +29,14 @@ def setUp(self): @mock.patch.object(ClassAnalyzer, "fetch_classes_for_generation") @mock.patch.object(ClassAnalyzer, "flatten_classes") - @mock.patch.object(ClassAnalyzer, "update_abstract_classes") @mock.patch.object(ClassAnalyzer, "create_substitutions_index") + @mock.patch.object(ClassAnalyzer, "handle_duplicate_classes") @mock.patch.object(ClassAnalyzer, "create_class_index") - @mock.patch.object(ClassAnalyzer, "merge_redefined_classes") def test_process( self, - mock_merge_redefined_classes, mock_create_class_index, + mock_handle_duplicates, mock_create_substitutions_index, - mock_update_abstract_classes, mock_flatten_classes, mock_fetch_classes_for_generation, ): @@ -49,13 +46,50 @@ def test_process( self.assertEqual(gen_classes, self.analyzer.process(classes)) - mock_merge_redefined_classes.assert_called_once_with(classes) - mock_create_substitutions_index.assert_called_once_with(classes) mock_create_class_index.assert_called_once_with(classes) - mock_update_abstract_classes.assert_called_once() + mock_handle_duplicates.assert_called_once() + mock_create_substitutions_index.assert_called_once_with() mock_flatten_classes.assert_called_once_with() mock_fetch_classes_for_generation.assert_called_once_with() + @mock.patch.object(ClassAnalyzer, "update_abstract_classes") + @mock.patch.object(ClassAnalyzer, "merge_redefined_classes") + @mock.patch.object(ClassAnalyzer, "remove_invalid_classes") + def test_handle_duplicate_classes( + self, + mock_remove_invalid_classes, + mock_merge_redefined_classes, + mock_update_abstract_classes, + ): + first = ClassFactory.create() + second = first.clone() + third = ClassFactory.create() + + self.analyzer.create_class_index([first, second, third]) + self.analyzer.handle_duplicate_classes() + + mock_remove_invalid_classes.assert_called_once_with([first, second]) + mock_merge_redefined_classes.assert_called_once_with([first, second]) + mock_update_abstract_classes.assert_called_once_with([first, second]) + + def test_remove_invalid_classes(self): + first = ClassFactory.create( + extensions=[ + ExtensionFactory.create(type=AttrTypeFactory.xs_bool()), + ExtensionFactory.create(type=AttrTypeFactory.create(name="foo")), + ] + ) + second = ClassFactory.create( + extensions=[ExtensionFactory.create(type=AttrTypeFactory.xs_bool()),] + ) + third = ClassFactory.create() + + self.analyzer.create_class_index([first, second, third]) + + classes = [first, second, third] + self.analyzer.remove_invalid_classes(classes) + self.assertEqual([second, third], classes) + def test_create_class_index(self): classes = [ ClassFactory.create(type=Element, name="foo"), @@ -83,7 +117,8 @@ def test_create_substitutions_index(self, mock_create_reference_attribute): reference_attrs = AttrFactory.list(3) mock_create_reference_attribute.side_effect = reference_attrs - self.analyzer.create_substitutions_index(classes) + self.analyzer.create_class_index(classes) + self.analyzer.create_substitutions_index() expected = { QName(namespace, "foo"): [reference_attrs[0], reference_attrs[2]], @@ -101,27 +136,6 @@ def test_create_substitutions_index(self, mock_create_reference_attribute): ] ) - def test_mark_abstract_duplicate_classes(self): - one = ClassFactory.create(name="foo", abstract=True, type=Element) - two = ClassFactory.create(name="foo", type=Element) - three = ClassFactory.create(name="foo", type=ComplexType) - four = ClassFactory.create(name="foo", type=SimpleType) - - five = ClassFactory.create(name="bar", abstract=True, type=Element) - six = ClassFactory.create(name="bar", type=ComplexType) - seven = ClassFactory.create(name="opa", type=ComplexType) - - self.analyzer.create_class_index([one, two, three, four, five, six, seven]) - self.analyzer.update_abstract_classes() - - self.assertTrue(one.abstract) # Was abstract already - self.assertFalse(two.abstract) # Is an element - self.assertTrue(three.abstract) # Marked as abstract - self.assertFalse(four.abstract) # Is common - self.assertTrue(five.abstract) # Was abstract already - self.assertFalse(six.abstract) # No element in group - self.assertFalse(seven.abstract) # Alone - @mock.patch.object(ClassAnalyzer, "sanitize_attributes") def test_fetch_classes_for_generation(self, mock_sanitize_attributes): classes = [ @@ -690,53 +704,6 @@ def test_add_substitution_attrs(self, mock_find_attribute): self.analyzer.add_substitution_attrs(target, AttrFactory.enumeration()) self.assertEqual(4, len(target.attrs)) - def test_merge_redefined_classes_with_unique_classes(self): - classes = ClassFactory.list(2) - self.analyzer.merge_redefined_classes(classes) - self.assertEqual(2, len(classes)) - - @mock.patch.object(ClassAnalyzer, "copy_attributes") - def test_merge_redefined_classes_copies_attributes(self, mock_copy_attributes): - class_a = ClassFactory.create() - class_b = ClassFactory.create() - class_c = class_a.clone() - - ext_a = ExtensionFactory.create(type=AttrTypeFactory.create(name=class_a.name)) - ext_str = ExtensionFactory.create(type=AttrTypeFactory.create(name="foo")) - class_c.extensions.append(ext_a) - class_c.extensions.append(ext_str) - classes = [class_a, class_b, class_c] - - self.analyzer.merge_redefined_classes(classes) - self.assertEqual(2, len(classes)) - - mock_copy_attributes.assert_called_once_with(class_a, class_c, ext_a) - - def test_merge_redefined_classes_copies_extensions(self): - class_a = ClassFactory.create() - class_c = class_a.clone() - - type_int = AttrTypeFactory.xs_int() - - ext_a = ExtensionFactory.create( - type=type_int, - restrictions=Restrictions(max_inclusive=10, min_inclusive=1, required=True), - ) - ext_c = ExtensionFactory.create( - type=AttrTypeFactory.create(name=class_a.name), - restrictions=Restrictions(max_inclusive=0, min_inclusive=-10), - ) - - class_a.extensions.append(ext_a) - class_c.extensions.append(ext_c) - classes = [class_a, class_c] - expected = {"max_inclusive": 0, "min_inclusive": -10, "required": True} - - self.analyzer.merge_redefined_classes(classes) - self.assertEqual(1, len(classes)) - self.assertEqual(1, len(classes[0].extensions)) - self.assertEqual(expected, classes[0].extensions[0].restrictions.asdict()) - @mock.patch.object(ClassAnalyzer, "find_class") @mock.patch.object(Class, "dependencies") def test_class_depends_on(self, mock_dependencies, mock_find_class): diff --git a/tests/utils/test_classes.py b/tests/utils/test_classes.py index 7b667a846..c27040311 100644 --- a/tests/utils/test_classes.py +++ b/tests/utils/test_classes.py @@ -11,6 +11,9 @@ from tests.factories import RestrictionsFactory from xsdata.models.codegen import AttrType from xsdata.models.codegen import Restrictions +from xsdata.models.elements import ComplexType +from xsdata.models.elements import Element +from xsdata.models.elements import SimpleType from xsdata.models.enums import DataType from xsdata.models.enums import Tag from xsdata.utils.classes import ClassUtils @@ -386,3 +389,63 @@ def test_merge_attribute_type_when_source_attrs_is_not_one( source.attrs = AttrFactory.list(2) ClassUtils.merge_attribute_type(source, target, attr, attr.types[0]) self.assertEqual("string", attr.types[0].name) + + def test_merge_redefined_classes_with_unique_classes(self): + classes = ClassFactory.list(2) + ClassUtils.merge_redefined_classes(classes) + self.assertEqual(2, len(classes)) + + @mock.patch.object(ClassUtils, "copy_attributes") + def test_merge_redefined_classes_copies_attributes(self, mock_copy_attributes): + class_a = ClassFactory.create() + class_b = ClassFactory.create() + class_c = class_a.clone() + + ext_a = ExtensionFactory.create(type=AttrTypeFactory.create(name=class_a.name)) + ext_str = ExtensionFactory.create(type=AttrTypeFactory.create(name="foo")) + class_c.extensions.append(ext_a) + class_c.extensions.append(ext_str) + classes = [class_a, class_b, class_c] + + ClassUtils.merge_redefined_classes(classes) + self.assertEqual(2, len(classes)) + + mock_copy_attributes.assert_called_once_with(class_a, class_c, ext_a) + + def test_merge_redefined_classes_copies_extensions(self): + class_a = ClassFactory.create() + class_c = class_a.clone() + + type_int = AttrTypeFactory.xs_int() + + ext_a = ExtensionFactory.create( + type=type_int, + restrictions=Restrictions(max_inclusive=10, min_inclusive=1, required=True), + ) + ext_c = ExtensionFactory.create( + type=AttrTypeFactory.create(name=class_a.name), + restrictions=Restrictions(max_inclusive=0, min_inclusive=-10), + ) + + class_a.extensions.append(ext_a) + class_c.extensions.append(ext_c) + classes = [class_a, class_c] + expected = {"max_inclusive": 0, "min_inclusive": -10, "required": True} + + ClassUtils.merge_redefined_classes(classes) + self.assertEqual(1, len(classes)) + self.assertEqual(1, len(classes[0].extensions)) + self.assertEqual(expected, classes[0].extensions[0].restrictions.asdict()) + + def test_mark_abstract_duplicate_classes(self): + one = ClassFactory.create(name="foo", abstract=True, type=Element) + two = ClassFactory.create(name="foo", type=Element) + three = ClassFactory.create(name="foo", type=ComplexType) + four = ClassFactory.create(name="foo", type=SimpleType) + + ClassUtils.update_abstract_classes([one, two, three, four]) + + self.assertTrue(one.abstract) # Was abstract already + self.assertFalse(two.abstract) # Is an element + self.assertTrue(three.abstract) # Marked as abstract + self.assertFalse(four.abstract) # Is common diff --git a/xsdata/analyzer.py b/xsdata/analyzer.py index 404ab8c88..964e7224f 100644 --- a/xsdata/analyzer.py +++ b/xsdata/analyzer.py @@ -5,7 +5,8 @@ from typing import Dict from typing import List from typing import Optional -from xml.etree.ElementTree import QName + +from lxml.etree import QName from xsdata.exceptions import AnalyzerError from xsdata.logger import logger @@ -41,27 +42,52 @@ def process(self, classes: List[Class]) -> List[Class]: Process class list in steps. Steps: - * Merge redefined classes - * Create a class index - * Create a substitution index - * Mark as abstract classes with the same qname - * Flatten classes + * Create a class index. + * Handle duplicate types. + * Create a substitution index. + * Flatten classes. * Return a final class list for code generators. """ - self.merge_redefined_classes(classes) - self.create_class_index(classes) - self.create_substitutions_index(classes) + self.handle_duplicate_classes() - self.update_abstract_classes() + self.create_substitutions_index() self.flatten_classes() - gen_classes = self.fetch_classes_for_generation() + return self.fetch_classes_for_generation() + + def handle_duplicate_classes(self): + """ + Remove if possible classes with the same qualified name. + + Steps: + 1. Remove classes with missing extension type. + 2. Merge redefined classes. + 3. Fix implied abstract flags. + """ + for classes in self.class_index.values(): + + if len(classes) > 1: + self.remove_invalid_classes(classes) - return gen_classes + if len(classes) > 1: + self.merge_redefined_classes(classes) + + if len(classes) > 1: + self.update_abstract_classes(classes) + + def remove_invalid_classes(self, classes: List[Class]): + """Remove from the given class list any class with missing extension + type.""" + for target in list(classes): + if any( + self.attr_type_is_missing(target, extension.type) + for extension in target.extensions + ): + classes.remove(target) def fetch_classes_for_generation(self) -> List[Class]: """ @@ -90,20 +116,30 @@ def create_class_index(self, classes: List[Class]): for item in classes: self.class_index[item.source_qname()].append(item) - def create_substitutions_index(self, classes: List[Class]): + def create_substitutions_index(self): """Create reference attributes for all the classes substitutions and group them by their fully qualified name.""" - for item in classes: - for substitution in item.substitutions: - item.abstract = False - qname = item.source_qname(substitution) - attr = self.create_reference_attribute(item, qname) - self.substitutions_index[qname].append(attr) + + for classes in self.class_index.values(): + for item in classes: + for substitution in item.substitutions: + item.abstract = False + qname = item.source_qname(substitution) + attr = self.create_reference_attribute(item, qname) + self.substitutions_index[qname].append(attr) def find_attr_type(self, source: Class, attr_type: AttrType) -> Optional[Class]: qname = source.source_qname(attr_type.name) return self.find_class(qname) + def attr_type_is_missing(self, source: Class, attr_type: AttrType) -> bool: + """Check if given type declaration is not native and is missing.""" + if attr_type.native: + return False + + qname = source.source_qname(attr_type.name) + return qname not in self.class_index + def find_attr_simple_type( self, source: Class, attr_type: AttrType ) -> Optional[Class]: @@ -135,53 +171,6 @@ def find_class( return None - def merge_redefined_classes(self, classes: List[Class]): - """Merge original and redefined classes.""" - grouped: Dict[str, List[Class]] = defaultdict(list) - for item in classes: - grouped[f"{item.type.__name__}{item.source_qname()}"].append(item) - - for items in grouped.values(): - if len(items) == 1: - continue - - winner: Class = items.pop() - for item in items: - classes.remove(item) - - self_extension = next( - ( - ext - for ext in winner.extensions - if text.suffix(ext.type.name) == winner.name - ), - None, - ) - - if not self_extension: - continue - - self.copy_attributes(item, winner, self_extension) - for looser_ext in item.extensions: - new_ext = looser_ext.clone() - new_ext.restrictions.merge(self_extension.restrictions) - winner.extensions.append(new_ext) - - def update_abstract_classes(self): - """Explicitly update classes to mark them as abstract when it's - implied.""" - for classes in self.class_index.values(): - if len(classes) == 1: - continue - - element = next((obj for obj in classes if obj.is_element), None) - if not element: - continue - - for obj in classes: - if obj is not element and obj.is_complex: - obj.abstract = True - def flatten_classes(self): for classes in self.class_index.values(): for obj in classes: diff --git a/xsdata/utils/classes.py b/xsdata/utils/classes.py index 0d51095e3..4c25217a8 100644 --- a/xsdata/utils/classes.py +++ b/xsdata/utils/classes.py @@ -1,4 +1,6 @@ import sys +from collections import defaultdict +from typing import Dict from typing import List from typing import Optional @@ -211,6 +213,55 @@ def copy_inner_classes(cls, source: Class, target: Class): if not any(existing.name == inner.name for existing in target.inner): target.inner.append(inner) + @classmethod + def merge_redefined_classes(cls, classes: List[Class]): + """Merge original and redefined classes.""" + grouped: Dict[str, List[Class]] = defaultdict(list) + for item in classes: + grouped[f"{item.type.__name__}{item.source_qname()}"].append(item) + + for items in grouped.values(): + if len(items) == 1: + continue + + winner: Class = items.pop() + for item in items: + classes.remove(item) + + self_extension = next( + ( + ext + for ext in winner.extensions + if text.suffix(ext.type.name) == winner.name + ), + None, + ) + + if not self_extension: + continue + + cls.copy_attributes(item, winner, self_extension) + for looser_ext in item.extensions: + new_ext = looser_ext.clone() + new_ext.restrictions.merge(self_extension.restrictions) + winner.extensions.append(new_ext) + + @classmethod + def update_abstract_classes(cls, classes: List[Class]): + """ + Update classes with the same qualified name to set implied abstract + flags. + + If a non abstract xs:element exists in the list mark the rest + xs:complexType(s) as abstract. + """ + + element = next((obj for obj in classes if obj.is_element), None) + if element: + for obj in classes: + if obj is not element and obj.is_complex: + obj.abstract = True + @classmethod def create_mixed_attribute(cls, target: Class): if not target.mixed or target.has_wild_attr: