Skip to content

Commit

Permalink
Merge pull request #88 from canonical/update-tls-lib
Browse files Browse the repository at this point in the history
fix issues with tls on immediate relation
  • Loading branch information
MiaAltieri authored Jan 16, 2025
2 parents c3e9b71 + 5676c37 commit 9f81c03
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 22 deletions.
118 changes: 96 additions & 22 deletions lib/charms/mongodb/v1/mongodb_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
external relation.
"""
import base64
import json
import logging
import re
import socket
from typing import List, Optional, Tuple
from typing import Optional, Tuple

from charms.tls_certificates_interface.v3.tls_certificates import (
CertificateAvailableEvent,
Expand All @@ -20,6 +21,8 @@
generate_csr,
generate_private_key,
)
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from ops.charm import ActionEvent, RelationBrokenEvent, RelationJoinedEvent
from ops.framework import Object
from ops.model import ActiveStatus, MaintenanceStatus, WaitingStatus
Expand All @@ -28,7 +31,8 @@

UNIT_SCOPE = Config.Relations.UNIT_SCOPE
Scopes = Config.Relations.Scopes

SANS_DNS_KEY = "sans_dns"
SANS_IPS_KEY = "sans_ips"

# The unique Charmhub library identifier, never change it
LIBID = "e02a50f0795e4dd292f58e93b4f493dd"
Expand All @@ -38,7 +42,9 @@

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 2
LIBPATCH = 5

WAIT_CERT_UPDATE = "wait-cert-updated"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -104,12 +110,13 @@ def request_certificate(
else:
key = self._parse_tls_file(param)

sans = self.get_new_sans()
csr = generate_csr(
private_key=key,
subject=self._get_subject_name(),
organization=self._get_subject_name(),
sans=self._get_sans(),
sans_ip=[str(self.charm.model.get_binding(self.peer_relation).network.bind_address)],
sans=sans[SANS_DNS_KEY],
sans_ip=sans[SANS_IPS_KEY],
)
self.set_tls_secret(internal, Config.TLS.SECRET_KEY_LABEL, key.decode("utf-8"))
self.set_tls_secret(internal, Config.TLS.SECRET_CSR_LABEL, csr.decode("utf-8"))
Expand All @@ -118,9 +125,8 @@ def request_certificate(
label = "int" if internal else "ext"
self.charm.unit_peer_data[f"{label}_certs_subject"] = self._get_subject_name()
self.charm.unit_peer_data[f"{label}_certs_subject"] = self._get_subject_name()

if self.charm.model.get_relation(Config.TLS.TLS_PEER_RELATION):
self.certs.request_certificate_creation(certificate_signing_request=csr)
self.certs.request_certificate_creation(certificate_signing_request=csr)
self.set_waiting_for_cert_to_update(internal=internal, waiting=True)

@staticmethod
def _parse_tls_file(raw_content: str) -> bytes:
Expand Down Expand Up @@ -158,12 +164,18 @@ def _on_tls_relation_joined(self, event: RelationJoinedEvent) -> None:

def _on_tls_relation_broken(self, event: RelationBrokenEvent) -> None:
"""Disable TLS when TLS relation broken."""
logger.debug("Disabling external and internal TLS for unit: %s", self.charm.unit.name)
if not self.charm.db_initialised:
logger.info("Deferring %s. db is not initialised.", str(type(event)))
event.defer()
return

if self.charm.upgrade_in_progress:
logger.warning(
"Disabling TLS is not supported during an upgrade. The charm may be in a broken, unrecoverable state."
)

logger.debug("Disabling external and internal TLS for unit: %s", self.charm.unit.name)

for internal in [True, False]:
self.set_tls_secret(internal, Config.TLS.SECRET_CA_LABEL, None)
self.set_tls_secret(internal, Config.TLS.SECRET_CERT_LABEL, None)
Expand All @@ -188,6 +200,11 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None:
event.defer()
return

if not self.charm.db_initialised:
logger.info("Deferring %s. db is not initialised.", str(type(event)))
event.defer()
return

int_csr = self.get_tls_secret(internal=True, label_name=Config.TLS.SECRET_CSR_LABEL)
ext_csr = self.get_tls_secret(internal=False, label_name=Config.TLS.SECRET_CSR_LABEL)

Expand All @@ -208,12 +225,13 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None:
)
self.set_tls_secret(internal, Config.TLS.SECRET_CERT_LABEL, event.certificate)
self.set_tls_secret(internal, Config.TLS.SECRET_CA_LABEL, event.ca)
self.set_waiting_for_cert_to_update(internal=internal, waiting=False)

if self.charm.is_role(Config.Role.CONFIG_SERVER) and internal:
self.charm.cluster.update_ca_secret(new_ca=event.ca)
self.charm.config_server.update_ca_secret(new_ca=event.ca)

if self.waiting_for_certs():
if self.is_waiting_for_both_certs():
logger.debug(
"Defer till both internal and external TLS certificates available to avoid second restart."
)
Expand All @@ -235,7 +253,7 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None:
# clear waiting status if db service is ready
self.charm.status.set_and_share_status(ActiveStatus())

def waiting_for_certs(self):
def is_waiting_for_both_certs(self) -> bool:
"""Returns a boolean indicating whether additional certs are needed."""
if not self.get_tls_secret(internal=True, label_name=Config.TLS.SECRET_CERT_LABEL):
logger.debug("Waiting for internal certificate.")
Expand Down Expand Up @@ -268,21 +286,25 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None:
== self.get_tls_secret(internal=True, label_name=Config.TLS.SECRET_CERT_LABEL).rstrip()
):
logger.debug("The internal TLS certificate expiring.")

internal = True
else:
logger.error("An unknown certificate expiring.")
return

logger.debug("Generating a new Certificate Signing Request.")
self.request_new_certificates(internal)

def request_new_certificates(self, internal: bool) -> None:
"""Requests the renewel of a new certificate."""
key = self.get_tls_secret(internal, Config.TLS.SECRET_KEY_LABEL).encode("utf-8")
old_csr = self.get_tls_secret(internal, Config.TLS.SECRET_CSR_LABEL).encode("utf-8")
sans = self.get_new_sans()
new_csr = generate_csr(
private_key=key,
subject=self._get_subject_name(),
organization=self._get_subject_name(),
sans=self._get_sans(),
sans_ip=[str(self.charm.model.get_binding(self.peer_relation).network.bind_address)],
sans=sans[SANS_DNS_KEY],
sans_ip=sans[SANS_IPS_KEY],
)
logger.debug("Requesting a certificate renewal.")

Expand All @@ -292,21 +314,53 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None:
)

self.set_tls_secret(internal, Config.TLS.SECRET_CSR_LABEL, new_csr.decode("utf-8"))
self.set_waiting_for_cert_to_update(waiting=True, internal=internal)

def _get_sans(self) -> List[str]:
def get_new_sans(self) -> dict[str, list[str]]:
"""Create a list of DNS names for a MongoDB unit.
Returns:
A list representing the hostnames of the MongoDB unit.
"""
unit_id = self.charm.unit.name.split("/")[1]
return [
f"{self.charm.app.name}-{unit_id}",
socket.getfqdn(),
f"{self.charm.app.name}-{unit_id}.{self.charm.app.name}-endpoints",
str(self.charm.model.get_binding(self.peer_relation).network.bind_address),
"localhost",
]

sans = {
SANS_DNS_KEY: [
f"{self.charm.app.name}-{unit_id}",
socket.getfqdn(),
"localhost",
f"{self.charm.app.name}-{unit_id}.{self.charm.app.name}-endpoints",
],
SANS_IPS_KEY: [
str(self.charm.model.get_binding(self.peer_relation).network.bind_address)
],
}

if self.charm.is_role(Config.Role.MONGOS) and self.charm.is_external_client:
sans[SANS_IPS_KEY].append(
self.charm.get_ext_mongos_host(self.charm.unit, incl_port=False)
)

return sans

def get_current_sans(self, internal: bool) -> dict[str, list[str]] | None:
"""Gets the current SANs for the unit cert."""
# if unit has no certificates do not proceed.
if not self.is_tls_enabled(internal=internal):
return

pem_file = self.get_tls_secret(internal, Config.TLS.SECRET_CERT_LABEL)

try:
cert = x509.load_pem_x509_certificate(pem_file.encode(), default_backend())
sans = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
sans_ip = [str(san) for san in sans.get_values_for_type(x509.IPAddress)]
sans_dns = [str(san) for san in sans.get_values_for_type(x509.DNSName)]
except x509.ExtensionNotFound:
sans_ip = []
sans_dns = []

return {SANS_IPS_KEY: sorted(sans_ip), SANS_DNS_KEY: sorted(sans_dns)}

def get_tls_files(self, internal: bool) -> Tuple[Optional[str], Optional[str]]:
"""Prepare TLS files in special MongoDB way.
Expand Down Expand Up @@ -356,3 +410,23 @@ def _get_subject_name(self) -> str:
return self.charm.get_config_server_name() or self.charm.app.name

return self.charm.app.name

def is_set_waiting_for_cert_to_update(
self,
internal: bool = False,
) -> bool:
"""Returns True if we are waiting for a cert to update."""
scope = "int" if internal else "ext"
label_name = f"{scope}-{WAIT_CERT_UPDATE}"

return json.loads(self.charm.unit_peer_data.get(label_name, "false"))

def set_waiting_for_cert_to_update(
self,
waiting: bool,
internal: bool,
) -> None:
"""Sets a boolean indicator, for whether or not we are waiting for a cert to update."""
scope = "int" if internal else "ext"
label_name = f"{scope}-{WAIT_CERT_UPDATE}"
self.charm.unit_peer_data[label_name] = json.dumps(waiting)
5 changes: 5 additions & 0 deletions src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,11 @@ def unit_host(self, unit: Unit) -> str:
else:
raise ApplicationHostNotFoundError

@property
def db_initialised(self) -> bool:
"""Proxy for mongos_initialised, since some libs rely on db_initialised."""
return self.mongos_initialised

@property
def mongos_initialised(self) -> bool:
"""Check if mongos is initialised."""
Expand Down

0 comments on commit 9f81c03

Please sign in to comment.