Skip to content

Commit

Permalink
Fix optional/required override validation (#820)
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra authored Jul 16, 2023
2 parents fcb89c2 + 8f7e789 commit 68d5071
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 4 deletions.
8 changes: 8 additions & 0 deletions tests/codegen/handlers/test_calculate_attribute_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ def test_process(self):
min_occurs=1,
max_occurs=1,
path=[("s", 1, 1, 1), ("c", 5, 2, 2)],
),
),
AttrFactory.element(
restrictions=Restrictions(
min_occurs=1,
max_occurs=1,
path=[("s", 1, 1, 1), ("c", 6, 1, 21)],
)
),
]
Expand Down Expand Up @@ -99,5 +106,6 @@ def test_process(self):
(1, None, 0, 1),
(1, None, 2, 2),
(1, None, 2, 2),
(1, None, 1, 21),
]
self.assertEqual(expected, actual)
47 changes: 46 additions & 1 deletion tests/codegen/handlers/test_create_compound_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,31 @@ def test_process(self, mock_group_fields):
]
)

def test_process_with_config_enabled_false_calculate_min_occurs(self):
self.processor.config.enabled = False
target = ClassFactory.elements(5)
target.attrs[0].restrictions.choice = 1
target.attrs[1].restrictions.choice = 1
target.attrs[2].restrictions.choice = 2
target.attrs[3].restrictions.choice = 2

for attr in target.attrs:
attr.restrictions.min_occurs = 2
attr.restrictions.max_occurs = 3

target.attrs[0].restrictions.path = [("g", 0, 1, 1), ("c", 1, 2, 1)]
target.attrs[1].restrictions.path = [("g", 0, 1, 1), ("c", 1, 2, 1)]
target.attrs[2].restrictions.path = [("g", 0, 1, 1), ("c", 2, 1, 1)]
target.attrs[3].restrictions.path = [("g", 0, 1, 1), ("c", 2, 1, 1)]
self.processor.process(target)

actual = [
(attr.restrictions.min_occurs, attr.restrictions.max_occurs)
for attr in target.attrs
]
expected = [(2, 3), (2, 3), (0, 3), (0, 3), (2, 3)]
self.assertEqual(expected, actual)

def test_group_fields(self):
target = ClassFactory.create(attrs=AttrFactory.list(4))
target.attrs[0].restrictions.choice = 1
Expand Down Expand Up @@ -91,7 +116,7 @@ def test_group_fields(self):
),
],
)
expected_res = Restrictions(min_occurs=1, max_occurs=20)
expected_res = Restrictions(min_occurs=0, max_occurs=20)

self.processor.group_fields(target, list(target.attrs))
self.assertEqual(1, len(target.attrs))
Expand Down Expand Up @@ -241,3 +266,23 @@ def test_sum_counters(self):

result = self.processor.sum_counters(counters)
self.assertEqual((0, 3), (sum(result[0]), sum(result[1])))

def test_update_counters(self):
attr = AttrFactory.create()
attr.restrictions.min_occurs = 2
attr.restrictions.max_occurs = 3
attr.restrictions.path = [("c", 0, 1, 1)]

counters = {}
self.processor.update_counters(attr, counters)

expected = {("c", 0, 1, 1): {"max": [3], "min": [0]}}
self.assertEqual(expected, counters)

attr.restrictions.min_occurs = 2
attr.restrictions.path = [("c", 0, 2, 1)]

counters = {}
self.processor.update_counters(attr, counters)
expected = {("c", 0, 2, 1): {"max": [3], "min": [2]}}
self.assertEqual(expected, counters)
7 changes: 7 additions & 0 deletions tests/codegen/handlers/test_validate_attributes_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,19 @@ def test_validate_override(self):
attr_b.fixed = attr_a.fixed
attr_a.restrictions.tokens = not attr_b.restrictions.tokens
attr_a.restrictions.nillable = not attr_b.restrictions.nillable
attr_a.restrictions.min_occurs = 0
attr_b.restrictions.min_occurs = 1
attr_a.restrictions.max_occurs = 0
attr_b.restrictions.max_occurs = 1

self.processor.validate_override(target, attr_a, attr_b)
self.assertEqual(1, len(target.attrs))

# Restrictions are compatible again
attr_a.restrictions.tokens = attr_b.restrictions.tokens
attr_a.restrictions.nillable = attr_b.restrictions.nillable
attr_a.restrictions.min_occurs = attr_b.restrictions.min_occurs = 1
attr_a.restrictions.max_occurs = attr_b.restrictions.max_occurs = 1
self.processor.validate_override(target, attr_a, attr_b)
self.assertEqual(0, len(target.attrs))

Expand Down
2 changes: 0 additions & 2 deletions xsdata/codegen/handlers/calculate_attribute_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def process_attr_path(cls, attr: Attr):
if not attr.restrictions.sequence:
attr.restrictions.sequence = index
elif name == CHOICE:
if mi <= 1:
mi = 0
if not attr.restrictions.choice:
attr.restrictions.choice = index
elif name == GROUP:
Expand Down
22 changes: 21 additions & 1 deletion xsdata/codegen/handlers/create_compound_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
from xsdata.models.enums import Tag
from xsdata.utils.collections import group_by

ALL = "a"
GROUP = "g"
SEQUENCE = "s"
CHOICE = "c"


class CreateCompoundFields(RelativeHandlerInterface):
"""Group attributes that belong in the same choice and replace them by
Expand All @@ -35,20 +40,35 @@ def process(self, target: Class):
for choice, attrs in groups.items():
if choice and len(attrs) > 1:
self.group_fields(target, attrs)
else:
for attr in target.attrs:
if attr.restrictions.choice:
self.calculate_choice_min_occurs(attr)

@classmethod
def calculate_choice_min_occurs(cls, attr: Attr):
for path in attr.restrictions.path:
name, index, mi, ma = path
if name == CHOICE and mi <= 1:
attr.restrictions.min_occurs = 0

@classmethod
def update_counters(cls, attr: Attr, counters: Dict):
started = False
choice = attr.restrictions.choice
for path in attr.restrictions.path:
if not started and path[0] != "c" and path[1] != choice:
name, index, mi, ma = path
if not started and name != CHOICE and index != choice:
continue

started = True
if path not in counters:
counters[path] = {"min": [], "max": []}
counters = counters[path]

if mi <= 1:
attr.restrictions.min_occurs = 0

counters["min"].append(attr.restrictions.min_occurs)
counters["max"].append(attr.restrictions.max_occurs)

Expand Down
2 changes: 2 additions & 0 deletions xsdata/codegen/handlers/validate_attributes_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def validate_override(cls, target: Class, attr: Attr, source_attr: Attr):
and bool_eq(attr.mixed, source_attr.mixed)
and bool_eq(attr.restrictions.tokens, source_attr.restrictions.tokens)
and bool_eq(attr.restrictions.nillable, source_attr.restrictions.nillable)
and bool_eq(attr.is_prohibited, source_attr.is_prohibited)
and bool_eq(attr.is_optional, source_attr.is_optional)
):
cls.remove_attribute(target, attr)

Expand Down

0 comments on commit 68d5071

Please sign in to comment.