diff --git a/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py b/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py index edad1d90..44ddfdae 100644 --- a/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py +++ b/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py @@ -97,7 +97,7 @@ def _on_certificate_removed(self, event: CertificateRemovedEvent): import logging from typing import List -from jsonschema import exceptions, validate # type: ignore[import] +from jsonschema import exceptions, validate # type: ignore[import-untyped] from ops.charm import CharmBase, CharmEvents, RelationBrokenEvent, RelationChangedEvent from ops.framework import EventBase, EventSource, Handle, Object @@ -109,7 +109,7 @@ def _on_certificate_removed(self, event: CertificateRemovedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 4 +LIBPATCH = 5 PYDEPS = ["jsonschema"] diff --git a/lib/charms/loki_k8s/v0/loki_push_api.py b/lib/charms/loki_k8s/v0/loki_push_api.py index 9f9372d2..01d7dc16 100644 --- a/lib/charms/loki_k8s/v0/loki_push_api.py +++ b/lib/charms/loki_k8s/v0/loki_push_api.py @@ -12,9 +12,9 @@ implement the provider side of the `loki_push_api` relation interface. For instance, a Loki charm. The provider side of the relation represents the server side, to which logs are being pushed. -- `LokiPushApiConsumer`: This object is meant to be used by any Charmed Operator that needs to -send log to Loki by implementing the consumer side of the `loki_push_api` relation interface. -For instance, a Promtail or Grafana agent charm which needs to send logs to Loki. +- `LokiPushApiConsumer`: Used to obtain the loki api endpoint. This is useful for configuring + applications such as pebble, or charmed operators of workloads such as grafana-agent or promtail, + that can communicate with loki directly. - `LogProxyConsumer`: This object can be used by any Charmed Operator which needs to send telemetry, such as logs, to Loki through a Log Proxy by implementing the consumer side of the @@ -456,7 +456,7 @@ def _alert_rules_error(self, event): from urllib.error import HTTPError import yaml -from charms.observability_libs.v0.juju_topology import JujuTopology +from cosl import JujuTopology from ops.charm import ( CharmBase, HookEvent, @@ -480,7 +480,7 @@ def _alert_rules_error(self, event): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 22 +LIBPATCH = 25 logger = logging.getLogger(__name__) @@ -604,7 +604,9 @@ def _validate_relation_by_interface_and_direction( actual_relation_interface = relation.interface_name if actual_relation_interface != expected_relation_interface: raise RelationInterfaceMismatchError( - relation_name, expected_relation_interface, actual_relation_interface + relation_name, + expected_relation_interface, + actual_relation_interface, # pyright: ignore ) if expected_relation_role == RelationRole.provides: @@ -866,20 +868,20 @@ def _from_dir(self, dir_path: Path, recursive: bool) -> List[dict]: return alert_groups - def add_path(self, path: str, *, recursive: bool = False): + def add_path(self, path_str: str, *, recursive: bool = False): """Add rules from a dir path. All rules from files are aggregated into a data structure representing a single rule file. All group names are augmented with juju topology. Args: - path: either a rules file or a dir of rules files. + path_str: either a rules file or a dir of rules files. recursive: whether to read files recursively or not (no impact if `path` is a file). Raises: InvalidAlertRulePathError: if the provided path is invalid. """ - path = Path(path) # type: Path + path = Path(path_str) # type: Path if path.is_dir(): self.alert_groups.extend(self._from_dir(path, recursive)) elif path.is_file(): @@ -992,6 +994,8 @@ def __init__(self, handle, relation, relation_id, app=None, unit=None): def snapshot(self) -> Dict: """Save event information.""" + if not self.relation: + return {} snapshot = {"relation_name": self.relation.name, "relation_id": self.relation.id} if self.app: snapshot["app_name"] = self.app.name @@ -1052,7 +1056,7 @@ class LokiPushApiEvents(ObjectEvents): class LokiPushApiProvider(Object): """A LokiPushApiProvider class.""" - on = LokiPushApiEvents() + on = LokiPushApiEvents() # pyright: ignore def __init__( self, @@ -1146,11 +1150,11 @@ def _on_logging_relation_changed(self, event: HookEvent): event: a `CharmEvent` in response to which the consumer charm must update its relation data. """ - should_update = self._process_logging_relation_changed(event.relation) + should_update = self._process_logging_relation_changed(event.relation) # pyright: ignore if should_update: self.on.loki_push_api_alert_rules_changed.emit( - relation=event.relation, - relation_id=event.relation.id, + relation=event.relation, # pyright: ignore + relation_id=event.relation.id, # pyright: ignore app=self._charm.app, unit=self._charm.unit, ) @@ -1517,7 +1521,7 @@ def loki_endpoints(self) -> List[dict]: class LokiPushApiConsumer(ConsumerBase): """Loki Consumer class.""" - on = LokiPushApiEvents() + on = LokiPushApiEvents() # pyright: ignore def __init__( self, @@ -1760,7 +1764,7 @@ class LogProxyConsumer(ConsumerBase): role. """ - on = LogProxyEvents() + on = LogProxyEvents() # pyright: ignore def __init__( self, @@ -1885,7 +1889,7 @@ def _on_relation_departed(self, _: RelationEvent) -> None: self._container.stop(WORKLOAD_SERVICE_NAME) self.on.log_proxy_endpoint_departed.emit() - def _get_container(self, container_name: str = "") -> Container: + def _get_container(self, container_name: str = "") -> Container: # pyright: ignore """Gets a single container by name or using the only container running in the Pod. If there is more than one container in the Pod a `PromtailDigestError` is emitted. @@ -1959,7 +1963,9 @@ def _add_pebble_layer(self, workload_binary_path: str) -> None: } }, } - self._container.add_layer(self._container_name, pebble_layer, combine=True) + self._container.add_layer( + self._container_name, pebble_layer, combine=True # pyright: ignore + ) def _create_directories(self) -> None: """Creates the directories for Promtail binary and config file.""" @@ -1996,7 +2002,11 @@ def _push_binary_to_workload(self, binary_path: str, workload_binary_path: str) """ with open(binary_path, "rb") as f: self._container.push( - workload_binary_path, f, permissions=0o755, encoding=None, make_dirs=True + workload_binary_path, + f, + permissions=0o755, + encoding=None, # pyright: ignore + make_dirs=True, ) logger.debug("The promtail binary file has been pushed to the workload container.") diff --git a/lib/charms/oathkeeper/v0/forward_auth.py b/lib/charms/oathkeeper/v0/forward_auth.py index e0954a64..068fba67 100644 --- a/lib/charms/oathkeeper/v0/forward_auth.py +++ b/lib/charms/oathkeeper/v0/forward_auth.py @@ -54,13 +54,7 @@ def some_event_function(self, event: AuthConfigChangedEvent): from typing import Dict, List, Mapping, Optional import jsonschema -from ops.charm import ( - CharmBase, - RelationBrokenEvent, - RelationChangedEvent, - RelationCreatedEvent, - RelationDepartedEvent, -) +from ops.charm import CharmBase, RelationBrokenEvent, RelationChangedEvent, RelationCreatedEvent from ops.framework import EventBase, EventSource, Handle, Object, ObjectEvents from ops.model import Relation, TooManyRelatedAppsError @@ -72,7 +66,7 @@ def some_event_function(self, event: AuthConfigChangedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 2 +LIBPATCH = 3 RELATION_NAME = "forward-auth" INTERFACE_NAME = "forward_auth" @@ -505,7 +499,7 @@ def __init__( events = self.charm.on[relation_name] self.framework.observe(events.relation_created, self._on_relation_created_event) self.framework.observe(events.relation_changed, self._on_relation_changed_event) - self.framework.observe(events.relation_departed, self._on_relation_departed_event) + self.framework.observe(events.relation_broken, self._on_relation_broken_event) def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: """Update the relation with provider data when a relation is created.""" @@ -525,8 +519,8 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: # Compare ingress-related apps with apps that requested the proxy self._compare_apps() - def _on_relation_departed_event(self, event: RelationDepartedEvent) -> None: - """Wipe the relation databag and notify the charm that the relation has departed.""" + def _on_relation_broken_event(self, event: RelationBrokenEvent) -> None: + """Wipe the relation databag and notify the charm that the relation is broken.""" # Workaround for https://github.com/canonical/operator/issues/888 self._pop_relation_data(event.relation.id) diff --git a/lib/charms/observability_libs/v0/cert_handler.py b/lib/charms/observability_libs/v0/cert_handler.py index 88a8374e..db14e00f 100644 --- a/lib/charms/observability_libs/v0/cert_handler.py +++ b/lib/charms/observability_libs/v0/cert_handler.py @@ -64,7 +64,7 @@ LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a" LIBAPI = 0 -LIBPATCH = 8 +LIBPATCH = 9 def is_ip_address(value: str) -> bool: @@ -181,33 +181,40 @@ def _peer_relation(self) -> Optional[Relation]: return self.charm.model.get_relation(self.peer_relation_name, None) def _on_peer_relation_created(self, _): - """Generate the private key and store it in a peer relation.""" - # We're in "relation-created", so the relation should be there + """Generate the CSR if the certificates relation is ready.""" + self._generate_privkey() - # Just in case we already have a private key, do not overwrite it. - # Not sure how this could happen. - # TODO figure out how to go about key rotation. - if not self._private_key: - private_key = generate_private_key() - self._private_key = private_key.decode() - - # Generate CSR here, in case peer events fired after tls-certificate relation events + # check cert relation is ready if not (self.charm.model.get_relation(self.certificates_relation_name)): # peer relation event happened to fire before tls-certificates events. # Abort, and let the "certificates joined" observer create the CSR. + logger.info("certhandler waiting on certificates relation") return + logger.debug("certhandler has peer and certs relation: proceeding to generate csr") self._generate_csr() def _on_certificates_relation_joined(self, _) -> None: - """Generate the CSR and request the certificate creation.""" + """Generate the CSR if the peer relation is ready.""" + self._generate_privkey() + + # check peer relation is there if not self._peer_relation: # tls-certificates relation event happened to fire before peer events. # Abort, and let the "peer joined" relation create the CSR. + logger.info("certhandler waiting on peer relation") return + logger.debug("certhandler has peer and certs relation: proceeding to generate csr") self._generate_csr() + def _generate_privkey(self): + # Generate priv key unless done already + # TODO figure out how to go about key rotation. + if not self._private_key: + private_key = generate_private_key() + self._private_key = private_key.decode() + def _on_config_changed(self, _): # FIXME on config changed, the web_external_url may or may not change. But because every # call to `generate_csr` appends a uuid, CSRs cannot be easily compared to one another. @@ -237,7 +244,12 @@ def _generate_csr( # In case we already have a csr, do not overwrite it by default. if overwrite or renew or not self._csr: private_key = self._private_key - assert private_key is not None # for type checker + if private_key is None: + # FIXME: raise this in a less nested scope by + # generating privkey and csr in the same method. + raise RuntimeError( + "private key unset. call _generate_privkey() before you call this method." + ) csr = generate_csr( private_key=private_key.encode(), subject=self.cert_subject, diff --git a/lib/charms/observability_libs/v1/kubernetes_service_patch.py b/lib/charms/observability_libs/v1/kubernetes_service_patch.py index 64dd13ce..2cce729e 100644 --- a/lib/charms/observability_libs/v1/kubernetes_service_patch.py +++ b/lib/charms/observability_libs/v1/kubernetes_service_patch.py @@ -127,7 +127,7 @@ def setUp(self, *unused): from types import MethodType from typing import List, Literal, Optional, Union -from lightkube import ApiError, Client +from lightkube import ApiError, Client # pyright: ignore from lightkube.core import exceptions from lightkube.models.core_v1 import ServicePort, ServiceSpec from lightkube.models.meta_v1 import ObjectMeta @@ -146,7 +146,7 @@ def setUp(self, *unused): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 7 +LIBPATCH = 9 ServiceType = Literal["ClusterIP", "LoadBalancer"] @@ -268,7 +268,7 @@ def _patch(self, _) -> None: PatchFailed: if patching fails due to lack of permissions, or otherwise. """ try: - client = Client() + client = Client() # pyright: ignore except exceptions.ConfigError as e: logger.warning("Error creating k8s client: %s", e) return @@ -300,7 +300,7 @@ def is_patched(self) -> bool: Returns: bool: A boolean indicating if the service patch has been applied. """ - client = Client() + client = Client() # pyright: ignore return self._is_patched(client) def _is_patched(self, client: Client) -> bool: @@ -314,7 +314,7 @@ def _is_patched(self, client: Client) -> bool: raise # Construct a list of expected ports, should the patch be applied - expected_ports = [(p.port, p.targetPort) for p in self.service.spec.ports] + expected_ports = [(p.port, p.targetPort) for p in self.service.spec.ports] # type: ignore[attr-defined] # Construct a list in the same manner, using the fetched service fetched_ports = [ (p.port, p.targetPort) for p in service.spec.ports # type: ignore[attr-defined] diff --git a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py index e4297aa1..665af886 100644 --- a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py +++ b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py @@ -362,7 +362,7 @@ def _on_scrape_targets_changed(self, event): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 42 +LIBPATCH = 44 PYDEPS = ["cosl"] @@ -386,6 +386,7 @@ def _on_scrape_targets_changed(self, event): "basic_auth", "tls_config", "authorization", + "params", } DEFAULT_JOB = { "metrics_path": "/metrics", @@ -764,7 +765,7 @@ def _validate_relation_by_interface_and_direction( actual_relation_interface = relation.interface_name if actual_relation_interface != expected_relation_interface: raise RelationInterfaceMismatchError( - relation_name, expected_relation_interface, actual_relation_interface + relation_name, expected_relation_interface, actual_relation_interface or "None" ) if expected_relation_role == RelationRole.provides: @@ -857,7 +858,7 @@ class MonitoringEvents(ObjectEvents): class MetricsEndpointConsumer(Object): """A Prometheus based Monitoring service.""" - on = MonitoringEvents() + on = MonitoringEvents() # pyright: ignore def __init__(self, charm: CharmBase, relation_name: str = DEFAULT_RELATION_NAME): """A Prometheus based Monitoring service. @@ -1014,7 +1015,6 @@ def alerts(self) -> dict: try: scrape_metadata = json.loads(relation.data[relation.app]["scrape_metadata"]) identifier = JujuTopology.from_dict(scrape_metadata).identifier - alerts[identifier] = self._tool.apply_label_matchers(alert_rules) # type: ignore except KeyError as e: logger.debug( @@ -1029,6 +1029,10 @@ def alerts(self) -> dict: ) continue + # We need to append the relation info to the identifier. This is to allow for cases for there are two + # relations which eventually scrape the same application. Issue #551. + identifier = f"{identifier}_{relation.name}_{relation.id}" + alerts[identifier] = alert_rules _, errmsg = self._tool.validate_alert_rules(alert_rules) @@ -1294,7 +1298,7 @@ def _resolve_dir_against_charm_path(charm: CharmBase, *path_elements: str) -> st class MetricsEndpointProvider(Object): """A metrics endpoint for Prometheus.""" - on = MetricsEndpointProviderEvents() + on = MetricsEndpointProviderEvents() # pyright: ignore def __init__( self, @@ -1836,14 +1840,16 @@ def _set_prometheus_data(self, event): return jobs = [] + _type_convert_stored( - self._stored.jobs + self._stored.jobs # pyright: ignore ) # list of scrape jobs, one per relation for relation in self.model.relations[self._target_relation]: targets = self._get_targets(relation) if targets and relation.app: jobs.append(self._static_scrape_job(targets, relation.app.name)) - groups = [] + _type_convert_stored(self._stored.alert_rules) # list of alert rule groups + groups = [] + _type_convert_stored( + self._stored.alert_rules # pyright: ignore + ) # list of alert rule groups for relation in self.model.relations[self._alert_rules_relation]: unit_rules = self._get_alert_rules(relation) if unit_rules and relation.app: @@ -1895,7 +1901,7 @@ def set_target_job_data(self, targets: dict, app_name: str, **kwargs) -> None: jobs.append(updated_job) relation.data[self._charm.app]["scrape_jobs"] = json.dumps(jobs) - if not _type_convert_stored(self._stored.jobs) == jobs: + if not _type_convert_stored(self._stored.jobs) == jobs: # pyright: ignore self._stored.jobs = jobs def _on_prometheus_targets_departed(self, event): @@ -1947,7 +1953,7 @@ def remove_prometheus_jobs(self, job_name: str, unit_name: Optional[str] = ""): relation.data[self._charm.app]["scrape_jobs"] = json.dumps(jobs) - if not _type_convert_stored(self._stored.jobs) == jobs: + if not _type_convert_stored(self._stored.jobs) == jobs: # pyright: ignore self._stored.jobs = jobs def _job_name(self, appname) -> str: @@ -2126,7 +2132,7 @@ def set_alert_rule_data(self, name: str, unit_rules: dict, label_rules: bool = T groups.append(updated_group) relation.data[self._charm.app]["alert_rules"] = json.dumps({"groups": groups}) - if not _type_convert_stored(self._stored.alert_rules) == groups: + if not _type_convert_stored(self._stored.alert_rules) == groups: # pyright: ignore self._stored.alert_rules = groups def _on_alert_rules_departed(self, event): @@ -2176,7 +2182,7 @@ def remove_alert_rules(self, group_name: str, unit_name: str) -> None: json.dumps({"groups": groups}) if groups else "{}" ) - if not _type_convert_stored(self._stored.alert_rules) == groups: + if not _type_convert_stored(self._stored.alert_rules) == groups: # pyright: ignore self._stored.alert_rules = groups def _get_alert_rules(self, relation) -> dict: diff --git a/lib/charms/tempo_k8s/v0/charm_tracing.py b/lib/charms/tempo_k8s/v0/charm_tracing.py index 5939fff9..efad1605 100644 --- a/lib/charms/tempo_k8s/v0/charm_tracing.py +++ b/lib/charms/tempo_k8s/v0/charm_tracing.py @@ -86,7 +86,7 @@ def tracer(self) -> opentelemetry.trace.Tracer: # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 3 +LIBPATCH = 4 PYDEPS = ["opentelemetry-exporter-otlp-proto-grpc==1.17.0"] @@ -314,7 +314,7 @@ def trace_charm( Usage: >>> from charms.tempo_k8s.v0.charm_tracing import trace_charm - >>> from charms.tempo_k8s.v0.tracing import TracingEndpointProvider + >>> from charms.tempo_k8s.v1.tracing import TracingEndpointProvider >>> from ops import CharmBase >>> >>> @trace_charm( diff --git a/lib/charms/tempo_k8s/v0/tracing.py b/lib/charms/tempo_k8s/v0/tracing.py index cb5dac87..aa07afaa 100644 --- a/lib/charms/tempo_k8s/v0/tracing.py +++ b/lib/charms/tempo_k8s/v0/tracing.py @@ -93,7 +93,7 @@ def __init__(self, *args): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 7 +LIBPATCH = 8 PYDEPS = ["pydantic<2.0"] @@ -494,7 +494,7 @@ def is_ready(self, relation: Optional[Relation] = None): return False try: TracingProviderAppData.load(relation.data[relation.app]) - except (json.JSONDecodeError, pydantic.ValidationError): + except (json.JSONDecodeError, pydantic.ValidationError, DataValidationError): logger.info(f"failed validating relation data for {relation}") return False return True diff --git a/lib/charms/tls_certificates_interface/v2/tls_certificates.py b/lib/charms/tls_certificates_interface/v2/tls_certificates.py index f4a08366..08c5cb50 100644 --- a/lib/charms/tls_certificates_interface/v2/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v2/tls_certificates.py @@ -287,7 +287,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.serialization import pkcs12 from cryptography.x509.extensions import Extension, ExtensionNotFound -from jsonschema import exceptions, validate # type: ignore[import] +from jsonschema import exceptions, validate # type: ignore[import-untyped] from ops.charm import ( CharmBase, CharmEvents, @@ -298,7 +298,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven ) from ops.framework import EventBase, EventSource, Handle, Object from ops.jujuversion import JujuVersion -from ops.model import Relation, SecretNotFoundError +from ops.model import ModelError, Relation, RelationDataContent, SecretNotFoundError # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" @@ -308,13 +308,13 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 16 +LIBPATCH = 22 PYDEPS = ["cryptography", "jsonschema"] REQUIRER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v2/schemas/requirer.json", # noqa: E501 + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/requirer.json", "type": "object", "title": "`tls_certificates` requirer root schema", "description": "The `tls_certificates` root schema comprises the entire requirer databag for this interface.", # noqa: E501 @@ -335,7 +335,10 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven "type": "array", "items": { "type": "object", - "properties": {"certificate_signing_request": {"type": "string"}}, + "properties": { + "certificate_signing_request": {"type": "string"}, + "ca": {"type": "boolean"}, + }, "required": ["certificate_signing_request"], }, } @@ -346,7 +349,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven PROVIDER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v2/schemas/provider.json", # noqa: E501 + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/provider.json", "type": "object", "title": "`tls_certificates` provider root schema", "description": "The `tls_certificates` root schema comprises the entire provider databag for this interface.", # noqa: E501 @@ -536,22 +539,31 @@ def restore(self, snapshot: dict): class CertificateCreationRequestEvent(EventBase): """Charm Event triggered when a TLS certificate is required.""" - def __init__(self, handle: Handle, certificate_signing_request: str, relation_id: int): + def __init__( + self, + handle: Handle, + certificate_signing_request: str, + relation_id: int, + is_ca: bool = False, + ): super().__init__(handle) self.certificate_signing_request = certificate_signing_request self.relation_id = relation_id + self.is_ca = is_ca def snapshot(self) -> dict: """Returns snapshot.""" return { "certificate_signing_request": self.certificate_signing_request, "relation_id": self.relation_id, + "is_ca": self.is_ca, } def restore(self, snapshot: dict): """Restores snapshot.""" self.certificate_signing_request = snapshot["certificate_signing_request"] self.relation_id = snapshot["relation_id"] + self.is_ca = snapshot["is_ca"] class CertificateRevocationRequestEvent(EventBase): @@ -588,26 +600,63 @@ def restore(self, snapshot: dict): self.chain = snapshot["chain"] -def _load_relation_data(raw_relation_data: dict) -> dict: +def _load_relation_data(relation_data_content: RelationDataContent) -> dict: """Loads relation data from the relation data bag. Json loads all data. Args: - raw_relation_data: Relation data from the databag + relation_data_content: Relation data from the databag Returns: dict: Relation data in dict format. """ certificate_data = dict() - for key in raw_relation_data: - try: - certificate_data[key] = json.loads(raw_relation_data[key]) - except (json.decoder.JSONDecodeError, TypeError): - certificate_data[key] = raw_relation_data[key] + try: + for key in relation_data_content: + try: + certificate_data[key] = json.loads(relation_data_content[key]) + except (json.decoder.JSONDecodeError, TypeError): + certificate_data[key] = relation_data_content[key] + except ModelError: + pass return certificate_data +def _get_closest_future_time( + expiry_notification_time: datetime, expiry_time: datetime +) -> datetime: + """Return expiry_notification_time if not in the past, otherwise return expiry_time. + + Args: + expiry_notification_time (datetime): Notification time of impending expiration + expiry_time (datetime): Expiration time + + Returns: + datetime: expiry_notification_time if not in the past, expiry_time otherwise + """ + return ( + expiry_notification_time if datetime.utcnow() < expiry_notification_time else expiry_time + ) + + +def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: + """Extract expiry time from a certificate string. + + Args: + certificate (str): x509 certificate as a string + + Returns: + Optional[datetime]: Expiry datetime or None + """ + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + return certificate_object.not_valid_after + except ValueError: + logger.warning("Could not load certificate.") + return None + + def generate_ca( private_key: bytes, subject: str, @@ -678,6 +727,105 @@ def generate_ca( return cert.public_bytes(serialization.Encoding.PEM) +def get_certificate_extensions( + authority_key_identifier: bytes, + csr: x509.CertificateSigningRequest, + alt_names: Optional[List[str]], + is_ca: bool, +) -> List[x509.Extension]: + """Generates a list of certificate extensions from a CSR and other known information. + + Args: + authority_key_identifier (bytes): Authority key identifier + csr (x509.CertificateSigningRequest): CSR + alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + List[x509.Extension]: List of extensions + """ + cert_extensions_list: List[x509.Extension] = [ + x509.Extension( + oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, + value=x509.AuthorityKeyIdentifier( + key_identifier=authority_key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, + value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.BASIC_CONSTRAINTS, + critical=True, + value=x509.BasicConstraints(ca=is_ca, path_length=None), + ), + ] + + sans: List[x509.GeneralName] = [] + san_alt_names = [x509.DNSName(name) for name in alt_names] if alt_names else [] + sans.extend(san_alt_names) + try: + loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) + sans.extend( + [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] + ) + sans.extend( + [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] + ) + sans.extend( + [ + x509.RegisteredID(oid) + for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) + ] + ) + except x509.ExtensionNotFound: + pass + + if sans: + cert_extensions_list.append( + x509.Extension( + oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + critical=False, + value=x509.SubjectAlternativeName(sans), + ) + ) + + if is_ca: + cert_extensions_list.append( + x509.Extension( + ExtensionOID.KEY_USAGE, + critical=True, + value=x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + ) + ) + + existing_oids = {ext.oid for ext in cert_extensions_list} + for extension in csr.extensions: + if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: + continue + if extension.oid in existing_oids: + logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) + continue + cert_extensions_list.append(extension) + + return cert_extensions_list + + def generate_certificate( csr: bytes, ca: bytes, @@ -685,6 +833,7 @@ def generate_certificate( ca_key_password: Optional[bytes] = None, validity: int = 365, alt_names: Optional[List[str]] = None, + is_ca: bool = False, ) -> bytes: """Generates a TLS certificate based on a CSR. @@ -695,6 +844,7 @@ def generate_certificate( ca_key_password: CA private key password validity (int): Certificate validity (in days) alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR + is_ca (bool): Whether the certificate is a CA certificate Returns: bytes: Certificate @@ -713,52 +863,24 @@ def generate_certificate( .serial_number(x509.random_serial_number()) .not_valid_before(datetime.utcnow()) .not_valid_after(datetime.utcnow() + timedelta(days=validity)) - .add_extension( - x509.AuthorityKeyIdentifier( - key_identifier=ca_pem.extensions.get_extension_for_class( - x509.SubjectKeyIdentifier - ).value.key_identifier, - authority_cert_issuer=None, - authority_cert_serial_number=None, - ), - critical=False, - ) - .add_extension( - x509.SubjectKeyIdentifier.from_public_key(csr_object.public_key()), critical=False - ) - .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=False) ) - - extensions_list = csr_object.extensions - san_ext: Optional[x509.Extension] = None - if alt_names: - full_sans_dns = alt_names.copy() + extensions = get_certificate_extensions( + authority_key_identifier=ca_pem.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier + ).value.key_identifier, + csr=csr_object, + alt_names=alt_names, + is_ca=is_ca, + ) + for extension in extensions: try: - loaded_san_ext = csr_object.extensions.get_extension_for_class( - x509.SubjectAlternativeName + certificate_builder = certificate_builder.add_extension( + extval=extension.value, + critical=extension.critical, ) - full_sans_dns.extend(loaded_san_ext.value.get_values_for_type(x509.DNSName)) - except ExtensionNotFound: - pass - finally: - san_ext = Extension( - ExtensionOID.SUBJECT_ALTERNATIVE_NAME, - False, - x509.SubjectAlternativeName([x509.DNSName(name) for name in full_sans_dns]), - ) - if not extensions_list: - extensions_list = x509.Extensions([san_ext]) - - for extension in extensions_list: - if extension.value.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME and san_ext: - extension = san_ext - - certificate_builder = certificate_builder.add_extension( - extension.value, - critical=extension.critical, - ) + except ValueError as e: + logger.warning("Failed to add extension %s: %s", extension.oid, e) - certificate_builder._version = x509.Version.v3 cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] return cert.public_bytes(serialization.Encoding.PEM) @@ -896,6 +1018,38 @@ def generate_csr( return signed_certificate.public_bytes(serialization.Encoding.PEM) +def csr_matches_certificate(csr: str, cert: str) -> bool: + """Check if a CSR matches a certificate. + + Args: + csr (str): Certificate Signing Request as a string + cert (str): Certificate as a string + Returns: + bool: True/False depending on whether the CSR matches the certificate. + """ + try: + csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) + + if csr_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) != cert_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ): + return False + if ( + csr_object.public_key().public_numbers().n # type: ignore[union-attr] + != cert_object.public_key().public_numbers().n # type: ignore[union-attr] + ): + return False + except ValueError: + logger.warning("Could not load certificate or CSR.") + return False + return True + + class CertificatesProviderCharmEvents(CharmEvents): """List of events that the TLS Certificates provider charm can leverage.""" @@ -1171,15 +1325,19 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: certificate_creation_request["certificate_signing_request"] for certificate_creation_request in provider_certificates ] - requirer_unit_csrs = [ - certificate_creation_request["certificate_signing_request"] + requirer_unit_certificate_requests = [ + { + "csr": certificate_creation_request["certificate_signing_request"], + "is_ca": certificate_creation_request.get("ca", False), + } for certificate_creation_request in requirer_csrs ] - for certificate_signing_request in requirer_unit_csrs: - if certificate_signing_request not in provider_csrs: + for certificate_request in requirer_unit_certificate_requests: + if certificate_request["csr"] not in provider_csrs: self.on.certificate_creation_request.emit( - certificate_signing_request=certificate_signing_request, + certificate_signing_request=certificate_request["csr"], relation_id=event.relation.id, + is_ca=certificate_request["is_ca"], ) self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) @@ -1217,12 +1375,24 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None ) self.remove_certificate(certificate=certificate["certificate"]) - def get_requirer_csrs_with_no_certs( + def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Filters the requirer's units csrs. + """Returns CSR's for which no certificate has been issued. - Keeps the ones for which no certificate was provided. + Example return: [ + { + "relation_id": 0, + "application_name": "tls-certificates-requirer", + "unit_name": "tls-certificates-requirer/0", + "unit_csrs": [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "is_ca": false + } + ] + } + ] Args: relation_id (int): Relation id @@ -1239,6 +1409,7 @@ def get_requirer_csrs_with_no_certs( if not self.certificate_issued_for_csr( app_name=unit_csr_mapping["application_name"], # type: ignore[arg-type] csr=csr["certificate_signing_request"], # type: ignore[index] + relation_id=relation_id, ): csrs_without_certs.append(csr) if csrs_without_certs: @@ -1285,17 +1456,21 @@ def get_requirer_csrs( ) return unit_csr_mappings - def certificate_issued_for_csr(self, app_name: str, csr: str) -> bool: + def certificate_issued_for_csr( + self, app_name: str, csr: str, relation_id: Optional[int] + ) -> bool: """Checks whether a certificate has been issued for a given CSR. Args: app_name (str): Application name that the CSR belongs to. csr (str): Certificate Signing Request. - + relation_id (Optional[int]): Relation ID Returns: bool: True/False depending on whether a certificate has been issued for the given CSR. """ - issued_certificates_per_csr = self.get_issued_certificates()[app_name] + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)[ + app_name + ] for issued_pair in issued_certificates_per_csr: if "csr" in issued_pair and issued_pair["csr"] == csr: return csr_matches_certificate(csr, issued_pair["certificate"]) @@ -1337,8 +1512,17 @@ def __init__( self.framework.observe(charm.on.update_status, self._on_update_status) @property - def _requirer_csrs(self) -> List[Dict[str, str]]: - """Returns list of requirer's CSRs from relation data.""" + def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]: + """Returns list of requirer's CSRs from relation unit data. + + Example: + [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "ca": false + } + ] + """ relation = self.model.get_relation(self.relationship_name) if not relation: raise RuntimeError(f"Relation {self.relationship_name} does not exist") @@ -1361,11 +1545,12 @@ def _provider_certificates(self) -> List[Dict[str, str]]: return [] return provider_relation_data.get("certificates", []) - def _add_requirer_csr(self, csr: str) -> None: + def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: """Adds CSR to relation data. Args: csr (str): Certificate Signing Request + is_ca (bool): Whether the certificate is a CA certificate Returns: None @@ -1376,7 +1561,10 @@ def _add_requirer_csr(self, csr: str) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - new_csr_dict = {"certificate_signing_request": csr} + new_csr_dict: Dict[str, Union[bool, str]] = { + "certificate_signing_request": csr, + "ca": is_ca, + } if new_csr_dict in self._requirer_csrs: logger.info("CSR already in relation data - Doing nothing") return @@ -1400,18 +1588,22 @@ def _remove_requirer_csr(self, csr: str) -> None: f"The certificate request can't be completed" ) requirer_csrs = copy.deepcopy(self._requirer_csrs) - csr_dict = {"certificate_signing_request": csr} - if csr_dict not in requirer_csrs: - logger.info("CSR not in relation data - Doing nothing") + if not requirer_csrs: + logger.info("No CSRs in relation data - Doing nothing") return - requirer_csrs.remove(csr_dict) + for requirer_csr in requirer_csrs: + if requirer_csr["certificate_signing_request"] == csr: + requirer_csrs.remove(requirer_csr) relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) - def request_certificate_creation(self, certificate_signing_request: bytes) -> None: + def request_certificate_creation( + self, certificate_signing_request: bytes, is_ca: bool = False + ) -> None: """Request TLS certificate to provider charm. Args: certificate_signing_request (bytes): Certificate Signing Request + is_ca (bool): Whether the certificate is a CA certificate Returns: None @@ -1422,7 +1614,7 @@ def request_certificate_creation(self, certificate_signing_request: bytes) -> No f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - self._add_requirer_csr(certificate_signing_request.decode().strip()) + self._add_requirer_csr(certificate_signing_request.decode().strip(), is_ca=is_ca) logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: @@ -1466,6 +1658,92 @@ def request_certificate_renewal( ) logger.info("Certificate renewal request completed.") + def get_assigned_certificates(self) -> List[Dict[str, str]]: + """Get a list of certificates that were assigned to this unit. + + Returns: + List of certificates. For example: + [ + { + "ca": "-----BEGIN CERTIFICATE-----...", + "chain": [ + "-----BEGIN CERTIFICATE-----..." + ], + "certificate": "-----BEGIN CERTIFICATE-----...", + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + } + ] + """ + final_list = [] + for csr in self.get_certificate_signing_requests(fulfilled_only=True): + assert type(csr["certificate_signing_request"]) == str + if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): + final_list.append(cert) + return final_list + + def get_expiring_certificates(self) -> List[Dict[str, str]]: + """Get a list of certificates that were assigned to this unit that are expiring or expired. + + Returns: + List of certificates. For example: + [ + { + "ca": "-----BEGIN CERTIFICATE-----...", + "chain": [ + "-----BEGIN CERTIFICATE-----..." + ], + "certificate": "-----BEGIN CERTIFICATE-----...", + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + } + ] + """ + final_list = [] + for csr in self.get_certificate_signing_requests(fulfilled_only=True): + assert type(csr["certificate_signing_request"]) == str + if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): + expiry_time = _get_certificate_expiry_time(cert["certificate"]) + if not expiry_time: + continue + expiry_notification_time = expiry_time - timedelta( + hours=self.expiry_notification_time + ) + if datetime.utcnow() > expiry_notification_time: + final_list.append(cert) + return final_list + + def get_certificate_signing_requests( + self, + fulfilled_only: bool = False, + unfulfilled_only: bool = False, + ) -> List[Dict[str, Union[bool, str]]]: + """Gets the list of CSR's that were sent to the provider. + + You can choose to get only the CSR's that have a certificate assigned or only the CSR's + that don't. + + Args: + fulfilled_only (bool): This option will discard CSRs that don't have certificates yet. + unfulfilled_only (bool): This option will discard CSRs that have certificates signed. + Returns: + List of CSR dictionaries. For example: + [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "ca": false + } + ] + """ + + final_list = [] + for csr in self._requirer_csrs: + assert type(csr["certificate_signing_request"]) == str + cert = self._find_certificate_in_relation_data(csr["certificate_signing_request"]) + if (unfulfilled_only and cert) or (fulfilled_only and not cert): + continue + final_list.append(csr) + + return final_list + @staticmethod def _relation_data_is_valid(certificates_data: dict) -> bool: """Checks whether relation data is valid based on json schema. @@ -1676,68 +1954,3 @@ def _on_update_status(self, event: UpdateStatusEvent) -> None: certificate=certificate_dict["certificate"], expiry=expiry_time.isoformat(), ) - - -def csr_matches_certificate(csr: str, cert: str) -> bool: - """Check if a CSR matches a certificate. - - expects to get the original string representations. - - Args: - csr (str): Certificate Signing Request - cert (str): Certificate - Returns: - bool: True/False depending on whether the CSR matches the certificate. - """ - try: - csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) - cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) - - if csr_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) != cert_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ): - return False - if csr_object.subject != cert_object.subject: - return False - except ValueError: - logger.warning("Could not load certificate or CSR.") - return False - return True - - -def _get_closest_future_time( - expiry_notification_time: datetime, expiry_time: datetime -) -> datetime: - """Return expiry_notification_time if not in the past, otherwise return expiry_time. - - Args: - expiry_notification_time (datetime): Notification time of impending expiration - expiry_time (datetime): Expiration time - - Returns: - datetime: expiry_notification_time if not in the past, expiry_time otherwise - """ - return ( - expiry_notification_time if datetime.utcnow() < expiry_notification_time else expiry_time - ) - - -def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: - """Extract expiry time from a certificate string. - - Args: - certificate (str): x509 certificate as a string - - Returns: - Optional[datetime]: Expiry datetime or None - """ - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - return certificate_object.not_valid_after - except ValueError: - logger.warning("Could not load certificate.") - return None diff --git a/lib/charms/traefik_route_k8s/v0/traefik_route.py b/lib/charms/traefik_route_k8s/v0/traefik_route.py index 53da3cfe..48bedf38 100644 --- a/lib/charms/traefik_route_k8s/v0/traefik_route.py +++ b/lib/charms/traefik_route_k8s/v0/traefik_route.py @@ -88,7 +88,7 @@ def __init__(self, *args): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 8 +LIBPATCH = 9 log = logging.getLogger(__name__) @@ -137,7 +137,7 @@ class TraefikRouteProvider(Object): The TraefikRouteProvider provides api to do this easily. """ - on = TraefikRouteProviderEvents() + on = TraefikRouteProviderEvents() # pyright: ignore _stored = StoredState() def __init__( @@ -163,7 +163,10 @@ def __init__( self._charm = charm self._relation_name = relation_name - if self._stored.external_host != external_host or self._stored.scheme != scheme: + if ( + self._stored.external_host != external_host # pyright: ignore + or self._stored.scheme != scheme # pyright: ignore + ): # If traefik endpoint details changed, update self.update_traefik_address(external_host=external_host, scheme=scheme) @@ -197,7 +200,7 @@ def _update_stored(self) -> None: This is split out into a separate method since, in the case of multi-unit deployments, removal of a `TraefikRouteRequirer` will not cause a `RelationEvent`, but the guard on app data ensures that only the previous leader will know what it is. Separating it - allows for re-use both when the property is called and if the relation changes, so a + allows for reuse both when the property is called and if the relation changes, so a leader change where the new leader checks the property will do the right thing. """ if not self._charm.unit.is_leader(): @@ -209,9 +212,11 @@ def _update_stored(self) -> None: self._stored.scheme = "" return external_host = relation.data[relation.app].get("external_host", "") - self._stored.external_host = external_host or self._stored.external_host + self._stored.external_host = ( + external_host or self._stored.external_host # pyright: ignore + ) scheme = relation.data[relation.app].get("scheme", "") - self._stored.scheme = scheme or self._stored.scheme + self._stored.scheme = scheme or self._stored.scheme # pyright: ignore def _on_relation_changed(self, event: RelationEvent): if self.is_ready(event.relation): @@ -269,7 +274,7 @@ class TraefikRouteRequirer(Object): application databag. """ - on = TraefikRouteRequirerEvents() + on = TraefikRouteRequirerEvents() # pyright: ignore _stored = StoredState() def __init__(self, charm: CharmBase, relation: Relation, relation_name: str = "traefik-route"): @@ -304,7 +309,7 @@ def _update_stored(self) -> None: This is split out into a separate method since, in the case of multi-unit deployments, removal of a `TraefikRouteRequirer` will not cause a `RelationEvent`, but the guard on app data ensures that only the previous leader will know what it is. Separating it - allows for re-use both when the property is called and if the relation changes, so a + allows for reuse both when the property is called and if the relation changes, so a leader change where the new leader checks the property will do the right thing. """ if not self._charm.unit.is_leader(): @@ -317,9 +322,11 @@ def _update_stored(self) -> None: self._stored.scheme = "" return external_host = relation.data[relation.app].get("external_host", "") - self._stored.external_host = external_host or self._stored.external_host + self._stored.external_host = ( + external_host or self._stored.external_host # pyright: ignore + ) scheme = relation.data[relation.app].get("scheme", "") - self._stored.scheme = scheme or self._stored.scheme + self._stored.scheme = scheme or self._stored.scheme # pyright: ignore def _on_relation_changed(self, event: RelationEvent) -> None: """Update StoredState with external_host and other information from Traefik."""