From 737276998c92169f2be8dc6dfd75ebebbb1a35d7 Mon Sep 17 00:00:00 2001 From: Chris Tsou Date: Sun, 23 Jul 2023 09:42:56 +0300 Subject: [PATCH] Use definition elements location for generated classes (#832) Resolves #823 --- tests/codegen/mappers/test_definitions.py | 28 +++++++++++++---------- tests/codegen/parsers/test_definitions.py | 25 ++++++++------------ xsdata/codegen/mappers/definitions.py | 12 +++++----- xsdata/codegen/parsers/definitions.py | 27 ++++++++++++++++++---- xsdata/models/wsdl.py | 4 +++- 5 files changed, 57 insertions(+), 39 deletions(-) diff --git a/tests/codegen/mappers/test_definitions.py b/tests/codegen/mappers/test_definitions.py index 16488a32b..f6c49ae98 100644 --- a/tests/codegen/mappers/test_definitions.py +++ b/tests/codegen/mappers/test_definitions.py @@ -158,8 +158,8 @@ def test_map_binding( def test_map_binding_operation( self, mock_operation_namespace, mock_map_binding_operation_messages ): - definitions = Definitions(location="foo.wsdl", target_namespace="xsdata") - operation = BindingOperation(name="Add") + definitions = Definitions(target_namespace="xsdata") + operation = BindingOperation(name="Add", location="foo.wsdl") operation.ns_map["foo"] = "bar" port_operation = PortTypeOperation() config = {"a": "one", "b": "two", "style": "rpc"} @@ -333,14 +333,17 @@ def test_build_envelope_class( name = "some_operation_bindings" style = "document" namespace = "xsdata" - definitions = Definitions(location="foo.wsdl", target_namespace="bar") - port_type_message = PortTypeMessage(message="some_operation") + definitions = Definitions(target_namespace="bar") + port_type_message = PortTypeMessage( + message="some_operation", location="foo.wsdl" + ) binding_message = BindingMessage( extended=[ AnyElement(qname="body"), AnyElement(qname="header"), AnyElement(qname="header"), - ] + ], + location="foo.wsdl", ) binding_message.ns_map["foo"] = "bar" @@ -415,14 +418,17 @@ def test_build_envelope_class_with_style_rpc( name = "some_operation_bindings" style = "rpc" namespace = "xsdata" - definitions = Definitions(location="foo.wsdl", target_namespace="bar") - port_type_message = PortTypeMessage(message="some_operation") + definitions = Definitions(target_namespace="bar") + port_type_message = PortTypeMessage( + message="some_operation", location="foo.wsdl" + ) binding_message = BindingMessage( extended=[ AnyElement(qname="body", attributes={"namespace": "bodyns"}), AnyElement(qname="header"), AnyElement(qname="header"), - ] + ], + location="foo.wsdl", ) binding_message.ns_map["foo"] = "bar" @@ -705,10 +711,8 @@ def test_build_parts_attributes(self, mock_warning): def test_build_message_class(self, mock_create_message_attributes): message = Message(name="bar", parts=[Part()]) message.ns_map["foo"] = "bar" - definitions = Definitions( - messages=[message], target_namespace="xsdata", location="foo.wsdl" - ) - port_type_message = PortTypeMessage(message="foo:bar") + definitions = Definitions(messages=[message], target_namespace="xsdata") + port_type_message = PortTypeMessage(message="foo:bar", location="foo.wsdl") attrs = AttrFactory.list(2) mock_create_message_attributes.return_value = attrs diff --git a/tests/codegen/parsers/test_definitions.py b/tests/codegen/parsers/test_definitions.py index 292ad3883..6b071043c 100644 --- a/tests/codegen/parsers/test_definitions.py +++ b/tests/codegen/parsers/test_definitions.py @@ -13,7 +13,7 @@ def setUp(self): def test_complete(self): path = fixtures_dir.joinpath("calculator/services.wsdl").resolve() - parser = DefinitionsParser() + parser = DefinitionsParser(location="here.wsdl") definitions = parser.from_path(path, Definitions) self.assertIsInstance(definitions, Definitions) @@ -22,20 +22,15 @@ def test_complete(self): self.assertEqual(1, len(definitions.port_types)) self.assertEqual(1, len(definitions.types.schemas)) self.assertEqual(8, len(definitions.messages)) + self.assertEqual(parser.location, definitions.bindings[0].location) - def test_end_definitions(self): - parser = DefinitionsParser() - definitions = Definitions( - imports=[Import(location="../foo.xsd"), Import(location="bar.xsd")] - ) + def test_end_import(self): + parser = DefinitionsParser(location="foo/bar.wsdl") + imp = Import(location="../hello/foo.wsdl") - parser.end_definitions(definitions) - self.assertEqual("bar.xsd", definitions.imports[1].location) + parser.end_import(imp) + self.assertEqual("hello/foo.wsdl", imp.location) - parser.location = "file://a/b/services/parent.wsdl" - parser.end_definitions(definitions) - self.assertEqual("file://a/b/foo.xsd", definitions.imports[0].location) - self.assertEqual("file://a/b/services/bar.xsd", definitions.imports[1].location) - - # Update only Definitions instances - parser.end_definitions("foo") + parser.location = None + parser.end_import(imp) + self.assertEqual("hello/foo.wsdl", imp.location) diff --git a/xsdata/codegen/mappers/definitions.py b/xsdata/codegen/mappers/definitions.py index 4eec66308..be823d443 100644 --- a/xsdata/codegen/mappers/definitions.py +++ b/xsdata/codegen/mappers/definitions.py @@ -110,13 +110,13 @@ def map_binding_operation( message_type = message_class.name.split("_")[-1] attrs.append(cls.build_attr(message_type, message_class.qname)) - assert definitions.location is not None + assert binding_operation.location is not None yield Class( qname=namespaces.build_qname(definitions.target_namespace, name), status=Status.FLATTENED, tag=type(binding_operation).__name__, - location=definitions.location, + location=binding_operation.location, ns_map=binding_operation.ns_map, attrs=attrs, ) @@ -210,13 +210,13 @@ def build_envelope_class( """Step 6.1: Build Envelope class for the given binding message with attributes from the port type message.""" - assert definitions.location is not None + assert binding_message.location is not None target = Class( qname=namespaces.build_qname(definitions.target_namespace, name), meta_name="Envelope", tag=Tag.BINDING_MESSAGE, - location=definitions.location, + location=binding_message.location, ns_map=binding_message.ns_map, namespace=namespace, ) @@ -253,14 +253,14 @@ def build_message_class( ns_map = definition_message.ns_map.copy() source_namespace = ns_map.get(prefix) - assert definitions.location is not None + assert port_type_message.location is not None return Class( qname=namespaces.build_qname(source_namespace, name), namespace=source_namespace, status=Status.RAW, tag=Tag.ELEMENT, - location=definitions.location, + location=port_type_message.location, ns_map=ns_map, attrs=list(cls.build_parts_attributes(definition_message.parts, ns_map)), ) diff --git a/xsdata/codegen/parsers/definitions.py b/xsdata/codegen/parsers/definitions.py index 5c286f623..2e66a2205 100644 --- a/xsdata/codegen/parsers/definitions.py +++ b/xsdata/codegen/parsers/definitions.py @@ -1,7 +1,12 @@ from dataclasses import dataclass +from typing import Any +from typing import List +from typing import Optional from xsdata.codegen.parsers.schema import SchemaParser from xsdata.formats.bindings import T +from xsdata.formats.dataclass.parsers.bases import Parsed +from xsdata.formats.dataclass.parsers.mixins import XmlNode from xsdata.models import wsdl @@ -10,9 +15,21 @@ class DefinitionsParser(SchemaParser): """A simple parser to convert a wsdl to an easy to handle data structure based on dataclasses.""" - def end_definitions(self, obj: T): - """Normalize various properties for the schema and it's children.""" - if isinstance(obj, wsdl.Definitions) and self.location: + def end( + self, + queue: List[XmlNode], + objects: List[Parsed], + qname: str, + text: Optional[str], + tail: Optional[str], + ) -> Any: + """Override parent method to set element location.""" + obj = super().end(queue, objects, qname, text, tail) + if isinstance(obj, wsdl.WsdlElement): obj.location = self.location - for imp in obj.imports: - imp.location = self.resolve_path(imp.location) + + return obj + + def end_import(self, obj: T): + if isinstance(obj, wsdl.Import) and self.location: + obj.location = self.resolve_path(obj.location) diff --git a/xsdata/models/wsdl.py b/xsdata/models/wsdl.py index 023c583ed..3df87a6e0 100644 --- a/xsdata/models/wsdl.py +++ b/xsdata/models/wsdl.py @@ -32,10 +32,13 @@ class WsdlElement: """ :param name: :param documentation: + :param location: + :param ns_map """ name: str = attribute() documentation: Optional[Documentation] = element() + location: Optional[str] = field(default=None, metadata={"type": "Ignore"}) ns_map: Dict[str, str] = field( default_factory=dict, init=False, metadata={"type": "Ignore"} ) @@ -207,7 +210,6 @@ class Meta: port_types: List[PortType] = array_element(name="portType") bindings: List[Binding] = array_element(name="binding") services: List[Service] = array_element(name="service") - location: Optional[str] = field(default=None) @property def schemas(self):