diff --git a/digid_eherkenning/models/eherkenning.py b/digid_eherkenning/models/eherkenning.py index 78f1ab1..a0b2101 100644 --- a/digid_eherkenning/models/eherkenning.py +++ b/digid_eherkenning/models/eherkenning.py @@ -203,6 +203,10 @@ def as_dict(self) -> EHerkenningConfig: "service_uuid": str(self.eh_service_uuid), "service_name": self.service_name, "attribute_consuming_service_index": self.eh_attribute_consuming_service_index, + # always mark EH as default and EIDAS as not the default. If we ever support + # more assertion consumer services than these two, then we need to expand on + # this logic/configuration. + "mark_default": True, "service_instance_uuid": str(self.eh_service_instance_uuid), "service_description": self.service_description, "service_description_url": self.service_description_url, diff --git a/digid_eherkenning/saml2/base.py b/digid_eherkenning/saml2/base.py index 15baa93..d9d5a41 100644 --- a/digid_eherkenning/saml2/base.py +++ b/digid_eherkenning/saml2/base.py @@ -84,6 +84,8 @@ class BaseSaml2Client: "custom_base_path": None, } + settings_cls = OneLogin_Saml2_Settings + def __init__(self, conf=None): self.authn_storage = AuthnRequestStorage( self.cache_key_prefix, self.cache_timeout @@ -203,7 +205,8 @@ def create_config(self, config_dict): """ Convert to the format expected by the OneLogin SAML2 library. """ - return OneLogin_Saml2_Settings(config_dict, **self.saml2_setting_kwargs) + cls = self.settings_cls + return cls(config_dict, **self.saml2_setting_kwargs) def create_config_dict(self, conf): """ diff --git a/digid_eherkenning/saml2/eherkenning.py b/digid_eherkenning/saml2/eherkenning.py index a64d0b5..ee8e30c 100644 --- a/digid_eherkenning/saml2/eherkenning.py +++ b/digid_eherkenning/saml2/eherkenning.py @@ -11,6 +11,7 @@ from cryptography.x509 import load_pem_x509_certificate from lxml.builder import ElementMaker from lxml.etree import Element, tostring +from onelogin.saml2.metadata import OneLogin_Saml2_Metadata from onelogin.saml2.settings import OneLogin_Saml2_Settings from ..models import EherkenningConfiguration @@ -406,7 +407,7 @@ def get_metadata_eherkenning_requested_attributes( conf: ServiceConfig, service_id: str ) -> list[dict]: # There needs to be a RequestedAttribute element where the name is the ServiceID - # https://afsprakenstelsel.etoegang.nl/display/as/DV+metadata+for+HM + # https://afsprakenstelsel.etoegang.nl/Startpagina/v3/dv-metadata-for-hm requested_attributes = [{"name": service_id, "isRequired": False}] for requested_attribute in conf.get("requested_attributes", []): if isinstance(requested_attribute, dict): @@ -447,15 +448,58 @@ def create_attribute_consuming_services(conf: EHerkenningConfig) -> list[dict]: "serviceDescription": service_description, "requestedAttributes": requested_attributes, "language": service.get("language", "nl"), + "mark_default": service.get("mark_default", False), } ) return attribute_consuming_services +class CustomOneLogin_Saml2_Metadata(OneLogin_Saml2_Metadata): + """ + Modify the generated metadata to comply with AfsprakenStelsel 1.24a + """ + + @staticmethod + def make_attribute_consuming_services(service_provider: dict): + """ + Add an attribute to the default AttributeConsumingService element. + + .. note:: the upstream master branch has refactored this interface, so once we + rebase on master (quite a task I think), we will have to deal with this too. + """ + result = super( + CustomOneLogin_Saml2_Metadata, CustomOneLogin_Saml2_Metadata + ).make_attribute_consuming_services(service_provider) + + attribute_consuming_services = service_provider["attributeConsumingServices"] + if len(attribute_consuming_services) > 1: + # find the ACS that's marked as default - there *must* be one otherwise we + # don't comply with AfsprakenStelsel 1.24a requirements + default_service_index = next( + acs["index"] + for acs in attribute_consuming_services + if acs["mark_default"] + ) + + # do string replacement, because we can't pass any options to the metadata + # generation to modify this behaviour :/ + needle = f'' + replacement = f'' + result = result.replace(needle, replacement, 1) + + return result + + +class CustomOneLogin_Saml2_Settings(OneLogin_Saml2_Settings): + metadata_class = CustomOneLogin_Saml2_Metadata + + class eHerkenningClient(BaseSaml2Client): cache_key_prefix = "eherkenning" cache_timeout = 60 * 60 # 1 hour + settings_cls = CustomOneLogin_Saml2_Settings + @property def conf(self) -> EHerkenningConfig: if not hasattr(self, "_conf"): diff --git a/digid_eherkenning/types.py b/digid_eherkenning/types.py index f55e1f4..7e9730b 100644 --- a/digid_eherkenning/types.py +++ b/digid_eherkenning/types.py @@ -20,6 +20,7 @@ class ServiceConfig(TypedDict): entity_concerned_types_allowed: list[dict] language: str classifiers: Optional[list[str]] + mark_default: bool class EHerkenningConfig(TypedDict): diff --git a/tests/test_eherkenning_metadata.py b/tests/test_eherkenning_metadata.py index 44742c2..f9c0242 100644 --- a/tests/test_eherkenning_metadata.py +++ b/tests/test_eherkenning_metadata.py @@ -235,42 +235,49 @@ def test_generate_metadata_all_options_specified(self): eh_attribute_consuming_service_node = attribute_consuming_service_nodes[0] eidas_attribute_consuming_service_node = attribute_consuming_service_nodes[1] - self.assertEqual( - "urn:etoegang:DV:00000000000000000011:services:9050", - eh_attribute_consuming_service_node.find( - ".//md:RequestedAttribute", namespaces=NAME_SPACES - ).attrib["Name"], - ) - self.assertEqual( - "Test Service Name", - eh_attribute_consuming_service_node.find( - ".//md:ServiceName", namespaces=NAME_SPACES - ).text, - ) - self.assertEqual( - "Test Service Description", - eh_attribute_consuming_service_node.find( - ".//md:ServiceDescription", namespaces=NAME_SPACES - ).text, - ) - self.assertEqual( - "urn:etoegang:DV:00000000000000000011:services:9051", - eidas_attribute_consuming_service_node.find( - ".//md:RequestedAttribute", namespaces=NAME_SPACES - ).attrib["Name"], - ) - self.assertEqual( - "Test Service Name (eIDAS)", - eidas_attribute_consuming_service_node.find( - ".//md:ServiceName", namespaces=NAME_SPACES - ).text, - ) - self.assertEqual( - "Test Service Description", - eidas_attribute_consuming_service_node.find( - ".//md:ServiceDescription", namespaces=NAME_SPACES - ).text, - ) + with self.subTest("eh attribute consuming service"): + self.assertEqual( + eh_attribute_consuming_service_node.attrib["isDefault"], + "true", + ) + self.assertEqual( + "urn:etoegang:DV:00000000000000000011:services:9050", + eh_attribute_consuming_service_node.find( + ".//md:RequestedAttribute", namespaces=NAME_SPACES + ).attrib["Name"], + ) + self.assertEqual( + "Test Service Name", + eh_attribute_consuming_service_node.find( + ".//md:ServiceName", namespaces=NAME_SPACES + ).text, + ) + self.assertEqual( + "Test Service Description", + eh_attribute_consuming_service_node.find( + ".//md:ServiceDescription", namespaces=NAME_SPACES + ).text, + ) + + with self.subTest("eidas attribute consuming service"): + self.assertEqual( + "urn:etoegang:DV:00000000000000000011:services:9051", + eidas_attribute_consuming_service_node.find( + ".//md:RequestedAttribute", namespaces=NAME_SPACES + ).attrib["Name"], + ) + self.assertEqual( + "Test Service Name (eIDAS)", + eidas_attribute_consuming_service_node.find( + ".//md:ServiceName", namespaces=NAME_SPACES + ).text, + ) + self.assertEqual( + "Test Service Description", + eidas_attribute_consuming_service_node.find( + ".//md:ServiceDescription", namespaces=NAME_SPACES + ).text, + ) organisation_name_node = entity_descriptor_node.find( ".//md:OrganizationName",