Skip to content

Commit

Permalink
Use definition elements location for generated classes (#832)
Browse files Browse the repository at this point in the history
Resolves #823
  • Loading branch information
tefra authored Jul 23, 2023
1 parent ed6927a commit 7372769
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 39 deletions.
28 changes: 16 additions & 12 deletions tests/codegen/mappers/test_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def test_map_binding(
def test_map_binding_operation(
self, mock_operation_namespace, mock_map_binding_operation_messages
):
definitions = Definitions(location="foo.wsdl", target_namespace="xsdata")
operation = BindingOperation(name="Add")
definitions = Definitions(target_namespace="xsdata")
operation = BindingOperation(name="Add", location="foo.wsdl")
operation.ns_map["foo"] = "bar"
port_operation = PortTypeOperation()
config = {"a": "one", "b": "two", "style": "rpc"}
Expand Down Expand Up @@ -333,14 +333,17 @@ def test_build_envelope_class(
name = "some_operation_bindings"
style = "document"
namespace = "xsdata"
definitions = Definitions(location="foo.wsdl", target_namespace="bar")
port_type_message = PortTypeMessage(message="some_operation")
definitions = Definitions(target_namespace="bar")
port_type_message = PortTypeMessage(
message="some_operation", location="foo.wsdl"
)
binding_message = BindingMessage(
extended=[
AnyElement(qname="body"),
AnyElement(qname="header"),
AnyElement(qname="header"),
]
],
location="foo.wsdl",
)
binding_message.ns_map["foo"] = "bar"

Expand Down Expand Up @@ -415,14 +418,17 @@ def test_build_envelope_class_with_style_rpc(
name = "some_operation_bindings"
style = "rpc"
namespace = "xsdata"
definitions = Definitions(location="foo.wsdl", target_namespace="bar")
port_type_message = PortTypeMessage(message="some_operation")
definitions = Definitions(target_namespace="bar")
port_type_message = PortTypeMessage(
message="some_operation", location="foo.wsdl"
)
binding_message = BindingMessage(
extended=[
AnyElement(qname="body", attributes={"namespace": "bodyns"}),
AnyElement(qname="header"),
AnyElement(qname="header"),
]
],
location="foo.wsdl",
)
binding_message.ns_map["foo"] = "bar"

Expand Down Expand Up @@ -705,10 +711,8 @@ def test_build_parts_attributes(self, mock_warning):
def test_build_message_class(self, mock_create_message_attributes):
message = Message(name="bar", parts=[Part()])
message.ns_map["foo"] = "bar"
definitions = Definitions(
messages=[message], target_namespace="xsdata", location="foo.wsdl"
)
port_type_message = PortTypeMessage(message="foo:bar")
definitions = Definitions(messages=[message], target_namespace="xsdata")
port_type_message = PortTypeMessage(message="foo:bar", location="foo.wsdl")

attrs = AttrFactory.list(2)
mock_create_message_attributes.return_value = attrs
Expand Down
25 changes: 10 additions & 15 deletions tests/codegen/parsers/test_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setUp(self):

def test_complete(self):
path = fixtures_dir.joinpath("calculator/services.wsdl").resolve()
parser = DefinitionsParser()
parser = DefinitionsParser(location="here.wsdl")
definitions = parser.from_path(path, Definitions)

self.assertIsInstance(definitions, Definitions)
Expand All @@ -22,20 +22,15 @@ def test_complete(self):
self.assertEqual(1, len(definitions.port_types))
self.assertEqual(1, len(definitions.types.schemas))
self.assertEqual(8, len(definitions.messages))
self.assertEqual(parser.location, definitions.bindings[0].location)

def test_end_definitions(self):
parser = DefinitionsParser()
definitions = Definitions(
imports=[Import(location="../foo.xsd"), Import(location="bar.xsd")]
)
def test_end_import(self):
parser = DefinitionsParser(location="foo/bar.wsdl")
imp = Import(location="../hello/foo.wsdl")

parser.end_definitions(definitions)
self.assertEqual("bar.xsd", definitions.imports[1].location)
parser.end_import(imp)
self.assertEqual("hello/foo.wsdl", imp.location)

parser.location = "file://a/b/services/parent.wsdl"
parser.end_definitions(definitions)
self.assertEqual("file://a/b/foo.xsd", definitions.imports[0].location)
self.assertEqual("file://a/b/services/bar.xsd", definitions.imports[1].location)

# Update only Definitions instances
parser.end_definitions("foo")
parser.location = None
parser.end_import(imp)
self.assertEqual("hello/foo.wsdl", imp.location)
12 changes: 6 additions & 6 deletions xsdata/codegen/mappers/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ def map_binding_operation(
message_type = message_class.name.split("_")[-1]
attrs.append(cls.build_attr(message_type, message_class.qname))

assert definitions.location is not None
assert binding_operation.location is not None

yield Class(
qname=namespaces.build_qname(definitions.target_namespace, name),
status=Status.FLATTENED,
tag=type(binding_operation).__name__,
location=definitions.location,
location=binding_operation.location,
ns_map=binding_operation.ns_map,
attrs=attrs,
)
Expand Down Expand Up @@ -210,13 +210,13 @@ def build_envelope_class(
"""Step 6.1: Build Envelope class for the given binding message with
attributes from the port type message."""

assert definitions.location is not None
assert binding_message.location is not None

target = Class(
qname=namespaces.build_qname(definitions.target_namespace, name),
meta_name="Envelope",
tag=Tag.BINDING_MESSAGE,
location=definitions.location,
location=binding_message.location,
ns_map=binding_message.ns_map,
namespace=namespace,
)
Expand Down Expand Up @@ -253,14 +253,14 @@ def build_message_class(
ns_map = definition_message.ns_map.copy()
source_namespace = ns_map.get(prefix)

assert definitions.location is not None
assert port_type_message.location is not None

return Class(
qname=namespaces.build_qname(source_namespace, name),
namespace=source_namespace,
status=Status.RAW,
tag=Tag.ELEMENT,
location=definitions.location,
location=port_type_message.location,
ns_map=ns_map,
attrs=list(cls.build_parts_attributes(definition_message.parts, ns_map)),
)
Expand Down
27 changes: 22 additions & 5 deletions xsdata/codegen/parsers/definitions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from dataclasses import dataclass
from typing import Any
from typing import List
from typing import Optional

from xsdata.codegen.parsers.schema import SchemaParser
from xsdata.formats.bindings import T
from xsdata.formats.dataclass.parsers.bases import Parsed
from xsdata.formats.dataclass.parsers.mixins import XmlNode
from xsdata.models import wsdl


Expand All @@ -10,9 +15,21 @@ class DefinitionsParser(SchemaParser):
"""A simple parser to convert a wsdl to an easy to handle data structure
based on dataclasses."""

def end_definitions(self, obj: T):
"""Normalize various properties for the schema and it's children."""
if isinstance(obj, wsdl.Definitions) and self.location:
def end(
self,
queue: List[XmlNode],
objects: List[Parsed],
qname: str,
text: Optional[str],
tail: Optional[str],
) -> Any:
"""Override parent method to set element location."""
obj = super().end(queue, objects, qname, text, tail)
if isinstance(obj, wsdl.WsdlElement):
obj.location = self.location
for imp in obj.imports:
imp.location = self.resolve_path(imp.location)

return obj

def end_import(self, obj: T):
if isinstance(obj, wsdl.Import) and self.location:
obj.location = self.resolve_path(obj.location)
4 changes: 3 additions & 1 deletion xsdata/models/wsdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ class WsdlElement:
"""
:param name:
:param documentation:
:param location:
:param ns_map
"""

name: str = attribute()
documentation: Optional[Documentation] = element()
location: Optional[str] = field(default=None, metadata={"type": "Ignore"})
ns_map: Dict[str, str] = field(
default_factory=dict, init=False, metadata={"type": "Ignore"}
)
Expand Down Expand Up @@ -207,7 +210,6 @@ class Meta:
port_types: List[PortType] = array_element(name="portType")
bindings: List[Binding] = array_element(name="binding")
services: List[Service] = array_element(name="service")
location: Optional[str] = field(default=None)

@property
def schemas(self):
Expand Down

0 comments on commit 7372769

Please sign in to comment.