From 25ee95d90c0255ada13e3b2e20d748b54a351983 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Thu, 1 Jun 2023 16:34:41 +0000 Subject: [PATCH 01/57] Update libs & clear src --- .github/workflows/ci.yaml | 32 +- .../data_platform_libs/v0/data_interfaces.py | 1395 +++++++++++++++++ .../v0/database_provides.py | 316 ---- .../v0/database_requires.py | 496 ------ lib/charms/operator_libs_linux/v0/apt.py | 1329 ---------------- lib/charms/operator_libs_linux/v0/passwd.py | 255 --- lib/charms/operator_libs_linux/v1/systemd.py | 219 --- lib/charms/operator_libs_linux/v2/snap.py | 1065 +++++++++++++ src/charm.py | 212 --- src/constants.py | 23 - src/mysql_router_helpers.py | 222 --- src/relations/database_provides.py | 156 -- src/relations/database_requires.py | 150 -- src/relations/shared_db.py | 181 --- src/utils.py | 20 - tests/unit/__init__.py | 8 - tests/unit/test_charm.py | 136 -- tests/unit/test_mysql_router_helpers.py | 189 --- 18 files changed, 2477 insertions(+), 3927 deletions(-) create mode 100644 lib/charms/data_platform_libs/v0/data_interfaces.py delete mode 100644 lib/charms/data_platform_libs/v0/database_provides.py delete mode 100644 lib/charms/data_platform_libs/v0/database_requires.py delete mode 100644 lib/charms/operator_libs_linux/v0/apt.py delete mode 100644 lib/charms/operator_libs_linux/v0/passwd.py delete mode 100644 lib/charms/operator_libs_linux/v1/systemd.py create mode 100644 lib/charms/operator_libs_linux/v2/snap.py delete mode 100755 src/charm.py delete mode 100644 src/constants.py delete mode 100644 src/mysql_router_helpers.py delete mode 100644 src/relations/database_provides.py delete mode 100644 src/relations/database_requires.py delete mode 100644 src/relations/shared_db.py delete mode 100644 src/utils.py delete mode 100644 tests/unit/__init__.py delete mode 100644 tests/unit/test_charm.py delete mode 100644 tests/unit/test_mysql_router_helpers.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 3cb58d11..1f34c20c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -27,20 +27,21 @@ jobs: - name: Run linters run: tox run -e lint - unit-test: - name: Unit tests - runs-on: ubuntu-latest - timeout-minutes: 5 - steps: - - name: Checkout - uses: actions/checkout@v3 - - name: Install tox - # TODO: Consider replacing with custom image on self-hosted runner OR pinning version - run: python3 -m pip install tox - - name: Run tests - run: tox run -e unit - - name: Upload Coverage to Codecov - uses: codecov/codecov-action@v3 +# TODO: re-enable after adding unit tests +# unit-test: +# name: Unit tests +# runs-on: ubuntu-latest +# timeout-minutes: 5 +# steps: +# - name: Checkout +# uses: actions/checkout@v3 +# - name: Install tox +# # TODO: Consider replacing with custom image on self-hosted runner OR pinning version +# run: python3 -m pip install tox +# - name: Run tests +# run: tox run -e unit +# - name: Upload Coverage to Codecov +# uses: codecov/codecov-action@v3 lib-check: name: Check libraries @@ -81,7 +82,8 @@ jobs: name: ${{ matrix.tox-environments }} | ${{ matrix.ubuntu-versions.series }} needs: - lint - - unit-test + # TODO: re-enable after adding unit tests + # - unit-test - build runs-on: ubuntu-latest timeout-minutes: 120 diff --git a/lib/charms/data_platform_libs/v0/data_interfaces.py b/lib/charms/data_platform_libs/v0/data_interfaces.py new file mode 100644 index 00000000..86d7521a --- /dev/null +++ b/lib/charms/data_platform_libs/v0/data_interfaces.py @@ -0,0 +1,1395 @@ +# Copyright 2023 Canonical Ltd. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Library to manage the relation for the data-platform products.
+
+This library contains the Requires and Provides classes for handling the relation
+between an application and multiple managed applications supported by the data team:
+MySQL, Postgresql, MongoDB, Redis, and Kafka.
+
+### Database (MySQL, Postgresql, MongoDB, and Redis)
+
+#### Requires Charm
+This library is a uniform interface to a selection of common database
+metadata, with added custom events that add convenience to database management,
+and methods to consume the application related data.
+
+
+The following is an example of using the DatabaseCreatedEvent in the context of the
+application charm code:
+
+```python
+
+from charms.data_platform_libs.v0.data_interfaces import (
+    DatabaseCreatedEvent,
+    DatabaseRequires,
+)
+
+class ApplicationCharm(CharmBase):
+    # Application charm that connects to database charms.
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+        # Charm events defined in the database requires charm library.
+        self.database = DatabaseRequires(self, relation_name="database", database_name="database")
+        self.framework.observe(self.database.on.database_created, self._on_database_created)
+
+    def _on_database_created(self, event: DatabaseCreatedEvent) -> None:
+        # Handle the created database
+
+        # Create configuration file for app
+        config_file = self._render_app_config_file(
+            event.username,
+            event.password,
+            event.endpoints,
+        )
+
+        # Start application with rendered configuration
+        self._start_application(config_file)
+
+        # Set active status
+        self.unit.status = ActiveStatus("received database credentials")
+```
+
+As shown above, the library provides some custom events to handle specific situations,
+which are listed below:
+
+- database_created: event emitted when the requested database is created.
+- endpoints_changed: event emitted when the read/write endpoints of the database have changed.
+- read_only_endpoints_changed: event emitted when the read-only endpoints of the database
+  have changed. Event is not triggered if read/write endpoints changed too.
+
+If you need to connect multiple database clusters to the same relation endpoint,
+the application charm can implement the same code as if it were connecting to only
+one database cluster (like the above code example).
+
+To differentiate multiple clusters connected to the same relation endpoint,
+the application charm can use the name of the remote application:
+
+```python
+
+def _on_database_created(self, event: DatabaseCreatedEvent) -> None:
+    # Get the remote app name of the cluster that triggered this event
+    cluster = event.relation.app.name
+```
+
+It is also possible to provide an alias for each different database cluster/relation.
+
+So, it is possible to differentiate the clusters in two ways.
+The first is to use the remote application name, i.e., `event.relation.app.name`, as above.
+
+The second way is to use different event handlers to handle each cluster's events.
+The implementation would be something like the following code:
+
+```python
+
+from charms.data_platform_libs.v0.data_interfaces import (
+    DatabaseCreatedEvent,
+    DatabaseRequires,
+)
+
+class ApplicationCharm(CharmBase):
+    # Application charm that connects to database charms.
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+        # Define the cluster aliases and one handler for each cluster database created event.
+        self.database = DatabaseRequires(
+            self,
+            relation_name="database",
+            database_name="database",
+            relations_aliases=["cluster1", "cluster2"],
+        )
+        self.framework.observe(
+            self.database.on.cluster1_database_created, self._on_cluster1_database_created
+        )
+        self.framework.observe(
+            self.database.on.cluster2_database_created, self._on_cluster2_database_created
+        )
+
+    def _on_cluster1_database_created(self, event: DatabaseCreatedEvent) -> None:
+        # Handle the created database on the cluster named cluster1
+
+        # Create configuration file for app
+        config_file = self._render_app_config_file(
+            event.username,
+            event.password,
+            event.endpoints,
+        )
+        ...
+
+    def _on_cluster2_database_created(self, event: DatabaseCreatedEvent) -> None:
+        # Handle the created database on the cluster named cluster2
+
+        # Create configuration file for app
+        config_file = self._render_app_config_file(
+            event.username,
+            event.password,
+            event.endpoints,
+        )
+        ...
+
+```
+
+When you need to check whether a plugin (extension) is enabled on the PostgreSQL
+charm, you can use the is_postgresql_plugin_enabled method. To use that, you need to
+add the following dependency to your charmcraft.yaml file:
+
+```yaml
+
+parts:
+  charm:
+    charm-binary-python-packages:
+      - psycopg[binary]
+
+```
+
+### Provider Charm
+
+The following is an example of using the DatabaseRequestedEvent in the context of the
+database charm code:
+
+```python
+from charms.data_platform_libs.v0.data_interfaces import DatabaseProvides
+
+class SampleCharm(CharmBase):
+
+    def __init__(self, *args):
+        super().__init__(*args)
+        # Charm events defined in the database provides charm library.
+        self.provided_database = DatabaseProvides(self, relation_name="database")
+        self.framework.observe(self.provided_database.on.database_requested,
+            self._on_database_requested)
+        # Database generic helper
+        self.database = DatabaseHelper()
+
+    def _on_database_requested(self, event: DatabaseRequestedEvent) -> None:
+        # Handle the event triggered by a new database requested in the relation
+        # Retrieve the database name using the charm library.
+        db_name = event.database
+        # generate a new user credential
+        username = self.database.generate_user()
+        password = self.database.generate_password()
+        # set the credentials for the relation
+        self.provided_database.set_credentials(event.relation.id, username, password)
+        # set other variables for the relation
+        self.provided_database.set_tls(event.relation.id, "False")
+```
+As shown above, the library provides a custom event (database_requested) to handle
+the situation when an application charm requests a new database to be created.
+It's preferred to subscribe to this event instead of the relation changed event, to avoid
+creating a new database when information other than a database name is
+exchanged in the relation databag.
+
+### Kafka
+
+This library is the interface to use and interact with the Kafka charm. This library contains
+custom events that add convenience to manage Kafka, and provides methods to consume the
+application related data.
+
+#### Requirer Charm
+
+```python
+
+from charms.data_platform_libs.v0.data_interfaces import (
+    BootstrapServerChangedEvent,
+    KafkaRequires,
+    TopicCreatedEvent,
+)
+
+class ApplicationCharm(CharmBase):
+
+    def __init__(self, *args):
+        super().__init__(*args)
+        self.kafka = KafkaRequires(self, "kafka_client", "test-topic")
+        self.framework.observe(
+            self.kafka.on.bootstrap_server_changed, self._on_kafka_bootstrap_server_changed
+        )
+        self.framework.observe(
+            self.kafka.on.topic_created, self._on_kafka_topic_created
+        )
+
+    def _on_kafka_bootstrap_server_changed(self, event: BootstrapServerChangedEvent):
+        # Event triggered when a bootstrap server was changed for this application
+
+        new_bootstrap_server = event.bootstrap_server
+        ...
+
+    def _on_kafka_topic_created(self, event: TopicCreatedEvent):
+        # Event triggered when a topic was created for this application
+        username = event.username
+        password = event.password
+        tls = event.tls
+        tls_ca = event.tls_ca
+        bootstrap_server = event.bootstrap_server
+        consumer_group_prefix = event.consumer_group_prefix
+        zookeeper_uris = event.zookeeper_uris
+        ...
+
+```
+
+As shown above, the library provides some custom events to handle specific situations,
+which are listed below:
+
+- topic_created: event emitted when the requested topic is created.
+- bootstrap_server_changed: event emitted when the bootstrap server has changed.
+- credential_changed: event emitted when the Kafka credentials change.
+
+### Provider Charm
+
+Following the previous example, this is an example of the provider charm.
+
+```python
+from charms.data_platform_libs.v0.data_interfaces import (
+    KafkaProvides,
+    TopicRequestedEvent,
+)
+
+class SampleCharm(CharmBase):
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+        # Default charm events.
+        self.framework.observe(self.on.start, self._on_start)
+
+        # Charm events defined in the Kafka Provides charm library.
+        self.kafka_provider = KafkaProvides(self, relation_name="kafka_client")
+        self.framework.observe(self.kafka_provider.on.topic_requested, self._on_topic_requested)
+        # Kafka generic helper
+        self.kafka = KafkaHelper()
+
+    def _on_topic_requested(self, event: TopicRequestedEvent):
+        # Handle the on_topic_requested event.
+
+        topic = event.topic
+        relation_id = event.relation.id
+        # generate a new user credential
+        username = self.kafka.generate_user()
+        password = self.kafka.generate_password()
+        # set connection info in the databag relation
+        self.kafka_provider.set_bootstrap_server(relation_id, self.kafka.get_bootstrap_server())
+        self.kafka_provider.set_credentials(relation_id, username=username, password=password)
+        self.kafka_provider.set_consumer_group_prefix(relation_id, ...)
+        self.kafka_provider.set_tls(relation_id, "False")
+        self.kafka_provider.set_zookeeper_uris(relation_id, ...)
+
+```
+As shown above, the library provides a custom event (topic_requested) to handle
+the situation when an application charm requests a new topic to be created.
+It is preferred to subscribe to this event instead of the relation changed event, to avoid
+creating a new topic when information other than a topic name is
+exchanged in the relation databag.
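+
+### Reading relation data outside an event callback
+
+The requirer APIs can also be used outside an event callback. As a rough sketch
+(reusing the `self.database` requirer from the examples above and a hypothetical
+`update-status` handler), a charm can check whether credentials were already shared
+and read the raw databag contents:
+
+```python
+
+def _on_update_status(self, _) -> None:
+    # True once "username" and "password" were shared for every active relation
+    if not self.database.is_resource_created():
+        return
+
+    # Raw databag contents, indexed by relation id
+    # (cannot be used in *-relation-broken events)
+    for relation_id, data in self.database.fetch_relation_data().items():
+        endpoints = data.get("endpoints")
+        ...
+```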
+""" + +import json +import logging +from abc import ABC, abstractmethod +from collections import namedtuple +from datetime import datetime +from typing import List, Optional + +from ops.charm import ( + CharmBase, + CharmEvents, + RelationChangedEvent, + RelationEvent, + RelationJoinedEvent, +) +from ops.framework import EventSource, Object +from ops.model import Relation + +# The unique Charmhub library identifier, never change it +LIBID = "6c3e6b6680d64e9c89e611d1a15f65be" + +# Increment this major API version when introducing breaking changes +LIBAPI = 0 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 12 + +PYDEPS = ["ops>=2.0.0"] + +logger = logging.getLogger(__name__) + +Diff = namedtuple("Diff", "added changed deleted") +Diff.__doc__ = """ +A tuple for storing the diff between two data mappings. + +added - keys that were added +changed - keys that still exist but have new values +deleted - key that were deleted""" + + +def diff(event: RelationChangedEvent, bucket: str) -> Diff: + """Retrieves the diff of the data in the relation changed databag. + + Args: + event: relation changed event. + bucket: bucket of the databag (app or unit) + + Returns: + a Diff instance containing the added, deleted and changed + keys from the event relation databag. + """ + # Retrieve the old data from the data key in the application relation databag. + old_data = json.loads(event.relation.data[bucket].get("data", "{}")) + # Retrieve the new data from the event relation databag. + new_data = { + key: value for key, value in event.relation.data[event.app].items() if key != "data" + } + + # These are the keys that were added to the databag and triggered this event. + added = new_data.keys() - old_data.keys() + # These are the keys that were removed from the databag and triggered this event. + deleted = old_data.keys() - new_data.keys() + # These are the keys that already existed in the databag, + # but had their values changed. + changed = {key for key in old_data.keys() & new_data.keys() if old_data[key] != new_data[key]} + # Convert the new_data to a serializable format and save it for a next diff check. + event.relation.data[bucket].update({"data": json.dumps(new_data)}) + + # Return the diff with all possible changes. + return Diff(added, changed, deleted) + + +# Base DataProvides and DataRequires + + +class DataProvides(Object, ABC): + """Base provides-side of the data products relation.""" + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + super().__init__(charm, relation_name) + self.charm = charm + self.local_app = self.charm.model.app + self.local_unit = self.charm.unit + self.relation_name = relation_name + self.framework.observe( + charm.on[relation_name].relation_changed, + self._on_relation_changed, + ) + + def _diff(self, event: RelationChangedEvent) -> Diff: + """Retrieves the diff of the data in the relation changed databag. + + Args: + event: relation changed event. + + Returns: + a Diff instance containing the added, deleted and changed + keys from the event relation databag. + """ + return diff(event, self.local_app) + + @abstractmethod + def _on_relation_changed(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation data has changed.""" + raise NotImplementedError + + def fetch_relation_data(self) -> dict: + """Retrieves data from relation. 
+ + This function can be used to retrieve data from a relation + in the charm code when outside an event callback. + + Returns: + a dict of the values stored in the relation data bag + for all relation instances (indexed by the relation id). + """ + data = {} + for relation in self.relations: + data[relation.id] = { + key: value for key, value in relation.data[relation.app].items() if key != "data" + } + return data + + def _update_relation_data(self, relation_id: int, data: dict) -> None: + """Updates a set of key-value pairs in the relation. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation_id: the identifier for a particular relation. + data: dict containing the key-value pairs + that should be updated in the relation. + """ + if self.local_unit.is_leader(): + relation = self.charm.model.get_relation(self.relation_name, relation_id) + relation.data[self.local_app].update(data) + + @property + def relations(self) -> List[Relation]: + """The list of Relation instances associated with this relation_name.""" + return list(self.charm.model.relations[self.relation_name]) + + def set_credentials(self, relation_id: int, username: str, password: str) -> None: + """Set credentials. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation_id: the identifier for a particular relation. + username: user that was created. + password: password of the created user. + """ + self._update_relation_data( + relation_id, + { + "username": username, + "password": password, + }, + ) + + def set_tls(self, relation_id: int, tls: str) -> None: + """Set whether TLS is enabled. + + Args: + relation_id: the identifier for a particular relation. + tls: whether tls is enabled (True or False). + """ + self._update_relation_data(relation_id, {"tls": tls}) + + def set_tls_ca(self, relation_id: int, tls_ca: str) -> None: + """Set the TLS CA in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + tls_ca: TLS certification authority. + """ + self._update_relation_data(relation_id, {"tls-ca": tls_ca}) + + +class DataRequires(Object, ABC): + """Requires-side of the relation.""" + + def __init__( + self, + charm, + relation_name: str, + extra_user_roles: str = None, + ): + """Manager of base client relations.""" + super().__init__(charm, relation_name) + self.charm = charm + self.extra_user_roles = extra_user_roles + self.local_app = self.charm.model.app + self.local_unit = self.charm.unit + self.relation_name = relation_name + self.framework.observe( + self.charm.on[relation_name].relation_joined, self._on_relation_joined_event + ) + self.framework.observe( + self.charm.on[relation_name].relation_changed, self._on_relation_changed_event + ) + + @abstractmethod + def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: + """Event emitted when the application joins the relation.""" + raise NotImplementedError + + @abstractmethod + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + raise NotImplementedError + + def fetch_relation_data(self) -> dict: + """Retrieves data from relation. + + This function can be used to retrieve data from a relation + in the charm code when outside an event callback. + Function cannot be used in `*-relation-broken` events and will raise an exception. + + Returns: + a dict of the values stored in the relation data bag + for all relation instances (indexed by the relation ID). 
+ """ + data = {} + for relation in self.relations: + data[relation.id] = { + key: value for key, value in relation.data[relation.app].items() if key != "data" + } + return data + + def _update_relation_data(self, relation_id: int, data: dict) -> None: + """Updates a set of key-value pairs in the relation. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation_id: the identifier for a particular relation. + data: dict containing the key-value pairs + that should be updated in the relation. + """ + if self.local_unit.is_leader(): + relation = self.charm.model.get_relation(self.relation_name, relation_id) + relation.data[self.local_app].update(data) + + def _diff(self, event: RelationChangedEvent) -> Diff: + """Retrieves the diff of the data in the relation changed databag. + + Args: + event: relation changed event. + + Returns: + a Diff instance containing the added, deleted and changed + keys from the event relation databag. + """ + return diff(event, self.local_unit) + + @property + def relations(self) -> List[Relation]: + """The list of Relation instances associated with this relation_name.""" + return [ + relation + for relation in self.charm.model.relations[self.relation_name] + if self._is_relation_active(relation) + ] + + @staticmethod + def _is_relation_active(relation: Relation): + try: + _ = repr(relation.data) + return True + except RuntimeError: + return False + + @staticmethod + def _is_resource_created_for_relation(relation: Relation): + return ( + "username" in relation.data[relation.app] and "password" in relation.data[relation.app] + ) + + def is_resource_created(self, relation_id: Optional[int] = None) -> bool: + """Check if the resource has been created. + + This function can be used to check if the Provider answered with data in the charm code + when outside an event callback. 
+ + Args: + relation_id (int, optional): When provided the check is done only for the relation id + provided, otherwise the check is done for all relations + + Returns: + True or False + + Raises: + IndexError: If relation_id is provided but that relation does not exist + """ + if relation_id is not None: + try: + relation = [relation for relation in self.relations if relation.id == relation_id][ + 0 + ] + return self._is_resource_created_for_relation(relation) + except IndexError: + raise IndexError(f"relation id {relation_id} cannot be accessed") + else: + return ( + all( + [ + self._is_resource_created_for_relation(relation) + for relation in self.relations + ] + ) + if self.relations + else False + ) + + +# General events + + +class ExtraRoleEvent(RelationEvent): + """Base class for data events.""" + + @property + def extra_user_roles(self) -> Optional[str]: + """Returns the extra user roles that were requested.""" + return self.relation.data[self.relation.app].get("extra-user-roles") + + +class AuthenticationEvent(RelationEvent): + """Base class for authentication fields for events.""" + + @property + def username(self) -> Optional[str]: + """Returns the created username.""" + return self.relation.data[self.relation.app].get("username") + + @property + def password(self) -> Optional[str]: + """Returns the password for the created user.""" + return self.relation.data[self.relation.app].get("password") + + @property + def tls(self) -> Optional[str]: + """Returns whether TLS is configured.""" + return self.relation.data[self.relation.app].get("tls") + + @property + def tls_ca(self) -> Optional[str]: + """Returns TLS CA.""" + return self.relation.data[self.relation.app].get("tls-ca") + + +# Database related events and fields + + +class DatabaseProvidesEvent(RelationEvent): + """Base class for database events.""" + + @property + def database(self) -> Optional[str]: + """Returns the database that was requested.""" + return self.relation.data[self.relation.app].get("database") + + +class DatabaseRequestedEvent(DatabaseProvidesEvent, ExtraRoleEvent): + """Event emitted when a new database is requested for use on this relation.""" + + +class DatabaseProvidesEvents(CharmEvents): + """Database events. + + This class defines the events that the database can emit. + """ + + database_requested = EventSource(DatabaseRequestedEvent) + + +class DatabaseRequiresEvent(RelationEvent): + """Base class for database events.""" + + @property + def database(self) -> Optional[str]: + """Returns the database name.""" + return self.relation.data[self.relation.app].get("database") + + @property + def endpoints(self) -> Optional[str]: + """Returns a comma separated list of read/write endpoints. + + In VM charms, this is the primary's address. + In kubernetes charms, this is the service to the primary pod. + """ + return self.relation.data[self.relation.app].get("endpoints") + + @property + def read_only_endpoints(self) -> Optional[str]: + """Returns a comma separated list of read only endpoints. + + In VM charms, this is the address of all the secondary instances. + In kubernetes charms, this is the service to all replica pod instances. + """ + return self.relation.data[self.relation.app].get("read-only-endpoints") + + @property + def replset(self) -> Optional[str]: + """Returns the replicaset name. + + MongoDB only. + """ + return self.relation.data[self.relation.app].get("replset") + + @property + def uris(self) -> Optional[str]: + """Returns the connection URIs. + + MongoDB, Redis, OpenSearch. 
+ """ + return self.relation.data[self.relation.app].get("uris") + + @property + def version(self) -> Optional[str]: + """Returns the version of the database. + + Version as informed by the database daemon. + """ + return self.relation.data[self.relation.app].get("version") + + +class DatabaseCreatedEvent(AuthenticationEvent, DatabaseRequiresEvent): + """Event emitted when a new database is created for use on this relation.""" + + +class DatabaseEndpointsChangedEvent(AuthenticationEvent, DatabaseRequiresEvent): + """Event emitted when the read/write endpoints are changed.""" + + +class DatabaseReadOnlyEndpointsChangedEvent(AuthenticationEvent, DatabaseRequiresEvent): + """Event emitted when the read only endpoints are changed.""" + + +class DatabaseRequiresEvents(CharmEvents): + """Database events. + + This class defines the events that the database can emit. + """ + + database_created = EventSource(DatabaseCreatedEvent) + endpoints_changed = EventSource(DatabaseEndpointsChangedEvent) + read_only_endpoints_changed = EventSource(DatabaseReadOnlyEndpointsChangedEvent) + + +# Database Provider and Requires + + +class DatabaseProvides(DataProvides): + """Provider-side of the database relations.""" + + on = DatabaseProvidesEvents() + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + super().__init__(charm, relation_name) + + def _on_relation_changed(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + # Only the leader should handle this event. + if not self.local_unit.is_leader(): + return + + # Check which data has changed to emit customs events. + diff = self._diff(event) + + # Emit a database requested event if the setup key (database name and optional + # extra user roles) was added to the relation databag by the application. + if "database" in diff.added: + self.on.database_requested.emit(event.relation, app=event.app, unit=event.unit) + + def set_database(self, relation_id: int, database_name: str) -> None: + """Set database name. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation_id: the identifier for a particular relation. + database_name: database name. + """ + self._update_relation_data(relation_id, {"database": database_name}) + + def set_endpoints(self, relation_id: int, connection_strings: str) -> None: + """Set database primary connections. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + In VM charms, only the primary's address should be passed as an endpoint. + In kubernetes charms, the service endpoint to the primary pod should be + passed as an endpoint. + + Args: + relation_id: the identifier for a particular relation. + connection_strings: database hosts and ports comma separated list. + """ + self._update_relation_data(relation_id, {"endpoints": connection_strings}) + + def set_read_only_endpoints(self, relation_id: int, connection_strings: str) -> None: + """Set database replicas connection strings. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation_id: the identifier for a particular relation. + connection_strings: database hosts and ports comma separated list. + """ + self._update_relation_data(relation_id, {"read-only-endpoints": connection_strings}) + + def set_replset(self, relation_id: int, replset: str) -> None: + """Set replica set name in the application relation databag. + + MongoDB only. 
+ + Args: + relation_id: the identifier for a particular relation. + replset: replica set name. + """ + self._update_relation_data(relation_id, {"replset": replset}) + + def set_uris(self, relation_id: int, uris: str) -> None: + """Set the database connection URIs in the application relation databag. + + MongoDB, Redis, and OpenSearch only. + + Args: + relation_id: the identifier for a particular relation. + uris: connection URIs. + """ + self._update_relation_data(relation_id, {"uris": uris}) + + def set_version(self, relation_id: int, version: str) -> None: + """Set the database version in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + version: database version. + """ + self._update_relation_data(relation_id, {"version": version}) + + +class DatabaseRequires(DataRequires): + """Requires-side of the database relation.""" + + on = DatabaseRequiresEvents() + + def __init__( + self, + charm, + relation_name: str, + database_name: str, + extra_user_roles: str = None, + relations_aliases: List[str] = None, + ): + """Manager of database client relations.""" + super().__init__(charm, relation_name, extra_user_roles) + self.database = database_name + self.relations_aliases = relations_aliases + + # Define custom event names for each alias. + if relations_aliases: + # Ensure the number of aliases does not exceed the maximum + # of connections allowed in the specific relation. + relation_connection_limit = self.charm.meta.requires[relation_name].limit + if len(relations_aliases) != relation_connection_limit: + raise ValueError( + f"The number of aliases must match the maximum number of connections allowed in the relation. " + f"Expected {relation_connection_limit}, got {len(relations_aliases)}" + ) + + for relation_alias in relations_aliases: + self.on.define_event(f"{relation_alias}_database_created", DatabaseCreatedEvent) + self.on.define_event( + f"{relation_alias}_endpoints_changed", DatabaseEndpointsChangedEvent + ) + self.on.define_event( + f"{relation_alias}_read_only_endpoints_changed", + DatabaseReadOnlyEndpointsChangedEvent, + ) + + def _assign_relation_alias(self, relation_id: int) -> None: + """Assigns an alias to a relation. + + This function writes in the unit data bag. + + Args: + relation_id: the identifier for a particular relation. + """ + # If no aliases were provided, return immediately. + if not self.relations_aliases: + return + + # Return if an alias was already assigned to this relation + # (like when there are more than one unit joining the relation). + if ( + self.charm.model.get_relation(self.relation_name, relation_id) + .data[self.local_unit] + .get("alias") + ): + return + + # Retrieve the available aliases (the ones that weren't assigned to any relation). + available_aliases = self.relations_aliases[:] + for relation in self.charm.model.relations[self.relation_name]: + alias = relation.data[self.local_unit].get("alias") + if alias: + logger.debug("Alias %s was already assigned to relation %d", alias, relation.id) + available_aliases.remove(alias) + + # Set the alias in the unit relation databag of the specific relation. + relation = self.charm.model.get_relation(self.relation_name, relation_id) + relation.data[self.local_unit].update({"alias": available_aliases[0]}) + + def _emit_aliased_event(self, event: RelationChangedEvent, event_name: str) -> None: + """Emit an aliased event to a particular relation if it has an alias. + + Args: + event: the relation changed event that was received. 
+ event_name: the name of the event to emit. + """ + alias = self._get_relation_alias(event.relation.id) + if alias: + getattr(self.on, f"{alias}_{event_name}").emit( + event.relation, app=event.app, unit=event.unit + ) + + def _get_relation_alias(self, relation_id: int) -> Optional[str]: + """Returns the relation alias. + + Args: + relation_id: the identifier for a particular relation. + + Returns: + the relation alias or None if the relation was not found. + """ + for relation in self.charm.model.relations[self.relation_name]: + if relation.id == relation_id: + return relation.data[self.local_unit].get("alias") + return None + + def is_postgresql_plugin_enabled(self, plugin: str, relation_index: int = 0) -> bool: + """Returns whether a plugin is enabled in the database. + + Args: + plugin: name of the plugin to check. + relation_index: optional relation index to check the database + (default: 0 - first relation). + + PostgreSQL only. + """ + # Psycopg 3 is imported locally to avoid the need of its package installation + # when relating to a database charm other than PostgreSQL. + import psycopg + + # Return False if no relation is established. + if len(self.relations) == 0: + return False + + relation_data = self.fetch_relation_data()[self.relations[relation_index].id] + host = relation_data.get("endpoints") + + # Return False if there is no endpoint available. + if host is None: + return False + + host = host.split(":")[0] + user = relation_data.get("username") + password = relation_data.get("password") + connection_string = ( + f"host='{host}' dbname='{self.database}' user='{user}' password='{password}'" + ) + try: + with psycopg.connect(connection_string) as connection: + with connection.cursor() as cursor: + cursor.execute(f"SELECT TRUE FROM pg_extension WHERE extname='{plugin}';") + return cursor.fetchone() is not None + except psycopg.Error as e: + logger.exception( + f"failed to check whether {plugin} plugin is enabled in the database: %s", str(e) + ) + return False + + def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: + """Event emitted when the application joins the database relation.""" + # If relations aliases were provided, assign one to the relation. + self._assign_relation_alias(event.relation.id) + + # Sets both database and extra user roles in the relation + # if the roles are provided. Otherwise, sets only the database. + if self.extra_user_roles: + self._update_relation_data( + event.relation.id, + { + "database": self.database, + "extra-user-roles": self.extra_user_roles, + }, + ) + else: + self._update_relation_data(event.relation.id, {"database": self.database}) + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the database relation has changed.""" + # Check which data has changed to emit customs events. + diff = self._diff(event) + + # Check if the database is created + # (the database charm shared the credentials). + if "username" in diff.added and "password" in diff.added: + # Emit the default event (the one without an alias). + logger.info("database created at %s", datetime.now()) + self.on.database_created.emit(event.relation, app=event.app, unit=event.unit) + + # Emit the aliased event (if any). + self._emit_aliased_event(event, "database_created") + + # To avoid unnecessary application restarts do not trigger + # “endpoints_changed“ event if “database_created“ is triggered. 
+ return + + # Emit an endpoints changed event if the database + # added or changed this info in the relation databag. + if "endpoints" in diff.added or "endpoints" in diff.changed: + # Emit the default event (the one without an alias). + logger.info("endpoints changed on %s", datetime.now()) + self.on.endpoints_changed.emit(event.relation, app=event.app, unit=event.unit) + + # Emit the aliased event (if any). + self._emit_aliased_event(event, "endpoints_changed") + + # To avoid unnecessary application restarts do not trigger + # “read_only_endpoints_changed“ event if “endpoints_changed“ is triggered. + return + + # Emit a read only endpoints changed event if the database + # added or changed this info in the relation databag. + if "read-only-endpoints" in diff.added or "read-only-endpoints" in diff.changed: + # Emit the default event (the one without an alias). + logger.info("read-only-endpoints changed on %s", datetime.now()) + self.on.read_only_endpoints_changed.emit( + event.relation, app=event.app, unit=event.unit + ) + + # Emit the aliased event (if any). + self._emit_aliased_event(event, "read_only_endpoints_changed") + + +# Kafka related events + + +class KafkaProvidesEvent(RelationEvent): + """Base class for Kafka events.""" + + @property + def topic(self) -> Optional[str]: + """Returns the topic that was requested.""" + return self.relation.data[self.relation.app].get("topic") + + @property + def consumer_group_prefix(self) -> Optional[str]: + """Returns the consumer-group-prefix that was requested.""" + return self.relation.data[self.relation.app].get("consumer-group-prefix") + + +class TopicRequestedEvent(KafkaProvidesEvent, ExtraRoleEvent): + """Event emitted when a new topic is requested for use on this relation.""" + + +class KafkaProvidesEvents(CharmEvents): + """Kafka events. + + This class defines the events that the Kafka can emit. + """ + + topic_requested = EventSource(TopicRequestedEvent) + + +class KafkaRequiresEvent(RelationEvent): + """Base class for Kafka events.""" + + @property + def topic(self) -> Optional[str]: + """Returns the topic.""" + return self.relation.data[self.relation.app].get("topic") + + @property + def bootstrap_server(self) -> Optional[str]: + """Returns a comma-separated list of broker uris.""" + return self.relation.data[self.relation.app].get("endpoints") + + @property + def consumer_group_prefix(self) -> Optional[str]: + """Returns the consumer-group-prefix.""" + return self.relation.data[self.relation.app].get("consumer-group-prefix") + + @property + def zookeeper_uris(self) -> Optional[str]: + """Returns a comma separated list of Zookeeper uris.""" + return self.relation.data[self.relation.app].get("zookeeper-uris") + + +class TopicCreatedEvent(AuthenticationEvent, KafkaRequiresEvent): + """Event emitted when a new topic is created for use on this relation.""" + + +class BootstrapServerChangedEvent(AuthenticationEvent, KafkaRequiresEvent): + """Event emitted when the bootstrap server is changed.""" + + +class KafkaRequiresEvents(CharmEvents): + """Kafka events. + + This class defines the events that the Kafka can emit. 
+ """ + + topic_created = EventSource(TopicCreatedEvent) + bootstrap_server_changed = EventSource(BootstrapServerChangedEvent) + + +# Kafka Provides and Requires + + +class KafkaProvides(DataProvides): + """Provider-side of the Kafka relation.""" + + on = KafkaProvidesEvents() + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + super().__init__(charm, relation_name) + + def _on_relation_changed(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + # Only the leader should handle this event. + if not self.local_unit.is_leader(): + return + + # Check which data has changed to emit customs events. + diff = self._diff(event) + + # Emit a topic requested event if the setup key (topic name and optional + # extra user roles) was added to the relation databag by the application. + if "topic" in diff.added: + self.on.topic_requested.emit(event.relation, app=event.app, unit=event.unit) + + def set_topic(self, relation_id: int, topic: str) -> None: + """Set topic name in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + topic: the topic name. + """ + self._update_relation_data(relation_id, {"topic": topic}) + + def set_bootstrap_server(self, relation_id: int, bootstrap_server: str) -> None: + """Set the bootstrap server in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + bootstrap_server: the bootstrap server address. + """ + self._update_relation_data(relation_id, {"endpoints": bootstrap_server}) + + def set_consumer_group_prefix(self, relation_id: int, consumer_group_prefix: str) -> None: + """Set the consumer group prefix in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + consumer_group_prefix: the consumer group prefix string. + """ + self._update_relation_data(relation_id, {"consumer-group-prefix": consumer_group_prefix}) + + def set_zookeeper_uris(self, relation_id: int, zookeeper_uris: str) -> None: + """Set the zookeeper uris in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + zookeeper_uris: comma-separated list of ZooKeeper server uris. + """ + self._update_relation_data(relation_id, {"zookeeper-uris": zookeeper_uris}) + + +class KafkaRequires(DataRequires): + """Requires-side of the Kafka relation.""" + + on = KafkaRequiresEvents() + + def __init__( + self, + charm, + relation_name: str, + topic: str, + extra_user_roles: Optional[str] = None, + consumer_group_prefix: Optional[str] = None, + ): + """Manager of Kafka client relations.""" + # super().__init__(charm, relation_name) + super().__init__(charm, relation_name, extra_user_roles) + self.charm = charm + self.topic = topic + self.consumer_group_prefix = consumer_group_prefix or "" + + def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: + """Event emitted when the application joins the Kafka relation.""" + # Sets topic, extra user roles, and "consumer-group-prefix" in the relation + relation_data = { + f: getattr(self, f.replace("-", "_"), "") + for f in ["consumer-group-prefix", "extra-user-roles", "topic"] + } + + self._update_relation_data(event.relation.id, relation_data) + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the Kafka relation has changed.""" + # Check which data has changed to emit customs events. 
+ diff = self._diff(event) + + # Check if the topic is created + # (the Kafka charm shared the credentials). + if "username" in diff.added and "password" in diff.added: + # Emit the default event (the one without an alias). + logger.info("topic created at %s", datetime.now()) + self.on.topic_created.emit(event.relation, app=event.app, unit=event.unit) + + # To avoid unnecessary application restarts do not trigger + # “endpoints_changed“ event if “topic_created“ is triggered. + return + + # Emit an endpoints (bootstrap-server) changed event if the Kafka endpoints + # added or changed this info in the relation databag. + if "endpoints" in diff.added or "endpoints" in diff.changed: + # Emit the default event (the one without an alias). + logger.info("endpoints changed on %s", datetime.now()) + self.on.bootstrap_server_changed.emit( + event.relation, app=event.app, unit=event.unit + ) # here check if this is the right design + return + + +# Opensearch related events + + +class OpenSearchProvidesEvent(RelationEvent): + """Base class for OpenSearch events.""" + + @property + def index(self) -> Optional[str]: + """Returns the index that was requested.""" + return self.relation.data[self.relation.app].get("index") + + +class IndexRequestedEvent(OpenSearchProvidesEvent, ExtraRoleEvent): + """Event emitted when a new index is requested for use on this relation.""" + + +class OpenSearchProvidesEvents(CharmEvents): + """OpenSearch events. + + This class defines the events that OpenSearch can emit. + """ + + index_requested = EventSource(IndexRequestedEvent) + + +class OpenSearchRequiresEvent(DatabaseRequiresEvent): + """Base class for OpenSearch requirer events.""" + + +class IndexCreatedEvent(AuthenticationEvent, OpenSearchRequiresEvent): + """Event emitted when a new index is created for use on this relation.""" + + +class OpenSearchRequiresEvents(CharmEvents): + """OpenSearch events. + + This class defines the events that the opensearch requirer can emit. + """ + + index_created = EventSource(IndexCreatedEvent) + endpoints_changed = EventSource(DatabaseEndpointsChangedEvent) + authentication_updated = EventSource(AuthenticationEvent) + + +# OpenSearch Provides and Requires Objects + + +class OpenSearchProvides(DataProvides): + """Provider-side of the OpenSearch relation.""" + + on = OpenSearchProvidesEvents() + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + super().__init__(charm, relation_name) + + def _on_relation_changed(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + # Only the leader should handle this event. + if not self.local_unit.is_leader(): + return + + # Check which data has changed to emit customs events. + diff = self._diff(event) + + # Emit an index requested event if the setup key (index name and optional extra user roles) + # have been added to the relation databag by the application. + if "index" in diff.added: + self.on.index_requested.emit(event.relation, app=event.app, unit=event.unit) + + def set_index(self, relation_id: int, index: str) -> None: + """Set the index in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + index: the index as it is _created_ on the provider charm. This needn't match the + requested index, and can be used to present a different index name if, for example, + the requested index is invalid. 
+ """ + self._update_relation_data(relation_id, {"index": index}) + + def set_endpoints(self, relation_id: int, endpoints: str) -> None: + """Set the endpoints in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + endpoints: the endpoint addresses for opensearch nodes. + """ + self._update_relation_data(relation_id, {"endpoints": endpoints}) + + def set_version(self, relation_id: int, version: str) -> None: + """Set the opensearch version in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + version: database version. + """ + self._update_relation_data(relation_id, {"version": version}) + + +class OpenSearchRequires(DataRequires): + """Requires-side of the OpenSearch relation.""" + + on = OpenSearchRequiresEvents() + + def __init__( + self, charm, relation_name: str, index: str, extra_user_roles: Optional[str] = None + ): + """Manager of OpenSearch client relations.""" + super().__init__(charm, relation_name, extra_user_roles) + self.charm = charm + self.index = index + + def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: + """Event emitted when the application joins the OpenSearch relation.""" + # Sets both index and extra user roles in the relation if the roles are provided. + # Otherwise, sets only the index. + data = {"index": self.index} + if self.extra_user_roles: + data["extra-user-roles"] = self.extra_user_roles + + self._update_relation_data(event.relation.id, data) + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the OpenSearch relation has changed. + + This event triggers individual custom events depending on the changing relation. + """ + # Check which data has changed to emit customs events. + diff = self._diff(event) + + # Check if authentication has updated, emit event if so + updates = {"username", "password", "tls", "tls-ca"} + if len(set(diff._asdict().keys()) - updates) < len(diff): + logger.info("authentication updated at: %s", datetime.now()) + self.on.authentication_updated.emit(event.relation, app=event.app, unit=event.unit) + + # Check if the index is created + # (the OpenSearch charm shares the credentials). + if "username" in diff.added and "password" in diff.added: + # Emit the default event (the one without an alias). + logger.info("index created at: %s", datetime.now()) + self.on.index_created.emit(event.relation, app=event.app, unit=event.unit) + + # To avoid unnecessary application restarts do not trigger + # “endpoints_changed“ event if “index_created“ is triggered. + return + + # Emit a endpoints changed event if the OpenSearch application added or changed this info + # in the relation databag. + if "endpoints" in diff.added or "endpoints" in diff.changed: + # Emit the default event (the one without an alias). + logger.info("endpoints changed on %s", datetime.now()) + self.on.endpoints_changed.emit( + event.relation, app=event.app, unit=event.unit + ) # here check if this is the right design + return diff --git a/lib/charms/data_platform_libs/v0/database_provides.py b/lib/charms/data_platform_libs/v0/database_provides.py deleted file mode 100644 index 8135da9d..00000000 --- a/lib/charms/data_platform_libs/v0/database_provides.py +++ /dev/null @@ -1,316 +0,0 @@ -# Copyright 2022 Canonical Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Relation provider side abstraction for database relation. - -This library is a uniform interface to a selection of common database -metadata, with added custom events that add convenience to database management, -and methods to set the application related data. - -It can be used as the main library in a database charm to handle relations with -application charms or be extended/used as a template when creating a more complete library -(like one that also handles the database and user creation using database specific APIs). - -Following an example of using the DatabaseRequestedEvent, in the context of the -database charm code: - -```python -from charms.data_platform_libs.v0.database_provides import DatabaseProvides - -class SampleCharm(CharmBase): - - def __init__(self, *args): - super().__init__(*args) - - # Charm events defined in the database provides charm library. - self.provided_database = DatabaseProvides(self, relation_name="database") - self.framework.observe(self.provided_database.on.database_requested, - self._on_database_requested) - - # Database generic helper - self.database = DatabaseHelper() - - def _on_database_requested(self, event: DatabaseRequestedEvent) -> None: - # Handle the event triggered by a new database requested in the relation - - # Retrieve the database name using the charm library. - db_name = event.database - - # generate a new user credential - username = self.database.generate_user() - password = self.database.generate_password() - - # set the credentials for the relation - self.provided_database.set_credentials(event.relation.id, username, password) - - # set other variables for the relation event.set_tls("False") -``` - -As shown above, the library provides a custom event (database_requested) to handle -the situation when an application charm requests a new database to be created. -It's preferred to subscribe to this event instead of relation changed event to avoid -creating a new database when other information other than a database name is -exchanged in the relation databag. 
-""" -import json -import logging -from collections import namedtuple -from typing import List, Optional - -from ops.charm import CharmBase, CharmEvents, RelationChangedEvent, RelationEvent -from ops.framework import EventSource, Object -from ops.model import Relation - -# The unique Charmhub library identifier, never change it -LIBID = "8eea9ca584d84c7bb357f1946b6f34ce" - -# Increment this major API version when introducing breaking changes -LIBAPI = 0 - -# Increment this PATCH version before using `charmcraft publish-lib` or reset -# to 0 if you are raising the major API version -LIBPATCH = 2 - -logger = logging.getLogger(__name__) - - -class DatabaseEvent(RelationEvent): - """Base class for database events.""" - - @property - def database(self) -> Optional[str]: - """Returns the database that was requested.""" - return self.relation.data[self.relation.app].get("database") - - @property - def extra_user_roles(self) -> Optional[str]: - """Returns the extra user roles that were requested.""" - return self.relation.data[self.relation.app].get("extra-user-roles") - - -class DatabaseRequestedEvent(DatabaseEvent): - """Event emitted when a new database is requested for use on this relation.""" - - -class DatabaseEvents(CharmEvents): - """Database events. - - This class defines the events that the database can emit. - """ - - database_requested = EventSource(DatabaseRequestedEvent) - - -Diff = namedtuple("Diff", "added changed deleted") -Diff.__doc__ = """ -A tuple for storing the diff between two data mappings. - -added - keys that were added -changed - keys that still exist but have new values -deleted - key that were deleted""" - - -class DatabaseProvides(Object): - """Provides-side of the database relation.""" - - on = DatabaseEvents() - - def __init__(self, charm: CharmBase, relation_name: str) -> None: - super().__init__(charm, relation_name) - self.charm = charm - self.local_app = self.charm.model.app - self.local_unit = self.charm.unit - self.relation_name = relation_name - self.framework.observe( - charm.on[relation_name].relation_changed, - self._on_relation_changed, - ) - - def _diff(self, event: RelationChangedEvent) -> Diff: - """Retrieves the diff of the data in the relation changed databag. - - Args: - event: relation changed event. - - Returns: - a Diff instance containing the added, deleted and changed - keys from the event relation databag. - """ - # Retrieve the old data from the data key in the application relation databag. - old_data = json.loads(event.relation.data[self.local_app].get("data", "{}")) - # Retrieve the new data from the event relation databag. - new_data = { - key: value for key, value in event.relation.data[event.app].items() if key != "data" - } - - # These are the keys that were added to the databag and triggered this event. - added = new_data.keys() - old_data.keys() - # These are the keys that were removed from the databag and triggered this event. - deleted = old_data.keys() - new_data.keys() - # These are the keys that already existed in the databag, - # but had their values changed. - changed = { - key for key in old_data.keys() & new_data.keys() if old_data[key] != new_data[key] - } - - # TODO: evaluate the possibility of losing the diff if some error - # happens in the charm before the diff is completely checked (DPE-412). - # Convert the new_data to a serializable format and save it for a next diff check. - event.relation.data[self.local_app].update({"data": json.dumps(new_data)}) - - # Return the diff with all possible changes. 
- return Diff(added, changed, deleted) - - def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Event emitted when the database relation has changed.""" - # Only the leader should handle this event. - if not self.local_unit.is_leader(): - return - - # Check which data has changed to emit customs events. - diff = self._diff(event) - - # Emit a database requested event if the setup key (database name and optional - # extra user roles) was added to the relation databag by the application. - if "database" in diff.added: - self.on.database_requested.emit(event.relation, app=event.app, unit=event.unit) - - def fetch_relation_data(self) -> dict: - """Retrieves data from relation. - - This function can be used to retrieve data from a relation - in the charm code when outside an event callback. - - Returns: - a dict of the values stored in the relation data bag - for all relation instances (indexed by the relation id). - """ - data = {} - for relation in self.relations: - data[relation.id] = { - key: value for key, value in relation.data[relation.app].items() if key != "data" - } - return data - - def _update_relation_data(self, relation_id: int, data: dict) -> None: - """Updates a set of key-value pairs in the relation. - - This function writes in the application data bag, therefore, - only the leader unit can call it. - - Args: - relation_id: the identifier for a particular relation. - data: dict containing the key-value pairs - that should be updated in the relation. - """ - if self.local_unit.is_leader(): - relation = self.charm.model.get_relation(self.relation_name, relation_id) - relation.data[self.local_app].update(data) - - @property - def relations(self) -> List[Relation]: - """The list of Relation instances associated with this relation_name.""" - return list(self.charm.model.relations[self.relation_name]) - - def set_credentials(self, relation_id: int, username: str, password: str) -> None: - """Set database primary connections. - - This function writes in the application data bag, therefore, - only the leader unit can call it. - - Args: - relation_id: the identifier for a particular relation. - username: user that was created. - password: password of the created user. - """ - self._update_relation_data( - relation_id, - { - "username": username, - "password": password, - }, - ) - - def set_endpoints(self, relation_id: int, connection_strings: str) -> None: - """Set database primary connections. - - This function writes in the application data bag, therefore, - only the leader unit can call it. - - Args: - relation_id: the identifier for a particular relation. - connection_strings: database hosts and ports comma separated list. - """ - self._update_relation_data(relation_id, {"endpoints": connection_strings}) - - def set_read_only_endpoints(self, relation_id: int, connection_strings: str) -> None: - """Set database replicas connection strings. - - This function writes in the application data bag, therefore, - only the leader unit can call it. - - Args: - relation_id: the identifier for a particular relation. - connection_strings: database hosts and ports comma separated list. - """ - self._update_relation_data(relation_id, {"read-only-endpoints": connection_strings}) - - def set_replset(self, relation_id: int, replset: str) -> None: - """Set replica set name in the application relation databag. - - MongoDB only. - - Args: - relation_id: the identifier for a particular relation. - replset: replica set name. 
- """ - self._update_relation_data(relation_id, {"replset": replset}) - - def set_tls(self, relation_id: int, tls: str) -> None: - """Set whether TLS is enabled. - - Args: - relation_id: the identifier for a particular relation. - tls: whether tls is enabled (True or False). - """ - self._update_relation_data(relation_id, {"tls": tls}) - - def set_tls_ca(self, relation_id: int, tls_ca: str) -> None: - """Set the TLS CA in the application relation databag. - - Args: - relation_id: the identifier for a particular relation. - tls_ca: TLS certification authority. - """ - self._update_relation_data(relation_id, {"tls_ca": tls_ca}) - - def set_uris(self, relation_id: int, uris: str) -> None: - """Set the database connection URIs in the application relation databag. - - MongoDB, Redis, OpenSearch and Kafka only. - - Args: - relation_id: the identifier for a particular relation. - uris: connection URIs. - """ - self._update_relation_data(relation_id, {"uris": uris}) - - def set_version(self, relation_id: int, version: str) -> None: - """Set the database version in the application relation databag. - - Args: - relation_id: the identifier for a particular relation. - version: database version. - """ - self._update_relation_data(relation_id, {"version": version}) diff --git a/lib/charms/data_platform_libs/v0/database_requires.py b/lib/charms/data_platform_libs/v0/database_requires.py deleted file mode 100644 index 53d61912..00000000 --- a/lib/charms/data_platform_libs/v0/database_requires.py +++ /dev/null @@ -1,496 +0,0 @@ -# Copyright 2022 Canonical Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Relation 'requires' side abstraction for database relation. - -This library is a uniform interface to a selection of common database -metadata, with added custom events that add convenience to database management, -and methods to consume the application related data. - -Following an example of using the DatabaseCreatedEvent, in the context of the -application charm code: - -```python - -from charms.data_platform_libs.v0.database_requires import DatabaseRequires - -class ApplicationCharm(CharmBase): - # Application charm that connects to database charms. - - def __init__(self, *args): - super().__init__(*args) - - # Charm events defined in the database requires charm library. 
- self.database = DatabaseRequires(self, relation_name="database", database_name="database") - self.framework.observe(self.database.on.database_created, self._on_database_created) - - def _on_database_created(self, event: DatabaseCreatedEvent) -> None: - # Handle the created database - - # Create configuration file for app - config_file = self._render_app_config_file( - event.username, - event.password, - event.endpoints, - ) - - # Start application with rendered configuration - self._start_application(config_file) - - # Set active status - self.unit.status = ActiveStatus("received database credentials") -``` - -As shown above, the library provides some custom events to handle specific situations, -which are listed below: - -— database_created: event emitted when the requested database is created. -— endpoints_changed: event emitted when the read/write endpoints of the database have changed. -— read_only_endpoints_changed: event emitted when the read-only endpoints of the database - have changed. Event is not triggered if read/write endpoints changed too. - -If it is needed to connect multiple database clusters to the same relation endpoint -the application charm can implement the same code as if it would connect to only -one database cluster (like the above code example). - -To differentiate multiple clusters connected to the same relation endpoint -the application charm can use the name of the remote application: - -```python - -def _on_database_created(self, event: DatabaseCreatedEvent) -> None: - # Get the remote app name of the cluster that triggered this event - cluster = event.relation.app.name -``` - -It is also possible to provide an alias for each different database cluster/relation. - -So, it is possible to differentiate the clusters in two ways. -The first is to use the remote application name, i.e., `event.relation.app.name`, as above. - -The second way is to use different event handlers to handle each cluster events. -The implementation would be something like the following code: - -```python - -from charms.data_platform_libs.v0.database_requires import DatabaseRequires - -class ApplicationCharm(CharmBase): - # Application charm that connects to database charms. - - def __init__(self, *args): - super().__init__(*args) - - # Define the cluster aliases and one handler for each cluster database created event. - self.database = DatabaseRequires( - self, - relation_name="database", - database_name="database", - relations_aliases = ["cluster1", "cluster2"], - ) - self.framework.observe( - self.database.on.cluster1_database_created, self._on_cluster1_database_created - ) - self.framework.observe( - self.database.on.cluster2_database_created, self._on_cluster2_database_created - ) - - def _on_cluster1_database_created(self, event: DatabaseCreatedEvent) -> None: - # Handle the created database on the cluster named cluster1 - - # Create configuration file for app - config_file = self._render_app_config_file( - event.username, - event.password, - event.endpoints, - ) - ... - - def _on_cluster2_database_created(self, event: DatabaseCreatedEvent) -> None: - # Handle the created database on the cluster named cluster2 - - # Create configuration file for app - config_file = self._render_app_config_file( - event.username, - event.password, - event.endpoints, - ) - ... 
- -``` -""" - -import json -import logging -from collections import namedtuple -from datetime import datetime -from typing import List, Optional - -from ops.charm import ( - CharmEvents, - RelationChangedEvent, - RelationEvent, - RelationJoinedEvent, -) -from ops.framework import EventSource, Object -from ops.model import Relation - -# The unique Charmhub library identifier, never change it -LIBID = "0241e088ffa9440fb4e3126349b2fb62" - -# Increment this major API version when introducing breaking changes -LIBAPI = 0 - -# Increment this PATCH version before using `charmcraft publish-lib` or reset -# to 0 if you are raising the major API version. -LIBPATCH = 4 - -logger = logging.getLogger(__name__) - - -class DatabaseEvent(RelationEvent): - """Base class for database events.""" - - @property - def endpoints(self) -> Optional[str]: - """Returns a comma separated list of read/write endpoints.""" - return self.relation.data[self.relation.app].get("endpoints") - - @property - def password(self) -> Optional[str]: - """Returns the password for the created user.""" - return self.relation.data[self.relation.app].get("password") - - @property - def read_only_endpoints(self) -> Optional[str]: - """Returns a comma separated list of read only endpoints.""" - return self.relation.data[self.relation.app].get("read-only-endpoints") - - @property - def replset(self) -> Optional[str]: - """Returns the replicaset name. - - MongoDB only. - """ - return self.relation.data[self.relation.app].get("replset") - - @property - def tls(self) -> Optional[str]: - """Returns whether TLS is configured.""" - return self.relation.data[self.relation.app].get("tls") - - @property - def tls_ca(self) -> Optional[str]: - """Returns TLS CA.""" - return self.relation.data[self.relation.app].get("tls-ca") - - @property - def uris(self) -> Optional[str]: - """Returns the connection URIs. - - MongoDB, Redis, OpenSearch and Kafka only. - """ - return self.relation.data[self.relation.app].get("uris") - - @property - def username(self) -> Optional[str]: - """Returns the created username.""" - return self.relation.data[self.relation.app].get("username") - - @property - def version(self) -> Optional[str]: - """Returns the version of the database. - - Version as informed by the database daemon. - """ - return self.relation.data[self.relation.app].get("version") - - -class DatabaseCreatedEvent(DatabaseEvent): - """Event emitted when a new database is created for use on this relation.""" - - -class DatabaseEndpointsChangedEvent(DatabaseEvent): - """Event emitted when the read/write endpoints are changed.""" - - -class DatabaseReadOnlyEndpointsChangedEvent(DatabaseEvent): - """Event emitted when the read only endpoints are changed.""" - - -class DatabaseEvents(CharmEvents): - """Database events. - - This class defines the events that the database can emit. - """ - - database_created = EventSource(DatabaseCreatedEvent) - endpoints_changed = EventSource(DatabaseEndpointsChangedEvent) - read_only_endpoints_changed = EventSource(DatabaseReadOnlyEndpointsChangedEvent) - - -Diff = namedtuple("Diff", "added changed deleted") -Diff.__doc__ = """ -A tuple for storing the diff between two data mappings. - -— added — keys that were added. -— changed — keys that still exist but have new values. -— deleted — keys that were deleted. 
-""" - - -class DatabaseRequires(Object): - """Requires-side of the database relation.""" - - on = DatabaseEvents() - - def __init__( - self, - charm, - relation_name: str, - database_name: str, - extra_user_roles: str = None, - relations_aliases: List[str] = None, - ): - """Manager of database client relations.""" - super().__init__(charm, relation_name) - self.charm = charm - self.database = database_name - self.extra_user_roles = extra_user_roles - self.local_app = self.charm.model.app - self.local_unit = self.charm.unit - self.relation_name = relation_name - self.relations_aliases = relations_aliases - self.framework.observe( - self.charm.on[relation_name].relation_joined, self._on_relation_joined_event - ) - self.framework.observe( - self.charm.on[relation_name].relation_changed, self._on_relation_changed_event - ) - - # Define custom event names for each alias. - if relations_aliases: - # Ensure the number of aliases does not exceed the maximum - # of connections allowed in the specific relation. - relation_connection_limit = self.charm.meta.requires[relation_name].limit - if len(relations_aliases) != relation_connection_limit: - raise ValueError( - f"The number of aliases must match the maximum number of connections allowed in the relation. " - f"Expected {relation_connection_limit}, got {len(relations_aliases)}" - ) - - for relation_alias in relations_aliases: - self.on.define_event(f"{relation_alias}_database_created", DatabaseCreatedEvent) - self.on.define_event( - f"{relation_alias}_endpoints_changed", DatabaseEndpointsChangedEvent - ) - self.on.define_event( - f"{relation_alias}_read_only_endpoints_changed", - DatabaseReadOnlyEndpointsChangedEvent, - ) - - def _assign_relation_alias(self, relation_id: int) -> None: - """Assigns an alias to a relation. - - This function writes in the unit data bag. - - Args: - relation_id: the identifier for a particular relation. - """ - # If no aliases were provided, return immediately. - if not self.relations_aliases: - return - - # Return if an alias was already assigned to this relation - # (like when there are more than one unit joining the relation). - if ( - self.charm.model.get_relation(self.relation_name, relation_id) - .data[self.local_unit] - .get("alias") - ): - return - - # Retrieve the available aliases (the ones that weren't assigned to any relation). - available_aliases = self.relations_aliases[:] - for relation in self.charm.model.relations[self.relation_name]: - alias = relation.data[self.local_unit].get("alias") - if alias: - logger.debug("Alias %s was already assigned to relation %d", alias, relation.id) - available_aliases.remove(alias) - - # Set the alias in the unit relation databag of the specific relation. - relation = self.charm.model.get_relation(self.relation_name, relation_id) - relation.data[self.local_unit].update({"alias": available_aliases[0]}) - - def _diff(self, event: RelationChangedEvent) -> Diff: - """Retrieves the diff of the data in the relation changed databag. - - Args: - event: relation changed event. - - Returns: - a Diff instance containing the added, deleted and changed - keys from the event relation databag. - """ - # Retrieve the old data from the data key in the local unit relation databag. - old_data = json.loads(event.relation.data[self.local_unit].get("data", "{}")) - # Retrieve the new data from the event relation databag. 
- new_data = { - key: value for key, value in event.relation.data[event.app].items() if key != "data" - } - - # These are the keys that were added to the databag and triggered this event. - added = new_data.keys() - old_data.keys() - # These are the keys that were removed from the databag and triggered this event. - deleted = old_data.keys() - new_data.keys() - # These are the keys that already existed in the databag, - # but had their values changed. - changed = { - key for key in old_data.keys() & new_data.keys() if old_data[key] != new_data[key] - } - - # TODO: evaluate the possibility of losing the diff if some error - # happens in the charm before the diff is completely checked (DPE-412). - # Convert the new_data to a serializable format and save it for a next diff check. - event.relation.data[self.local_unit].update({"data": json.dumps(new_data)}) - - # Return the diff with all possible changes. - return Diff(added, changed, deleted) - - def _emit_aliased_event(self, event: RelationChangedEvent, event_name: str) -> None: - """Emit an aliased event to a particular relation if it has an alias. - - Args: - event: the relation changed event that was received. - event_name: the name of the event to emit. - """ - alias = self._get_relation_alias(event.relation.id) - if alias: - getattr(self.on, f"{alias}_{event_name}").emit( - event.relation, app=event.app, unit=event.unit - ) - - def _get_relation_alias(self, relation_id: int) -> Optional[str]: - """Returns the relation alias. - - Args: - relation_id: the identifier for a particular relation. - - Returns: - the relation alias or None if the relation was not found. - """ - for relation in self.charm.model.relations[self.relation_name]: - if relation.id == relation_id: - return relation.data[self.local_unit].get("alias") - return None - - def fetch_relation_data(self) -> dict: - """Retrieves data from relation. - - This function can be used to retrieve data from a relation - in the charm code when outside an event callback. - - Returns: - a dict of the values stored in the relation data bag - for all relation instances (indexed by the relation ID). - """ - data = {} - for relation in self.relations: - data[relation.id] = { - key: value for key, value in relation.data[relation.app].items() if key != "data" - } - return data - - def _update_relation_data(self, relation_id: int, data: dict) -> None: - """Updates a set of key-value pairs in the relation. - - This function writes in the application data bag, therefore, - only the leader unit can call it. - - Args: - relation_id: the identifier for a particular relation. - data: dict containing the key-value pairs - that should be updated in the relation. - """ - if self.local_unit.is_leader(): - relation = self.charm.model.get_relation(self.relation_name, relation_id) - relation.data[self.local_app].update(data) - - def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: - """Event emitted when the application joins the database relation.""" - # If relations aliases were provided, assign one to the relation. - self._assign_relation_alias(event.relation.id) - - # Sets both database and extra user roles in the relation - # if the roles are provided. Otherwise, sets only the database. 
- if self.extra_user_roles: - self._update_relation_data( - event.relation.id, - { - "database": self.database, - "extra-user-roles": self.extra_user_roles, - }, - ) - else: - self._update_relation_data(event.relation.id, {"database": self.database}) - - def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: - """Event emitted when the database relation has changed.""" - # Check which data has changed to emit customs events. - diff = self._diff(event) - - # Check if the database is created - # (the database charm shared the credentials). - if "username" in diff.added and "password" in diff.added: - # Emit the default event (the one without an alias). - logger.info("database created at %s", datetime.now()) - self.on.database_created.emit(event.relation, app=event.app, unit=event.unit) - - # Emit the aliased event (if any). - self._emit_aliased_event(event, "database_created") - - # To avoid unnecessary application restarts do not trigger - # “endpoints_changed“ event if “database_created“ is triggered. - return - - # Emit an endpoints changed event if the database - # added or changed this info in the relation databag. - if "endpoints" in diff.added or "endpoints" in diff.changed: - # Emit the default event (the one without an alias). - logger.info("endpoints changed on %s", datetime.now()) - self.on.endpoints_changed.emit(event.relation, app=event.app, unit=event.unit) - - # Emit the aliased event (if any). - self._emit_aliased_event(event, "endpoints_changed") - - # To avoid unnecessary application restarts do not trigger - # “read_only_endpoints_changed“ event if “endpoints_changed“ is triggered. - return - - # Emit a read only endpoints changed event if the database - # added or changed this info in the relation databag. - if "read-only-endpoints" in diff.added or "read-only-endpoints" in diff.changed: - # Emit the default event (the one without an alias). - logger.info("read-only-endpoints changed on %s", datetime.now()) - self.on.read_only_endpoints_changed.emit( - event.relation, app=event.app, unit=event.unit - ) - - # Emit the aliased event (if any). - self._emit_aliased_event(event, "read_only_endpoints_changed") - - @property - def relations(self) -> List[Relation]: - """The list of Relation instances associated with this relation_name.""" - return list(self.charm.model.relations[self.relation_name]) diff --git a/lib/charms/operator_libs_linux/v0/apt.py b/lib/charms/operator_libs_linux/v0/apt.py deleted file mode 100644 index 42090a74..00000000 --- a/lib/charms/operator_libs_linux/v0/apt.py +++ /dev/null @@ -1,1329 +0,0 @@ -# Copyright 2022 Canonical Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Abstractions for the system's Debian/Ubuntu package information and repositories. - -This module contains abstractions and wrappers around Debian/Ubuntu-style repositories and -packages, in order to easily provide an idiomatic and Pythonic mechanism for adding packages and/or -repositories to systems for use in machine charms. 
- -A sane default configuration is attainable through nothing more than instantiation of the -appropriate classes. `DebianPackage` objects provide information about the architecture, version, -name, and status of a package. - -`DebianPackage` will try to look up a package either from `dpkg -L` or from `apt-cache` when -provided with a string indicating the package name. If it cannot be located, `PackageNotFoundError` -will be returned, as `apt` and `dpkg` otherwise return `100` for all errors, and a meaningful error -message if the package is not known is desirable. - -To install packages with convenience methods: - -```python -try: - # Run `apt-get update` - apt.update() - apt.add_package("zsh") - apt.add_package(["vim", "htop", "wget"]) -except PackageNotFoundError: - logger.error("a specified package not found in package cache or on system") -except PackageError as e: - logger.error("could not install package. Reason: %s", e.message) -```` - -To find details of a specific package: - -```python -try: - vim = apt.DebianPackage.from_system("vim") - - # To find from the apt cache only - # apt.DebianPackage.from_apt_cache("vim") - - # To find from installed packages only - # apt.DebianPackage.from_installed_package("vim") - - vim.ensure(PackageState.Latest) - logger.info("updated vim to version: %s", vim.fullversion) -except PackageNotFoundError: - logger.error("a specified package not found in package cache or on system") -except PackageError as e: - logger.error("could not install package. Reason: %s", e.message) -``` - - -`RepositoryMapping` will return a dict-like object containing enabled system repositories -and their properties (available groups, baseuri. gpg key). This class can add, disable, or -manipulate repositories. Items can be retrieved as `DebianRepository` objects. - -In order add a new repository with explicit details for fields, a new `DebianRepository` can -be added to `RepositoryMapping` - -`RepositoryMapping` provides an abstraction around the existing repositories on the system, -and can be accessed and iterated over like any `Mapping` object, to retrieve values by key, -iterate, or perform other operations. - -Keys are constructed as `{repo_type}-{}-{release}` in order to uniquely identify a repository. - -Repositories can be added with explicit values through a Python constructor. - -Example: - -```python -repositories = apt.RepositoryMapping() - -if "deb-example.com-focal" not in repositories: - repositories.add(DebianRepository(enabled=True, repotype="deb", - uri="https://example.com", release="focal", groups=["universe"])) -``` - -Alternatively, any valid `sources.list` line may be used to construct a new -`DebianRepository`. 
- -Example: - -```python -repositories = apt.RepositoryMapping() - -if "deb-us.archive.ubuntu.com-xenial" not in repositories: - line = "deb http://us.archive.ubuntu.com/ubuntu xenial main restricted" - repo = DebianRepository.from_repo_line(line) - repositories.add(repo) -``` -""" - -import fileinput -import glob -import logging -import os -import re -import subprocess -from collections.abc import Mapping -from enum import Enum -from subprocess import PIPE, CalledProcessError, check_call, check_output -from typing import Iterable, List, Optional, Tuple, Union -from urllib.parse import urlparse - -logger = logging.getLogger(__name__) - -# The unique Charmhub library identifier, never change it -LIBID = "7c3dbc9c2ad44a47bd6fcb25caa270e5" - -# Increment this major API version when introducing breaking changes -LIBAPI = 0 - -# Increment this PATCH version before using `charmcraft publish-lib` or reset -# to 0 if you are raising the major API version -LIBPATCH = 7 - - -VALID_SOURCE_TYPES = ("deb", "deb-src") -OPTIONS_MATCHER = re.compile(r"\[.*?\]") - - -class Error(Exception): - """Base class of most errors raised by this library.""" - - def __repr__(self): - """String representation of Error.""" - return "<{}.{} {}>".format(type(self).__module__, type(self).__name__, self.args) - - @property - def name(self): - """Return a string representation of the model plus class.""" - return "<{}.{}>".format(type(self).__module__, type(self).__name__) - - @property - def message(self): - """Return the message passed as an argument.""" - return self.args[0] - - -class PackageError(Error): - """Raised when there's an error installing or removing a package.""" - - -class PackageNotFoundError(Error): - """Raised when a requested package is not known to the system.""" - - -class PackageState(Enum): - """A class to represent possible package states.""" - - Present = "present" - Absent = "absent" - Latest = "latest" - Available = "available" - - -class DebianPackage: - """Represents a traditional Debian package and its utility functions. - - `DebianPackage` wraps information and functionality around a known package, whether installed - or available. The version, epoch, name, and architecture can be easily queried and compared - against other `DebianPackage` objects to determine the latest version or to install a specific - version. - - The representation of this object as a string mimics the output from `dpkg` for familiarity. - - Installation and removal of packages is handled through the `state` property or `ensure` - method, with the following options: - - apt.PackageState.Absent - apt.PackageState.Available - apt.PackageState.Present - apt.PackageState.Latest - - When `DebianPackage` is initialized, the state of a given `DebianPackage` object will be set to - `Available`, `Present`, or `Latest`, with `Absent` implemented as a convenience for removal - (though it operates essentially the same as `Available`). - """ - - def __init__( - self, name: str, version: str, epoch: str, arch: str, state: PackageState - ) -> None: - self._name = name - self._arch = arch - self._state = state - self._version = Version(version, epoch) - - def __eq__(self, other) -> bool: - """Equality for comparison. 
- - Args: - other: a `DebianPackage` object for comparison - - Returns: - A boolean reflecting equality - """ - return isinstance(other, self.__class__) and ( - self._name, - self._version.number, - ) == (other._name, other._version.number) - - def __hash__(self): - """A basic hash so this class can be used in Mappings and dicts.""" - return hash((self._name, self._version.number)) - - def __repr__(self): - """A representation of the package.""" - return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) - - def __str__(self): - """A human-readable representation of the package.""" - return "<{}: {}-{}.{} -- {}>".format( - self.__class__.__name__, - self._name, - self._version, - self._arch, - str(self._state), - ) - - @staticmethod - def _apt( - command: str, - package_names: Union[str, List], - optargs: Optional[List[str]] = None, - ) -> None: - """Wrap package management commands for Debian/Ubuntu systems. - - Args: - command: the command given to `apt-get` - package_names: a package name or list of package names to operate on - optargs: an (Optional) list of additioanl arguments - - Raises: - PackageError if an error is encountered - """ - optargs = optargs if optargs is not None else [] - if isinstance(package_names, str): - package_names = [package_names] - _cmd = ["apt-get", "-y", *optargs, command, *package_names] - try: - check_call(_cmd, stderr=PIPE, stdout=PIPE) - except CalledProcessError as e: - raise PackageError( - "Could not {} package(s) [{}]: {}".format(command, [*package_names], e.output) - ) from None - - def _add(self) -> None: - """Add a package to the system.""" - self._apt( - "install", - "{}={}".format(self.name, self.version), - optargs=["--option=Dpkg::Options::=--force-confold"], - ) - - def _remove(self) -> None: - """Removes a package from the system. Implementation-specific.""" - return self._apt("remove", "{}={}".format(self.name, self.version)) - - @property - def name(self) -> str: - """Returns the name of the package.""" - return self._name - - def ensure(self, state: PackageState): - """Ensures that a package is in a given state. - - Args: - state: a `PackageState` to reconcile the package to - - Raises: - PackageError from the underlying call to apt - """ - if self._state is not state: - if state not in (PackageState.Present, PackageState.Latest): - self._remove() - else: - self._add() - self._state = state - - @property - def present(self) -> bool: - """Returns whether or not a package is present.""" - return self._state in (PackageState.Present, PackageState.Latest) - - @property - def latest(self) -> bool: - """Returns whether the package is the most recent version.""" - return self._state is PackageState.Latest - - @property - def state(self) -> PackageState: - """Returns the current package state.""" - return self._state - - @state.setter - def state(self, state: PackageState) -> None: - """Sets the package state to a given value. - - Args: - state: a `PackageState` to reconcile the package to - - Raises: - PackageError from the underlying call to apt - """ - if state in (PackageState.Latest, PackageState.Present): - self._add() - else: - self._remove() - self._state = state - - @property - def version(self) -> "Version": - """Returns the version for a package.""" - return self._version - - @property - def epoch(self) -> str: - """Returns the epoch for a package. 
May be unset.""" - return self._version.epoch - - @property - def arch(self) -> str: - """Returns the architecture for a package.""" - return self._arch - - @property - def fullversion(self) -> str: - """Returns the name+epoch for a package.""" - return "{}.{}".format(self._version, self._arch) - - @staticmethod - def _get_epoch_from_version(version: str) -> Tuple[str, str]: - """Pull the epoch, if any, out of a version string.""" - epoch_matcher = re.compile(r"^((?P\d+):)?(?P.*)") - matches = epoch_matcher.search(version).groupdict() - return matches.get("epoch", ""), matches.get("version") - - @classmethod - def from_system( - cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" - ) -> "DebianPackage": - """Locates a package, either on the system or known to apt, and serializes the information. - - Args: - package: a string representing the package - version: an optional string if a specific version isr equested - arch: an optional architecture, defaulting to `dpkg --print-architecture`. If an - architecture is not specified, this will be used for selection. - - """ - try: - return DebianPackage.from_installed_package(package, version, arch) - except PackageNotFoundError: - logger.debug( - "package '%s' is not currently installed or has the wrong architecture.", package - ) - - # Ok, try `apt-cache ...` - try: - return DebianPackage.from_apt_cache(package, version, arch) - except (PackageNotFoundError, PackageError): - # If we get here, it's not known to the systems. - # This seems unnecessary, but virtually all `apt` commands have a return code of `100`, - # and providing meaningful error messages without this is ugly. - raise PackageNotFoundError( - "Package '{}{}' could not be found on the system or in the apt cache!".format( - package, ".{}".format(arch) if arch else "" - ) - ) from None - - @classmethod - def from_installed_package( - cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" - ) -> "DebianPackage": - """Check whether the package is already installed and return an instance. - - Args: - package: a string representing the package - version: an optional string if a specific version isr equested - arch: an optional architecture, defaulting to `dpkg --print-architecture`. - If an architecture is not specified, this will be used for selection. - """ - system_arch = check_output( - ["dpkg", "--print-architecture"], universal_newlines=True - ).strip() - arch = arch if arch else system_arch - - # Regexps are a really terrible way to do this. 
Thanks dpkg - output = "" - try: - output = check_output(["dpkg", "-l", package], stderr=PIPE, universal_newlines=True) - except CalledProcessError: - raise PackageNotFoundError("Package is not installed: {}".format(package)) from None - - # Pop off the output from `dpkg -l' because there's no flag to - # omit it` - lines = str(output).splitlines()[5:] - - dpkg_matcher = re.compile( - r""" - ^(?P\w+?)\s+ - (?P.*?)(?P:\w+?)?\s+ - (?P.*?)\s+ - (?P\w+?)\s+ - (?P.*) - """, - re.VERBOSE, - ) - - for line in lines: - try: - matches = dpkg_matcher.search(line).groupdict() - package_status = matches["package_status"] - - if not package_status.endswith("i"): - logger.debug( - "package '%s' in dpkg output but not installed, status: '%s'", - package, - package_status, - ) - break - - epoch, split_version = DebianPackage._get_epoch_from_version(matches["version"]) - pkg = DebianPackage( - matches["package_name"], - split_version, - epoch, - matches["arch"], - PackageState.Present, - ) - if (pkg.arch == "all" or pkg.arch == arch) and ( - version == "" or str(pkg.version) == version - ): - return pkg - except AttributeError: - logger.warning("dpkg matcher could not parse line: %s", line) - - # If we didn't find it, fail through - raise PackageNotFoundError("Package {}.{} is not installed!".format(package, arch)) - - @classmethod - def from_apt_cache( - cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" - ) -> "DebianPackage": - """Check whether the package is already installed and return an instance. - - Args: - package: a string representing the package - version: an optional string if a specific version isr equested - arch: an optional architecture, defaulting to `dpkg --print-architecture`. - If an architecture is not specified, this will be used for selection. - """ - system_arch = check_output( - ["dpkg", "--print-architecture"], universal_newlines=True - ).strip() - arch = arch if arch else system_arch - - # Regexps are a really terrible way to do this. Thanks dpkg - keys = ("Package", "Architecture", "Version") - - try: - output = check_output( - ["apt-cache", "show", package], stderr=PIPE, universal_newlines=True - ) - except CalledProcessError as e: - raise PackageError( - "Could not list packages in apt-cache: {}".format(e.output) - ) from None - - pkg_groups = output.strip().split("\n\n") - keys = ("Package", "Architecture", "Version") - - for pkg_raw in pkg_groups: - lines = str(pkg_raw).splitlines() - vals = {} - for line in lines: - if line.startswith(keys): - items = line.split(":", 1) - vals[items[0]] = items[1].strip() - else: - continue - - epoch, split_version = DebianPackage._get_epoch_from_version(vals["Version"]) - pkg = DebianPackage( - vals["Package"], - split_version, - epoch, - vals["Architecture"], - PackageState.Available, - ) - - if (pkg.arch == "all" or pkg.arch == arch) and ( - version == "" or str(pkg.version) == version - ): - return pkg - - # If we didn't find it, fail through - raise PackageNotFoundError("Package {}.{} is not in the apt cache!".format(package, arch)) - - -class Version: - """An abstraction around package versions. - - This seems like it should be strictly unnecessary, except that `apt_pkg` is not usable inside a - venv, and wedging version comparisions into `DebianPackage` would overcomplicate it. 
- - This class implements the algorithm found here: - https://www.debian.org/doc/debian-policy/ch-controlfields.html#version - """ - - def __init__(self, version: str, epoch: str): - self._version = version - self._epoch = epoch or "" - - def __repr__(self): - """A representation of the package.""" - return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) - - def __str__(self): - """A human-readable representation of the package.""" - return "{}{}".format("{}:".format(self._epoch) if self._epoch else "", self._version) - - @property - def epoch(self): - """Returns the epoch for a package. May be empty.""" - return self._epoch - - @property - def number(self) -> str: - """Returns the version number for a package.""" - return self._version - - def _get_parts(self, version: str) -> Tuple[str, str]: - """Separate the version into component upstream and Debian pieces.""" - try: - version.rindex("-") - except ValueError: - # No hyphens means no Debian version - return version, "0" - - upstream, debian = version.rsplit("-", 1) - return upstream, debian - - def _listify(self, revision: str) -> List[str]: - """Split a revision string into a listself. - - This list is comprised of alternating between strings and numbers, - padded on either end to always be "str, int, str, int..." and - always be of even length. This allows us to trivially implement the - comparison algorithm described. - """ - result = [] - while revision: - rev_1, remains = self._get_alphas(revision) - rev_2, remains = self._get_digits(remains) - result.extend([rev_1, rev_2]) - revision = remains - return result - - def _get_alphas(self, revision: str) -> Tuple[str, str]: - """Return a tuple of the first non-digit characters of a revision.""" - # get the index of the first digit - for i, char in enumerate(revision): - if char.isdigit(): - if i == 0: - return "", revision - return revision[0:i], revision[i:] - # string is entirely alphas - return revision, "" - - def _get_digits(self, revision: str) -> Tuple[int, str]: - """Return a tuple of the first integer characters of a revision.""" - # If the string is empty, return (0,'') - if not revision: - return 0, "" - # get the index of the first non-digit - for i, char in enumerate(revision): - if not char.isdigit(): - if i == 0: - return 0, revision - return int(revision[0:i]), revision[i:] - # string is entirely digits - return int(revision), "" - - def _dstringcmp(self, a, b): # noqa: C901 - """Debian package version string section lexical sort algorithm. - - The lexical comparison is a comparison of ASCII values modified so - that all the letters sort earlier than all the non-letters and so that - a tilde sorts before anything, even the end of a part. - """ - if a == b: - return 0 - try: - for i, char in enumerate(a): - if char == b[i]: - continue - # "a tilde sorts before anything, even the end of a part" - # (emptyness) - if char == "~": - return -1 - if b[i] == "~": - return 1 - # "all the letters sort earlier than all the non-letters" - if char.isalpha() and not b[i].isalpha(): - return -1 - if not char.isalpha() and b[i].isalpha(): - return 1 - # otherwise lexical sort - if ord(char) > ord(b[i]): - return 1 - if ord(char) < ord(b[i]): - return -1 - except IndexError: - # a is longer than b but otherwise equal, greater unless there are tildes - if char == "~": - return -1 - return 1 - # if we get here, a is shorter than b but otherwise equal, so check for tildes... 
- if b[len(a)] == "~": - return 1 - return -1 - - def _compare_revision_strings(self, first: str, second: str): # noqa: C901 - """Compare two debian revision strings.""" - if first == second: - return 0 - - # listify pads results so that we will always be comparing ints to ints - # and strings to strings (at least until we fall off the end of a list) - first_list = self._listify(first) - second_list = self._listify(second) - if first_list == second_list: - return 0 - try: - for i, item in enumerate(first_list): - # explicitly raise IndexError if we've fallen off the edge of list2 - if i >= len(second_list): - raise IndexError - # if the items are equal, next - if item == second_list[i]: - continue - # numeric comparison - if isinstance(item, int): - if item > second_list[i]: - return 1 - if item < second_list[i]: - return -1 - else: - # string comparison - return self._dstringcmp(item, second_list[i]) - except IndexError: - # rev1 is longer than rev2 but otherwise equal, hence greater - # ...except for goddamn tildes - if first_list[len(second_list)][0][0] == "~": - return 1 - return 1 - # rev1 is shorter than rev2 but otherwise equal, hence lesser - # ...except for goddamn tildes - if second_list[len(first_list)][0][0] == "~": - return -1 - return -1 - - def _compare_version(self, other) -> int: - if (self.number, self.epoch) == (other.number, other.epoch): - return 0 - - if self.epoch < other.epoch: - return -1 - if self.epoch > other.epoch: - return 1 - - # If none of these are true, follow the algorithm - upstream_version, debian_version = self._get_parts(self.number) - other_upstream_version, other_debian_version = self._get_parts(other.number) - - upstream_cmp = self._compare_revision_strings(upstream_version, other_upstream_version) - if upstream_cmp != 0: - return upstream_cmp - - debian_cmp = self._compare_revision_strings(debian_version, other_debian_version) - if debian_cmp != 0: - return debian_cmp - - return 0 - - def __lt__(self, other) -> bool: - """Less than magic method impl.""" - return self._compare_version(other) < 0 - - def __eq__(self, other) -> bool: - """Equality magic method impl.""" - return self._compare_version(other) == 0 - - def __gt__(self, other) -> bool: - """Greater than magic method impl.""" - return self._compare_version(other) > 0 - - def __le__(self, other) -> bool: - """Less than or equal to magic method impl.""" - return self.__eq__(other) or self.__lt__(other) - - def __ge__(self, other) -> bool: - """Greater than or equal to magic method impl.""" - return self.__gt__(other) or self.__eq__(other) - - def __ne__(self, other) -> bool: - """Not equal to magic method impl.""" - return not self.__eq__(other) - - -def add_package( - package_names: Union[str, List[str]], - version: Optional[str] = "", - arch: Optional[str] = "", - update_cache: Optional[bool] = False, -) -> Union[DebianPackage, List[DebianPackage]]: - """Add a package or list of packages to the system. - - Args: - name: the name(s) of the package(s) - version: an (Optional) version as a string. Defaults to the latest known - arch: an optional architecture for the package - update_cache: whether or not to run `apt-get update` prior to operating - - Raises: - PackageNotFoundError if the package is not in the cache. 
- """ - cache_refreshed = False - if update_cache: - update() - cache_refreshed = True - - packages = {"success": [], "retry": [], "failed": []} - - package_names = [package_names] if type(package_names) is str else package_names - if not package_names: - raise TypeError("Expected at least one package name to add, received zero!") - - if len(package_names) != 1 and version: - raise TypeError( - "Explicit version should not be set if more than one package is being added!" - ) - - for p in package_names: - pkg, success = _add(p, version, arch) - if success: - packages["success"].append(pkg) - else: - logger.warning("failed to locate and install/update '%s'", pkg) - packages["retry"].append(p) - - if packages["retry"] and not cache_refreshed: - logger.info("updating the apt-cache and retrying installation of failed packages.") - update() - - for p in packages["retry"]: - pkg, success = _add(p, version, arch) - if success: - packages["success"].append(pkg) - else: - packages["failed"].append(p) - - if packages["failed"]: - raise PackageError("Failed to install packages: {}".format(", ".join(packages["failed"]))) - - return packages["success"] if len(packages["success"]) > 1 else packages["success"][0] - - -def _add( - name: str, - version: Optional[str] = "", - arch: Optional[str] = "", -) -> Tuple[Union[DebianPackage, str], bool]: - """Adds a package. - - Args: - name: the name(s) of the package(s) - version: an (Optional) version as a string. Defaults to the latest known - arch: an optional architecture for the package - - Returns: a tuple of `DebianPackage` if found, or a :str: if it is not, and - a boolean indicating success - """ - try: - pkg = DebianPackage.from_system(name, version, arch) - pkg.ensure(state=PackageState.Present) - return pkg, True - except PackageNotFoundError: - return name, False - - -def remove_package( - package_names: Union[str, List[str]] -) -> Union[DebianPackage, List[DebianPackage]]: - """Removes a package from the system. - - Args: - package_names: the name of a package - - Raises: - PackageNotFoundError if the package is not found. 
- """ - packages = [] - - package_names = [package_names] if type(package_names) is str else package_names - if not package_names: - raise TypeError("Expected at least one package name to add, received zero!") - - for p in package_names: - try: - pkg = DebianPackage.from_installed_package(p) - pkg.ensure(state=PackageState.Absent) - packages.append(pkg) - except PackageNotFoundError: - logger.info("package '%s' was requested for removal, but it was not installed.", p) - - # the list of packages will be empty when no package is removed - logger.debug("packages: '%s'", packages) - return packages[0] if len(packages) == 1 else packages - - -def update() -> None: - """Updates the apt cache via `apt-get update`.""" - check_call(["apt-get", "update"], stderr=PIPE, stdout=PIPE) - - -class InvalidSourceError(Error): - """Exceptions for invalid source entries.""" - - -class GPGKeyError(Error): - """Exceptions for GPG keys.""" - - -class DebianRepository: - """An abstraction to represent a repository.""" - - def __init__( - self, - enabled: bool, - repotype: str, - uri: str, - release: str, - groups: List[str], - filename: Optional[str] = "", - gpg_key_filename: Optional[str] = "", - options: Optional[dict] = None, - ): - self._enabled = enabled - self._repotype = repotype - self._uri = uri - self._release = release - self._groups = groups - self._filename = filename - self._gpg_key_filename = gpg_key_filename - self._options = options - - @property - def enabled(self): - """Return whether or not the repository is enabled.""" - return self._enabled - - @property - def repotype(self): - """Return whether it is binary or source.""" - return self._repotype - - @property - def uri(self): - """Return the URI.""" - return self._uri - - @property - def release(self): - """Return which Debian/Ubuntu releases it is valid for.""" - return self._release - - @property - def groups(self): - """Return the enabled package groups.""" - return self._groups - - @property - def filename(self): - """Returns the filename for a repository.""" - return self._filename - - @filename.setter - def filename(self, fname: str) -> None: - """Sets the filename used when a repo is written back to diskself. - - Args: - fname: a filename to write the repository information to. - """ - if not fname.endswith(".list"): - raise InvalidSourceError("apt source filenames should end in .list!") - - self._filename = fname - - @property - def gpg_key(self): - """Returns the path to the GPG key for this repository.""" - return self._gpg_key_filename - - @property - def options(self): - """Returns any additional repo options which are set.""" - return self._options - - def make_options_string(self) -> str: - """Generate the complete options string for a a repository. - - Combining `gpg_key`, if set, and the rest of the options to find - a complex repo string. 
- """ - options = self._options if self._options else {} - if self._gpg_key_filename: - options["signed-by"] = self._gpg_key_filename - - return ( - "[{}] ".format(" ".join(["{}={}".format(k, v) for k, v in options.items()])) - if options - else "" - ) - - @staticmethod - def prefix_from_uri(uri: str) -> str: - """Get a repo list prefix from the uri, depending on whether a path is set.""" - uridetails = urlparse(uri) - path = ( - uridetails.path.lstrip("/").replace("/", "-") if uridetails.path else uridetails.netloc - ) - return "/etc/apt/sources.list.d/{}".format(path) - - @staticmethod - def from_repo_line(repo_line: str, write_file: Optional[bool] = True) -> "DebianRepository": - """Instantiate a new `DebianRepository` a `sources.list` entry line. - - Args: - repo_line: a string representing a repository entry - write_file: boolean to enable writing the new repo to disk - """ - repo = RepositoryMapping._parse(repo_line, "UserInput") - fname = "{}-{}.list".format( - DebianRepository.prefix_from_uri(repo.uri), repo.release.replace("/", "-") - ) - repo.filename = fname - - options = repo.options if repo.options else {} - if repo.gpg_key: - options["signed-by"] = repo.gpg_key - - # For Python 3.5 it's required to use sorted in the options dict in order to not have - # different results in the order of the options between executions. - options_str = ( - "[{}] ".format(" ".join(["{}={}".format(k, v) for k, v in sorted(options.items())])) - if options - else "" - ) - - if write_file: - with open(fname, "wb") as f: - f.write( - ( - "{}".format("#" if not repo.enabled else "") - + "{} {}{} ".format(repo.repotype, options_str, repo.uri) - + "{} {}\n".format(repo.release, " ".join(repo.groups)) - ).encode("utf-8") - ) - - return repo - - def disable(self) -> None: - """Remove this repository from consideration. - - Disable it instead of removing from the repository file. - """ - searcher = "{} {}{} {}".format( - self.repotype, self.make_options_string(), self.uri, self.release - ) - for line in fileinput.input(self._filename, inplace=True): - if re.match(r"^{}\s".format(re.escape(searcher)), line): - print("# {}".format(line), end="") - else: - print(line, end="") - - def import_key(self, key: str) -> None: - """Import an ASCII Armor key. - - A Radix64 format keyid is also supported for backwards - compatibility. In this case Ubuntu keyserver will be - queried for a key via HTTPS by its keyid. This method - is less preferrable because https proxy servers may - require traffic decryption which is equivalent to a - man-in-the-middle attack (a proxy server impersonates - keyserver TLS certificates and has to be explicitly - trusted by the system). - - Args: - key: A GPG key in ASCII armor format, - including BEGIN and END markers or a keyid. - - Raises: - GPGKeyError if the key could not be imported - """ - key = key.strip() - if "-" in key or "\n" in key: - # Send everything not obviously a keyid to GPG to import, as - # we trust its validation better than our own. eg. handling - # comments before the key. 
- logger.debug("PGP key found (looks like ASCII Armor format)") - if ( - "-----BEGIN PGP PUBLIC KEY BLOCK-----" in key - and "-----END PGP PUBLIC KEY BLOCK-----" in key - ): - logger.debug("Writing provided PGP key in the binary format") - key_bytes = key.encode("utf-8") - key_name = self._get_keyid_by_gpg_key(key_bytes) - key_gpg = self._dearmor_gpg_key(key_bytes) - self._gpg_key_filename = "/etc/apt/trusted.gpg.d/{}.gpg".format(key_name) - self._write_apt_gpg_keyfile(key_name=self._gpg_key_filename, key_material=key_gpg) - else: - raise GPGKeyError("ASCII armor markers missing from GPG key") - else: - logger.warning( - "PGP key found (looks like Radix64 format). " - "SECURELY importing PGP key from keyserver; " - "full key not provided." - ) - # as of bionic add-apt-repository uses curl with an HTTPS keyserver URL - # to retrieve GPG keys. `apt-key adv` command is deprecated as is - # apt-key in general as noted in its manpage. See lp:1433761 for more - # history. Instead, /etc/apt/trusted.gpg.d is used directly to drop - # gpg - key_asc = self._get_key_by_keyid(key) - # write the key in GPG format so that apt-key list shows it - key_gpg = self._dearmor_gpg_key(key_asc.encode("utf-8")) - self._gpg_key_filename = "/etc/apt/trusted.gpg.d/{}.gpg".format(key) - self._write_apt_gpg_keyfile(key_name=key, key_material=key_gpg) - - @staticmethod - def _get_keyid_by_gpg_key(key_material: bytes) -> str: - """Get a GPG key fingerprint by GPG key material. - - Gets a GPG key fingerprint (40-digit, 160-bit) by the ASCII armor-encoded - or binary GPG key material. Can be used, for example, to generate file - names for keys passed via charm options. - """ - # Use the same gpg command for both Xenial and Bionic - cmd = ["gpg", "--with-colons", "--with-fingerprint"] - ps = subprocess.run( - cmd, - stdout=PIPE, - stderr=PIPE, - input=key_material, - ) - out, err = ps.stdout.decode(), ps.stderr.decode() - if "gpg: no valid OpenPGP data found." in err: - raise GPGKeyError("Invalid GPG key material provided") - # from gnupg2 docs: fpr :: Fingerprint (fingerprint is in field 10) - return re.search(r"^fpr:{9}([0-9A-F]{40}):$", out, re.MULTILINE).group(1) - - @staticmethod - def _get_key_by_keyid(keyid: str) -> str: - """Get a key via HTTPS from the Ubuntu keyserver. - - Different key ID formats are supported by SKS keyservers (the longer ones - are more secure, see "dead beef attack" and https://evil32.com/). Since - HTTPS is used, if SSLBump-like HTTPS proxies are in place, they will - impersonate keyserver.ubuntu.com and generate a certificate with - keyserver.ubuntu.com in the CN field or in SubjAltName fields of a - certificate. If such proxy behavior is expected it is necessary to add the - CA certificate chain containing the intermediate CA of the SSLBump proxy to - every machine that this code runs on via ca-certs cloud-init directive (via - cloudinit-userdata model-config) or via other means (such as through a - custom charm option). Also note that DNS resolution for the hostname in a - URL is done at a proxy server - not at the client side. 
- 8-digit (32 bit) key ID - https://keyserver.ubuntu.com/pks/lookup?search=0x4652B4E6 - 16-digit (64 bit) key ID - https://keyserver.ubuntu.com/pks/lookup?search=0x6E85A86E4652B4E6 - 40-digit key ID: - https://keyserver.ubuntu.com/pks/lookup?search=0x35F77D63B5CEC106C577ED856E85A86E4652B4E6 - - Args: - keyid: An 8, 16 or 40 hex digit keyid to find a key for - - Returns: - A string contining key material for the specified GPG key id - - - Raises: - subprocess.CalledProcessError - """ - # options=mr - machine-readable output (disables html wrappers) - keyserver_url = ( - "https://keyserver.ubuntu.com" "/pks/lookup?op=get&options=mr&exact=on&search=0x{}" - ) - curl_cmd = ["curl", keyserver_url.format(keyid)] - # use proxy server settings in order to retrieve the key - return check_output(curl_cmd).decode() - - @staticmethod - def _dearmor_gpg_key(key_asc: bytes) -> bytes: - """Converts a GPG key in the ASCII armor format to the binary format. - - Args: - key_asc: A GPG key in ASCII armor format. - - Returns: - A GPG key in binary format as a string - - Raises: - GPGKeyError - """ - ps = subprocess.run(["gpg", "--dearmor"], stdout=PIPE, stderr=PIPE, input=key_asc) - out, err = ps.stdout, ps.stderr.decode() - if "gpg: no valid OpenPGP data found." in err: - raise GPGKeyError( - "Invalid GPG key material. Check your network setup" - " (MTU, routing, DNS) and/or proxy server settings" - " as well as destination keyserver status." - ) - else: - return out - - @staticmethod - def _write_apt_gpg_keyfile(key_name: str, key_material: bytes) -> None: - """Writes GPG key material into a file at a provided path. - - Args: - key_name: A key name to use for a key file (could be a fingerprint) - key_material: A GPG key material (binary) - """ - with open(key_name, "wb") as keyf: - keyf.write(key_material) - - -class RepositoryMapping(Mapping): - """An representation of known repositories. - - Instantiation of `RepositoryMapping` will iterate through the - filesystem, parse out repository files in `/etc/apt/...`, and create - `DebianRepository` objects in this list. - - Typical usage: - - repositories = apt.RepositoryMapping() - repositories.add(DebianRepository( - enabled=True, repotype="deb", uri="https://example.com", release="focal", - groups=["universe"] - )) - """ - - def __init__(self): - self._repository_map = {} - # Repositories that we're adding -- used to implement mode param - self.default_file = "/etc/apt/sources.list" - - # read sources.list if it exists - if os.path.isfile(self.default_file): - self.load(self.default_file) - - # read sources.list.d - for file in glob.iglob("/etc/apt/sources.list.d/*.list"): - self.load(file) - - def __contains__(self, key: str) -> bool: - """Magic method for checking presence of repo in mapping.""" - return key in self._repository_map - - def __len__(self) -> int: - """Return number of repositories in map.""" - return len(self._repository_map) - - def __iter__(self) -> Iterable[DebianRepository]: - """Iterator magic method for RepositoryMapping.""" - return iter(self._repository_map.values()) - - def __getitem__(self, repository_uri: str) -> DebianRepository: - """Return a given `DebianRepository`.""" - return self._repository_map[repository_uri] - - def __setitem__(self, repository_uri: str, repository: DebianRepository) -> None: - """Add a `DebianRepository` to the cache.""" - self._repository_map[repository_uri] = repository - - def load(self, filename: str): - """Load a repository source file into the cache. 
- - Args: - filename: the path to the repository file - """ - parsed = [] - skipped = [] - with open(filename, "r") as f: - for n, line in enumerate(f): - try: - repo = self._parse(line, filename) - except InvalidSourceError: - skipped.append(n) - else: - repo_identifier = "{}-{}-{}".format(repo.repotype, repo.uri, repo.release) - self._repository_map[repo_identifier] = repo - parsed.append(n) - logger.debug("parsed repo: '%s'", repo_identifier) - - if skipped: - skip_list = ", ".join(str(s) for s in skipped) - logger.debug("skipped the following lines in file '%s': %s", filename, skip_list) - - if parsed: - logger.info("parsed %d apt package repositories", len(parsed)) - else: - raise InvalidSourceError("all repository lines in '{}' were invalid!".format(filename)) - - @staticmethod - def _parse(line: str, filename: str) -> DebianRepository: - """Parse a line in a sources.list file. - - Args: - line: a single line from `load` to parse - filename: the filename being read - - Raises: - InvalidSourceError if the source type is unknown - """ - enabled = True - repotype = uri = release = gpg_key = "" - options = {} - groups = [] - - line = line.strip() - if line.startswith("#"): - enabled = False - line = line[1:] - - # Check for "#" in the line and treat a part after it as a comment then strip it off. - i = line.find("#") - if i > 0: - line = line[:i] - - # Split a source into substrings to initialize a new repo. - source = line.strip() - if source: - # Match any repo options, and get a dict representation. - for v in re.findall(OPTIONS_MATCHER, source): - opts = dict(o.split("=") for o in v.strip("[]").split()) - # Extract the 'signed-by' option for the gpg_key - gpg_key = opts.pop("signed-by", "") - options = opts - - # Remove any options from the source string and split the string into chunks - source = re.sub(OPTIONS_MATCHER, "", source) - chunks = source.split() - - # Check we've got a valid list of chunks - if len(chunks) < 3 or chunks[0] not in VALID_SOURCE_TYPES: - raise InvalidSourceError("An invalid sources line was found in %s!", filename) - - repotype = chunks[0] - uri = chunks[1] - release = chunks[2] - groups = chunks[3:] - - return DebianRepository( - enabled, repotype, uri, release, groups, filename, gpg_key, options - ) - else: - raise InvalidSourceError("An invalid sources line was found in %s!", filename) - - def add(self, repo: DebianRepository, default_filename: Optional[bool] = False) -> None: - """Add a new repository to the system. - - Args: - repo: a `DebianRepository` object - default_filename: an (Optional) filename if the default is not desirable - """ - new_filename = "{}-{}.list".format( - DebianRepository.prefix_from_uri(repo.uri), repo.release.replace("/", "-") - ) - - fname = repo.filename or new_filename - - options = repo.options if repo.options else {} - if repo.gpg_key: - options["signed-by"] = repo.gpg_key - - with open(fname, "wb") as f: - f.write( - ( - "{}".format("#" if not repo.enabled else "") - + "{} {}{} ".format(repo.repotype, repo.make_options_string(), repo.uri) - + "{} {}\n".format(repo.release, " ".join(repo.groups)) - ).encode("utf-8") - ) - - self._repository_map["{}-{}-{}".format(repo.repotype, repo.uri, repo.release)] = repo - - def disable(self, repo: DebianRepository) -> None: - """Remove a repository. Disable by default. 
- - Args: - repo: a `DebianRepository` to disable - """ - searcher = "{} {}{} {}".format( - repo.repotype, repo.make_options_string(), repo.uri, repo.release - ) - - for line in fileinput.input(repo.filename, inplace=True): - if re.match(r"^{}\s".format(re.escape(searcher)), line): - print("# {}".format(line), end="") - else: - print(line, end="") - - self._repository_map["{}-{}-{}".format(repo.repotype, repo.uri, repo.release)] = repo diff --git a/lib/charms/operator_libs_linux/v0/passwd.py b/lib/charms/operator_libs_linux/v0/passwd.py deleted file mode 100644 index b692e700..00000000 --- a/lib/charms/operator_libs_linux/v0/passwd.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright 2021 Canonical Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Simple library for managing Linux users and groups. - -The `passwd` module provides convenience methods and abstractions around users and groups on a -Linux system, in order to make adding and managing users and groups easy. - -Example of adding a user named 'test': - -```python -import passwd -passwd.add_group(name='special_group') -passwd.add_user(username='test', secondary_groups=['sudo']) - -if passwd.user_exists('some_user'): - do_stuff() -``` -""" - -import grp -import logging -import pwd -from subprocess import STDOUT, check_output -from typing import List, Optional, Union - -logger = logging.getLogger(__name__) - -# The unique Charmhub library identifier, never change it -LIBID = "cf7655b2bf914d67ac963f72b930f6bb" - -# Increment this major API version when introducing breaking changes -LIBAPI = 0 - -# Increment this PATCH version before using `charmcraft publish-lib` or reset -# to 0 if you are raising the major API version -LIBPATCH = 3 - - -def user_exists(user: Union[str, int]) -> Optional[pwd.struct_passwd]: - """Check if a user exists. - - Args: - user: username or gid of user whose existence to check - - Raises: - TypeError: where neither a string or int is passed as the first argument - """ - try: - if type(user) is int: - return pwd.getpwuid(user) - elif type(user) is str: - return pwd.getpwnam(user) - else: - raise TypeError("specified argument '%r' should be a string or int", user) - except KeyError: - logger.info("specified user '%s' doesn't exist", str(user)) - return None - - -def group_exists(group: Union[str, int]) -> Optional[grp.struct_group]: - """Check if a group exists. 
- - Args: - group: username or gid of user whose existence to check - - Raises: - TypeError: where neither a string or int is passed as the first argument - """ - try: - if type(group) is int: - return grp.getgrgid(group) - elif type(group) is str: - return grp.getgrnam(group) - else: - raise TypeError("specified argument '%r' should be a string or int", group) - except KeyError: - logger.info("specified group '%s' doesn't exist", str(group)) - return None - - -def add_user( - username: str, - password: Optional[str] = None, - shell: str = "/bin/bash", - system_user: bool = False, - primary_group: str = None, - secondary_groups: List[str] = None, - uid: int = None, - home_dir: str = None, -) -> str: - """Add a user to the system. - - Will log but otherwise succeed if the user already exists. - - Arguments: - username: Username to create - password: Password for user; if ``None``, create a system user - shell: The default shell for the user - system_user: Whether to create a login or system user - primary_group: Primary group for user; defaults to username - secondary_groups: Optional list of additional groups - uid: UID for user being created - home_dir: Home directory for user - - Returns: - The password database entry struct, as returned by `pwd.getpwnam` - """ - try: - if uid: - user_info = pwd.getpwuid(int(uid)) - logger.info("user '%d' already exists", uid) - return user_info - user_info = pwd.getpwnam(username) - logger.info("user with uid '%s' already exists", username) - return user_info - except KeyError: - logger.info("creating user '%s'", username) - - cmd = ["useradd", "--shell", shell] - - if uid: - cmd.extend(["--uid", str(uid)]) - if home_dir: - cmd.extend(["--home", str(home_dir)]) - if password: - cmd.extend(["--password", password, "--create-home"]) - if system_user or password is None: - cmd.append("--system") - - if not primary_group: - try: - grp.getgrnam(username) - primary_group = username # avoid "group exists" error - except KeyError: - pass - - if primary_group: - cmd.extend(["-g", primary_group]) - if secondary_groups: - cmd.extend(["-G", ",".join(secondary_groups)]) - - cmd.append(username) - check_output(cmd, stderr=STDOUT) - user_info = pwd.getpwnam(username) - return user_info - - -def add_group(group_name: str, system_group: bool = False, gid: int = None): - """Add a group to the system. - - Will log but otherwise succeed if the group already exists. - - Args: - group_name: group to create - system_group: Create system group - gid: GID for user being created - - Returns: - The group's password database entry struct, as returned by `grp.getgrnam` - """ - try: - group_info = grp.getgrnam(group_name) - logger.info("group '%s' already exists", group_name) - if gid: - group_info = grp.getgrgid(gid) - logger.info("group with gid '%d' already exists", gid) - except KeyError: - logger.info("creating group '%s'", group_name) - cmd = ["addgroup"] - if gid: - cmd.extend(["--gid", str(gid)]) - if system_group: - cmd.append("--system") - else: - cmd.extend(["--group"]) - cmd.append(group_name) - check_output(cmd, stderr=STDOUT) - group_info = grp.getgrnam(group_name) - return group_info - - -def add_user_to_group(username: str, group: str): - """Add a user to a group. 
- - Args: - username: user to add to specified group - group: name of group to add user to - - Returns: - The group's password database entry struct, as returned by `grp.getgrnam` - """ - if not user_exists(username): - raise ValueError("user '{}' does not exist".format(username)) - if not group_exists(group): - raise ValueError("group '{}' does not exist".format(group)) - - logger.info("adding user '%s' to group '%s'", username, group) - check_output(["gpasswd", "-a", username, group], stderr=STDOUT) - return grp.getgrnam(group) - - -def remove_user(user: Union[str, int], remove_home: bool = False) -> bool: - """Remove a user from the system. - - Args: - user: the username or uid of the user to remove - remove_home: indicates whether the user's home directory should be removed - """ - u = user_exists(user) - if not u: - logger.info("user '%s' does not exist", str(u)) - return True - - cmd = ["userdel"] - if remove_home: - cmd.append("-f") - cmd.append(u.pw_name) - - logger.info("removing user '%s'", u.pw_name) - check_output(cmd, stderr=STDOUT) - return True - - -def remove_group(group: Union[str, int], force: bool = False) -> bool: - """Remove a user from the system. - - Args: - group: the name or gid of the group to remove - force: force group removal even if it's the primary group for a user - """ - g = group_exists(group) - if not g: - logger.info("group '%s' does not exist", str(g)) - return True - - cmd = ["groupdel"] - if force: - cmd.append("-f") - cmd.append(g.gr_name) - - logger.info("removing group '%s'", g.gr_name) - check_output(cmd, stderr=STDOUT) - return True diff --git a/lib/charms/operator_libs_linux/v1/systemd.py b/lib/charms/operator_libs_linux/v1/systemd.py deleted file mode 100644 index 5be34c17..00000000 --- a/lib/charms/operator_libs_linux/v1/systemd.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2021 Canonical Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -"""Abstractions for stopping, starting and managing system services via systemd. - -This library assumes that your charm is running on a platform that uses systemd. E.g., -Centos 7 or later, Ubuntu Xenial (16.04) or later. - -For the most part, we transparently provide an interface to a commonly used selection of -systemd commands, with a few shortcuts baked in. For example, service_pause and -service_resume with run the mask/unmask and enable/disable invocations. - -Example usage: -```python -from charms.operator_libs_linux.v0.systemd import service_running, service_reload - -# Start a service -if not service_running("mysql"): - success = service_start("mysql") - -# Attempt to reload a service, restarting if necessary -success = service_reload("nginx", restart_on_failure=True) -``` - -""" - -import logging -import subprocess - -__all__ = [ # Don't export `_systemctl`. (It's not the intended way of using this lib.) 
- "service_pause", - "service_reload", - "service_restart", - "service_resume", - "service_running", - "service_start", - "service_stop", - "daemon_reload", -] - -logger = logging.getLogger(__name__) - -# The unique Charmhub library identifier, never change it -LIBID = "045b0d179f6b4514a8bb9b48aee9ebaf" - -# Increment this major API version when introducing breaking changes -LIBAPI = 1 - -# Increment this PATCH version before using `charmcraft publish-lib` or reset -# to 0 if you are raising the major API version -LIBPATCH = 0 - - -class SystemdError(Exception): - pass - - -def _popen_kwargs(): - return dict( - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - bufsize=1, - universal_newlines=True, - encoding="utf-8", - ) - - -def _systemctl( - sub_cmd: str, service_name: str = None, now: bool = None, quiet: bool = None -) -> bool: - """Control a system service. - - Args: - sub_cmd: the systemctl subcommand to issue - service_name: the name of the service to perform the action on - now: passes the --now flag to the shell invocation. - quiet: passes the --quiet flag to the shell invocation. - """ - cmd = ["systemctl", sub_cmd] - - if service_name is not None: - cmd.append(service_name) - if now is not None: - cmd.append("--now") - if quiet is not None: - cmd.append("--quiet") - if sub_cmd != "is-active": - logger.debug("Attempting to {} '{}' with command {}.".format(cmd, service_name, cmd)) - else: - logger.debug("Checking if '{}' is active".format(service_name)) - - proc = subprocess.Popen(cmd, **_popen_kwargs()) - last_line = "" - for line in iter(proc.stdout.readline, ""): - last_line = line - logger.debug(line) - - proc.wait() - - if sub_cmd == "is-active": - # If we are just checking whether a service is running, return True/False, rather - # than raising an error. - if proc.returncode < 1: - return True - if proc.returncode == 3: # Code returned when service is not active. - return False - - if proc.returncode < 1: - return True - - raise SystemdError( - "Could not {}{}: systemd output: {}".format( - sub_cmd, " {}".format(service_name) if service_name else "", last_line - ) - ) - - -def service_running(service_name: str) -> bool: - """Determine whether a system service is running. - - Args: - service_name: the name of the service - """ - return _systemctl("is-active", service_name, quiet=True) - - -def service_start(service_name: str) -> bool: - """Start a system service. - - Args: - service_name: the name of the service to stop - """ - return _systemctl("start", service_name) - - -def service_stop(service_name: str) -> bool: - """Stop a system service. - - Args: - service_name: the name of the service to stop - """ - return _systemctl("stop", service_name) - - -def service_restart(service_name: str) -> bool: - """Restart a system service. - - Args: - service_name: the name of the service to restart - """ - return _systemctl("restart", service_name) - - -def service_reload(service_name: str, restart_on_failure: bool = False) -> bool: - """Reload a system service, optionally falling back to restart if reload fails. - - Args: - service_name: the name of the service to reload - restart_on_failure: boolean indicating whether to fallback to a restart if the - reload fails. - """ - try: - return _systemctl("reload", service_name) - except SystemdError: - if restart_on_failure: - return _systemctl("restart", service_name) - else: - raise - - -def service_pause(service_name: str) -> bool: - """Pause a system service. - - Stop it, and prevent it from starting again at boot. 
- - Args: - service_name: the name of the service to pause - """ - _systemctl("disable", service_name, now=True) - _systemctl("mask", service_name) - - if not service_running(service_name): - return True - - raise SystemdError("Attempted to pause '{}', but it is still running.".format(service_name)) - - -def service_resume(service_name: str) -> bool: - """Resume a system service. - - Re-enable starting again at boot. Start the service. - - Args: - service_name: the name of the service to resume - """ - _systemctl("unmask", service_name) - _systemctl("enable", service_name, now=True) - - if service_running(service_name): - return True - - raise SystemdError("Attempted to resume '{}', but it is not running.".format(service_name)) - - -def daemon_reload() -> bool: - """Reload systemd manager configuration.""" - return _systemctl("daemon-reload") diff --git a/lib/charms/operator_libs_linux/v2/snap.py b/lib/charms/operator_libs_linux/v2/snap.py new file mode 100644 index 00000000..b82024c5 --- /dev/null +++ b/lib/charms/operator_libs_linux/v2/snap.py @@ -0,0 +1,1065 @@ +# Copyright 2021 Canonical Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Representations of the system's Snaps, and abstractions around managing them. + +The `snap` module provides convenience methods for listing, installing, refreshing, and removing +Snap packages, in addition to setting and getting configuration options for them. + +In the `snap` module, `SnapCache` creates a dict-like mapping of `Snap` objects at when +instantiated. Installed snaps are fully populated, and available snaps are lazily-loaded upon +request. This module relies on an installed and running `snapd` daemon to perform operations over +the `snapd` HTTP API. + +`SnapCache` objects can be used to install or modify Snap packages by name in a manner similar to +using the `snap` command from the commandline. + +An example of adding Juju to the system with `SnapCache` and setting a config value: + +```python +try: + cache = snap.SnapCache() + juju = cache["juju"] + + if not juju.present: + juju.ensure(snap.SnapState.Latest, channel="beta") + juju.set({"some.key": "value", "some.key2": "value2"}) +except snap.SnapError as e: + logger.error("An exception occurred when installing charmcraft. Reason: %s", e.message) +``` + +In addition, the `snap` module provides "bare" methods which can act on Snap packages as +simple function calls. :meth:`add`, :meth:`remove`, and :meth:`ensure` are provided, as +well as :meth:`add_local` for installing directly from a local `.snap` file. These return +`Snap` objects. + +As an example of installing several Snaps and checking details: + +```python +try: + nextcloud, charmcraft = snap.add(["nextcloud", "charmcraft"]) + if nextcloud.get("mode") != "production": + nextcloud.set({"mode": "production"}) +except snap.SnapError as e: + logger.error("An exception occurred when installing snaps. 
Reason: %s" % e.message) +``` +""" + +import http.client +import json +import logging +import os +import re +import socket +import subprocess +import sys +import urllib.error +import urllib.parse +import urllib.request +from collections.abc import Mapping +from datetime import datetime, timedelta, timezone +from enum import Enum +from subprocess import CalledProcessError, CompletedProcess +from typing import Any, Dict, Iterable, List, Optional, Union + +logger = logging.getLogger(__name__) + +# The unique Charmhub library identifier, never change it +LIBID = "05394e5893f94f2d90feb7cbe6b633cd" + +# Increment this major API version when introducing breaking changes +LIBAPI = 2 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 0 + + +# Regex to locate 7-bit C1 ANSI sequences +ansi_filter = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + + +def _cache_init(func): + def inner(*args, **kwargs): + if _Cache.cache is None: + _Cache.cache = SnapCache() + return func(*args, **kwargs) + + return inner + + +# recursive hints seems to error out pytest +JSONType = Union[Dict[str, Any], List[Any], str, int, float] + + +class SnapService: + """Data wrapper for snap services.""" + + def __init__( + self, + daemon: Optional[str] = None, + daemon_scope: Optional[str] = None, + enabled: bool = False, + active: bool = False, + activators: List[str] = [], + **kwargs, + ): + self.daemon = daemon + self.daemon_scope = kwargs.get("daemon-scope", None) or daemon_scope + self.enabled = enabled + self.active = active + self.activators = activators + + def as_dict(self) -> Dict: + """Return instance representation as dict.""" + return { + "daemon": self.daemon, + "daemon_scope": self.daemon_scope, + "enabled": self.enabled, + "active": self.active, + "activators": self.activators, + } + + +class MetaCache(type): + """MetaCache class used for initialising the snap cache.""" + + @property + def cache(cls) -> "SnapCache": + """Property for returning the snap cache.""" + return cls._cache + + @cache.setter + def cache(cls, cache: "SnapCache") -> None: + """Setter for the snap cache.""" + cls._cache = cache + + def __getitem__(cls, name) -> "Snap": + """Snap cache getter.""" + return cls._cache[name] + + +class _Cache(object, metaclass=MetaCache): + _cache = None + + +class Error(Exception): + """Base class of most errors raised by this library.""" + + def __repr__(self): + """Represent the Error class.""" + return "<{}.{} {}>".format(type(self).__module__, type(self).__name__, self.args) + + @property + def name(self): + """Return a string representation of the model plus class.""" + return "<{}.{}>".format(type(self).__module__, type(self).__name__) + + @property + def message(self): + """Return the message passed as an argument.""" + return self.args[0] + + +class SnapAPIError(Error): + """Raised when an HTTP API error occurs talking to the Snapd server.""" + + def __init__(self, body: Dict, code: int, status: str, message: str): + super().__init__(message) # Makes str(e) return message + self.body = body + self.code = code + self.status = status + self._message = message + + def __repr__(self): + """Represent the SnapAPIError class.""" + return "APIError({!r}, {!r}, {!r}, {!r})".format( + self.body, self.code, self.status, self._message + ) + + +class SnapState(Enum): + """The state of a snap on the system or in the cache.""" + + Present = "present" + Absent = "absent" + Latest = "latest" + Available = "available" + + 
+class SnapError(Error): + """Raised when there's an error running snap control commands.""" + + +class SnapNotFoundError(Error): + """Raised when a requested snap is not known to the system.""" + + +class Snap(object): + """Represents a snap package and its properties. + + `Snap` exposes the following properties about a snap: + - name: the name of the snap + - state: a `SnapState` representation of its install status + - channel: "stable", "candidate", "beta", and "edge" are common + - revision: a string representing the snap's revision + - confinement: "classic" or "strict" + """ + + def __init__( + self, + name, + state: SnapState, + channel: str, + revision: str, + confinement: str, + apps: Optional[List[Dict[str, str]]] = None, + cohort: Optional[str] = "", + ) -> None: + self._name = name + self._state = state + self._channel = channel + self._revision = revision + self._confinement = confinement + self._cohort = cohort + self._apps = apps or [] + self._snap_client = SnapClient() + + def __eq__(self, other) -> bool: + """Equality for comparison.""" + return isinstance(other, self.__class__) and ( + self._name, + self._revision, + ) == (other._name, other._revision) + + def __hash__(self): + """Calculate a hash for this snap.""" + return hash((self._name, self._revision)) + + def __repr__(self): + """Represent the object such that it can be reconstructed.""" + return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) + + def __str__(self): + """Represent the snap object as a string.""" + return "<{}: {}-{}.{} -- {}>".format( + self.__class__.__name__, + self._name, + self._revision, + self._channel, + str(self._state), + ) + + def _snap(self, command: str, optargs: Optional[Iterable[str]] = None) -> str: + """Perform a snap operation. + + Args: + command: the snap command to execute + optargs: an (optional) list of additional arguments to pass, + commonly confinement or channel + + Raises: + SnapError if there is a problem encountered + """ + optargs = optargs or [] + _cmd = ["snap", command, self._name, *optargs] + try: + return subprocess.check_output(_cmd, universal_newlines=True) + except CalledProcessError as e: + raise SnapError( + "Snap: {!r}; command {!r} failed with output = {!r}".format( + self._name, _cmd, e.output + ) + ) + + def _snap_daemons( + self, + command: List[str], + services: Optional[List[str]] = None, + ) -> CompletedProcess: + """Perform snap app commands. + + Args: + command: the snap command to execute + services: the snap service to execute command on + + Raises: + SnapError if there is a problem encountered + """ + if services: + # an attempt to keep the command constrained to the snap instance's services + services = ["{}.{}".format(self._name, service) for service in services] + else: + services = [self._name] + + _cmd = ["snap", *command, *services] + + try: + return subprocess.run(_cmd, universal_newlines=True, check=True, capture_output=True) + except CalledProcessError as e: + raise SnapError("Could not {} for snap [{}]: {}".format(_cmd, self._name, e.stderr)) + + def get(self, key) -> str: + """Fetch a snap configuration value. + + Args: + key: the key to retrieve + """ + return self._snap("get", [key]).strip() + + def set(self, config: Dict) -> str: + """Set a snap configuration value. + + Args: + config: a dictionary containing keys and values specifying the config to set. 
+ """ + args = ['{}="{}"'.format(key, val) for key, val in config.items()] + + return self._snap("set", [*args]) + + def unset(self, key) -> str: + """Unset a snap configuration value. + + Args: + key: the key to unset + """ + return self._snap("unset", [key]) + + def start(self, services: Optional[List[str]] = None, enable: Optional[bool] = False) -> None: + """Start a snap's services. + + Args: + services (list): (optional) list of individual snap services to start (otherwise all) + enable (bool): (optional) flag to enable snap services on start. Default `false` + """ + args = ["start", "--enable"] if enable else ["start"] + self._snap_daemons(args, services) + + def stop(self, services: Optional[List[str]] = None, disable: Optional[bool] = False) -> None: + """Stop a snap's services. + + Args: + services (list): (optional) list of individual snap services to stop (otherwise all) + disable (bool): (optional) flag to disable snap services on stop. Default `False` + """ + args = ["stop", "--disable"] if disable else ["stop"] + self._snap_daemons(args, services) + + def logs(self, services: Optional[List[str]] = None, num_lines: Optional[int] = 10) -> str: + """Fetch a snap services' logs. + + Args: + services (list): (optional) list of individual snap services to show logs from + (otherwise all) + num_lines (int): (optional) integer number of log lines to return. Default `10` + """ + args = ["logs", "-n={}".format(num_lines)] if num_lines else ["logs"] + return self._snap_daemons(args, services).stdout + + def connect( + self, plug: str, service: Optional[str] = None, slot: Optional[str] = None + ) -> None: + """Connect a plug to a slot. + + Args: + plug (str): the plug to connect + service (str): (optional) the snap service name to plug into + slot (str): (optional) the snap service slot to plug in to + + Raises: + SnapError if there is a problem encountered + """ + command = ["connect", "{}:{}".format(self._name, plug)] + + if service and slot: + command = command + ["{}:{}".format(service, slot)] + elif slot: + command = command + [slot] + + _cmd = ["snap", *command] + try: + subprocess.run(_cmd, universal_newlines=True, check=True, capture_output=True) + except CalledProcessError as e: + raise SnapError("Could not {} for snap [{}]: {}".format(_cmd, self._name, e.stderr)) + + def hold(self, duration: Optional[timedelta] = None) -> None: + """Add a refresh hold to a snap. + + Args: + duration: duration for the hold, or None (the default) to hold this snap indefinitely. + """ + hold_str = "forever" + if duration is not None: + seconds = round(duration.total_seconds()) + hold_str = f"{seconds}s" + self._snap("refresh", [f"--hold={hold_str}"]) + + def unhold(self) -> None: + """Remove the refresh hold of a snap.""" + self._snap("refresh", ["--unhold"]) + + def restart( + self, services: Optional[List[str]] = None, reload: Optional[bool] = False + ) -> None: + """Restarts a snap's services. + + Args: + services (list): (optional) list of individual snap services to restart. + (otherwise all) + reload (bool): (optional) flag to use the service reload command, if available. + Default `False` + """ + args = ["restart", "--reload"] if reload else ["restart"] + self._snap_daemons(args, services) + + def _install( + self, + channel: Optional[str] = "", + cohort: Optional[str] = "", + revision: Optional[str] = None, + ) -> None: + """Add a snap to the system. 
+ + Args: + channel: the channel to install from + cohort: optional, the key of a cohort that this snap belongs to + revision: optional, the revision of the snap to install + """ + cohort = cohort or self._cohort + + args = [] + if self.confinement == "classic": + args.append("--classic") + if channel: + args.append('--channel="{}"'.format(channel)) + if revision: + args.append('--revision="{}"'.format(revision)) + if cohort: + args.append('--cohort="{}"'.format(cohort)) + + self._snap("install", args) + + def _refresh( + self, + channel: Optional[str] = "", + cohort: Optional[str] = "", + revision: Optional[str] = None, + leave_cohort: Optional[bool] = False, + ) -> None: + """Refresh a snap. + + Args: + channel: the channel to install from + cohort: optionally, specify a cohort. + revision: optionally, specify the revision of the snap to refresh + leave_cohort: leave the current cohort. + """ + args = [] + if channel: + args.append('--channel="{}"'.format(channel)) + + if revision: + args.append('--revision="{}"'.format(revision)) + + if not cohort: + cohort = self._cohort + + if leave_cohort: + self._cohort = "" + args.append("--leave-cohort") + elif cohort: + args.append('--cohort="{}"'.format(cohort)) + + self._snap("refresh", args) + + def _remove(self) -> str: + """Remove a snap from the system.""" + return self._snap("remove") + + @property + def name(self) -> str: + """Returns the name of the snap.""" + return self._name + + def ensure( + self, + state: SnapState, + classic: Optional[bool] = False, + channel: Optional[str] = "", + cohort: Optional[str] = "", + revision: Optional[str] = None, + ): + """Ensure that a snap is in a given state. + + Args: + state: a `SnapState` to reconcile to. + classic: an (Optional) boolean indicating whether classic confinement should be used + channel: the channel to install from + cohort: optional. Specify the key of a snap cohort. + revision: optional. the revision of the snap to install/refresh + + While both channel and revision could be specified, the underlying snap install/refresh + command will determine which one takes precedence (revision at this time) + + Raises: + SnapError if an error is encountered + """ + self._confinement = "classic" if classic or self._confinement == "classic" else "" + + if state not in (SnapState.Present, SnapState.Latest): + # We are attempting to remove this snap. + if self._state in (SnapState.Present, SnapState.Latest): + # The snap is installed, so we run _remove. + self._remove() + else: + # The snap is not installed -- no need to do anything. + pass + else: + # We are installing or refreshing a snap. + if self._state not in (SnapState.Present, SnapState.Latest): + # The snap is not installed, so we install it. + self._install(channel, cohort, revision) + else: + # The snap is installed, but we are changing it (e.g., switching channels). 
+ self._refresh(channel, cohort, revision) + + self._update_snap_apps() + self._state = state + + def _update_snap_apps(self) -> None: + """Update a snap's apps after snap changes state.""" + try: + self._apps = self._snap_client.get_installed_snap_apps(self._name) + except SnapAPIError: + logger.debug("Unable to retrieve snap apps for {}".format(self._name)) + self._apps = [] + + @property + def present(self) -> bool: + """Report whether or not a snap is present.""" + return self._state in (SnapState.Present, SnapState.Latest) + + @property + def latest(self) -> bool: + """Report whether the snap is the most recent version.""" + return self._state is SnapState.Latest + + @property + def state(self) -> SnapState: + """Report the current snap state.""" + return self._state + + @state.setter + def state(self, state: SnapState) -> None: + """Set the snap state to a given value. + + Args: + state: a `SnapState` to reconcile the snap to. + + Raises: + SnapError if an error is encountered + """ + if self._state is not state: + self.ensure(state) + self._state = state + + @property + def revision(self) -> str: + """Returns the revision for a snap.""" + return self._revision + + @property + def channel(self) -> str: + """Returns the channel for a snap.""" + return self._channel + + @property + def confinement(self) -> str: + """Returns the confinement for a snap.""" + return self._confinement + + @property + def apps(self) -> List: + """Returns (if any) the installed apps of the snap.""" + self._update_snap_apps() + return self._apps + + @property + def services(self) -> Dict: + """Returns (if any) the installed services of the snap.""" + self._update_snap_apps() + services = {} + for app in self._apps: + if "daemon" in app: + services[app["name"]] = SnapService(**app).as_dict() + + return services + + @property + def held(self) -> bool: + """Report whether the snap has a hold.""" + info = self._snap("info") + return "hold:" in info + + +class _UnixSocketConnection(http.client.HTTPConnection): + """Implementation of HTTPConnection that connects to a named Unix socket.""" + + def __init__(self, host, timeout=None, socket_path=None): + if timeout is None: + super().__init__(host) + else: + super().__init__(host, timeout=timeout) + self.socket_path = socket_path + + def connect(self): + """Override connect to use Unix socket (instead of TCP socket).""" + if not hasattr(socket, "AF_UNIX"): + raise NotImplementedError("Unix sockets not supported on {}".format(sys.platform)) + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.sock.connect(self.socket_path) + if self.timeout is not None: + self.sock.settimeout(self.timeout) + + +class _UnixSocketHandler(urllib.request.AbstractHTTPHandler): + """Implementation of HTTPHandler that uses a named Unix socket.""" + + def __init__(self, socket_path: str): + super().__init__() + self.socket_path = socket_path + + def http_open(self, req) -> http.client.HTTPResponse: + """Override http_open to use a Unix socket connection (instead of TCP).""" + return self.do_open(_UnixSocketConnection, req, socket_path=self.socket_path) + + +class SnapClient: + """Snapd API client to talk to HTTP over UNIX sockets. + + In order to avoid shelling out and/or involving sudo in calling the snapd API, + use a wrapper based on the Pebble Client, trimmed down to only the utility methods + needed for talking to snapd. 
+ """ + + def __init__( + self, + socket_path: str = "/run/snapd.socket", + opener: Optional[urllib.request.OpenerDirector] = None, + base_url: str = "http://localhost/v2/", + timeout: float = 5.0, + ): + """Initialize a client instance. + + Args: + socket_path: a path to the socket on the filesystem. Defaults to /run/snap/snapd.socket + opener: specifies an opener for unix socket, if unspecified a default is used + base_url: base url for making requests to the snap client. Defaults to + http://localhost/v2/ + timeout: timeout in seconds to use when making requests to the API. Default is 5.0s. + """ + if opener is None: + opener = self._get_default_opener(socket_path) + self.opener = opener + self.base_url = base_url + self.timeout = timeout + + @classmethod + def _get_default_opener(cls, socket_path): + """Build the default opener to use for requests (HTTP over Unix socket).""" + opener = urllib.request.OpenerDirector() + opener.add_handler(_UnixSocketHandler(socket_path)) + opener.add_handler(urllib.request.HTTPDefaultErrorHandler()) + opener.add_handler(urllib.request.HTTPRedirectHandler()) + opener.add_handler(urllib.request.HTTPErrorProcessor()) + return opener + + def _request( + self, + method: str, + path: str, + query: Dict = None, + body: Dict = None, + ) -> JSONType: + """Make a JSON request to the Snapd server with the given HTTP method and path. + + If query dict is provided, it is encoded and appended as a query string + to the URL. If body dict is provided, it is serialied as JSON and used + as the HTTP body (with Content-Type: "application/json"). The resulting + body is decoded from JSON. + """ + headers = {"Accept": "application/json"} + data = None + if body is not None: + data = json.dumps(body).encode("utf-8") + headers["Content-Type"] = "application/json" + + response = self._request_raw(method, path, query, headers, data) + return json.loads(response.read().decode())["result"] + + def _request_raw( + self, + method: str, + path: str, + query: Dict = None, + headers: Dict = None, + data: bytes = None, + ) -> http.client.HTTPResponse: + """Make a request to the Snapd server; return the raw HTTPResponse object.""" + url = self.base_url + path + if query: + url = url + "?" + urllib.parse.urlencode(query) + + if headers is None: + headers = {} + request = urllib.request.Request(url, method=method, data=data, headers=headers) + + try: + response = self.opener.open(request, timeout=self.timeout) + except urllib.error.HTTPError as e: + code = e.code + status = e.reason + message = "" + try: + body = json.loads(e.read().decode())["result"] + except (IOError, ValueError, KeyError) as e2: + # Will only happen on read error or if Pebble sends invalid JSON. 
+ body = {} + message = "{} - {}".format(type(e2).__name__, e2) + raise SnapAPIError(body, code, status, message) + except urllib.error.URLError as e: + raise SnapAPIError({}, 500, "Not found", e.reason) + return response + + def get_installed_snaps(self) -> Dict: + """Get information about currently installed snaps.""" + return self._request("GET", "snaps") + + def get_snap_information(self, name: str) -> Dict: + """Query the snap server for information about single snap.""" + return self._request("GET", "find", {"name": name})[0] + + def get_installed_snap_apps(self, name: str) -> List: + """Query the snap server for apps belonging to a named, currently installed snap.""" + return self._request("GET", "apps", {"names": name, "select": "service"}) + + +class SnapCache(Mapping): + """An abstraction to represent installed/available packages. + + When instantiated, `SnapCache` iterates through the list of installed + snaps using the `snapd` HTTP API, and a list of available snaps by reading + the filesystem to populate the cache. Information about available snaps is lazily-loaded + from the `snapd` API when requested. + """ + + def __init__(self): + if not self.snapd_installed: + raise SnapError("snapd is not installed or not in /usr/bin") from None + self._snap_client = SnapClient() + self._snap_map = {} + if self.snapd_installed: + self._load_available_snaps() + self._load_installed_snaps() + + def __contains__(self, key: str) -> bool: + """Check if a given snap is in the cache.""" + return key in self._snap_map + + def __len__(self) -> int: + """Report number of items in the snap cache.""" + return len(self._snap_map) + + def __iter__(self) -> Iterable["Snap"]: + """Provide iterator for the snap cache.""" + return iter(self._snap_map.values()) + + def __getitem__(self, snap_name: str) -> Snap: + """Return either the installed version or latest version for a given snap.""" + snap = self._snap_map.get(snap_name, None) + if snap is None: + # The snapd cache file may not have existed when _snap_map was + # populated. This is normal. + try: + self._snap_map[snap_name] = self._load_info(snap_name) + except SnapAPIError: + raise SnapNotFoundError("Snap '{}' not found!".format(snap_name)) + + return self._snap_map[snap_name] + + @property + def snapd_installed(self) -> bool: + """Check whether snapd has been installled on the system.""" + return os.path.isfile("/usr/bin/snap") + + def _load_available_snaps(self) -> None: + """Load the list of available snaps from disk. + + Leave them empty and lazily load later if asked for. + """ + if not os.path.isfile("/var/cache/snapd/names"): + # The snap catalog may not be populated yet; this is normal. + # snapd updates the cache infrequently and the cache file may not + # currently exist. + return + + with open("/var/cache/snapd/names", "r") as f: + for line in f: + if line.strip(): + self._snap_map[line.strip()] = None + + def _load_installed_snaps(self) -> None: + """Load the installed snaps into the dict.""" + installed = self._snap_client.get_installed_snaps() + + for i in installed: + snap = Snap( + name=i["name"], + state=SnapState.Latest, + channel=i["channel"], + revision=i["revision"], + confinement=i["confinement"], + apps=i.get("apps", None), + ) + self._snap_map[snap.name] = snap + + def _load_info(self, name) -> Snap: + """Load info for snaps which are not installed if requested. 
+ + Args: + name: a string representing the name of the snap + """ + info = self._snap_client.get_snap_information(name) + + return Snap( + name=info["name"], + state=SnapState.Available, + channel=info["channel"], + revision=info["revision"], + confinement=info["confinement"], + apps=None, + ) + + +@_cache_init +def add( + snap_names: Union[str, List[str]], + state: Union[str, SnapState] = SnapState.Latest, + channel: Optional[str] = "", + classic: Optional[bool] = False, + cohort: Optional[str] = "", + revision: Optional[str] = None, +) -> Union[Snap, List[Snap]]: + """Add a snap to the system. + + Args: + snap_names: the name or names of the snaps to install + state: a string or `SnapState` representation of the desired state, one of + [`Present` or `Latest`] + channel: an (Optional) channel as a string. Defaults to 'latest' + classic: an (Optional) boolean specifying whether it should be added with classic + confinement. Default `False` + cohort: an (Optional) string specifying the snap cohort to use + revision: an (Optional) string specifying the snap revision to use + + Raises: + SnapError if some snaps failed to install or were not found. + """ + if not channel and not revision: + channel = "latest" + + snap_names = [snap_names] if type(snap_names) is str else snap_names + if not snap_names: + raise TypeError("Expected at least one snap to add, received zero!") + + if type(state) is str: + state = SnapState(state) + + return _wrap_snap_operations(snap_names, state, channel, classic, cohort, revision) + + +@_cache_init +def remove(snap_names: Union[str, List[str]]) -> Union[Snap, List[Snap]]: + """Remove specified snap(s) from the system. + + Args: + snap_names: the name or names of the snaps to install + + Raises: + SnapError if some snaps failed to install. + """ + snap_names = [snap_names] if type(snap_names) is str else snap_names + if not snap_names: + raise TypeError("Expected at least one snap to add, received zero!") + + return _wrap_snap_operations(snap_names, SnapState.Absent, "", False) + + +@_cache_init +def ensure( + snap_names: Union[str, List[str]], + state: str, + channel: Optional[str] = "", + classic: Optional[bool] = False, + cohort: Optional[str] = "", + revision: Optional[int] = None, +) -> Union[Snap, List[Snap]]: + """Ensure specified snaps are in a given state on the system. + + Args: + snap_names: the name(s) of the snaps to operate on + state: a string representation of the desired state, from `SnapState` + channel: an (Optional) channel as a string. Defaults to 'latest' + classic: an (Optional) boolean specifying whether it should be added with classic + confinement. Default `False` + cohort: an (Optional) string specifying the snap cohort to use + revision: an (Optional) integer specifying the snap revision to use + + When both channel and revision are specified, the underlying snap install/refresh + command will determine the precedence (revision at the time of adding this) + + Raises: + SnapError if the snap is not in the cache. 
+ """ + if not revision and not channel: + channel = "latest" + + if state in ("present", "latest") or revision: + return add(snap_names, SnapState(state), channel, classic, cohort, revision) + else: + return remove(snap_names) + + +def _wrap_snap_operations( + snap_names: List[str], + state: SnapState, + channel: str, + classic: bool, + cohort: Optional[str] = "", + revision: Optional[str] = None, +) -> Union[Snap, List[Snap]]: + """Wrap common operations for bare commands.""" + snaps = {"success": [], "failed": []} + + op = "remove" if state is SnapState.Absent else "install or refresh" + + for s in snap_names: + try: + snap = _Cache[s] + if state is SnapState.Absent: + snap.ensure(state=SnapState.Absent) + else: + snap.ensure( + state=state, classic=classic, channel=channel, cohort=cohort, revision=revision + ) + snaps["success"].append(snap) + except SnapError as e: + logger.warning("Failed to {} snap {}: {}!".format(op, s, e.message)) + snaps["failed"].append(s) + except SnapNotFoundError: + logger.warning("Snap '{}' not found in cache!".format(s)) + snaps["failed"].append(s) + + if len(snaps["failed"]): + raise SnapError( + "Failed to install or refresh snap(s): {}".format(", ".join(list(snaps["failed"]))) + ) + + return snaps["success"] if len(snaps["success"]) > 1 else snaps["success"][0] + + +def install_local( + filename: str, classic: Optional[bool] = False, dangerous: Optional[bool] = False +) -> Snap: + """Perform a snap operation. + + Args: + filename: the path to a local .snap file to install + classic: whether to use classic confinement + dangerous: whether --dangerous should be passed to install snaps without a signature + + Raises: + SnapError if there is a problem encountered + """ + _cmd = [ + "snap", + "install", + filename, + ] + if classic: + _cmd.append("--classic") + if dangerous: + _cmd.append("--dangerous") + try: + result = subprocess.check_output(_cmd, universal_newlines=True).splitlines()[-1] + snap_name, _ = result.split(" ", 1) + snap_name = ansi_filter.sub("", snap_name) + + c = SnapCache() + + try: + return c[snap_name] + except SnapAPIError as e: + logger.error( + "Could not find snap {} when querying Snapd socket: {}".format(snap_name, e.body) + ) + raise SnapError("Failed to find snap {} in Snap cache".format(snap_name)) + except CalledProcessError as e: + raise SnapError("Could not install snap {}: {}".format(filename, e.output)) + + +def _system_set(config_item: str, value: str) -> None: + """Set system snapd config values. + + Args: + config_item: name of snap system setting. E.g. 'refresh.hold' + value: value to assign + """ + _cmd = ["snap", "set", "system", "{}={}".format(config_item, value)] + try: + subprocess.check_call(_cmd, universal_newlines=True) + except CalledProcessError: + raise SnapError("Failed setting system config '{}' to '{}'".format(config_item, value)) + + +def hold_refresh(days: int = 90, forever: bool = False) -> bool: + """Set the system-wide snap refresh hold. + + Args: + days: number of days to hold system refreshes for. Maximum 90. Set to zero to remove hold. + forever: if True, will set a hold forever. 
+ """ + if not isinstance(forever, bool): + raise TypeError("forever must be a bool") + if not isinstance(days, int): + raise TypeError("days must be an int") + if forever: + _system_set("refresh.hold", "forever") + logger.info("Set system-wide snap refresh hold to: forever") + elif days == 0: + _system_set("refresh.hold", "") + logger.info("Removed system-wide snap refresh hold") + else: + # Currently the snap daemon can only hold for a maximum of 90 days + if not 1 <= days <= 90: + raise ValueError("days must be between 1 and 90") + # Add the number of days to current time + target_date = datetime.now(timezone.utc).astimezone() + timedelta(days=days) + # Format for the correct datetime format + hold_date = target_date.strftime("%Y-%m-%dT%H:%M:%S%z") + # Python dumps the offset in format '+0100', we need '+01:00' + hold_date = "{0}:{1}".format(hold_date[:-2], hold_date[-2:]) + # Actually set the hold date + _system_set("refresh.hold", hold_date) + logger.info("Set system-wide snap refresh hold to: %s", hold_date) diff --git a/src/charm.py b/src/charm.py deleted file mode 100755 index d5420d3f..00000000 --- a/src/charm.py +++ /dev/null @@ -1,212 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Canonical Ltd. -# See LICENSE file for licensing details. -# -# Learn more at: https://juju.is/docs/sdk - -"""MySQL-Router machine charm.""" - -import json -import logging -import subprocess -from typing import Optional - -from charms.operator_libs_linux.v1 import systemd -from ops.charm import CharmBase, RelationChangedEvent -from ops.main import main -from ops.model import ActiveStatus, BlockedStatus, MaintenanceStatus, WaitingStatus - -from constants import ( - LEGACY_SHARED_DB, - MYSQL_ROUTER_LEADER_BOOTSTRAPED, - MYSQL_ROUTER_REQUIRES_DATA, - PEER, -) -from mysql_router_helpers import ( - MySQLRouter, - MySQLRouterBootstrapError, - MySQLRouterInstallAndConfigureError, -) -from relations.database_provides import DatabaseProvidesRelation -from relations.database_requires import DatabaseRequiresRelation -from relations.shared_db import SharedDBRelation - -logger = logging.getLogger(__name__) - - -class MySQLRouterOperatorCharm(CharmBase): - """Charm the service.""" - - def __init__(self, *args): - super().__init__(*args) - - self.framework.observe(self.on.install, self._on_install) - self.framework.observe(self.on.upgrade_charm, self._on_upgrade_charm) - self.framework.observe(self.on[PEER].relation_changed, self._on_peer_relation_changed) - - self.shared_db_relation = SharedDBRelation(self) - self.database_requires_relation = DatabaseRequiresRelation(self) - self.database_provides_relation = DatabaseProvidesRelation(self) - - # ======================= - # Properties - # ======================= - - @property - def _peers(self): - """Retrieve the peer relation.""" - return self.model.get_relation(PEER) - - @property - def app_peer_data(self): - """Application peer data object.""" - if not self._peers: - return {} - - return self._peers.data[self.app] - - @property - def unit_peer_data(self): - """Unit peer data object.""" - if not self._peers: - return {} - - return self._peers.data[self.unit] - - # ======================= - # Helpers - # ======================= - - def _get_secret(self, scope: str, key: str) -> Optional[str]: - """Get secret from the peer relation databag.""" - if scope == "unit": - return self.unit_peer_data.get(key, None) - elif scope == "app": - return self.app_peer_data.get(key, None) - else: - raise RuntimeError("Unknown secret scope.") - - def _set_secret(self, scope: str, 
key: str, value: Optional[str]) -> None: - """Set secret in the peer relation databag.""" - if scope == "unit": - if not value: - del self.unit_peer_data[key] - return - self.unit_peer_data.update({key: value}) - elif scope == "app": - if not value: - del self.app_peer_data[key] - return - self.app_peer_data.update({key: value}) - else: - raise RuntimeError("Unknown secret scope.") - - # ======================= - # Handlers - # ======================= - - def _on_install(self, _) -> None: - """Install the mysql-router package.""" - self.unit.status = MaintenanceStatus("Installing packages") - - try: - MySQLRouter.install_and_configure_mysql_router() - except MySQLRouterInstallAndConfigureError: - self.unit.status = BlockedStatus("Failed to install mysqlrouter") - return - - for port in [6446, 6447, 6448, 6449]: - try: - subprocess.check_call(["open-port", f"{port}/tcp"]) - except subprocess.CalledProcessError: - logger.exception(f"failed to open port {port}") - - self.unit.status = WaitingStatus("Waiting for relations") - - def _on_upgrade_charm(self, _) -> None: - """Update the mysql-router config on charm upgrade.""" - if isinstance(self.unit.status, ActiveStatus): - self.unit.status = MaintenanceStatus("Upgrading charm") - - requires_data = json.loads(self.app_peer_data.get(MYSQL_ROUTER_REQUIRES_DATA)) - related_app_name = ( - self.shared_db_relation._get_related_app_name() - if self.shared_db_relation._shared_db_relation_exists() - else self.database_provides_relation._get_related_app_name() - ) - - try: - MySQLRouter.bootstrap_and_start_mysql_router( - requires_data["username"], - self._get_secret("app", "database-password"), - related_app_name, - requires_data["endpoints"].split(",")[0].split(":")[0], - "3306", - force=True, - ) - except MySQLRouterBootstrapError: - self.unit.status = BlockedStatus("Failed to bootstrap mysqlrouter") - return - - self.unit.status = ActiveStatus() - - def _on_peer_relation_changed(self, event: RelationChangedEvent) -> None: - """Handle the peer relation changed event. - - If a peer is being joined for the first time, bootstrap mysqlrouter - and share relevant connection data with the related app. 
- """ - if isinstance(self.unit.status, WaitingStatus) and self.app_peer_data.get( - MYSQL_ROUTER_LEADER_BOOTSTRAPED - ): - try: - mysqlrouter_running = MySQLRouter.is_mysqlrouter_running() - except systemd.SystemdError as e: - logger.exception("Failed to check if mysqlrouter with systemd", exc_info=e) - self.unit.status = BlockedStatus("Failed to bootstrap mysqlrouter") - return - - if not self.unit.is_leader() and not mysqlrouter_running: - # Occasionally, the related unit is not in the relation databag if this handler - # is invoked in short succession after the peer joins the cluster - shared_db_relation_exists = self.shared_db_relation._shared_db_relation_exists() - shared_db_related_unit_name = self.shared_db_relation._get_related_unit_name() - if shared_db_relation_exists and not shared_db_related_unit_name: - event.defer() - return - - requires_data = json.loads(self.app_peer_data.get(MYSQL_ROUTER_REQUIRES_DATA)) - related_app_name = ( - self.shared_db_relation._get_related_app_name() - if self.shared_db_relation._shared_db_relation_exists() - else self.database_provides_relation._get_related_app_name() - ) - - try: - MySQLRouter.bootstrap_and_start_mysql_router( - requires_data["username"], - self._get_secret("app", "database-password"), - related_app_name, - requires_data["endpoints"].split(",")[0].split(":")[0], - "3306", - ) - except MySQLRouterBootstrapError: - self.unit.status = BlockedStatus("Failed to bootstrap mysqlrouter") - return - - if shared_db_relation_exists: - self.model.relations[LEGACY_SHARED_DB][0].data[self.unit].update( - { - "allowed_units": shared_db_related_unit_name, - "db_host": "127.0.0.1", - "db_port": "3306", - "password": self._get_secret("app", "application-password"), - "wait_timeout": "3600", - } - ) - - self.unit.status = ActiveStatus() - - -if __name__ == "__main__": - main(MySQLRouterOperatorCharm) diff --git a/src/constants.py b/src/constants.py deleted file mode 100644 index fe0ddd99..00000000 --- a/src/constants.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2022 Canonical Ltd. -# See LICENSE file for licensing details. - -"""File containing constants to be used in the charm.""" - -MYSQL_ROUTER_APT_PACKAGE = "mysql-router" -MYSQL_ROUTER_GROUP = "mysql" -MYSQL_ROUTER_USER = "mysql" -MYSQL_HOME_DIRECTORY = "/var/lib/mysql" -PEER = "mysql-router-peers" -DATABASE_REQUIRES_RELATION = "backend-database" -DATABASE_PROVIDES_RELATION = "database" -MYSQL_ROUTER_LEADER_BOOTSTRAPED = "mysql-router-leader-bootstraped" -MYSQL_ROUTER_UNIT_TEMPLATE = "templates/mysqlrouter.service.j2" -MYSQL_ROUTER_SERVICE_NAME = "mysqlrouter.service" -MYSQL_ROUTER_SYSTEMD_DIRECTORY = "/etc/systemd/system" -MYSQL_ROUTER_REQUIRES_DATA = "requires-database-data" -MYSQL_ROUTER_PROVIDES_DATA = "provides-database-data" -PASSWORD_LENGTH = 24 -# Constants for legacy relations -LEGACY_SHARED_DB = "shared-db" -LEGACY_SHARED_DB_DATA = "shared-db-data" -LEGACY_SHARED_DB_DATA_FORWARDED = "shared-db-data-forwarded" diff --git a/src/mysql_router_helpers.py b/src/mysql_router_helpers.py deleted file mode 100644 index 9ba39d73..00000000 --- a/src/mysql_router_helpers.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright 2022 Canonical Ltd. -# See LICENSE file for licensing details. 
- -"""Helper class to manage the MySQL Router lifecycle.""" - -import grp -import logging -import os -import pwd -import subprocess - -import jinja2 -import mysql.connector -from charms.operator_libs_linux.v0 import apt, passwd -from charms.operator_libs_linux.v1 import systemd - -from constants import ( - MYSQL_HOME_DIRECTORY, - MYSQL_ROUTER_APT_PACKAGE, - MYSQL_ROUTER_GROUP, - MYSQL_ROUTER_SERVICE_NAME, - MYSQL_ROUTER_SYSTEMD_DIRECTORY, - MYSQL_ROUTER_UNIT_TEMPLATE, - MYSQL_ROUTER_USER, -) - -logger = logging.getLogger(__name__) - - -class Error(Exception): - """Base class for exceptions in this module.""" - - def __repr__(self): - """String representation of the Error class.""" - return "<{}.{} {}>".format(type(self).__module__, type(self).__name__, self.args) - - @property - def name(self): - """Return a string representation of the model plus class.""" - return "<{}.{}>".format(type(self).__module__, type(self).__name__) - - @property - def message(self): - """Return the message passed as an argument.""" - return self.args[0] - - -class MySQLRouterInstallAndConfigureError(Error): - """Exception raised when there is an issue installing MySQLRouter.""" - - -class MySQLRouterBootstrapError(Error): - """Exception raised when there is an issue bootstrapping MySQLRouter.""" - - -class MySQLRouterCreateUserWithDatabasePrivilegesError(Error): - """Exception raised when there is an issue creating a database scoped user.""" - - -class MySQLRouter: - """Class to encapsulate all operations related to MySQLRouter.""" - - @staticmethod - def install_and_configure_mysql_router() -> None: - """Install and configure MySQLRouter.""" - try: - apt.update() - apt.add_package(MYSQL_ROUTER_APT_PACKAGE) - - if not passwd.group_exists(MYSQL_ROUTER_GROUP): - passwd.add_group(MYSQL_ROUTER_GROUP, system_group=True) - - if not passwd.user_exists(MYSQL_ROUTER_USER): - passwd.add_user( - MYSQL_ROUTER_USER, - shell="/usr/sbin/nologin", - system_user=True, - primary_group=MYSQL_ROUTER_GROUP, - home_dir=MYSQL_HOME_DIRECTORY, - ) - - if not os.path.exists(MYSQL_HOME_DIRECTORY): - os.makedirs(MYSQL_HOME_DIRECTORY, mode=0o755, exist_ok=True) - - user_id = pwd.getpwnam(MYSQL_ROUTER_USER).pw_uid - group_id = grp.getgrnam(MYSQL_ROUTER_GROUP).gr_gid - - os.chown(MYSQL_HOME_DIRECTORY, user_id, group_id) - except Exception as e: - logger.exception(f"Failed to install the {MYSQL_ROUTER_APT_PACKAGE} apt package.") - raise MySQLRouterInstallAndConfigureError(e.stderr) - - @staticmethod - def _render_and_copy_mysqlrouter_systemd_unit_file(app_name): - with open(MYSQL_ROUTER_UNIT_TEMPLATE, "r") as file: - template = jinja2.Template(file.read()) - - rendered_template = template.render(charm_app_name=app_name) - systemd_file_path = f"{MYSQL_ROUTER_SYSTEMD_DIRECTORY}/mysqlrouter.service" - - with open(systemd_file_path, "w+") as file: - file.write(rendered_template) - - os.chmod(systemd_file_path, 0o644) - mysql_user = pwd.getpwnam(MYSQL_ROUTER_USER) - os.chown(systemd_file_path, uid=mysql_user.pw_uid, gid=mysql_user.pw_gid) - - @staticmethod - def bootstrap_and_start_mysql_router( - user, - password, - name, - db_host, - port, - force=False, - ) -> None: - """Bootstrap MySQLRouter and register the service with systemd. 
- - Args: - user: The user to connect to the database with - password: The password to connect to the database with - name: The name of application that will use mysqlrouter - db_host: The hostname of the database to connect to - port: The port at which to bootstrap mysqlrouter to - force: Overwrite existing config if any - - Raises: - MySQLRouterBootstrapError - if there is an issue bootstrapping MySQLRouter - """ - # server_ssl_mode is set to enforce unix_socket connections to be established - # via encryption (see more at - # https://dev.mysql.com/doc/refman/8.0/en/caching-sha2-pluggable-authentication.html) - bootstrap_mysqlrouter_command = [ - "sudo", - "/usr/bin/mysqlrouter", - "--user", - MYSQL_ROUTER_USER, - "--name", - name, - "--bootstrap", - f"{user}:{password}@{db_host}", - "--directory", - f"{MYSQL_HOME_DIRECTORY}/{name}", - "--conf-use-sockets", - "--conf-bind-address", - "127.0.0.1", - "--conf-base-port", - f"{port}", - "--conf-set-option", - "DEFAULT.server_ssl_mode=PREFERRED", - "--conf-set-option", - "http_server.bind_address=127.0.0.1", - "--conf-use-gr-notifications", - ] - - if force: - bootstrap_mysqlrouter_command.append("--force") - - try: - subprocess.run(bootstrap_mysqlrouter_command) - - subprocess.run(f"sudo chmod 755 {MYSQL_HOME_DIRECTORY}/{name}".split()) - - MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file(name) - - if not systemd.daemon_reload(): - error_message = "Failed to load the mysqlrouter systemd service" - logger.exception(error_message) - raise MySQLRouterBootstrapError(error_message) - - systemd.service_start(MYSQL_ROUTER_SERVICE_NAME) - if not MySQLRouter.is_mysqlrouter_running(): - error_message = "Failed to start the mysqlrouter systemd service" - logger.exception(error_message) - raise MySQLRouterBootstrapError(error_message) - except subprocess.CalledProcessError as e: - logger.exception("Failed to bootstrap mysqlrouter") - raise MySQLRouterBootstrapError(e.stderr) - except systemd.SystemdError: - error_message = "Failed to set up mysqlrouter as a systemd service" - logger.exception(error_message) - raise MySQLRouterBootstrapError(error_message) - - @staticmethod - def is_mysqlrouter_running() -> bool: - """Indicates whether MySQLRouter is running as a systemd service.""" - return systemd.service_running(MYSQL_ROUTER_SERVICE_NAME) - - @staticmethod - def create_user_with_database_privileges( - username, password, hostname, database, db_username, db_password, db_host, db_port - ) -> None: - """Create a database scope mysql user. 
- - Args: - username: Username of the user to create - password: Password of the user to create - hostname: Hostname of the user to create - database: Database that the user should be restricted to - db_username: The user to connect to the database with - db_password: The password to use to connect to the database - db_host: The host name of the database - db_port: The port for the database - - Raises: - MySQLRouterCreateUserWithDatabasePrivilegesError - - when there is an issue creating a database scoped user - """ - try: - connection = mysql.connector.connect( - user=db_username, password=db_password, host=db_host, port=db_port - ) - cursor = connection.cursor() - - cursor.execute(f"CREATE USER `{username}`@`{hostname}` IDENTIFIED BY '{password}'") - cursor.execute(f"GRANT ALL PRIVILEGES ON `{database}`.* TO `{username}`@`{hostname}`") - - cursor.close() - connection.close() - except mysql.connector.Error as e: - logger.exception("Failed to create user scoped to a database") - raise MySQLRouterCreateUserWithDatabasePrivilegesError(e.msg) diff --git a/src/relations/database_provides.py b/src/relations/database_provides.py deleted file mode 100644 index 13b2b43f..00000000 --- a/src/relations/database_provides.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2022 Canonical Ltd. -# See LICENSE file for licensing details. - -"""Library containing the implementation of the database provides relation.""" - -import json -import logging - -from charms.data_platform_libs.v0.database_provides import ( - DatabaseProvides, - DatabaseRequestedEvent, -) -from ops.framework import Object -from ops.model import Application, BlockedStatus - -from constants import ( - DATABASE_PROVIDES_RELATION, - MYSQL_ROUTER_LEADER_BOOTSTRAPED, - MYSQL_ROUTER_PROVIDES_DATA, - MYSQL_ROUTER_REQUIRES_DATA, - PASSWORD_LENGTH, - PEER, -) -from mysql_router_helpers import ( - MySQLRouter, - MySQLRouterBootstrapError, - MySQLRouterCreateUserWithDatabasePrivilegesError, -) -from utils import generate_random_password - -logger = logging.getLogger(__name__) - - -class DatabaseProvidesRelation(Object): - """Encapsulation of the relation between mysqlrouter and the consumer application.""" - - def __init__(self, charm): - super().__init__(charm, DATABASE_PROVIDES_RELATION) - - self.charm = charm - self.database = DatabaseProvides(self.charm, relation_name=DATABASE_PROVIDES_RELATION) - - self.framework.observe(self.database.on.database_requested, self._on_database_requested) - - self.framework.observe( - self.charm.on[PEER].relation_changed, self._on_peer_relation_changed - ) - - # ======================= - # Helpers - # ======================= - - def _database_provides_relation_exists(self) -> bool: - database_provides_relations = self.charm.model.relations.get(DATABASE_PROVIDES_RELATION) - return bool(database_provides_relations) - - def _get_related_app_name(self) -> str: - """Helper to get the name of the related `database-provides` application.""" - if not self._database_provides_relation_exists(): - return None - - for key in self.charm.model.relations[DATABASE_PROVIDES_RELATION][0].data: - if type(key) == Application and key.name != self.charm.app.name: - return key.name - - return None - - # ======================= - # Handlers - # ======================= - - def _on_database_requested(self, event: DatabaseRequestedEvent) -> None: - """Handle the database requested event.""" - if not self.charm.unit.is_leader(): - return - - self.charm.app_peer_data[MYSQL_ROUTER_PROVIDES_DATA] = json.dumps( - { - "database": event.database, - 
"extra_user_roles": event.extra_user_roles, - } - ) - - def _on_peer_relation_changed(self, _) -> None: - """Handle the peer relation changed event.""" - if not self.charm.unit.is_leader(): - return - - if self.charm.app_peer_data.get(MYSQL_ROUTER_LEADER_BOOTSTRAPED): - return - - if not self.charm.app_peer_data.get(MYSQL_ROUTER_REQUIRES_DATA): - return - - database_provides_relations = self.charm.model.relations.get(DATABASE_PROVIDES_RELATION) - if not database_provides_relations: - return - - parsed_database_requires_data = json.loads( - self.charm.app_peer_data[MYSQL_ROUTER_REQUIRES_DATA] - ) - parsed_database_provides_data = json.loads( - self.charm.app_peer_data[MYSQL_ROUTER_PROVIDES_DATA] - ) - - db_host = parsed_database_requires_data["endpoints"].split(",")[0].split(":")[0] - mysqlrouter_username = parsed_database_requires_data["username"] - mysqlrouter_user_password = self.charm._get_secret("app", "database-password") - related_app_name = self._get_related_app_name() - - try: - MySQLRouter.bootstrap_and_start_mysql_router( - mysqlrouter_username, - mysqlrouter_user_password, - related_app_name, - db_host, - "3306", - ) - except MySQLRouterBootstrapError: - self.charm.unit.status = BlockedStatus("Failed to bootstrap mysqlrouter") - return - - provides_relation_id = database_provides_relations[0].id - application_username = f"application-user-{provides_relation_id}" - application_password = generate_random_password(PASSWORD_LENGTH) - - try: - MySQLRouter.create_user_with_database_privileges( - application_username, - application_password, - "%", - parsed_database_provides_data["database"], - mysqlrouter_username, - mysqlrouter_user_password, - db_host, - "3306", - ) - except MySQLRouterCreateUserWithDatabasePrivilegesError: - self.charm.unit.status = BlockedStatus("Failed to create application user") - return - - self.charm._set_secret( - "app", f"application-user-{provides_relation_id}-password", application_password - ) - - self.database.set_credentials( - provides_relation_id, application_username, application_password - ) - self.database.set_endpoints( - provides_relation_id, f"file:///var/lib/mysql/{related_app_name}/mysql.sock" - ) - self.database.set_read_only_endpoints( - provides_relation_id, f"file:///var/lib/mysql/{related_app_name}/mysqlro.sock" - ) - - self.charm.app_peer_data[MYSQL_ROUTER_LEADER_BOOTSTRAPED] = "true" diff --git a/src/relations/database_requires.py b/src/relations/database_requires.py deleted file mode 100644 index 3b31782a..00000000 --- a/src/relations/database_requires.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2022 Canonical Ltd. -# See LICENSE file for licensing details. 
- -"""Library containing the implementation of the database requires relation.""" - -import json -import logging -from typing import Dict - -from charms.data_platform_libs.v0.database_requires import ( - DatabaseCreatedEvent, - DatabaseEndpointsChangedEvent, - DatabaseRequires, -) -from ops.charm import RelationJoinedEvent -from ops.framework import Object -from ops.model import BlockedStatus - -from constants import ( - DATABASE_REQUIRES_RELATION, - LEGACY_SHARED_DB_DATA, - MYSQL_ROUTER_PROVIDES_DATA, - MYSQL_ROUTER_REQUIRES_DATA, -) - -logger = logging.getLogger(__name__) - - -class DatabaseRequiresRelation(Object): - """Encapsulation of the relation between mysqlrouter and mysql database.""" - - def __init__(self, charm): - super().__init__(charm, DATABASE_REQUIRES_RELATION) - - self.charm = charm - - self.framework.observe( - self.charm.on[DATABASE_REQUIRES_RELATION].relation_joined, - self._on_database_requires_relation_joined, - ) - - shared_db_data = self._get_shared_db_data() - provides_data = self._get_provides_data() - - if provides_data and shared_db_data: - logger.error("Both shared-db and database relations created") - self.charm.unit.status = BlockedStatus("Both shared-db and database relations exists") - return - - if not shared_db_data and not provides_data: - return - - database_name = shared_db_data["database"] if shared_db_data else provides_data["database"] - - self.database_requires_relation = DatabaseRequires( - self.charm, - relation_name=DATABASE_REQUIRES_RELATION, - database_name=database_name, - extra_user_roles="mysqlrouter", - ) - self.framework.observe( - self.database_requires_relation.on.database_created, self._on_database_created - ) - self.framework.observe( - self.database_requires_relation.on.endpoints_changed, self._on_endpoints_changed - ) - - # ======================= - # Helpers - # ======================= - - def _get_shared_db_data(self) -> Dict: - """Helper to get the `shared-db` relation data from the app peer databag.""" - peers = self.charm._peers - if not peers: - return None - - shared_db_data = self.charm.app_peer_data.get(LEGACY_SHARED_DB_DATA) - if not shared_db_data: - return None - - return json.loads(shared_db_data) - - def _get_provides_data(self) -> Dict: - """Helper to get the provides relation data from the app peer databag.""" - peers = self.charm._peers - if not peers: - return None - - provides_data = self.charm.app_peer_data.get(MYSQL_ROUTER_PROVIDES_DATA) - if not provides_data: - return None - - return json.loads(provides_data) - - # ======================= - # Handlers - # ======================= - - def _on_database_requires_relation_joined(self, event: RelationJoinedEvent) -> None: - """Handle the backend-database relation joined event. - - Waits until the database (provides) relation with the application is formed - before triggering the database_requires relation joined event (which will - request the database). - """ - if not self.charm.unit.is_leader(): - return - - provides_data = self._get_provides_data() - if not provides_data: - event.defer() - return - - self.database_requires_relation._on_relation_joined_event(event) - - def _on_database_created(self, event: DatabaseCreatedEvent) -> None: - """Handle the database created event. - - Set the relation data in the app peer databag for the `shared-db`/`database-provides` - code to be able to bootstrap mysqlrouter, create an application - user and relay the application user credentials to the consumer application. 
- """ - if not self.charm.unit.is_leader(): - return - - self.charm.app_peer_data[MYSQL_ROUTER_REQUIRES_DATA] = json.dumps( - { - "username": event.username, - "endpoints": event.endpoints, - } - ) - - self.charm._set_secret("app", "database-password", event.password) - - def _on_endpoints_changed(self, event: DatabaseEndpointsChangedEvent) -> None: - """Handle the database endpoints changed event. - - Update the MYSQL_ROUTER_REQUIRES_DATA in the app peer databag so that - bootstraps of future units work. - """ - if not self.charm.unit.is_leader(): - return - - if self.charm.app_peer_data.get(MYSQL_ROUTER_REQUIRES_DATA): - requires_data = json.loads(self.charm.app_peer_data[MYSQL_ROUTER_REQUIRES_DATA]) - - requires_data["endpoints"] = event.endpoints - - self.charm.app_peer_data[MYSQL_ROUTER_REQUIRES_DATA] = json.dumps(requires_data) diff --git a/src/relations/shared_db.py b/src/relations/shared_db.py deleted file mode 100644 index b1ea0995..00000000 --- a/src/relations/shared_db.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright 2022 Canonical Ltd. -# See LICENSE file for licensing details. - -"""Library containing the implementation of the legacy shared-db relation.""" - -import json -import logging - -from ops.charm import RelationChangedEvent -from ops.framework import Object -from ops.model import Application, BlockedStatus, Unit - -from constants import ( - LEGACY_SHARED_DB, - LEGACY_SHARED_DB_DATA, - LEGACY_SHARED_DB_DATA_FORWARDED, - MYSQL_ROUTER_LEADER_BOOTSTRAPED, - MYSQL_ROUTER_REQUIRES_DATA, - PASSWORD_LENGTH, - PEER, -) -from mysql_router_helpers import ( - MySQLRouter, - MySQLRouterBootstrapError, - MySQLRouterCreateUserWithDatabasePrivilegesError, -) -from utils import generate_random_password - -logger = logging.getLogger(__name__) - - -class SharedDBRelation(Object): - """Legacy `shared-db` relation implementation.""" - - def __init__(self, charm): - super().__init__(charm, LEGACY_SHARED_DB) - - self.charm = charm - - self.framework.observe( - self.charm.on[LEGACY_SHARED_DB].relation_changed, self._on_shared_db_relation_changed - ) - self.framework.observe( - self.charm.on[PEER].relation_changed, self._on_peer_relation_changed - ) - - # ======================= - # Helpers - # ======================= - - def _shared_db_relation_exists(self) -> bool: - """Indicates whether a shared-db relation exists.""" - shared_db_relations = self.charm.model.relations.get(LEGACY_SHARED_DB) - return bool(shared_db_relations) - - def _get_related_app_name(self) -> str: - """Helper to get the name of the related `shared-db` application.""" - if not self._shared_db_relation_exists(): - return None - - for key in self.charm.model.relations[LEGACY_SHARED_DB][0].data: - if type(key) == Application and key.name != self.charm.app.name: - return key.name - - return None - - def _get_related_unit_name(self) -> str: - """Helper to get the name of the related `shared-db` unit.""" - if not self._shared_db_relation_exists(): - return None - - for key in self.charm.model.relations[LEGACY_SHARED_DB][0].data: - if type(key) == Unit and key.app.name != self.charm.app.name: - return key.name - - return None - - # ======================= - # Handlers - # ======================= - - def _on_shared_db_relation_changed(self, event: RelationChangedEvent) -> None: - """Handle the shared-db relation changed event.""" - if not self.charm.unit.is_leader(): - return - - # Forward incoming relation data into the app peer databag - # (so that the relation with the database can be formed with the appropriate parameters) - 
if not self.charm.app_peer_data.get(LEGACY_SHARED_DB_DATA_FORWARDED): - changed_unit_databag = event.relation.data[event.unit] - - database = changed_unit_databag.get("database") - hostname = changed_unit_databag.get("hostname") - username = changed_unit_databag.get("username") - - if not (database and hostname and username): - logger.debug( - "Waiting for `shared-db` databag to be populated by client application" - ) - event.defer() - return - - logger.warning("DEPRECATION WARNING - `shared-db` is a legacy interface") - - self.charm.app_peer_data[LEGACY_SHARED_DB_DATA] = json.dumps( - { - "database": database, - "hostname": hostname, - "username": username, - } - ) - self.charm.app_peer_data[LEGACY_SHARED_DB_DATA_FORWARDED] = "true" - - def _on_peer_relation_changed(self, _) -> None: - """Handler the peer relation changed event. - - Once the `database` relation has been formed, the appropriate database - credentials will be stored in the app peer databag. These credentials - can be used to bootstrap mysqlrouter and create the application user. - """ - if not self.charm.unit.is_leader(): - return - - # Only execute if mysqlrouter has not already been bootstrapped - if self.charm.app_peer_data.get(MYSQL_ROUTER_LEADER_BOOTSTRAPED): - return - - if not self.charm.app_peer_data.get(MYSQL_ROUTER_REQUIRES_DATA): - return - - if not self._shared_db_relation_exists(): - return - - parsed_requires_data = json.loads(self.charm.app_peer_data[MYSQL_ROUTER_REQUIRES_DATA]) - database_password = self.charm._get_secret("app", "database-password") - parsed_shared_db_data = json.loads(self.charm.app_peer_data[LEGACY_SHARED_DB_DATA]) - - db_host = parsed_requires_data["endpoints"].split(",")[0].split(":")[0] - related_app_name = self._get_related_app_name() - application_password = generate_random_password(PASSWORD_LENGTH) - - try: - MySQLRouter.bootstrap_and_start_mysql_router( - parsed_requires_data["username"], - database_password, - related_app_name, - db_host, - "3306", - ) - except MySQLRouterBootstrapError: - self.charm.unit.status = BlockedStatus("Failed to bootstrap mysqlrouter") - return - - try: - MySQLRouter.create_user_with_database_privileges( - parsed_shared_db_data["username"], - application_password, - "%", - parsed_shared_db_data["database"], - parsed_requires_data["username"], - database_password, - db_host, - "3306", - ) - except MySQLRouterCreateUserWithDatabasePrivilegesError: - self.charm.unit.status = BlockedStatus("Failed to create application user") - return - - self.charm._set_secret("app", "application-password", application_password) - - unit_databag = self.charm.model.relations[LEGACY_SHARED_DB][0].data[self.charm.unit] - updates = { - "allowed_units": self._get_related_unit_name(), - "db_host": "127.0.0.1", - "db_port": "3306", - "password": application_password, - "wait_timeout": "3600", - } - unit_databag.update(updates) - - self.charm.app_peer_data[MYSQL_ROUTER_LEADER_BOOTSTRAPED] = "true" diff --git a/src/utils.py b/src/utils.py deleted file mode 100644 index 9428b840..00000000 --- a/src/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2022 Canonical Ltd. -# See LICENSE file for licensing details. - -"""A collection of utility functions that are used in the charm.""" - -import secrets -import string - - -def generate_random_password(length: int) -> str: - """Randomly generate a string intended to be used as a password. 
- - Args: - length: length of the randomly generated string to be returned - - Returns: - a string with random letters and digits of length specified - """ - choices = string.ascii_letters + string.digits - return "".join([secrets.choice(choices) for i in range(length)]) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py deleted file mode 100644 index 4fb6309c..00000000 --- a/tests/unit/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2022 Canonical Ltd. -# See LICENSE file for licensing details. - -import ops.testing - -# Since ops>=1.4 this enables better connection tracking. -# See: More at https://juju.is/docs/sdk/testing#heading--simulate-can-connect -ops.testing.SIMULATE_CAN_CONNECT = True diff --git a/tests/unit/test_charm.py b/tests/unit/test_charm.py deleted file mode 100644 index a95660d5..00000000 --- a/tests/unit/test_charm.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2022 Canonical Ltd. -# See LICENSE file for licensing details. - -import json -import unittest -from unittest.mock import patch - -from ops.model import ActiveStatus, BlockedStatus, WaitingStatus -from ops.testing import Harness - -from charm import MySQLRouterOperatorCharm -from constants import MYSQL_ROUTER_REQUIRES_DATA, PEER -from mysql_router_helpers import ( - MySQLRouterBootstrapError, - MySQLRouterInstallAndConfigureError, -) - - -class TestCharm(unittest.TestCase): - def setUp(self): - self.harness = Harness(MySQLRouterOperatorCharm) - self.addCleanup(self.harness.cleanup) - self.peer_relation_id = self.harness.add_relation(f"{PEER}", f"{PEER}") - self.harness.begin() - self.charm = self.harness.charm - - def test_get_secret(self): - self.harness.set_leader() - - # Test application scope - self.assertIsNone(self.charm._get_secret("app", "password")) - self.harness.update_relation_data( - self.peer_relation_id, self.charm.app.name, {"password": "test-password"} - ) - self.assertEqual(self.charm._get_secret("app", "password"), "test-password") - - # Test unit scope - self.assertIsNone(self.charm._get_secret("unit", "password")) - self.harness.update_relation_data( - self.peer_relation_id, self.charm.unit.name, {"password": "test-password"} - ) - self.assertEqual(self.charm._get_secret("unit", "password"), "test-password") - - def test_set_secret(self): - self.harness.set_leader() - - # Test application scope - self.assertNotIn( - "password", self.harness.get_relation_data(self.peer_relation_id, self.charm.app.name) - ) - self.charm._set_secret("app", "password", "test-password") - self.assertEqual( - self.harness.get_relation_data(self.peer_relation_id, self.charm.app.name)["password"], - "test-password", - ) - - # Test unit scope - self.assertNotIn( - "password", self.harness.get_relation_data(self.peer_relation_id, self.charm.unit.name) - ) - self.charm._set_secret("unit", "password", "test-password") - self.assertEqual( - self.harness.get_relation_data(self.peer_relation_id, self.charm.unit.name)[ - "password" - ], - "test-password", - ) - - @patch("subprocess.check_call") - @patch("mysql_router_helpers.MySQLRouter.install_and_configure_mysql_router") - def test_on_install(self, _install_and_configure_mysql_router, _check_call): - self.charm.on.install.emit() - - self.assertTrue(isinstance(self.harness.model.unit.status, WaitingStatus)) - - @patch( - "mysql_router_helpers.MySQLRouter.install_and_configure_mysql_router", - side_effect=MySQLRouterInstallAndConfigureError(), - ) - def test_on_install_exception(self, _install_and_configure_mysql_router): - self.charm.on.install.emit() - - 
self.assertTrue(isinstance(self.harness.model.unit.status, BlockedStatus)) - - @patch("charm.DatabaseProvidesRelation._get_related_app_name") - @patch("charm.MySQLRouterOperatorCharm._get_secret") - @patch("mysql_router_helpers.MySQLRouter.bootstrap_and_start_mysql_router") - def test_on_upgrade_charm( - self, bootstrap_and_start_mysql_router, get_secret, get_related_app_name - ): - self.charm.unit.status = ActiveStatus() - get_secret.return_value = "s3kr1t" - get_related_app_name.return_value = "testapp" - self.charm.app_peer_data[MYSQL_ROUTER_REQUIRES_DATA] = json.dumps( - { - "username": "test_user", - "endpoints": "10.10.0.1:3306,10.10.0.2:3306", - } - ) - self.charm.on.upgrade_charm.emit() - - self.assertTrue(isinstance(self.harness.model.unit.status, ActiveStatus)) - bootstrap_and_start_mysql_router.assert_called_with( - "test_user", "s3kr1t", "testapp", "10.10.0.1", "3306", force=True - ) - - @patch("mysql_router_helpers.MySQLRouter.bootstrap_and_start_mysql_router") - def test_on_upgrade_charm_waiting(self, bootstrap_and_start_mysql_router): - self.charm.unit.status = WaitingStatus() - self.charm.on.upgrade_charm.emit() - - self.assertTrue(isinstance(self.harness.model.unit.status, WaitingStatus)) - bootstrap_and_start_mysql_router.assert_not_called() - - @patch("charm.DatabaseProvidesRelation._get_related_app_name") - @patch("charm.MySQLRouterOperatorCharm._get_secret") - @patch("mysql_router_helpers.MySQLRouter.bootstrap_and_start_mysql_router") - def test_on_upgrade_charm_error( - self, bootstrap_and_start_mysql_router, get_secret, get_related_app_name - ): - bootstrap_and_start_mysql_router.side_effect = MySQLRouterBootstrapError() - get_secret.return_value = "s3kr1t" - get_related_app_name.return_value = "testapp" - self.charm.unit.status = ActiveStatus() - self.charm.app_peer_data[MYSQL_ROUTER_REQUIRES_DATA] = json.dumps( - { - "username": "test_user", - "endpoints": "10.10.0.1:3306,10.10.0.2:3306", - } - ) - self.charm.on.upgrade_charm.emit() - - self.assertTrue(isinstance(self.harness.model.unit.status, BlockedStatus)) - bootstrap_and_start_mysql_router.assert_called_with( - "test_user", "s3kr1t", "testapp", "10.10.0.1", "3306", force=True - ) diff --git a/tests/unit/test_mysql_router_helpers.py b/tests/unit/test_mysql_router_helpers.py deleted file mode 100644 index 9d08cd2c..00000000 --- a/tests/unit/test_mysql_router_helpers.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright 2022 Canonical Ltd. -# See LICENSE file for licensing details. 
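The `test_get_secret`/`test_set_secret` cases above exercise a simple convention: "app"-scoped values are stored in the application's peer-relation databag and "unit"-scoped values in the unit's own databag. A simplified stand-in (not the removed charm code) that captures the behaviour those tests assert:

```python
# Simplified stand-in for the peer-databag secret helpers exercised by
# test_get_secret/test_set_secret above.
peer_data = {"app": {}, "unit": {}}  # application databag vs. unit databag

def set_secret(scope: str, key: str, value: str) -> None:
    peer_data[scope][key] = value

def get_secret(scope: str, key: str):
    return peer_data[scope].get(key)

assert get_secret("app", "password") is None
set_secret("app", "password", "test-password")
assert get_secret("app", "password") == "test-password"
assert get_secret("unit", "password") is None
```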
- -import unittest -from subprocess import CalledProcessError -from unittest.mock import call, patch - -from charms.operator_libs_linux.v1.systemd import SystemdError - -from constants import MYSQL_ROUTER_SERVICE_NAME -from mysql_router_helpers import MySQLRouter, MySQLRouterBootstrapError - -bootstrap_cmd = [ - "sudo", - "/usr/bin/mysqlrouter", - "--user", - "mysql", - "--name", - "testapp", - "--bootstrap", - "test_user:qweqwe@10.10.0.1", - "--directory", - "/var/lib/mysql/testapp", - "--conf-use-sockets", - "--conf-bind-address", - "127.0.0.1", - "--conf-base-port", - "3306", - "--conf-set-option", - "DEFAULT.server_ssl_mode=PREFERRED", - "--conf-set-option", - "http_server.bind_address=127.0.0.1", - "--conf-use-gr-notifications", -] -chmod_cmd = [ - "sudo", - "chmod", - "755", - "/var/lib/mysql/testapp", -] - - -class TestMysqlRouterHelpers(unittest.TestCase): - @patch("mysql_router_helpers.MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file") - @patch("mysql_router_helpers.systemd") - @patch("mysql_router_helpers.subprocess.run") - def test_bootstrap_and_start_mysql_router(self, run, systemd, render_and_copy): - MySQLRouter.bootstrap_and_start_mysql_router( - "test_user", "qweqwe", "testapp", "10.10.0.1", "3306" - ) - - self.assertEqual( - sorted(run.mock_calls), - sorted( - [ - call(bootstrap_cmd), - call(chmod_cmd), - ] - ), - ) - render_and_copy.assert_called_with("testapp") - systemd.daemon_reload.assert_called_with() - systemd.service_start.assert_called_with(MYSQL_ROUTER_SERVICE_NAME) - - @patch("mysql_router_helpers.MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file") - @patch("mysql_router_helpers.systemd") - @patch("mysql_router_helpers.subprocess.run") - def test_bootstrap_and_start_mysql_router_force(self, run, systemd, render_and_copy): - MySQLRouter.bootstrap_and_start_mysql_router( - "test_user", "qweqwe", "testapp", "10.10.0.1", "3306", force=True - ) - - self.assertEqual( - sorted(run.mock_calls), - sorted( - [ - call(bootstrap_cmd + ["--force"]), - call(chmod_cmd), - ] - ), - ) - render_and_copy.assert_called_with("testapp") - systemd.daemon_reload.assert_called_with() - systemd.service_start.assert_called_with(MYSQL_ROUTER_SERVICE_NAME) - - @patch("mysql_router_helpers.logger") - @patch("mysql_router_helpers.MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file") - @patch("mysql_router_helpers.systemd") - @patch("mysql_router_helpers.subprocess.run") - def test_bootstrap_and_start_mysql_router_subprocess_error( - self, run, systemd, render_and_copy, logger - ): - e = CalledProcessError(1, bootstrap_cmd) - run.side_effect = e - with self.assertRaises(MySQLRouterBootstrapError): - MySQLRouter.bootstrap_and_start_mysql_router( - "test_user", "qweqwe", "testapp", "10.10.0.1", "3306" - ) - - run.assert_called_with(bootstrap_cmd) - render_and_copy.assert_not_called() - systemd.daemon_reload.assert_not_called() - systemd.service_start.assert_not_called() - logger.exception.assert_called_with("Failed to bootstrap mysqlrouter") - - @patch("mysql_router_helpers.logger") - @patch("mysql_router_helpers.MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file") - @patch("mysql_router_helpers.systemd.service_start") - @patch("mysql_router_helpers.systemd.daemon_reload") - @patch("mysql_router_helpers.subprocess.run") - def test_bootstrap_and_start_mysql_router_systemd_error( - self, run, daemon_reload, service_start, render_and_copy, logger - ): - e = SystemdError() - daemon_reload.side_effect = e - with self.assertRaises(MySQLRouterBootstrapError): - 
MySQLRouter.bootstrap_and_start_mysql_router( - "test_user", "qweqwe", "testapp", "10.10.0.1", "3306" - ) - - self.assertEqual( - sorted(run.mock_calls), - sorted( - [ - call(bootstrap_cmd), - call(chmod_cmd), - ] - ), - ) - render_and_copy.assert_called_with("testapp") - daemon_reload.assert_called_with() - service_start.assert_not_called() - logger.exception.assert_called_with("Failed to set up mysqlrouter as a systemd service") - - @patch("mysql_router_helpers.logger") - @patch("mysql_router_helpers.MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file") - @patch("mysql_router_helpers.systemd.service_start") - @patch("mysql_router_helpers.systemd.daemon_reload") - @patch("mysql_router_helpers.subprocess.run") - def test_bootstrap_and_start_mysql_router_no_daemon_reload( - self, run, daemon_reload, service_start, render_and_copy, logger - ): - daemon_reload.return_value = False - with self.assertRaises(MySQLRouterBootstrapError): - MySQLRouter.bootstrap_and_start_mysql_router( - "test_user", "qweqwe", "testapp", "10.10.0.1", "3306" - ) - - self.assertEqual( - sorted(run.mock_calls), - sorted( - [ - call(bootstrap_cmd), - call(chmod_cmd), - ] - ), - ) - render_and_copy.assert_called_with("testapp") - daemon_reload.assert_called_with() - service_start.assert_not_called() - logger.exception.assert_called_with("Failed to load the mysqlrouter systemd service") - - @patch("mysql_router_helpers.logger") - @patch("mysql_router_helpers.MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file") - @patch("mysql_router_helpers.systemd.service_start") - @patch("mysql_router_helpers.systemd.daemon_reload") - @patch("mysql_router_helpers.subprocess.run") - def test_bootstrap_and_start_mysql_router_no_service_start( - self, run, daemon_reload, service_start, render_and_copy, logger - ): - service_start.return_value = False - with self.assertRaises(MySQLRouterBootstrapError): - MySQLRouter.bootstrap_and_start_mysql_router( - "test_user", "qweqwe", "testapp", "10.10.0.1", "3306" - ) - - self.assertEqual( - sorted(run.mock_calls), - sorted( - [ - call(bootstrap_cmd), - call(chmod_cmd), - ] - ), - ) - render_and_copy.assert_called_with("testapp") - daemon_reload.assert_called_with() - service_start.assert_called_with(MYSQL_ROUTER_SERVICE_NAME) - logger.exception.assert_called_with("Failed to start the mysqlrouter systemd service") From 0dfee4ac670d56387381cb27850fd41cdd9615e1 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Thu, 1 Jun 2023 16:36:04 +0000 Subject: [PATCH 02/57] Copy /src/ from k8s charm --- src/charm.py | 255 +++++++++++++++++++++++ src/mysql_shell.py | 142 +++++++++++++ src/relations/database_provides.py | 235 +++++++++++++++++++++ src/relations/database_requires.py | 117 +++++++++++ src/relations/remote_databag.py | 48 +++++ src/relations/tls.py | 304 +++++++++++++++++++++++++++ src/status_exception.py | 14 ++ src/workload.py | 324 +++++++++++++++++++++++++++++ 8 files changed, 1439 insertions(+) create mode 100755 src/charm.py create mode 100644 src/mysql_shell.py create mode 100644 src/relations/database_provides.py create mode 100644 src/relations/database_requires.py create mode 100644 src/relations/remote_databag.py create mode 100644 src/relations/tls.py create mode 100644 src/status_exception.py create mode 100644 src/workload.py diff --git a/src/charm.py b/src/charm.py new file mode 100755 index 00000000..d31327a6 --- /dev/null +++ b/src/charm.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +# Copyright 2022 Canonical Ltd. +# See LICENSE file for licensing details. 
+# +# Learn more at: https://juju.is/docs/sdk + +"""MySQL Router kubernetes (k8s) charm""" + +import logging +import socket + +import lightkube +import lightkube.models.core_v1 +import lightkube.models.meta_v1 +import lightkube.resources.core_v1 +import ops +import tenacity + +import relations.database_provides +import relations.database_requires +import relations.tls +import workload + +logger = logging.getLogger(__name__) + + +class MySQLRouterOperatorCharm(ops.CharmBase): + """Operator charm for MySQL Router""" + + def __init__(self, *args) -> None: + super().__init__(*args) + + self.database_requires = relations.database_requires.RelationEndpoint(self) + + self.database_provides = relations.database_provides.RelationEndpoint(self) + + self.framework.observe(self.on.install, self._on_install) + self.framework.observe(self.on.start, self._on_start) + self.framework.observe( + getattr(self.on, "mysql_router_pebble_ready"), self._on_mysql_router_pebble_ready + ) + self.framework.observe(self.on.leader_elected, self._on_leader_elected) + + # Start workload after pod restart + self.framework.observe(self.on.upgrade_charm, self.reconcile_database_relations) + + self.tls = relations.tls.RelationEndpoint(self) + + def get_workload(self, *, event): + """MySQL Router workload""" + container = self.unit.get_container(workload.Workload.CONTAINER_NAME) + if connection_info := self.database_requires.get_connection_info(event=event): + return workload.AuthenticatedWorkload( + _container=container, + _connection_info=connection_info, + _charm=self, + ) + return workload.Workload(_container=container) + + @property + def model_service_domain(self): + """K8s service domain for Juju model""" + # Example: "mysql-router-k8s-0.mysql-router-k8s-endpoints.my-model.svc.cluster.local" + fqdn = socket.getfqdn() + # Example: "mysql-router-k8s-0.mysql-router-k8s-endpoints." + prefix = f"{self.unit.name.replace('/', '-')}.{self.app.name}-endpoints." + assert fqdn.startswith(f"{prefix}{self.model.name}.") + # Example: my-model.svc.cluster.local + return fqdn.removeprefix(prefix) + + @property + def _endpoint(self) -> str: + """K8s endpoint for MySQL Router""" + # Example: mysql-router-k8s.my-model.svc.cluster.local + return f"{self.app.name}.{self.model_service_domain}" + + @staticmethod + def _prioritize_statuses(statuses: list[ops.StatusBase]) -> ops.StatusBase: + """Report the highest priority status. 
+ + (Statuses of the same type are reported in the order they were added to `statuses`) + """ + status_priority = ( + ops.BlockedStatus, + ops.WaitingStatus, + ops.MaintenanceStatus, + # Catch any unknown status type + ops.StatusBase, + ) + for status_type in status_priority: + for status in statuses: + if isinstance(status, status_type): + return status + return ops.ActiveStatus() + + def _determine_app_status(self, *, event) -> ops.StatusBase: + """Report app status.""" + statuses = [] + for endpoint in (self.database_requires, self.database_provides): + if status := endpoint.get_status(event): + statuses.append(status) + return self._prioritize_statuses(statuses) + + def _determine_unit_status(self, *, event) -> ops.StatusBase: + """Report unit status.""" + statuses = [] + if not self.get_workload(event=event).container_ready: + statuses.append(ops.MaintenanceStatus("Waiting for container")) + return self._prioritize_statuses(statuses) + + def set_status(self, *, event) -> None: + """Set charm status.""" + if self.unit.is_leader(): + self.app.status = self._determine_app_status(event=event) + logger.debug(f"Set app status to {self.app.status}") + self.unit.status = self._determine_unit_status(event=event) + logger.debug(f"Set unit status to {self.unit.status}") + + def wait_until_mysql_router_ready(self) -> None: + """Wait until a connection to MySQL Router is possible. + + Retry every 5 seconds for up to 30 seconds. + """ + logger.debug("Waiting until MySQL Router is ready") + self.unit.status = ops.WaitingStatus("MySQL Router starting") + try: + for attempt in tenacity.Retrying( + reraise=True, + stop=tenacity.stop_after_delay(30), + wait=tenacity.wait_fixed(5), + ): + with attempt: + for port in (6446, 6447): + with socket.socket() as s: + assert s.connect_ex(("localhost", port)) == 0 + except AssertionError: + logger.exception("Unable to connect to MySQL Router") + raise + else: + logger.debug("MySQL Router is ready") + + def _patch_service(self, *, name: str, ro_port: int, rw_port: int) -> None: + """Patch Juju-created k8s service. + + The k8s service will be tied to pod-0 so that the service is auto cleaned by + k8s when the last pod is scaled down. + + Args: + name: The name of the service. + ro_port: The read only port. + rw_port: The read write port. 
+ """ + logger.debug(f"Patching k8s service {name=}, {ro_port=}, {rw_port=}") + client = lightkube.Client() + pod0 = client.get( + res=lightkube.resources.core_v1.Pod, + name=self.app.name + "-0", + namespace=self.model.name, + ) + service = lightkube.resources.core_v1.Service( + metadata=lightkube.models.meta_v1.ObjectMeta( + name=name, + namespace=self.model.name, + ownerReferences=pod0.metadata.ownerReferences, + labels={ + "app.kubernetes.io/name": self.app.name, + }, + ), + spec=lightkube.models.core_v1.ServiceSpec( + ports=[ + lightkube.models.core_v1.ServicePort( + name="mysql-ro", + port=ro_port, + targetPort=ro_port, + ), + lightkube.models.core_v1.ServicePort( + name="mysql-rw", + port=rw_port, + targetPort=rw_port, + ), + ], + selector={"app.kubernetes.io/name": self.app.name}, + ), + ) + client.patch( + res=lightkube.resources.core_v1.Service, + obj=service, + name=service.metadata.name, + namespace=service.metadata.namespace, + force=True, + field_manager=self.model.app.name, + ) + logger.debug(f"Patched k8s service {name=}, {ro_port=}, {rw_port=}") + + # ======================= + # Handlers + # ======================= + + def reconcile_database_relations(self, event=None) -> None: + """Handle database requires/provides events.""" + workload_ = self.get_workload(event=event) + logger.debug( + "State of reconcile " + f"{self.unit.is_leader()=}, " + f"{isinstance(workload_, workload.AuthenticatedWorkload)=}, " + f"{workload_.container_ready=}, " + f"{self.database_requires.is_relation_breaking(event)=}, " + f"{isinstance(event, ops.UpgradeCharmEvent)=}" + ) + if self.unit.is_leader() and self.database_requires.is_relation_breaking(event): + self.database_provides.delete_all_databags() + elif ( + self.unit.is_leader() + and isinstance(workload_, workload.AuthenticatedWorkload) + and workload_.container_ready + ): + self.database_provides.reconcile_users( + event=event, + router_endpoint=self._endpoint, + shell=workload_.shell, + ) + if isinstance(workload_, workload.AuthenticatedWorkload) and workload_.container_ready: + if isinstance(event, ops.UpgradeCharmEvent): + # Pod restart (https://juju.is/docs/sdk/start-event#heading--emission-sequence) + workload_.cleanup_after_pod_restart() + workload_.enable(tls=self.tls.certificate_saved, unit_name=self.unit.name) + elif workload_.container_ready: + workload_.disable() + self.set_status(event=event) + + def _on_install(self, _) -> None: + """Patch existing k8s service to include read-write and read-only services.""" + if not self.unit.is_leader(): + return + try: + self._patch_service(name=self.app.name, ro_port=6447, rw_port=6446) + except lightkube.ApiError: + logger.exception("Failed to patch k8s service") + raise + + def _on_start(self, _) -> None: + # Set status on first start if no relations active + self.set_status(event=None) + + def _on_mysql_router_pebble_ready(self, _) -> None: + self.unit.set_workload_version(self.get_workload(event=None).version) + self.reconcile_database_relations() + + def _on_leader_elected(self, _) -> None: + # Update app status + self.set_status(event=None) + + +if __name__ == "__main__": + ops.main.main(MySQLRouterOperatorCharm) diff --git a/src/mysql_shell.py b/src/mysql_shell.py new file mode 100644 index 00000000..c152cb6d --- /dev/null +++ b/src/mysql_shell.py @@ -0,0 +1,142 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. 
+ +"""MySQL Shell in Python execution mode + +https://dev.mysql.com/doc/mysql-shell/8.0/en/ +""" + +import dataclasses +import json +import logging +import secrets +import string + +import ops + +_PASSWORD_LENGTH = 24 +logger = logging.getLogger(__name__) + + +# TODO python3.10 min version: Add `(kw_only=True)` +@dataclasses.dataclass +class Shell: + """MySQL Shell connected to MySQL cluster""" + + _container: ops.Container + username: str + _password: str + _host: str + _port: str + + _TEMPORARY_SCRIPT_FILE = "/tmp/script.py" + + def _run_commands(self, commands: list[str]) -> None: + """Connect to MySQL cluster and run commands.""" + # Redact password from log + logged_commands = commands.copy() + logged_commands.insert( + 0, f"shell.connect('{self.username}:***@{self._host}:{self._port}')" + ) + + commands.insert( + 0, f"shell.connect('{self.username}:{self._password}@{self._host}:{self._port}')" + ) + self._container.push(self._TEMPORARY_SCRIPT_FILE, "\n".join(commands)) + try: + process = self._container.exec( + ["mysqlsh", "--no-wizard", "--python", "--file", self._TEMPORARY_SCRIPT_FILE] + ) + process.wait_output() + except ops.pebble.ExecError as e: + logger.exception(f"Failed to run {logged_commands=}\nstderr:\n{e.stderr}\n") + raise + finally: + self._container.remove_path(self._TEMPORARY_SCRIPT_FILE) + + def _run_sql(self, sql_statements: list[str]) -> None: + """Connect to MySQL cluster and execute SQL.""" + commands = [] + for statement in sql_statements: + # Escape double quote (") characters in statement + statement = statement.replace('"', r"\"") + commands.append('session.run_sql("' + statement + '")') + self._run_commands(commands) + + @staticmethod + def _generate_password() -> str: + choices = string.ascii_letters + string.digits + return "".join(secrets.choice(choices) for _ in range(_PASSWORD_LENGTH)) + + def _get_attributes(self, additional_attributes: dict = None) -> str: + """Attributes for (MySQL) users created by this charm + + If the relation with the MySQL charm is broken, the MySQL charm will use this attribute + to delete all users created by this charm. 
+ """ + attributes = {"created_by_user": self.username} + if additional_attributes: + attributes.update(additional_attributes) + return json.dumps(attributes) + + def create_application_database_and_user(self, *, username: str, database: str) -> str: + """Create database and user for related database_provides application.""" + attributes = self._get_attributes() + logger.debug(f"Creating {database=} and {username=} with {attributes=}") + password = self._generate_password() + self._run_sql( + [ + f"CREATE DATABASE IF NOT EXISTS `{database}`", + f"CREATE USER `{username}` IDENTIFIED BY '{password}' ATTRIBUTE '{attributes}'", + f"GRANT ALL PRIVILEGES ON `{database}`.* TO `{username}`", + ] + ) + logger.debug(f"Created {database=} and {username=} with {attributes=}") + return password + + def add_attributes_to_mysql_router_user( + self, *, username: str, router_id: str, unit_name: str + ) -> None: + """Add attributes to user created during MySQL Router bootstrap.""" + attributes = self._get_attributes( + {"router_id": router_id, "created_by_juju_unit": unit_name} + ) + logger.debug(f"Adding {attributes=} to {username=}") + self._run_sql([f"ALTER USER `{username}` ATTRIBUTE '{attributes}'"]) + logger.debug(f"Added {attributes=} to {username=}") + + def delete_user(self, username: str) -> None: + """Delete user.""" + logger.debug(f"Deleting {username=}") + self._run_sql([f"DROP USER `{username}`"]) + logger.debug(f"Deleted {username=}") + + def delete_router_user_after_pod_restart(self, router_id: str) -> None: + """Delete MySQL Router user created by a previous instance of this unit. + + Before pod restart, the charm does not have an opportunity to delete the MySQL Router user. + During MySQL Router bootstrap, a new user is created. Before bootstrap, the old user + should be deleted. + """ + logger.debug(f"Deleting MySQL Router user {router_id=} created by {self.username=}") + self._run_sql( + [ + f"SELECT CONCAT('DROP USER ', GROUP_CONCAT(QUOTE(USER), '@', QUOTE(HOST))) INTO @sql FROM INFORMATION_SCHEMA.USER_ATTRIBUTES WHERE ATTRIBUTE->'$.created_by_user'='{self.username}' AND ATTRIBUTE->'$.router_id'='{router_id}'", + "PREPARE stmt FROM @sql", + "EXECUTE stmt", + "DEALLOCATE PREPARE stmt", + ] + ) + logger.debug(f"Deleted MySQL Router user {router_id=} created by {self.username=}") + + def remove_router_from_cluster_metadata(self, router_id: str) -> None: + """Remove MySQL Router from InnoDB Cluster metadata. + + On pod restart, MySQL Router bootstrap will fail without `--force` if cluster metadata + already exists for the router ID. + """ + logger.debug(f"Removing {router_id=} from cluster metadata") + self._run_commands( + ["cluster = dba.get_cluster()", f'cluster.remove_router_metadata("{router_id}")'] + ) + logger.debug(f"Removed {router_id=} from cluster metadata") diff --git a/src/relations/database_provides.py b/src/relations/database_provides.py new file mode 100644 index 00000000..7788a504 --- /dev/null +++ b/src/relations/database_provides.py @@ -0,0 +1,235 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. 
+ +"""Relation(s) to one or more application charms""" + +import logging +import typing + +import charms.data_platform_libs.v0.data_interfaces as data_interfaces +import ops + +import mysql_shell +import relations.remote_databag as remote_databag +import status_exception + +if typing.TYPE_CHECKING: + import charm + +logger = logging.getLogger(__name__) + + +class _RelationBreaking(Exception): + """Relation will be broken after the current event is handled""" + + +class _UnsupportedExtraUserRole(status_exception.StatusException): + """Application charm requested unsupported extra user role""" + + def __init__(self, *, app_name: str, endpoint_name: str) -> None: + message = ( + f"{app_name} app requested unsupported extra user role on {endpoint_name} endpoint" + ) + logger.warning(message) + super().__init__(ops.BlockedStatus(message)) + + +class _Relation: + """Relation to one application charm""" + + def __init__(self, *, relation: ops.Relation) -> None: + self._id = relation.id + + def __eq__(self, other) -> bool: + if not isinstance(other, _Relation): + return False + return self._id == other._id + + def _get_username(self, database_requires_username: str) -> str: + """Database username""" + # Prefix username with username from database requires relation. + # This ensures a unique username if MySQL Router is deployed in a different Juju model + # from MySQL. + # (Relation IDs are only unique within a Juju model.) + return f"{database_requires_username}-{self._id}" + + +class _RelationThatRequestedUser(_Relation): + """Related application charm that has requested a database & user""" + + def __init__( + self, *, relation: ops.Relation, interface: data_interfaces.DatabaseProvides, event + ) -> None: + super().__init__(relation=relation) + self._interface = interface + if isinstance(event, ops.RelationBrokenEvent) and event.relation.id == self._id: + raise _RelationBreaking + # Application charm databag + databag = remote_databag.RemoteDatabag(interface=interface, relation=relation) + self._database: str = databag["database"] + if databag.get("extra-user-roles"): + raise _UnsupportedExtraUserRole( + app_name=relation.app.name, endpoint_name=relation.name + ) + + def _set_databag(self, *, username: str, password: str, router_endpoint: str) -> None: + """Share connection information with application charm.""" + read_write_endpoint = f"{router_endpoint}:6446" + read_only_endpoint = f"{router_endpoint}:6447" + logger.debug( + f"Setting databag {self._id=} {self._database=}, {username=}, {read_write_endpoint=}, {read_only_endpoint=}" + ) + self._interface.set_database(self._id, self._database) + self._interface.set_credentials(self._id, username, password) + self._interface.set_endpoints(self._id, read_write_endpoint) + self._interface.set_read_only_endpoints(self._id, read_only_endpoint) + logger.debug( + f"Set databag {self._id=} {self._database=}, {username=}, {read_write_endpoint=}, {read_only_endpoint=}" + ) + + def create_database_and_user(self, *, router_endpoint: str, shell: mysql_shell.Shell) -> None: + """Create database & user and update databag.""" + username = self._get_username(shell.username) + password = shell.create_application_database_and_user( + username=username, database=self._database + ) + self._set_databag(username=username, password=password, router_endpoint=router_endpoint) + + +class _UserNotCreated(Exception): + """Database & user has not been provided to related application charm""" + + +class _RelationWithCreatedUser(_Relation): + """Related application charm that 
has been provided with a database & user""" + + def __init__( + self, *, relation: ops.Relation, interface: data_interfaces.DatabaseProvides + ) -> None: + super().__init__(relation=relation) + self._local_databag = relation.data[interface.local_app] + for key in ("database", "username", "password", "endpoints"): + if key not in self._local_databag: + raise _UserNotCreated + + def delete_databag(self) -> None: + """Remove connection information from databag.""" + logger.debug(f"Deleting databag {self._id=}") + self._local_databag.clear() + logger.debug(f"Deleted databag {self._id=}") + + def delete_user(self, *, shell: mysql_shell.Shell) -> None: + """Delete user and update databag.""" + self.delete_databag() + shell.delete_user(self._get_username(shell.username)) + + +class RelationEndpoint: + """Relation endpoint for application charm(s)""" + + NAME = "database" + + def __init__(self, charm_: "charm.MySQLRouterOperatorCharm") -> None: + self._interface = data_interfaces.DatabaseProvides(charm_, relation_name=self.NAME) + charm_.framework.observe( + charm_.on[self.NAME].relation_joined, + charm_.reconcile_database_relations, + ) + charm_.framework.observe( + self._interface.on.database_requested, + charm_.reconcile_database_relations, + ) + charm_.framework.observe( + charm_.on[self.NAME].relation_broken, + charm_.reconcile_database_relations, + ) + + @property + def _created_users(self) -> list[_RelationWithCreatedUser]: + created_users = [] + for relation in self._interface.relations: + try: + created_users.append( + _RelationWithCreatedUser(relation=relation, interface=self._interface) + ) + except _UserNotCreated: + pass + return created_users + + def reconcile_users( + self, + *, + event, + router_endpoint: str, + shell: mysql_shell.Shell, + ) -> None: + """Create requested users and delete inactive users. + + When the relation to the MySQL charm is broken, the MySQL charm will delete all users + created by this charm. Therefore, this charm does not need to delete users when that + relation is broken. + """ + logger.debug(f"Reconciling users {event=}, {router_endpoint=}") + requested_users = [] + for relation in self._interface.relations: + try: + requested_users.append( + _RelationThatRequestedUser( + relation=relation, interface=self._interface, event=event + ) + ) + except ( + _RelationBreaking, + remote_databag.IncompleteDatabag, + _UnsupportedExtraUserRole, + ): + pass + logger.debug(f"State of reconcile users {requested_users=}, {self._created_users=}") + for relation in requested_users: + if relation not in self._created_users: + relation.create_database_and_user(router_endpoint=router_endpoint, shell=shell) + for relation in self._created_users: + if relation not in requested_users: + relation.delete_user(shell=shell) + logger.debug(f"Reconciled users {event=}, {router_endpoint=}") + + def delete_all_databags(self) -> None: + """Remove connection information from all databags. + + Called when relation with MySQL is breaking + + When the MySQL relation is re-established, it could be a different MySQL cluster—new users + will need to be created. 
+ """ + logger.debug("Deleting all application databags") + for relation in self._created_users: + # MySQL charm will delete user; just delete databag + relation.delete_databag() + logger.debug("Deleted all application databags") + + def get_status(self, event) -> typing.Optional[ops.StatusBase]: + """Report non-active status.""" + requested_users = [] + exceptions: list[status_exception.StatusException] = [] + for relation in self._interface.relations: + try: + requested_users.append( + _RelationThatRequestedUser( + relation=relation, interface=self._interface, event=event + ) + ) + except _RelationBreaking: + pass + except (remote_databag.IncompleteDatabag, _UnsupportedExtraUserRole) as exception: + exceptions.append(exception) + # Always report unsupported extra user role + for exception in exceptions: + if isinstance(exception, _UnsupportedExtraUserRole): + return exception.status + if requested_users: + # At least one relation is active—do not report about inactive relations + return + for exception in exceptions: + if isinstance(exception, remote_databag.IncompleteDatabag): + return exception.status + return ops.BlockedStatus(f"Missing relation: {self.NAME}") diff --git a/src/relations/database_requires.py b/src/relations/database_requires.py new file mode 100644 index 00000000..307d73ee --- /dev/null +++ b/src/relations/database_requires.py @@ -0,0 +1,117 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Relation to MySQL charm""" + +import logging +import typing + +import charms.data_platform_libs.v0.data_interfaces as data_interfaces +import ops + +import relations.remote_databag as remote_databag +import status_exception + +if typing.TYPE_CHECKING: + import charm + +logger = logging.getLogger(__name__) + + +class _MissingRelation(status_exception.StatusException): + """Relation to MySQL charm does (or will) not exist""" + + def __init__(self, *, endpoint_name: str) -> None: + super().__init__(ops.BlockedStatus(f"Missing relation: {endpoint_name}")) + + +class _RelationBreaking(_MissingRelation): + """Relation to MySQL charm will be broken after the current event is handled + + Relation currently exists + """ + + +class ConnectionInformation: + """Information for connection to MySQL cluster + + User has permission to: + - Create databases & users + - Grant all privileges on a database to a user + (Different from user that MySQL Router runs with after bootstrap.) 
+ """ + + def __init__(self, *, interface: data_interfaces.DatabaseRequires, event) -> None: + relations = interface.relations + endpoint_name = interface.relation_name + if not relations: + raise _MissingRelation(endpoint_name=endpoint_name) + assert len(relations) == 1 + relation = relations[0] + if isinstance(event, ops.RelationBrokenEvent) and event.relation.id == relation.id: + # Relation will be broken after the current event is handled + raise _RelationBreaking(endpoint_name=endpoint_name) + # MySQL charm databag + databag = remote_databag.RemoteDatabag(interface=interface, relation=relation) + endpoints = databag["endpoints"].split(",") + assert len(endpoints) == 1 + endpoint = endpoints[0] + self.host: str = endpoint.split(":")[0] + self.port: str = endpoint.split(":")[1] + self.username: str = databag["username"] + self.password: str = databag["password"] + + +class RelationEndpoint: + """Relation endpoint for MySQL charm""" + + NAME = "backend-database" + + def __init__(self, charm_: "charm.MySQLRouterOperatorCharm") -> None: + self._interface = data_interfaces.DatabaseRequires( + charm_, + relation_name=self.NAME, + # Database name disregarded by MySQL charm if "mysqlrouter" extra user role requested + database_name="mysql_innodb_cluster_metadata", + extra_user_roles="mysqlrouter", + ) + charm_.framework.observe( + charm_.on[self.NAME].relation_created, + charm_.reconcile_database_relations, + ) + charm_.framework.observe( + self._interface.on.database_created, + charm_.reconcile_database_relations, + ) + charm_.framework.observe( + self._interface.on.endpoints_changed, + charm_.reconcile_database_relations, + ) + charm_.framework.observe( + charm_.on[self.NAME].relation_broken, + charm_.reconcile_database_relations, + ) + + def get_connection_info(self, *, event) -> typing.Optional[ConnectionInformation]: + """Information for connection to MySQL cluster""" + try: + return ConnectionInformation(interface=self._interface, event=event) + except (_MissingRelation, remote_databag.IncompleteDatabag): + return + + def is_relation_breaking(self, event) -> bool: + """Whether relation will be broken after the current event is handled""" + try: + ConnectionInformation(interface=self._interface, event=event) + except _RelationBreaking: + return True + except (_MissingRelation, remote_databag.IncompleteDatabag): + pass + return False + + def get_status(self, event) -> typing.Optional[ops.StatusBase]: + """Report non-active status.""" + try: + ConnectionInformation(interface=self._interface, event=event) + except (_MissingRelation, remote_databag.IncompleteDatabag) as exception: + return exception.status diff --git a/src/relations/remote_databag.py b/src/relations/remote_databag.py new file mode 100644 index 00000000..6d6cb491 --- /dev/null +++ b/src/relations/remote_databag.py @@ -0,0 +1,48 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. 
+ +"""Relation databag for remote application""" + +import logging +import typing + +import charms.data_platform_libs.v0.data_interfaces as data_interfaces +import ops + +import status_exception + +logger = logging.getLogger(__name__) + + +class IncompleteDatabag(status_exception.StatusException): + """Databag is missing required key""" + + def __init__(self, *, app_name: str, endpoint_name: str) -> None: + super().__init__( + ops.WaitingStatus(f"Waiting for {app_name} app on {endpoint_name} endpoint") + ) + + +class RemoteDatabag(dict): + """Relation databag for remote application""" + + def __init__( + self, + # TODO python3.10 min version: Use `|` instead of `typing.Union` + interface: typing.Union[ + data_interfaces.DatabaseRequires, data_interfaces.DatabaseProvides + ], + relation: ops.Relation, + ) -> None: + super().__init__(interface.fetch_relation_data()[relation.id]) + self._app_name = relation.app.name + self._endpoint_name = relation.name + + def __getitem__(self, key): + try: + return super().__getitem__(key) + except KeyError: + logger.debug( + f"Required {key=} missing from databag for {self._app_name=} on {self._endpoint_name=}" + ) + raise IncompleteDatabag(app_name=self._app_name, endpoint_name=self._endpoint_name) diff --git a/src/relations/tls.py b/src/relations/tls.py new file mode 100644 index 00000000..c0c749ce --- /dev/null +++ b/src/relations/tls.py @@ -0,0 +1,304 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Relation to TLS certificate provider""" + +import base64 +import dataclasses +import inspect +import json +import logging +import re +import socket +import typing + +import charms.tls_certificates_interface.v1.tls_certificates as tls_certificates +import ops + +if typing.TYPE_CHECKING: + import charm + +_PEER_RELATION_ENDPOINT_NAME = "mysql-router-peers" +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class _UnitSecrets: + """Secrets for charm unit + + Stored in peer unit databag (to support Juju 2.9) + """ + + _peer_unit_databag: ops.RelationDataContent + + @staticmethod + def generate_private_key() -> str: + """Generate TLS private key.""" + return tls_certificates.generate_private_key().decode("utf-8") + + @property + def private_key(self) -> str: + """TLS private key + + Generate & save key if it doesn't exist. 
+ """ + return self._peer_unit_databag.setdefault( + "secrets.tls_private_key", self.generate_private_key() + ) + + @private_key.setter + def private_key(self, value: str) -> None: + self._peer_unit_databag["secrets.tls_private_key"] = value + + +class _PeerUnitDatabag: + """Peer relation unit databag""" + + # CSR stands for certificate signing request + requested_csr: str + active_csr: str + certificate: str + ca: str # Certificate authority + chain: str + + def __init__(self, databag: ops.RelationDataContent) -> None: + # Cannot use `self._databag =` since this class overrides `__setattr__()` + super().__setattr__("_databag", databag) + + @staticmethod + def _get_key(key: str) -> str: + """Create databag key by adding a 'tls_' prefix.""" + return f"tls_{key}" + + @property + def _attribute_names(self) -> typing.Iterable[str]: + """Class attributes with type annotation""" + return (name for name in inspect.get_annotations(type(self))) + + def __getattr__(self, name: str) -> typing.Optional[str]: + assert name in self._attribute_names, f"Invalid attribute {name=}" + return self._databag.get(self._get_key(name)) + + def __setattr__(self, name: str, value: str) -> None: + assert name in self._attribute_names, f"Invalid attribute {name=}" + self._databag[self._get_key(name)] = value + + def __delattr__(self, name: str) -> None: + assert name in self._attribute_names, f"Invalid attribute {name=}" + self._databag.pop(self._get_key(name), None) + + def clear(self) -> None: + """Delete all items in databag.""" + for name in self._attribute_names: + delattr(self, name) + + +@dataclasses.dataclass(kw_only=True) +class _Relation: + """Relation to TLS certificate provider""" + + _charm: "charm.MySQLRouterOperatorCharm" + _interface: tls_certificates.TLSCertificatesRequiresV1 + _peer_unit_databag: _PeerUnitDatabag + _unit_secrets: _UnitSecrets + + @property + def certificate_saved(self) -> bool: + """Whether a TLS certificate is available to use""" + for value in (self._peer_unit_databag.certificate, self._peer_unit_databag.ca): + if not value: + return False + return True + + def save_certificate(self, event: tls_certificates.CertificateAvailableEvent) -> None: + """Save TLS certificate in peer relation unit databag.""" + if ( + event.certificate_signing_request.strip() + != self._peer_unit_databag.requested_csr.strip() + ): + logger.warning("Unknown certificate received. 
Ignoring.") + return + if ( + self.certificate_saved + and event.certificate_signing_request.strip() + == self._peer_unit_databag.active_csr.strip() + ): + # Workaround for https://github.com/canonical/tls-certificates-operator/issues/34 + logger.debug("TLS certificate already saved.") + return + logger.debug(f"Saving TLS certificate {event=}") + self._peer_unit_databag.certificate = event.certificate + self._peer_unit_databag.ca = event.ca + self._peer_unit_databag.chain = json.dumps(event.chain) + self._peer_unit_databag.active_csr = self._peer_unit_databag.requested_csr + logger.debug(f"Saved TLS certificate {event=}") + self._charm.get_workload(event=None).enable_tls( + key=self._unit_secrets.private_key, + certificate=self._peer_unit_databag.certificate, + ) + + def _generate_csr(self, key: bytes) -> bytes: + """Generate certificate signing request (CSR).""" + unit_name = self._charm.unit.name.replace("/", "-") + return tls_certificates.generate_csr( + private_key=key, + subject=socket.getfqdn(), + organization=self._charm.app.name, + sans_dns=[ + unit_name, + f"{unit_name}.{self._charm.app.name}-endpoints", + f"{unit_name}.{self._charm.app.name}-endpoints.{self._charm.model_service_domain}", + f"{self._charm.app.name}-endpoints", + f"{self._charm.app.name}-endpoints.{self._charm.model_service_domain}", + f"{unit_name}.{self._charm.app.name}", + f"{unit_name}.{self._charm.app.name}.{self._charm.model_service_domain}", + self._charm.app.name, + f"{self._charm.app.name}.{self._charm.model_service_domain}", + ], + sans_ip=[ + str(self._charm.model.get_binding("juju-info").network.bind_address), + ], + ) + + def request_certificate_creation(self): + """Request new TLS certificate from related provider charm.""" + logger.debug("Requesting TLS certificate creation") + csr = self._generate_csr(self._unit_secrets.private_key.encode("utf-8")) + self._interface.request_certificate_creation(certificate_signing_request=csr) + self._peer_unit_databag.requested_csr = csr.decode("utf-8") + logger.debug( + f"Requested TLS certificate creation {self._peer_unit_databag.requested_csr=}" + ) + + def request_certificate_renewal(self): + """Request TLS certificate renewal from related provider charm.""" + logger.debug(f"Requesting TLS certificate renewal {self._peer_unit_databag.active_csr=}") + old_csr = self._peer_unit_databag.active_csr.encode("utf-8") + new_csr = self._generate_csr(self._unit_secrets.private_key.encode("utf-8")) + self._interface.request_certificate_renewal( + old_certificate_signing_request=old_csr, new_certificate_signing_request=new_csr + ) + self._peer_unit_databag.requested_csr = new_csr.decode("utf-8") + logger.debug(f"Requested TLS certificate renewal {self._peer_unit_databag.requested_csr=}") + + +class RelationEndpoint(ops.Object): + """Relation endpoint and handlers for TLS certificate provider""" + + NAME = "certificates" + + def __init__(self, charm_: "charm.MySQLRouterOperatorCharm") -> None: + super().__init__(charm_, self.NAME) + self._charm = charm_ + self._interface = tls_certificates.TLSCertificatesRequiresV1(self._charm, self.NAME) + + self.framework.observe( + self._charm.on.set_tls_private_key_action, + self._on_set_tls_private_key, + ) + self.framework.observe( + self._charm.on[self.NAME].relation_joined, self._on_tls_relation_joined + ) + self.framework.observe( + self._charm.on[self.NAME].relation_broken, self._on_tls_relation_broken + ) + + self.framework.observe( + self._interface.on.certificate_available, self._on_certificate_available + ) + 
self.framework.observe( + self._interface.on.certificate_expiring, self._on_certificate_expiring + ) + + @property + def _peer_unit_raw_databag(self) -> ops.RelationDataContent: + peer_relation = self._charm.model.get_relation(_PEER_RELATION_ENDPOINT_NAME) + return peer_relation.data[self._charm.unit] + + @property + def _peer_unit_databag(self) -> _PeerUnitDatabag: + return _PeerUnitDatabag(self._peer_unit_raw_databag) + + @property + def _unit_secrets(self) -> _UnitSecrets: + return _UnitSecrets(self._peer_unit_raw_databag) + + @property + def _relation(self) -> typing.Optional[_Relation]: + if not self._charm.model.get_relation(self.NAME): + return + return _Relation( + _charm=self._charm, + _interface=self._interface, + _peer_unit_databag=self._peer_unit_databag, + _unit_secrets=self._unit_secrets, + ) + + @property + def certificate_saved(self) -> bool: + """Whether a TLS certificate is available to use""" + if self._relation is None: + return False + return self._relation.certificate_saved + + @staticmethod + def _parse_tls_key(raw_content: str) -> str: + """Parse TLS key from plain text or base64 format.""" + if re.match(r"(-+(BEGIN|END) [A-Z ]+-+)", raw_content): + return re.sub( + r"(-+(BEGIN|END) [A-Z ]+-+)", + "\n\\1\n", + raw_content, + ) + return base64.b64decode(raw_content).decode("utf-8") + + def _on_set_tls_private_key(self, event: ops.ActionEvent) -> None: + """Handle action to set unit TLS private key.""" + logger.debug("Handling set TLS private key action") + if key := event.params.get("internal-key"): + key = self._parse_tls_key(key) + else: + key = self._unit_secrets.generate_private_key() + event.log("No key provided. Generated new key.") + logger.debug("No TLS key provided via action. Generated new key.") + self._unit_secrets.private_key = key + event.log("Saved TLS private key") + logger.debug("Saved TLS private key") + if self._relation is None: + event.log( + "No TLS certificate relation active. Relate a certificate provider charm to enable TLS." + ) + logger.debug("No TLS certificate relation active. Skipped certificate request") + else: + try: + self._relation.request_certificate_creation() + except Exception as e: + event.fail(f"Failed to request certificate: {e}") + logger.exception( + "Failed to request certificate after TLS private key set via action" + ) + raise + logger.debug("Handled set TLS private key action") + + def _on_tls_relation_joined(self, _) -> None: + """Request certificate when TLS relation joined.""" + self._relation.request_certificate_creation() + + def _on_tls_relation_broken(self, _) -> None: + """Delete TLS certificate.""" + logger.debug("Deleting TLS certificate") + self._peer_unit_databag.clear() + self._charm.get_workload(event=None).disable_tls() + logger.debug("Deleted TLS certificate") + + def _on_certificate_available(self, event: tls_certificates.CertificateAvailableEvent) -> None: + """Save TLS certificate.""" + self._relation.save_certificate(event) + + def _on_certificate_expiring(self, event: tls_certificates.CertificateExpiringEvent) -> None: + """Request the new certificate when old certificate is expiring.""" + if event.certificate != self._peer_unit_databag.certificate: + logger.warning("Unknown certificate expiring") + return + + self._relation.request_certificate_renewal() diff --git a/src/status_exception.py b/src/status_exception.py new file mode 100644 index 00000000..f3cc6dd7 --- /dev/null +++ b/src/status_exception.py @@ -0,0 +1,14 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. 
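+#
+# Pattern sketch: subclasses attach an ops status at construction, and handlers such
+# as get_status() in the relation modules return `exception.status` instead of letting
+# the exception escape. The subclass below is illustrative, not part of this patch:
+#
+#     class _WaitingForContainer(StatusException):
+#         def __init__(self) -> None:
+#             super().__init__(ops.MaintenanceStatus("Waiting for workload container"))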
+ +"""Exception with ops status""" + +import ops + + +class StatusException(Exception): + """Exception with ops status""" + + def __init__(self, status: ops.StatusBase) -> None: + super().__init__(status.message) + self.status = status diff --git a/src/workload.py b/src/workload.py new file mode 100644 index 00000000..5590af61 --- /dev/null +++ b/src/workload.py @@ -0,0 +1,324 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""MySQL Router workload""" + +import configparser +import dataclasses +import logging +import pathlib +import socket +import string +import typing + +import ops + +import mysql_shell + +if typing.TYPE_CHECKING: + import charm + import relations.database_requires + +logger = logging.getLogger(__name__) + + +# TODO python3.10 min version: Add `(kw_only=True)` +@dataclasses.dataclass +class Workload: + """MySQL Router workload""" + + _container: ops.Container + + CONTAINER_NAME = "mysql-router" + _SERVICE_NAME = "mysql_router" + _UNIX_USERNAME = "mysql" + _ROUTER_CONFIG_DIRECTORY = pathlib.Path("/etc/mysqlrouter") + _ROUTER_DATA_DIRECTORY = pathlib.Path("/var/lib/mysqlrouter") + _ROUTER_CONFIG_FILE = "mysqlrouter.conf" + _TLS_CONFIG_FILE = "tls.conf" + + @property + def container_ready(self) -> bool: + """Whether container is ready""" + return self._container.can_connect() + + @property + def _enabled(self) -> bool: + """Service status""" + service = self._container.get_services(self._SERVICE_NAME).get(self._SERVICE_NAME) + if service is None: + return False + return service.startup == ops.pebble.ServiceStartup.ENABLED + + @property + def version(self) -> str: + """MySQL Router version""" + process = self._container.exec(["mysqlrouter", "--version"]) + raw_version, _ = process.wait_output() + for version in raw_version.split(): + if version.startswith("8"): + return version + return "" + + def _update_layer(self, *, enabled: bool, tls: bool = None) -> None: + """Update and restart services. + + Args: + enabled: Whether MySQL Router service is enabled + tls: Whether TLS is enabled. Required if enabled=True + """ + if enabled: + assert tls is not None, "`tls` argument required when enabled=True" + command = ( + f"mysqlrouter --config {self._ROUTER_CONFIG_DIRECTORY / self._ROUTER_CONFIG_FILE}" + ) + if tls: + command = ( + f"{command} --extra-config {self._ROUTER_CONFIG_DIRECTORY / self._TLS_CONFIG_FILE}" + ) + if enabled: + startup = ops.pebble.ServiceStartup.ENABLED.value + else: + startup = ops.pebble.ServiceStartup.DISABLED.value + layer = ops.pebble.Layer( + { + "summary": "mysql router layer", + "description": "the pebble config layer for mysql router", + "services": { + self._SERVICE_NAME: { + "override": "replace", + "summary": "mysql router", + "command": command, + "startup": startup, + "user": self._UNIX_USERNAME, + "group": self._UNIX_USERNAME, + }, + }, + } + ) + self._container.add_layer(self._SERVICE_NAME, layer, combine=True) + self._container.replan() + + def _create_directory(self, path: pathlib.Path) -> None: + """Create directory. + + Args: + path: Full filesystem path + """ + path = str(path) + self._container.make_dir(path, user=self._UNIX_USERNAME, group=self._UNIX_USERNAME) + + def _delete_directory(self, path: pathlib.Path) -> None: + """Delete directory. 
+ + Args: + path: Full filesystem path + """ + path = str(path) + self._container.remove_path(path, recursive=True) + + def disable(self) -> None: + """Stop and disable MySQL Router service.""" + if not self._enabled: + return + logger.debug("Disabling MySQL Router service") + self._update_layer(enabled=False) + self._delete_directory(self._ROUTER_CONFIG_DIRECTORY) + self._create_directory(self._ROUTER_CONFIG_DIRECTORY) + self._delete_directory(self._ROUTER_DATA_DIRECTORY) + logger.debug("Disabled MySQL Router service") + + +# TODO python3.10 min version: Add `(kw_only=True)` +@dataclasses.dataclass +class AuthenticatedWorkload(Workload): + """Workload with connection to MySQL cluster""" + + _connection_info: "relations.database_requires.ConnectionInformation" + _charm: "charm.MySQLRouterOperatorCharm" + + _TLS_KEY_FILE = "custom-key.pem" + _TLS_CERTIFICATE_FILE = "custom-certificate.pem" + + @property + def shell(self) -> mysql_shell.Shell: + """MySQL Shell""" + return mysql_shell.Shell( + _container=self._container, + username=self._connection_info.username, + _password=self._connection_info.password, + _host=self._connection_info.host, + _port=self._connection_info.port, + ) + + @property + def _router_id(self) -> str: + """MySQL Router ID in InnoDB Cluster metadata + + Used to remove MySQL Router metadata from InnoDB cluster + """ + # MySQL Router is bootstrapped without `--directory`—there is one system-wide instance. + return f"{socket.getfqdn()}::system" + + def cleanup_after_pod_restart(self) -> None: + """Remove MySQL Router cluster metadata & user after pod restart. + + (Storage is not persisted on pod restart—MySQL Router's config file is deleted. + Therefore, MySQL Router needs to be bootstrapped again.) + """ + self.shell.remove_router_from_cluster_metadata(self._router_id) + self.shell.delete_router_user_after_pod_restart(self._router_id) + + def _bootstrap_router(self, *, tls: bool) -> None: + """Bootstrap MySQL Router and enable service.""" + logger.debug( + f"Bootstrapping router {tls=}, {self._connection_info.host=}, {self._connection_info.port=}" + ) + + def _get_command(password: str): + return [ + "mysqlrouter", + "--bootstrap", + self._connection_info.username + + ":" + + password + + "@" + + self._connection_info.host + + ":" + + self._connection_info.port, + "--strict", + "--user", + self._UNIX_USERNAME, + "--conf-set-option", + "http_server.bind_address=127.0.0.1", + "--conf-use-gr-notifications", + ] + + # Redact password from log + logged_command = _get_command("***") + + command = _get_command(self._connection_info.password) + try: + # Bootstrap MySQL Router + process = self._container.exec( + command, + timeout=30, + ) + process.wait_output() + except ops.pebble.ExecError as e: + # Use `logger.error` instead of `logger.exception` so password isn't logged + logger.error(f"Failed to bootstrap router\n{logged_command=}\nstderr:\n{e.stderr}\n") + # Original exception contains password + # Re-raising would log the password to Juju's debug log + # Raise new exception + # `from None` disables exception chaining so that the original exception is not + # included in the traceback + raise Exception("Failed to bootstrap router") from None + # Enable service + self._update_layer(enabled=True, tls=tls) + + logger.debug( + f"Bootstrapped router {tls=}, {self._connection_info.host=}, {self._connection_info.port=}" + ) + + @property + def _router_username(self) -> str: + """Read MySQL Router username from config file. 
+ + During bootstrap, MySQL Router creates a config file at + `/etc/mysqlrouter/mysqlrouter.conf`. This file contains the username that was created + during bootstrap. + """ + config = configparser.ConfigParser() + config.read_file( + self._container.pull(self._ROUTER_CONFIG_DIRECTORY / self._ROUTER_CONFIG_FILE) + ) + return config["metadata_cache:bootstrap"]["user"] + + def enable(self, *, tls: bool, unit_name: str) -> None: + """Start and enable MySQL Router service.""" + if self._enabled: + # If the host or port changes, MySQL Router will receive topology change + # notifications from MySQL. + # Therefore, if the host or port changes, we do not need to restart MySQL Router. + return + logger.debug("Enabling MySQL Router service") + self._bootstrap_router(tls=tls) + self.shell.add_attributes_to_mysql_router_user( + username=self._router_username, router_id=self._router_id, unit_name=unit_name + ) + logger.debug("Enabled MySQL Router service") + self._charm.wait_until_mysql_router_ready() + + def _restart(self, *, tls: bool) -> None: + """Restart MySQL Router to enable or disable TLS.""" + logger.debug("Restarting MySQL Router") + assert self._enabled is True + self._bootstrap_router(tls=tls) + logger.debug("Restarted MySQL Router") + self._charm.wait_until_mysql_router_ready() + # wait_until_mysql_router_ready will set WaitingStatus—override it with current charm + # status + self._charm.set_status(event=None) + + def _write_file(self, path: pathlib.Path, content: str) -> None: + """Write content to file. + + Args: + path: Full filesystem path (with filename) + content: File content + """ + self._container.push( + str(path), + content, + permissions=0o600, + user=self._UNIX_USERNAME, + group=self._UNIX_USERNAME, + ) + logger.debug(f"Wrote file {path=}") + + def _delete_file(self, path: pathlib.Path) -> None: + """Delete file. + + Args: + path: Full filesystem path (with filename) + """ + path = str(path) + if self._container.exists(path): + self._container.remove_path(path) + logger.debug(f"Deleted file {path=}") + + @property + def _tls_config_file(self) -> str: + """Render config file template to string. + + Config file enables TLS on MySQL Router. 
+ """ + with open("templates/tls.cnf", "r") as template_file: + template = string.Template(template_file.read()) + config_string = template.substitute( + tls_ssl_key_file=self._ROUTER_CONFIG_DIRECTORY / self._TLS_KEY_FILE, + tls_ssl_cert_file=self._ROUTER_CONFIG_DIRECTORY / self._TLS_CERTIFICATE_FILE, + ) + return config_string + + def enable_tls(self, *, key: str, certificate: str): + """Enable TLS and restart MySQL Router.""" + logger.debug("Enabling TLS") + self._write_file( + self._ROUTER_CONFIG_DIRECTORY / self._TLS_CONFIG_FILE, self._tls_config_file + ) + self._write_file(self._ROUTER_CONFIG_DIRECTORY / self._TLS_KEY_FILE, key) + self._write_file(self._ROUTER_CONFIG_DIRECTORY / self._TLS_CERTIFICATE_FILE, certificate) + if self._enabled: + self._restart(tls=True) + logger.debug("Enabled TLS") + + def disable_tls(self) -> None: + """Disable TLS and restart MySQL Router.""" + logger.debug("Disabling TLS") + for file in (self._TLS_CONFIG_FILE, self._TLS_KEY_FILE, self._TLS_CERTIFICATE_FILE): + self._delete_file(self._ROUTER_CONFIG_DIRECTORY / file) + if self._enabled: + self._restart(tls=False) + logger.debug("Disabled TLS") From f3908a718807d25ecdc2591ee6e6d5022aa04e86 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Thu, 1 Jun 2023 16:40:17 +0000 Subject: [PATCH 03/57] Update lint ignore --- pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4e9d36cf..ba93d4a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,11 @@ exclude = [".git", "__pycache__", ".tox", "build", "dist", "*.egg_info", "venv"] select = ["E", "W", "F", "C", "N", "R", "D", "H"] # Ignore W503, E501 because using black creates errors with this # Ignore D107 Missing docstring in __init__ -ignore = ["W503", "E501", "D107"] +# Ignore D105 Missing docstring in magic method +# Ignore D415 Docstring first line punctuation (doesn't make sense for properties) +# Ignore D403 First word of the first line should be properly capitalized (false positive on "MySQL") +# Ignore N818 Exception should be named with an Error suffix +ignore = ["W503", "E501", "D107", "D105", "D415", "D403", "N818"] # D100, D101, D102, D103: Ignore missing docstrings in tests per-file-ignores = ["tests/*:D100,D101,D102,D103,D104"] docstring-convention = "google" From c2b59d3fb86b0b723c4b75dbd41d63f7b6f931b4 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Thu, 1 Jun 2023 16:40:44 +0000 Subject: [PATCH 04/57] Remove TLS --- src/charm.py | 5 +- src/relations/tls.py | 304 ------------------------------------------- src/workload.py | 99 +------------- 3 files changed, 8 insertions(+), 400 deletions(-) delete mode 100644 src/relations/tls.py diff --git a/src/charm.py b/src/charm.py index d31327a6..ce55aacd 100755 --- a/src/charm.py +++ b/src/charm.py @@ -18,7 +18,6 @@ import relations.database_provides import relations.database_requires -import relations.tls import workload logger = logging.getLogger(__name__) @@ -44,8 +43,6 @@ def __init__(self, *args) -> None: # Start workload after pod restart self.framework.observe(self.on.upgrade_charm, self.reconcile_database_relations) - self.tls = relations.tls.RelationEndpoint(self) - def get_workload(self, *, event): """MySQL Router workload""" container = self.unit.get_container(workload.Workload.CONTAINER_NAME) @@ -223,7 +220,7 @@ def reconcile_database_relations(self, event=None) -> None: if isinstance(event, ops.UpgradeCharmEvent): # Pod restart (https://juju.is/docs/sdk/start-event#heading--emission-sequence) 
workload_.cleanup_after_pod_restart() - workload_.enable(tls=self.tls.certificate_saved, unit_name=self.unit.name) + workload_.enable(unit_name=self.unit.name) elif workload_.container_ready: workload_.disable() self.set_status(event=event) diff --git a/src/relations/tls.py b/src/relations/tls.py deleted file mode 100644 index c0c749ce..00000000 --- a/src/relations/tls.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright 2023 Canonical Ltd. -# See LICENSE file for licensing details. - -"""Relation to TLS certificate provider""" - -import base64 -import dataclasses -import inspect -import json -import logging -import re -import socket -import typing - -import charms.tls_certificates_interface.v1.tls_certificates as tls_certificates -import ops - -if typing.TYPE_CHECKING: - import charm - -_PEER_RELATION_ENDPOINT_NAME = "mysql-router-peers" -logger = logging.getLogger(__name__) - - -@dataclasses.dataclass -class _UnitSecrets: - """Secrets for charm unit - - Stored in peer unit databag (to support Juju 2.9) - """ - - _peer_unit_databag: ops.RelationDataContent - - @staticmethod - def generate_private_key() -> str: - """Generate TLS private key.""" - return tls_certificates.generate_private_key().decode("utf-8") - - @property - def private_key(self) -> str: - """TLS private key - - Generate & save key if it doesn't exist. - """ - return self._peer_unit_databag.setdefault( - "secrets.tls_private_key", self.generate_private_key() - ) - - @private_key.setter - def private_key(self, value: str) -> None: - self._peer_unit_databag["secrets.tls_private_key"] = value - - -class _PeerUnitDatabag: - """Peer relation unit databag""" - - # CSR stands for certificate signing request - requested_csr: str - active_csr: str - certificate: str - ca: str # Certificate authority - chain: str - - def __init__(self, databag: ops.RelationDataContent) -> None: - # Cannot use `self._databag =` since this class overrides `__setattr__()` - super().__setattr__("_databag", databag) - - @staticmethod - def _get_key(key: str) -> str: - """Create databag key by adding a 'tls_' prefix.""" - return f"tls_{key}" - - @property - def _attribute_names(self) -> typing.Iterable[str]: - """Class attributes with type annotation""" - return (name for name in inspect.get_annotations(type(self))) - - def __getattr__(self, name: str) -> typing.Optional[str]: - assert name in self._attribute_names, f"Invalid attribute {name=}" - return self._databag.get(self._get_key(name)) - - def __setattr__(self, name: str, value: str) -> None: - assert name in self._attribute_names, f"Invalid attribute {name=}" - self._databag[self._get_key(name)] = value - - def __delattr__(self, name: str) -> None: - assert name in self._attribute_names, f"Invalid attribute {name=}" - self._databag.pop(self._get_key(name), None) - - def clear(self) -> None: - """Delete all items in databag.""" - for name in self._attribute_names: - delattr(self, name) - - -@dataclasses.dataclass(kw_only=True) -class _Relation: - """Relation to TLS certificate provider""" - - _charm: "charm.MySQLRouterOperatorCharm" - _interface: tls_certificates.TLSCertificatesRequiresV1 - _peer_unit_databag: _PeerUnitDatabag - _unit_secrets: _UnitSecrets - - @property - def certificate_saved(self) -> bool: - """Whether a TLS certificate is available to use""" - for value in (self._peer_unit_databag.certificate, self._peer_unit_databag.ca): - if not value: - return False - return True - - def save_certificate(self, event: tls_certificates.CertificateAvailableEvent) -> None: - """Save TLS certificate in 
peer relation unit databag.""" - if ( - event.certificate_signing_request.strip() - != self._peer_unit_databag.requested_csr.strip() - ): - logger.warning("Unknown certificate received. Ignoring.") - return - if ( - self.certificate_saved - and event.certificate_signing_request.strip() - == self._peer_unit_databag.active_csr.strip() - ): - # Workaround for https://github.com/canonical/tls-certificates-operator/issues/34 - logger.debug("TLS certificate already saved.") - return - logger.debug(f"Saving TLS certificate {event=}") - self._peer_unit_databag.certificate = event.certificate - self._peer_unit_databag.ca = event.ca - self._peer_unit_databag.chain = json.dumps(event.chain) - self._peer_unit_databag.active_csr = self._peer_unit_databag.requested_csr - logger.debug(f"Saved TLS certificate {event=}") - self._charm.get_workload(event=None).enable_tls( - key=self._unit_secrets.private_key, - certificate=self._peer_unit_databag.certificate, - ) - - def _generate_csr(self, key: bytes) -> bytes: - """Generate certificate signing request (CSR).""" - unit_name = self._charm.unit.name.replace("/", "-") - return tls_certificates.generate_csr( - private_key=key, - subject=socket.getfqdn(), - organization=self._charm.app.name, - sans_dns=[ - unit_name, - f"{unit_name}.{self._charm.app.name}-endpoints", - f"{unit_name}.{self._charm.app.name}-endpoints.{self._charm.model_service_domain}", - f"{self._charm.app.name}-endpoints", - f"{self._charm.app.name}-endpoints.{self._charm.model_service_domain}", - f"{unit_name}.{self._charm.app.name}", - f"{unit_name}.{self._charm.app.name}.{self._charm.model_service_domain}", - self._charm.app.name, - f"{self._charm.app.name}.{self._charm.model_service_domain}", - ], - sans_ip=[ - str(self._charm.model.get_binding("juju-info").network.bind_address), - ], - ) - - def request_certificate_creation(self): - """Request new TLS certificate from related provider charm.""" - logger.debug("Requesting TLS certificate creation") - csr = self._generate_csr(self._unit_secrets.private_key.encode("utf-8")) - self._interface.request_certificate_creation(certificate_signing_request=csr) - self._peer_unit_databag.requested_csr = csr.decode("utf-8") - logger.debug( - f"Requested TLS certificate creation {self._peer_unit_databag.requested_csr=}" - ) - - def request_certificate_renewal(self): - """Request TLS certificate renewal from related provider charm.""" - logger.debug(f"Requesting TLS certificate renewal {self._peer_unit_databag.active_csr=}") - old_csr = self._peer_unit_databag.active_csr.encode("utf-8") - new_csr = self._generate_csr(self._unit_secrets.private_key.encode("utf-8")) - self._interface.request_certificate_renewal( - old_certificate_signing_request=old_csr, new_certificate_signing_request=new_csr - ) - self._peer_unit_databag.requested_csr = new_csr.decode("utf-8") - logger.debug(f"Requested TLS certificate renewal {self._peer_unit_databag.requested_csr=}") - - -class RelationEndpoint(ops.Object): - """Relation endpoint and handlers for TLS certificate provider""" - - NAME = "certificates" - - def __init__(self, charm_: "charm.MySQLRouterOperatorCharm") -> None: - super().__init__(charm_, self.NAME) - self._charm = charm_ - self._interface = tls_certificates.TLSCertificatesRequiresV1(self._charm, self.NAME) - - self.framework.observe( - self._charm.on.set_tls_private_key_action, - self._on_set_tls_private_key, - ) - self.framework.observe( - self._charm.on[self.NAME].relation_joined, self._on_tls_relation_joined - ) - self.framework.observe( - 
self._charm.on[self.NAME].relation_broken, self._on_tls_relation_broken - ) - - self.framework.observe( - self._interface.on.certificate_available, self._on_certificate_available - ) - self.framework.observe( - self._interface.on.certificate_expiring, self._on_certificate_expiring - ) - - @property - def _peer_unit_raw_databag(self) -> ops.RelationDataContent: - peer_relation = self._charm.model.get_relation(_PEER_RELATION_ENDPOINT_NAME) - return peer_relation.data[self._charm.unit] - - @property - def _peer_unit_databag(self) -> _PeerUnitDatabag: - return _PeerUnitDatabag(self._peer_unit_raw_databag) - - @property - def _unit_secrets(self) -> _UnitSecrets: - return _UnitSecrets(self._peer_unit_raw_databag) - - @property - def _relation(self) -> typing.Optional[_Relation]: - if not self._charm.model.get_relation(self.NAME): - return - return _Relation( - _charm=self._charm, - _interface=self._interface, - _peer_unit_databag=self._peer_unit_databag, - _unit_secrets=self._unit_secrets, - ) - - @property - def certificate_saved(self) -> bool: - """Whether a TLS certificate is available to use""" - if self._relation is None: - return False - return self._relation.certificate_saved - - @staticmethod - def _parse_tls_key(raw_content: str) -> str: - """Parse TLS key from plain text or base64 format.""" - if re.match(r"(-+(BEGIN|END) [A-Z ]+-+)", raw_content): - return re.sub( - r"(-+(BEGIN|END) [A-Z ]+-+)", - "\n\\1\n", - raw_content, - ) - return base64.b64decode(raw_content).decode("utf-8") - - def _on_set_tls_private_key(self, event: ops.ActionEvent) -> None: - """Handle action to set unit TLS private key.""" - logger.debug("Handling set TLS private key action") - if key := event.params.get("internal-key"): - key = self._parse_tls_key(key) - else: - key = self._unit_secrets.generate_private_key() - event.log("No key provided. Generated new key.") - logger.debug("No TLS key provided via action. Generated new key.") - self._unit_secrets.private_key = key - event.log("Saved TLS private key") - logger.debug("Saved TLS private key") - if self._relation is None: - event.log( - "No TLS certificate relation active. Relate a certificate provider charm to enable TLS." - ) - logger.debug("No TLS certificate relation active. 
Skipped certificate request") - else: - try: - self._relation.request_certificate_creation() - except Exception as e: - event.fail(f"Failed to request certificate: {e}") - logger.exception( - "Failed to request certificate after TLS private key set via action" - ) - raise - logger.debug("Handled set TLS private key action") - - def _on_tls_relation_joined(self, _) -> None: - """Request certificate when TLS relation joined.""" - self._relation.request_certificate_creation() - - def _on_tls_relation_broken(self, _) -> None: - """Delete TLS certificate.""" - logger.debug("Deleting TLS certificate") - self._peer_unit_databag.clear() - self._charm.get_workload(event=None).disable_tls() - logger.debug("Deleted TLS certificate") - - def _on_certificate_available(self, event: tls_certificates.CertificateAvailableEvent) -> None: - """Save TLS certificate.""" - self._relation.save_certificate(event) - - def _on_certificate_expiring(self, event: tls_certificates.CertificateExpiringEvent) -> None: - """Request the new certificate when old certificate is expiring.""" - if event.certificate != self._peer_unit_databag.certificate: - logger.warning("Unknown certificate expiring") - return - - self._relation.request_certificate_renewal() diff --git a/src/workload.py b/src/workload.py index 5590af61..9d35246d 100644 --- a/src/workload.py +++ b/src/workload.py @@ -8,7 +8,6 @@ import logging import pathlib import socket -import string import typing import ops @@ -35,7 +34,6 @@ class Workload: _ROUTER_CONFIG_DIRECTORY = pathlib.Path("/etc/mysqlrouter") _ROUTER_DATA_DIRECTORY = pathlib.Path("/var/lib/mysqlrouter") _ROUTER_CONFIG_FILE = "mysqlrouter.conf" - _TLS_CONFIG_FILE = "tls.conf" @property def container_ready(self) -> bool: @@ -60,22 +58,15 @@ def version(self) -> str: return version return "" - def _update_layer(self, *, enabled: bool, tls: bool = None) -> None: + def _update_layer(self, *, enabled: bool) -> None: """Update and restart services. Args: enabled: Whether MySQL Router service is enabled - tls: Whether TLS is enabled. 
Required if enabled=True """ - if enabled: - assert tls is not None, "`tls` argument required when enabled=True" command = ( f"mysqlrouter --config {self._ROUTER_CONFIG_DIRECTORY / self._ROUTER_CONFIG_FILE}" ) - if tls: - command = ( - f"{command} --extra-config {self._ROUTER_CONFIG_DIRECTORY / self._TLS_CONFIG_FILE}" - ) if enabled: startup = ops.pebble.ServiceStartup.ENABLED.value else: @@ -137,9 +128,6 @@ class AuthenticatedWorkload(Workload): _connection_info: "relations.database_requires.ConnectionInformation" _charm: "charm.MySQLRouterOperatorCharm" - _TLS_KEY_FILE = "custom-key.pem" - _TLS_CERTIFICATE_FILE = "custom-certificate.pem" - @property def shell(self) -> mysql_shell.Shell: """MySQL Shell""" @@ -169,10 +157,10 @@ def cleanup_after_pod_restart(self) -> None: self.shell.remove_router_from_cluster_metadata(self._router_id) self.shell.delete_router_user_after_pod_restart(self._router_id) - def _bootstrap_router(self, *, tls: bool) -> None: + def _bootstrap_router(self) -> None: """Bootstrap MySQL Router and enable service.""" logger.debug( - f"Bootstrapping router {tls=}, {self._connection_info.host=}, {self._connection_info.port=}" + f"Bootstrapping router {self._connection_info.host=}, {self._connection_info.port=}" ) def _get_command(password: str): @@ -215,10 +203,10 @@ def _get_command(password: str): # included in the traceback raise Exception("Failed to bootstrap router") from None # Enable service - self._update_layer(enabled=True, tls=tls) + self._update_layer(enabled=True) logger.debug( - f"Bootstrapped router {tls=}, {self._connection_info.host=}, {self._connection_info.port=}" + f"Bootstrapped router {self._connection_info.host=}, {self._connection_info.port=}" ) @property @@ -235,7 +223,7 @@ def _router_username(self) -> str: ) return config["metadata_cache:bootstrap"]["user"] - def enable(self, *, tls: bool, unit_name: str) -> None: + def enable(self, *, unit_name: str) -> None: """Start and enable MySQL Router service.""" if self._enabled: # If the host or port changes, MySQL Router will receive topology change @@ -243,82 +231,9 @@ def enable(self, *, tls: bool, unit_name: str) -> None: # Therefore, if the host or port changes, we do not need to restart MySQL Router. return logger.debug("Enabling MySQL Router service") - self._bootstrap_router(tls=tls) + self._bootstrap_router() self.shell.add_attributes_to_mysql_router_user( username=self._router_username, router_id=self._router_id, unit_name=unit_name ) logger.debug("Enabled MySQL Router service") self._charm.wait_until_mysql_router_ready() - - def _restart(self, *, tls: bool) -> None: - """Restart MySQL Router to enable or disable TLS.""" - logger.debug("Restarting MySQL Router") - assert self._enabled is True - self._bootstrap_router(tls=tls) - logger.debug("Restarted MySQL Router") - self._charm.wait_until_mysql_router_ready() - # wait_until_mysql_router_ready will set WaitingStatus—override it with current charm - # status - self._charm.set_status(event=None) - - def _write_file(self, path: pathlib.Path, content: str) -> None: - """Write content to file. - - Args: - path: Full filesystem path (with filename) - content: File content - """ - self._container.push( - str(path), - content, - permissions=0o600, - user=self._UNIX_USERNAME, - group=self._UNIX_USERNAME, - ) - logger.debug(f"Wrote file {path=}") - - def _delete_file(self, path: pathlib.Path) -> None: - """Delete file. 
- - Args: - path: Full filesystem path (with filename) - """ - path = str(path) - if self._container.exists(path): - self._container.remove_path(path) - logger.debug(f"Deleted file {path=}") - - @property - def _tls_config_file(self) -> str: - """Render config file template to string. - - Config file enables TLS on MySQL Router. - """ - with open("templates/tls.cnf", "r") as template_file: - template = string.Template(template_file.read()) - config_string = template.substitute( - tls_ssl_key_file=self._ROUTER_CONFIG_DIRECTORY / self._TLS_KEY_FILE, - tls_ssl_cert_file=self._ROUTER_CONFIG_DIRECTORY / self._TLS_CERTIFICATE_FILE, - ) - return config_string - - def enable_tls(self, *, key: str, certificate: str): - """Enable TLS and restart MySQL Router.""" - logger.debug("Enabling TLS") - self._write_file( - self._ROUTER_CONFIG_DIRECTORY / self._TLS_CONFIG_FILE, self._tls_config_file - ) - self._write_file(self._ROUTER_CONFIG_DIRECTORY / self._TLS_KEY_FILE, key) - self._write_file(self._ROUTER_CONFIG_DIRECTORY / self._TLS_CERTIFICATE_FILE, certificate) - if self._enabled: - self._restart(tls=True) - logger.debug("Enabled TLS") - - def disable_tls(self) -> None: - """Disable TLS and restart MySQL Router.""" - logger.debug("Disabling TLS") - for file in (self._TLS_CONFIG_FILE, self._TLS_KEY_FILE, self._TLS_CERTIFICATE_FILE): - self._delete_file(self._ROUTER_CONFIG_DIRECTORY / file) - if self._enabled: - self._restart(tls=False) - logger.debug("Disabled TLS") From ba4b08fc02cafbf2c088e29fd8b5950e1000aec0 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Wed, 7 Jun 2023 12:57:21 +0000 Subject: [PATCH 05/57] Remove pod restart handling --- src/charm.py | 4 ---- src/mysql_shell.py | 30 ------------------------------ src/workload.py | 9 --------- 3 files changed, 43 deletions(-) diff --git a/src/charm.py b/src/charm.py index ce55aacd..ba3b6aee 100755 --- a/src/charm.py +++ b/src/charm.py @@ -202,7 +202,6 @@ def reconcile_database_relations(self, event=None) -> None: f"{isinstance(workload_, workload.AuthenticatedWorkload)=}, " f"{workload_.container_ready=}, " f"{self.database_requires.is_relation_breaking(event)=}, " - f"{isinstance(event, ops.UpgradeCharmEvent)=}" ) if self.unit.is_leader() and self.database_requires.is_relation_breaking(event): self.database_provides.delete_all_databags() @@ -217,9 +216,6 @@ def reconcile_database_relations(self, event=None) -> None: shell=workload_.shell, ) if isinstance(workload_, workload.AuthenticatedWorkload) and workload_.container_ready: - if isinstance(event, ops.UpgradeCharmEvent): - # Pod restart (https://juju.is/docs/sdk/start-event#heading--emission-sequence) - workload_.cleanup_after_pod_restart() workload_.enable(unit_name=self.unit.name) elif workload_.container_ready: workload_.disable() diff --git a/src/mysql_shell.py b/src/mysql_shell.py index c152cb6d..1a105424 100644 --- a/src/mysql_shell.py +++ b/src/mysql_shell.py @@ -110,33 +110,3 @@ def delete_user(self, username: str) -> None: logger.debug(f"Deleting {username=}") self._run_sql([f"DROP USER `{username}`"]) logger.debug(f"Deleted {username=}") - - def delete_router_user_after_pod_restart(self, router_id: str) -> None: - """Delete MySQL Router user created by a previous instance of this unit. - - Before pod restart, the charm does not have an opportunity to delete the MySQL Router user. - During MySQL Router bootstrap, a new user is created. Before bootstrap, the old user - should be deleted. 
- """ - logger.debug(f"Deleting MySQL Router user {router_id=} created by {self.username=}") - self._run_sql( - [ - f"SELECT CONCAT('DROP USER ', GROUP_CONCAT(QUOTE(USER), '@', QUOTE(HOST))) INTO @sql FROM INFORMATION_SCHEMA.USER_ATTRIBUTES WHERE ATTRIBUTE->'$.created_by_user'='{self.username}' AND ATTRIBUTE->'$.router_id'='{router_id}'", - "PREPARE stmt FROM @sql", - "EXECUTE stmt", - "DEALLOCATE PREPARE stmt", - ] - ) - logger.debug(f"Deleted MySQL Router user {router_id=} created by {self.username=}") - - def remove_router_from_cluster_metadata(self, router_id: str) -> None: - """Remove MySQL Router from InnoDB Cluster metadata. - - On pod restart, MySQL Router bootstrap will fail without `--force` if cluster metadata - already exists for the router ID. - """ - logger.debug(f"Removing {router_id=} from cluster metadata") - self._run_commands( - ["cluster = dba.get_cluster()", f'cluster.remove_router_metadata("{router_id}")'] - ) - logger.debug(f"Removed {router_id=} from cluster metadata") diff --git a/src/workload.py b/src/workload.py index 9d35246d..34afb291 100644 --- a/src/workload.py +++ b/src/workload.py @@ -148,15 +148,6 @@ def _router_id(self) -> str: # MySQL Router is bootstrapped without `--directory`—there is one system-wide instance. return f"{socket.getfqdn()}::system" - def cleanup_after_pod_restart(self) -> None: - """Remove MySQL Router cluster metadata & user after pod restart. - - (Storage is not persisted on pod restart—MySQL Router's config file is deleted. - Therefore, MySQL Router needs to be bootstrapped again.) - """ - self.shell.remove_router_from_cluster_metadata(self._router_id) - self.shell.delete_router_user_after_pod_restart(self._router_id) - def _bootstrap_router(self) -> None: """Bootstrap MySQL Router and enable service.""" logger.debug( From 9b51fdcf7109b59ae8f63904130f3325b3ae2d3f Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Tue, 6 Jun 2023 18:46:58 +0000 Subject: [PATCH 06/57] temp disable lint --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 1f34c20c..8f380b42 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -81,7 +81,7 @@ jobs: bases-index: 1 name: ${{ matrix.tox-environments }} | ${{ matrix.ubuntu-versions.series }} needs: - - lint +# - lint # TODO: re-enable after adding unit tests # - unit-test - build From 567e981d33b8242aa59eddf3ad7604c541221e94 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 9 Jun 2023 18:32:57 +0000 Subject: [PATCH 07/57] temp --- requirements.txt | 1 + src/charm.py | 125 ++++---------- src/container.py | 125 ++++++++++++++ src/mysql_shell.py | 74 +++++++-- src/relations/database_provides.py | 10 +- src/snap.py | 97 +++++++++++ src/socket_workload.py | 43 +++++ src/workload.py | 258 +++++++++++++++-------------- 8 files changed, 500 insertions(+), 233 deletions(-) create mode 100644 src/container.py create mode 100644 src/snap.py create mode 100644 src/socket_workload.py diff --git a/requirements.txt b/requirements.txt index 56f5f642..ed5bedd6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ ops >= 1.5.0 +tenacity \ No newline at end of file diff --git a/src/charm.py b/src/charm.py index ba3b6aee..89ede81e 100755 --- a/src/charm.py +++ b/src/charm.py @@ -9,15 +9,14 @@ import logging import socket -import lightkube -import lightkube.models.core_v1 -import lightkube.models.meta_v1 -import lightkube.resources.core_v1 +import charms.operator_libs_linux.v2.snap as 
snap_lib import ops import tenacity import relations.database_provides import relations.database_requires +import snap +import socket_workload import workload logger = logging.getLogger(__name__) @@ -34,42 +33,28 @@ def __init__(self, *args) -> None: self.database_provides = relations.database_provides.RelationEndpoint(self) self.framework.observe(self.on.install, self._on_install) + self.framework.observe(self.on.remove, self._on_remove) self.framework.observe(self.on.start, self._on_start) - self.framework.observe( - getattr(self.on, "mysql_router_pebble_ready"), self._on_mysql_router_pebble_ready - ) self.framework.observe(self.on.leader_elected, self._on_leader_elected) - # Start workload after pod restart - self.framework.observe(self.on.upgrade_charm, self.reconcile_database_relations) - def get_workload(self, *, event): """MySQL Router workload""" - container = self.unit.get_container(workload.Workload.CONTAINER_NAME) + container = snap.Snap() if connection_info := self.database_requires.get_connection_info(event=event): - return workload.AuthenticatedWorkload( - _container=container, - _connection_info=connection_info, - _charm=self, + return socket_workload.AuthenticatedSocketWorkload( + container_=container, + connection_info=connection_info, + charm_=self, ) - return workload.Workload(_container=container) - - @property - def model_service_domain(self): - """K8s service domain for Juju model""" - # Example: "mysql-router-k8s-0.mysql-router-k8s-endpoints.my-model.svc.cluster.local" - fqdn = socket.getfqdn() - # Example: "mysql-router-k8s-0.mysql-router-k8s-endpoints." - prefix = f"{self.unit.name.replace('/', '-')}.{self.app.name}-endpoints." - assert fqdn.startswith(f"{prefix}{self.model.name}.") - # Example: my-model.svc.cluster.local - return fqdn.removeprefix(prefix) + return socket_workload.SocketWorkload(container_=container) @property def _endpoint(self) -> str: """K8s endpoint for MySQL Router""" + # TODO: remove # Example: mysql-router-k8s.my-model.svc.cluster.local - return f"{self.app.name}.{self.model_service_domain}" + return "foo" + # return f"{self.app.name}.{self.model_service_domain}" @staticmethod def _prioritize_statuses(statuses: list[ops.StatusBase]) -> ops.StatusBase: @@ -136,59 +121,6 @@ def wait_until_mysql_router_ready(self) -> None: else: logger.debug("MySQL Router is ready") - def _patch_service(self, *, name: str, ro_port: int, rw_port: int) -> None: - """Patch Juju-created k8s service. - - The k8s service will be tied to pod-0 so that the service is auto cleaned by - k8s when the last pod is scaled down. - - Args: - name: The name of the service. - ro_port: The read only port. - rw_port: The read write port. 
- """ - logger.debug(f"Patching k8s service {name=}, {ro_port=}, {rw_port=}") - client = lightkube.Client() - pod0 = client.get( - res=lightkube.resources.core_v1.Pod, - name=self.app.name + "-0", - namespace=self.model.name, - ) - service = lightkube.resources.core_v1.Service( - metadata=lightkube.models.meta_v1.ObjectMeta( - name=name, - namespace=self.model.name, - ownerReferences=pod0.metadata.ownerReferences, - labels={ - "app.kubernetes.io/name": self.app.name, - }, - ), - spec=lightkube.models.core_v1.ServiceSpec( - ports=[ - lightkube.models.core_v1.ServicePort( - name="mysql-ro", - port=ro_port, - targetPort=ro_port, - ), - lightkube.models.core_v1.ServicePort( - name="mysql-rw", - port=rw_port, - targetPort=rw_port, - ), - ], - selector={"app.kubernetes.io/name": self.app.name}, - ), - ) - client.patch( - res=lightkube.resources.core_v1.Service, - obj=service, - name=service.metadata.name, - namespace=service.metadata.namespace, - force=True, - field_manager=self.model.app.name, - ) - logger.debug(f"Patched k8s service {name=}, {ro_port=}, {rw_port=}") - # ======================= # Handlers # ======================= @@ -223,22 +155,31 @@ def reconcile_database_relations(self, event=None) -> None: def _on_install(self, _) -> None: """Patch existing k8s service to include read-write and read-only services.""" - if not self.unit.is_leader(): - return - try: - self._patch_service(name=self.app.name, ro_port=6447, rw_port=6446) - except lightkube.ApiError: - logger.exception("Failed to patch k8s service") - raise + # TODO update docstring + # TODO: move to workload.py? + # TODO set workload version + _SNAP_NAME = "charmed-mysql" + _SNAP_REVISION = "51" + mysql_snap = snap_lib.SnapCache()[_SNAP_NAME] + if mysql_snap.present: + logger.error(f"{_SNAP_NAME} snap already installed on machine. Installation aborted") + raise Exception(f"Multiple {_SNAP_NAME} snap installs not supported on one machine") + logger.debug(f"Installing {_SNAP_NAME=}, {_SNAP_REVISION=}") + # TODO: set status + # TODO catch/retry on error? 
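+        # One option for the retry TODO above, sketched as an assumption: it uses
+        # tenacity (added to requirements.txt in this patch) and assumes transient
+        # snapd failures raise snap_lib.SnapError.
+        #
+        #     for attempt in tenacity.Retrying(
+        #         stop=tenacity.stop_after_attempt(3),
+        #         wait=tenacity.wait_fixed(10),
+        #         retry=tenacity.retry_if_exception_type(snap_lib.SnapError),
+        #         reraise=True,
+        #     ):
+        #         with attempt:
+        #             mysql_snap.ensure(snap_lib.SnapState.Present, revision=_SNAP_REVISION)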
+ mysql_snap.ensure(snap_lib.SnapState.Present, revision=_SNAP_REVISION) + logger.debug(f"Installed {_SNAP_NAME=}, {_SNAP_REVISION=}") + self.unit.set_workload_version(self.get_workload(event=None).version) + + def _on_remove(self, _) -> None: + _SNAP_NAME = "charmed-mysql" + mysql_snap = snap_lib.SnapCache()[_SNAP_NAME] + mysql_snap.ensure(snap_lib.SnapState.Absent) def _on_start(self, _) -> None: # Set status on first start if no relations active self.set_status(event=None) - def _on_mysql_router_pebble_ready(self, _) -> None: - self.unit.set_workload_version(self.get_workload(event=None).version) - self.reconcile_database_relations() - def _on_leader_elected(self, _) -> None: # Update app status self.set_status(event=None) diff --git a/src/container.py b/src/container.py new file mode 100644 index 00000000..7fd99dfa --- /dev/null +++ b/src/container.py @@ -0,0 +1,125 @@ +import abc +import pathlib +import subprocess +import typing + + +class Installer(abc.ABC): + @abc.abstractmethod + def install(self): + pass + + @abc.abstractmethod + def uninstall(self): + pass + + +class Path(pathlib.PurePosixPath, abc.ABC): + @property + @abc.abstractmethod + def _UNIX_USERNAME(self) -> str: + pass + + @abc.abstractmethod + def read_text(self) -> str: + """Open the file in text mode, read it, and close the file.""" + + @abc.abstractmethod + def write_text(self, data: str): + """Open the file in text mode, write to it, and close the file.""" + + @abc.abstractmethod + def unlink(self): + """Remove this file or link.""" + + @abc.abstractmethod + def mkdir(self): + """Create a new directory at this path.""" + + @abc.abstractmethod + def rmtree(self): + """Recursively delete the directory tree at this path.""" + + +class CalledProcessError(subprocess.CalledProcessError): + """Command returned non-zero exit code""" + + def __init__(self, *, returncode: int, cmd: list[str], output: str, stderr: str) -> None: + super().__init__(returncode=returncode, cmd=cmd, output=output, stderr=stderr) + + +class Container(abc.ABC): + @property + @abc.abstractmethod + def UNIX_USERNAME(self) -> str: + pass + + @property + def router_config_directory(self) -> Path: + return self.path("/etc/mysqlrouter") + + @property + def router_config_file(self) -> Path: + return self.router_config_directory / "mysqlrouter.conf" + + @property + def tls_config_file(self) -> Path: + return self.router_config_directory / "tls.conf" + + def __init__(self, *, mysql_router_command: str, mysql_shell_command: str) -> None: + self._mysql_router_command = mysql_router_command + self._mysql_shell_command = mysql_shell_command + + @property + @abc.abstractmethod + def ready(self) -> bool: + """Whether container is ready + + Only applies to Kubernetes charm + """ + + @property + @abc.abstractmethod + def mysql_router_service_enabled(self) -> bool: + """MySQL Router service status""" + + @abc.abstractmethod + def update_mysql_router_service(self, *, enabled: bool, tls: bool = None) -> None: + """Update and restart MySQL Router service. + + Args: + enabled: Whether MySQL Router service is enabled + tls: Whether TLS is enabled. Required if enabled=True + """ + if enabled: + assert tls is not None, "`tls` argument required when enabled=True" + + @abc.abstractmethod + def _run_command(self, command: list[str], *, timeout: typing.Optional[int]) -> str: + """Run command in container. 
+ + Raises: + CalledProcessError: Command returns non-zero exit code + """ + + def run_mysql_router(self, args: list[str], *, timeout: int = None) -> str: + """Run MySQL Router command. + + Raises: + CalledProcessError: Command returns non-zero exit code + """ + args.insert(0, self._mysql_router_command) + return self._run_command(args, timeout=timeout) + + def run_mysql_shell(self, args: list[str], *, timeout: int = None) -> str: + """Run MySQL Shell command. + + Raises: + CalledProcessError: Command returns non-zero exit code + """ + args.insert(0, self._mysql_shell_command) + return self._run_command(args, timeout=timeout) + + @abc.abstractmethod + def path(self, *args) -> Path: + pass diff --git a/src/mysql_shell.py b/src/mysql_shell.py index 1a105424..dd0786f1 100644 --- a/src/mysql_shell.py +++ b/src/mysql_shell.py @@ -11,30 +11,39 @@ import logging import secrets import string +import typing -import ops +import container _PASSWORD_LENGTH = 24 logger = logging.getLogger(__name__) +# TODO python3.10 min version: Add `(kw_only=True)` +@dataclasses.dataclass +class RouterUserInformation: + """MySQL Router user information""" + + username: str + router_id: str + + # TODO python3.10 min version: Add `(kw_only=True)` @dataclasses.dataclass class Shell: """MySQL Shell connected to MySQL cluster""" - _container: ops.Container + _container: container.Container username: str _password: str _host: str _port: str - _TEMPORARY_SCRIPT_FILE = "/tmp/script.py" - - def _run_commands(self, commands: list[str]) -> None: + def _run_commands(self, commands: list[str]) -> str: """Connect to MySQL cluster and run commands.""" # Redact password from log logged_commands = commands.copy() + # TODO: Password is still logged on user creation logged_commands.insert( 0, f"shell.connect('{self.username}:***@{self._host}:{self._port}')" ) @@ -42,17 +51,18 @@ def _run_commands(self, commands: list[str]) -> None: commands.insert( 0, f"shell.connect('{self.username}:{self._password}@{self._host}:{self._port}')" ) - self._container.push(self._TEMPORARY_SCRIPT_FILE, "\n".join(commands)) + temporary_script_file = self._container.path("/tmp/script.py") + temporary_script_file.write_text("\n".join(commands)) try: - process = self._container.exec( - ["mysqlsh", "--no-wizard", "--python", "--file", self._TEMPORARY_SCRIPT_FILE] + output = self._container.run_mysql_shell( + ["--no-wizard", "--python", "--file", str(temporary_script_file)] ) - process.wait_output() - except ops.pebble.ExecError as e: + except container.CalledProcessError as e: logger.exception(f"Failed to run {logged_commands=}\nstderr:\n{e.stderr}\n") raise finally: - self._container.remove_path(self._TEMPORARY_SCRIPT_FILE) + temporary_script_file.unlink() + return output def _run_sql(self, sql_statements: list[str]) -> None: """Connect to MySQL cluster and execute SQL.""" @@ -105,6 +115,48 @@ def add_attributes_to_mysql_router_user( self._run_sql([f"ALTER USER `{username}` ATTRIBUTE '{attributes}'"]) logger.debug(f"Added {attributes=} to {username=}") + def get_mysql_router_user_for_unit( + self, unit_name: str + ) -> typing.Optional[RouterUserInformation]: + """Get MySQL Router user created by a previous instance of the unit. + + Get username & router ID attribute. + + Before container restart, the charm does not have an opportunity to delete the MySQL + Router user or cluster metadata created during MySQL Router bootstrap. After container + restart, the user and cluster metadata should be deleted before bootstrapping MySQL Router + again. 
+ """ + logger.debug(f"Getting MySQL Router user for {unit_name=}") + rows = json.loads( + self._run_commands( + [ + f"result = session.run_sql(\"SELECT USER, ATTRIBUTE->>'$.router_id' FROM INFORMATION_SCHEMA.USER_ATTRIBUTES WHERE ATTRIBUTE->'$.created_by_user'='{self.username}' AND ATTRIBUTE->'$.created_by_juju_unit'='{unit_name}'\")", + "print(result.fetch_all())", + ] + ) + ) + if not rows: + logger.debug(f"No MySQL Router user found for {unit_name=}") + return + assert len(rows) == 1 + username, router_id = rows[0] + user_info = RouterUserInformation(username=username, router_id=router_id) + logger.debug(f"MySQL Router user found for {unit_name=}: {user_info}") + return user_info + + def remove_router_from_cluster_metadata(self, router_id: str) -> None: + """Remove MySQL Router from InnoDB Cluster metadata. + + On container restart, MySQL Router bootstrap will fail without `--force` if cluster + metadata already exists for the router ID. + """ + logger.debug(f"Removing {router_id=} from cluster metadata") + self._run_commands( + ["cluster = dba.get_cluster()", f'cluster.remove_router_metadata("{router_id}")'] + ) + logger.debug(f"Removed {router_id=} from cluster metadata") + def delete_user(self, username: str) -> None: """Delete user.""" logger.debug(f"Deleting {username=}") diff --git a/src/relations/database_provides.py b/src/relations/database_provides.py index 7788a504..bbeda7bc 100644 --- a/src/relations/database_provides.py +++ b/src/relations/database_provides.py @@ -74,8 +74,14 @@ def __init__( def _set_databag(self, *, username: str, password: str, router_endpoint: str) -> None: """Share connection information with application charm.""" - read_write_endpoint = f"{router_endpoint}:6446" - read_only_endpoint = f"{router_endpoint}:6447" + # TODO: remove `file://`? + # TODO: get socket path from variable + read_write_endpoint = ( + "file:///var/snap/charmed-mysql/common/var/run/mysqlrouter/mysql.sock" + ) + read_only_endpoint = ( + "file:///var/snap/charmed-mysql/common/var/run/mysqlrouter/mysqlro.sock" + ) logger.debug( f"Setting databag {self._id=} {self._database=}, {username=}, {read_write_endpoint=}, {read_only_endpoint=}" ) diff --git a/src/snap.py b/src/snap.py new file mode 100644 index 00000000..df803096 --- /dev/null +++ b/src/snap.py @@ -0,0 +1,97 @@ +import pathlib +import shutil +import subprocess +import typing + +import charms.operator_libs_linux.v2.snap as snap_lib + +import container + +_UNIX_USERNAME = None # TODO +_SNAP_NAME = "charmed-mysql" + + +class _SnapPath(pathlib.PosixPath): + def __new__(cls, *args, **kwargs): + path = super().__new__(cls, *args, **kwargs) + if str(path).startswith("/etc/mysqlrouter") or str(path).startswith( + "/var/lib/mysqlrouter" + ): + parent = f"/var/snap/{_SNAP_NAME}/current" + elif str(path).startswith("/run"): + parent = f"/var/snap/{_SNAP_NAME}/common" + elif str(path).startswith("/tmp"): + parent = f"/tmp/snap-private-tmp/snap.{_SNAP_NAME}" + else: + return path + assert str(path).startswith("/") + return parent / path.relative_to("/") + + def __rtruediv__(self, other): + return type(self)(other, self) + + +class _Path(_SnapPath, container.Path): + _UNIX_USERNAME = _UNIX_USERNAME + + def read_text(self, encoding="utf-8", *args) -> str: + return super().read_text(encoding, *args) + + def write_text(self, data: str, encoding="utf-8", *args): + return super().write_text(data, encoding, *args) + + # TODO: override unlink with not exists no fail? 
+ + def rmtree(self): + shutil.rmtree(self) + + +class Snap(container.Container): + UNIX_USERNAME = _UNIX_USERNAME + _SNAP_REVISION = "51" + _SERVICE_NAME = "mysqlrouter-service" + + def __init__(self) -> None: + super().__init__( + mysql_router_command=f"{_SNAP_NAME}.mysqlrouter", + mysql_shell_command=f"{_SNAP_NAME}.mysqlsh", + ) + + def ready(self) -> bool: + return True + + @property + def _snap(self) -> snap_lib.Snap: + return snap_lib.SnapCache()[_SNAP_NAME] + + @property + def mysql_router_service_enabled(self) -> bool: + return self._snap.services[self._SERVICE_NAME]["active"] + + def update_mysql_router_service(self, *, enabled: bool, tls: bool = None) -> None: + # TODO: uncomment when TLS is implemented + # super().update_mysql_router_service(enabled=enabled, tls=tls) + if tls is not None: + raise NotImplementedError + if enabled: + self._snap.start([self._SERVICE_NAME], enable=True) + else: + self._snap.stop([self._SERVICE_NAME], disable=True) + + def _run_command(self, command: list[str], *, timeout: typing.Optional[int]) -> str: + try: + output = subprocess.run( + command, + capture_output=True, + timeout=timeout, + check=True, + encoding="utf-8", + ).stdout + except subprocess.CalledProcessError as e: + raise container.CalledProcessError( + returncode=e.returncode, cmd=e.cmd, output=e.output, stderr=e.stderr + ) + return output + + def path(self, *args, **kwargs) -> _Path: + return _Path(*args, **kwargs) diff --git a/src/socket_workload.py b/src/socket_workload.py new file mode 100644 index 00000000..eba8d672 --- /dev/null +++ b/src/socket_workload.py @@ -0,0 +1,43 @@ +import configparser +import io +import pathlib + +import workload + + +# TODO: rename to Workload? +class SocketWorkload(workload.Workload): + pass + + +class AuthenticatedSocketWorkload(workload.AuthenticatedWorkload, SocketWorkload): + def _get_bootstrap_command(self, password: str): + command = super()._get_bootstrap_command(password) + command.extend( + [ + "--conf-use-sockets", + # For unix sockets, authentication fails on first connection if this option is not + # set. 
Workaround for https://bugs.mysql.com/bug.php?id=107291 + "--conf-set-option", + "DEFAULT.server_ssl_mode=PREFERRED", + ] + ) + return command + + def _change_socket_file_locations(self) -> None: + # TODO: rename + config = configparser.ConfigParser() + config.read_string(self._container.router_config_file.read_text()) + for section_name, section in config.items(): + if not section_name.startswith("routing:"): + continue + section["socket"] = str( + self._container.path("/run/mysqlrouter") / pathlib.PurePath(section["socket"]).name + ) + with io.StringIO() as output: + config.write(output) + self._container.router_config_file.write_text(output.getvalue()) + + def _bootstrap_router(self, *, tls: bool) -> None: + super()._bootstrap_router(tls=tls) + self._change_socket_file_locations() diff --git a/src/workload.py b/src/workload.py index 34afb291..88898608 100644 --- a/src/workload.py +++ b/src/workload.py @@ -4,14 +4,12 @@ """MySQL Router workload""" import configparser -import dataclasses import logging -import pathlib import socket +import string import typing -import ops - +import container import mysql_shell if typing.TYPE_CHECKING: @@ -21,112 +19,91 @@ logger = logging.getLogger(__name__) -# TODO python3.10 min version: Add `(kw_only=True)` -@dataclasses.dataclass class Workload: """MySQL Router workload""" - _container: ops.Container - - CONTAINER_NAME = "mysql-router" - _SERVICE_NAME = "mysql_router" - _UNIX_USERNAME = "mysql" - _ROUTER_CONFIG_DIRECTORY = pathlib.Path("/etc/mysqlrouter") - _ROUTER_DATA_DIRECTORY = pathlib.Path("/var/lib/mysqlrouter") - _ROUTER_CONFIG_FILE = "mysqlrouter.conf" + def __init__(self, container_: container.Container) -> None: + self._container = container_ + self._router_data_directory = self._container.path("/var/lib/mysqlrouter") + self._tls_key_file = self._container.router_config_directory / "custom-key.pem" + self._tls_certificate_file = ( + self._container.router_config_directory / "custom-certificate.pem" + ) @property def container_ready(self) -> bool: - """Whether container is ready""" - return self._container.can_connect() + """Whether container is ready - @property - def _enabled(self) -> bool: - """Service status""" - service = self._container.get_services(self._SERVICE_NAME).get(self._SERVICE_NAME) - if service is None: - return False - return service.startup == ops.pebble.ServiceStartup.ENABLED + Only applies to Kubernetes charm + """ + return self._container.ready @property def version(self) -> str: """MySQL Router version""" - process = self._container.exec(["mysqlrouter", "--version"]) - raw_version, _ = process.wait_output() - for version in raw_version.split(): - if version.startswith("8"): - return version + version = self._container.run_mysql_router(["--version"]) + for component in version.split(): + if component.startswith("8"): + return component return "" - def _update_layer(self, *, enabled: bool) -> None: - """Update and restart services. 
- - Args: - enabled: Whether MySQL Router service is enabled - """ - command = ( - f"mysqlrouter --config {self._ROUTER_CONFIG_DIRECTORY / self._ROUTER_CONFIG_FILE}" - ) - if enabled: - startup = ops.pebble.ServiceStartup.ENABLED.value - else: - startup = ops.pebble.ServiceStartup.DISABLED.value - layer = ops.pebble.Layer( - { - "summary": "mysql router layer", - "description": "the pebble config layer for mysql router", - "services": { - self._SERVICE_NAME: { - "override": "replace", - "summary": "mysql router", - "command": command, - "startup": startup, - "user": self._UNIX_USERNAME, - "group": self._UNIX_USERNAME, - }, - }, - } - ) - self._container.add_layer(self._SERVICE_NAME, layer, combine=True) - self._container.replan() - - def _create_directory(self, path: pathlib.Path) -> None: - """Create directory. - - Args: - path: Full filesystem path - """ - path = str(path) - self._container.make_dir(path, user=self._UNIX_USERNAME, group=self._UNIX_USERNAME) - - def _delete_directory(self, path: pathlib.Path) -> None: - """Delete directory. - - Args: - path: Full filesystem path - """ - path = str(path) - self._container.remove_path(path, recursive=True) - def disable(self) -> None: """Stop and disable MySQL Router service.""" - if not self._enabled: + if not self._container.mysql_router_service_enabled: return logger.debug("Disabling MySQL Router service") - self._update_layer(enabled=False) - self._delete_directory(self._ROUTER_CONFIG_DIRECTORY) - self._create_directory(self._ROUTER_CONFIG_DIRECTORY) - self._delete_directory(self._ROUTER_DATA_DIRECTORY) + self._container.update_mysql_router_service(enabled=False) + self._container.router_config_directory.rmtree() + self._container.router_config_directory.mkdir() + self._router_data_directory.rmtree() logger.debug("Disabled MySQL Router service") + @property + def _tls_config_file_data(self) -> str: + """Render config file template to string. + + Config file enables TLS on MySQL Router. + """ + with open("templates/tls.cnf", "r") as template_file: + template = string.Template(template_file.read()) + config_string = template.substitute( + tls_ssl_key_file=self._tls_key_file, + tls_ssl_cert_file=self._tls_certificate_file, + ) + return config_string + + def enable_tls(self, *, key: str, certificate: str): + """Enable TLS.""" + logger.debug("Enabling TLS") + self._container.tls_config_file.write_text(self._tls_config_file_data) + self._tls_key_file.write_text(key) + self._tls_certificate_file.write_text(certificate) + logger.debug("Enabled TLS") + + def disable_tls(self) -> None: + """Disable TLS.""" + logger.debug("Disabling TLS") + for file in ( + self._container.tls_config_file, + self._tls_key_file, + self._tls_certificate_file, + ): + file.unlink() + logger.debug("Disabled TLS") + -# TODO python3.10 min version: Add `(kw_only=True)` -@dataclasses.dataclass class AuthenticatedWorkload(Workload): """Workload with connection to MySQL cluster""" - _connection_info: "relations.database_requires.ConnectionInformation" - _charm: "charm.MySQLRouterOperatorCharm" + def __init__( + self, + container_: container.Container, + connection_info: "relations.database_requires.ConnectionInformation", + charm_: "charm.MySQLRouterOperatorCharm", + ) -> None: + super().__init__(container_) + self._connection_info = connection_info + self._charm = charm_ @property def shell(self) -> mysql_shell.Shell: @@ -148,43 +125,46 @@ def _router_id(self) -> str: # MySQL Router is bootstrapped without `--directory`—there is one system-wide instance. 
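Side note, illustrative and not part of the patch: with the system-wide bootstrap described in the comment above, the router ID is simply the unit's fully qualified hostname with a `::system` suffix. That same value is what `remove_router_from_cluster_metadata()` later receives when cleaning up after a container restart.

```python
import socket

router_id = f"{socket.getfqdn()}::system"
# e.g. "juju-dba1f2-0.lxd::system" (hostname here is hypothetical)
```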
return f"{socket.getfqdn()}::system" - def _bootstrap_router(self) -> None: + def cleanup_after_potential_container_restart(self, *, unit_name: str) -> None: + """Remove MySQL Router cluster metadata & user after (potential) container restart. + + (Storage is not persisted on container restart—MySQL Router's config file is deleted. + Therefore, MySQL Router needs to be bootstrapped again.) + """ + if user_info := self.shell.get_mysql_router_user_for_unit(unit_name): + self.shell.remove_router_from_cluster_metadata(user_info.router_id) + self.shell.delete_user(user_info.username) + + def _get_bootstrap_command(self, password: str) -> list[str]: + return [ + "--bootstrap", + self._connection_info.username + + ":" + + password + + "@" + + self._connection_info.host + + ":" + + self._connection_info.port, + "--strict", + "--user", + self._container.UNIX_USERNAME, + "--conf-set-option", + "http_server.bind_address=127.0.0.1", + "--conf-use-gr-notifications", + ] + + def _bootstrap_router(self, *, tls: bool) -> None: """Bootstrap MySQL Router and enable service.""" logger.debug( - f"Bootstrapping router {self._connection_info.host=}, {self._connection_info.port=}" + f"Bootstrapping router {tls=}, {self._connection_info.host=}, {self._connection_info.port=}" ) - - def _get_command(password: str): - return [ - "mysqlrouter", - "--bootstrap", - self._connection_info.username - + ":" - + password - + "@" - + self._connection_info.host - + ":" - + self._connection_info.port, - "--strict", - "--user", - self._UNIX_USERNAME, - "--conf-set-option", - "http_server.bind_address=127.0.0.1", - "--conf-use-gr-notifications", - ] - # Redact password from log - logged_command = _get_command("***") + logged_command = self._get_bootstrap_command("***") - command = _get_command(self._connection_info.password) + command = self._get_bootstrap_command(self._connection_info.password) try: - # Bootstrap MySQL Router - process = self._container.exec( - command, - timeout=30, - ) - process.wait_output() - except ops.pebble.ExecError as e: + self._container.run_mysql_router(command, timeout=30) + except container.CalledProcessError as e: # Use `logger.error` instead of `logger.exception` so password isn't logged logger.error(f"Failed to bootstrap router\n{logged_command=}\nstderr:\n{e.stderr}\n") # Original exception contains password @@ -193,11 +173,8 @@ def _get_command(password: str): # `from None` disables exception chaining so that the original exception is not # included in the traceback raise Exception("Failed to bootstrap router") from None - # Enable service - self._update_layer(enabled=True) - logger.debug( - f"Bootstrapped router {self._connection_info.host=}, {self._connection_info.port=}" + f"Bootstrapped router {tls=}, {self._connection_info.host=}, {self._connection_info.port=}" ) @property @@ -208,23 +185,48 @@ def _router_username(self) -> str: `/etc/mysqlrouter/mysqlrouter.conf`. This file contains the username that was created during bootstrap. 
""" + # TODO: remove path from docstring config = configparser.ConfigParser() - config.read_file( - self._container.pull(self._ROUTER_CONFIG_DIRECTORY / self._ROUTER_CONFIG_FILE) - ) + config.read_string(self._container.router_config_file.read_text()) return config["metadata_cache:bootstrap"]["user"] - def enable(self, *, unit_name: str) -> None: + def enable(self, *, tls: bool, unit_name: str) -> None: """Start and enable MySQL Router service.""" - if self._enabled: + if self._container.mysql_router_service_enabled: # If the host or port changes, MySQL Router will receive topology change # notifications from MySQL. # Therefore, if the host or port changes, we do not need to restart MySQL Router. return logger.debug("Enabling MySQL Router service") - self._bootstrap_router() + self._bootstrap_router(tls=tls) + self._container.update_mysql_router_service(enabled=True, tls=tls) + # TODO: move before enable service self.shell.add_attributes_to_mysql_router_user( username=self._router_username, router_id=self._router_id, unit_name=unit_name ) logger.debug("Enabled MySQL Router service") self._charm.wait_until_mysql_router_ready() + + def _restart(self, *, tls: bool) -> None: + """Restart MySQL Router to enable or disable TLS.""" + logger.debug("Restarting MySQL Router") + assert self._container.mysql_router_service_enabled is True + self._bootstrap_router(tls=tls) + self._container.update_mysql_router_service(enabled=True, tls=tls) + logger.debug("Restarted MySQL Router") + self._charm.wait_until_mysql_router_ready() + # wait_until_mysql_router_ready will set WaitingStatus—override it with current charm + # status + self._charm.set_status(event=None) + + def enable_tls(self, *, key: str, certificate: str): + """Enable TLS and restart MySQL Router.""" + super().enable_tls(key=key, certificate=certificate) + if self._container.mysql_router_service_enabled: + self._restart(tls=True) + + def disable_tls(self) -> None: + """Disable TLS and restart MySQL Router.""" + super().disable_tls() + if self._container.mysql_router_service_enabled: + self._restart(tls=False) From f94160a07ae1e6483a36772eb589386ff485e5fd Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 12 Jun 2023 11:45:40 +0000 Subject: [PATCH 08/57] Fix recursion error --- src/snap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snap.py b/src/snap.py index df803096..b176e048 100644 --- a/src/snap.py +++ b/src/snap.py @@ -25,7 +25,7 @@ def __new__(cls, *args, **kwargs): else: return path assert str(path).startswith("/") - return parent / path.relative_to("/") + return super().__new__(cls, parent, path.relative_to("/"), **kwargs) def __rtruediv__(self, other): return type(self)(other, self) From fcdb6751bf827bdf9f0face4c0a08d8fbf6fc8be Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 12 Jun 2023 12:07:04 +0000 Subject: [PATCH 09/57] temp fix for /tmp path in snap --- src/mysql_shell.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mysql_shell.py b/src/mysql_shell.py index dd0786f1..c59ba18a 100644 --- a/src/mysql_shell.py +++ b/src/mysql_shell.py @@ -55,7 +55,7 @@ def _run_commands(self, commands: list[str]) -> str: temporary_script_file.write_text("\n".join(commands)) try: output = self._container.run_mysql_shell( - ["--no-wizard", "--python", "--file", str(temporary_script_file)] + ["--no-wizard", "--python", "--file", "/tmp/script.py"] ) except container.CalledProcessError as e: logger.exception(f"Failed to run {logged_commands=}\nstderr:\n{e.stderr}\n") From 
f90b96ed008d25cfaf47ec07958ad65ea19ea73a Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 12 Jun 2023 12:30:49 +0000 Subject: [PATCH 10/57] fix arg --- src/charm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/charm.py b/src/charm.py index 89ede81e..e2b92a52 100755 --- a/src/charm.py +++ b/src/charm.py @@ -148,7 +148,10 @@ def reconcile_database_relations(self, event=None) -> None: shell=workload_.shell, ) if isinstance(workload_, workload.AuthenticatedWorkload) and workload_.container_ready: - workload_.enable(unit_name=self.unit.name) + workload_.enable( + unit_name=self.unit.name, + tls=False, # TODO + ) elif workload_.container_ready: workload_.disable() self.set_status(event=event) From 8e8654cab3e2f92273f9164e6b7f984e5bc2c1f6 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 12 Jun 2023 13:00:22 +0000 Subject: [PATCH 11/57] add relative_to_container method to Path --- src/charm.py | 2 +- src/container.py | 8 ++++++++ src/mysql_shell.py | 7 ++++++- src/snap.py | 21 ++++++++++++++------- 4 files changed, 29 insertions(+), 9 deletions(-) diff --git a/src/charm.py b/src/charm.py index e2b92a52..e7c8e42c 100755 --- a/src/charm.py +++ b/src/charm.py @@ -150,7 +150,7 @@ def reconcile_database_relations(self, event=None) -> None: if isinstance(workload_, workload.AuthenticatedWorkload) and workload_.container_ready: workload_.enable( unit_name=self.unit.name, - tls=False, # TODO + tls=False, # TODO ) elif workload_.container_ready: workload_.disable() diff --git a/src/container.py b/src/container.py index 7fd99dfa..bd95c3d1 100644 --- a/src/container.py +++ b/src/container.py @@ -20,6 +20,14 @@ class Path(pathlib.PurePosixPath, abc.ABC): def _UNIX_USERNAME(self) -> str: pass + @property + @abc.abstractmethod + def relative_to_container(self) -> pathlib.PurePosixPath: + """Path from container root (instead of machine root) + + Only differs from `self` on machine charm + """ + @abc.abstractmethod def read_text(self) -> str: """Open the file in text mode, read it, and close the file.""" diff --git a/src/mysql_shell.py b/src/mysql_shell.py index c59ba18a..5940980a 100644 --- a/src/mysql_shell.py +++ b/src/mysql_shell.py @@ -55,7 +55,12 @@ def _run_commands(self, commands: list[str]) -> str: temporary_script_file.write_text("\n".join(commands)) try: output = self._container.run_mysql_shell( - ["--no-wizard", "--python", "--file", "/tmp/script.py"] + [ + "--no-wizard", + "--python", + "--file", + str(temporary_script_file.relative_to_container), + ] ) except container.CalledProcessError as e: logger.exception(f"Failed to run {logged_commands=}\nstderr:\n{e.stderr}\n") diff --git a/src/snap.py b/src/snap.py index b176e048..b80857f8 100644 --- a/src/snap.py +++ b/src/snap.py @@ -11,7 +11,9 @@ _SNAP_NAME = "charmed-mysql" -class _SnapPath(pathlib.PosixPath): +class _Path(pathlib.PosixPath, container.Path): + _UNIX_USERNAME = _UNIX_USERNAME + def __new__(cls, *args, **kwargs): path = super().__new__(cls, *args, **kwargs) if str(path).startswith("/etc/mysqlrouter") or str(path).startswith( @@ -23,16 +25,21 @@ def __new__(cls, *args, **kwargs): elif str(path).startswith("/tmp"): parent = f"/tmp/snap-private-tmp/snap.{_SNAP_NAME}" else: - return path - assert str(path).startswith("/") - return super().__new__(cls, parent, path.relative_to("/"), **kwargs) + parent = None + if parent: + assert str(path).startswith("/") + path = super().__new__(cls, parent, path.relative_to("/"), **kwargs) + path._container_parent = parent + return path def __rtruediv__(self, other): 
return type(self)(other, self) - -class _Path(_SnapPath, container.Path): - _UNIX_USERNAME = _UNIX_USERNAME + @property + def relative_to_container(self) -> pathlib.PurePosixPath: + if parent := self._container_parent: + return self.relative_to(parent) + return self def read_text(self, encoding="utf-8", *args) -> str: return super().read_text(encoding, *args) From 5f04da202353e5fcf5870699871d7ff76f698a1e Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 12 Jun 2023 13:12:32 +0000 Subject: [PATCH 12/57] cast to purepath before getting relative path --- src/snap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snap.py b/src/snap.py index b80857f8..56628509 100644 --- a/src/snap.py +++ b/src/snap.py @@ -38,7 +38,7 @@ def __rtruediv__(self, other): @property def relative_to_container(self) -> pathlib.PurePosixPath: if parent := self._container_parent: - return self.relative_to(parent) + return pathlib.PurePosixPath(self).relative_to(parent) return self def read_text(self, encoding="utf-8", *args) -> str: From 7af5af243fe2711fd0c0624c995baf98077c0988 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 12 Jun 2023 13:20:47 +0000 Subject: [PATCH 13/57] fix path --- src/snap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snap.py b/src/snap.py index 56628509..9699a649 100644 --- a/src/snap.py +++ b/src/snap.py @@ -38,7 +38,7 @@ def __rtruediv__(self, other): @property def relative_to_container(self) -> pathlib.PurePosixPath: if parent := self._container_parent: - return pathlib.PurePosixPath(self).relative_to(parent) + return pathlib.PurePosixPath("/", self.relative_to(parent)) return self def read_text(self, encoding="utf-8", *args) -> str: From aadadbcd93f20f9fe8e7e107b273b3307d76d912 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 12 Jun 2023 16:18:06 +0000 Subject: [PATCH 14/57] fix username --- src/snap.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/snap.py b/src/snap.py index 9699a649..1f7d4d3c 100644 --- a/src/snap.py +++ b/src/snap.py @@ -7,7 +7,7 @@ import container -_UNIX_USERNAME = None # TODO +_UNIX_USERNAME = "mysql" # TODO _SNAP_NAME = "charmed-mysql" @@ -60,10 +60,11 @@ class Snap(container.Container): def __init__(self) -> None: super().__init__( - mysql_router_command=f"{_SNAP_NAME}.mysqlrouter", - mysql_shell_command=f"{_SNAP_NAME}.mysqlsh", + mysql_router_command=f"/snap/bin/{_SNAP_NAME}.mysqlrouter", + mysql_shell_command=f"/snap/bin/{_SNAP_NAME}.mysqlsh", ) + @property def ready(self) -> bool: return True From 473d7f8cd5fc71d6af7c6827c39cb08b168620e9 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 12 Jun 2023 16:22:40 +0000 Subject: [PATCH 15/57] change username --- src/snap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snap.py b/src/snap.py index 1f7d4d3c..00e6d053 100644 --- a/src/snap.py +++ b/src/snap.py @@ -7,7 +7,7 @@ import container -_UNIX_USERNAME = "mysql" # TODO +_UNIX_USERNAME = "snap_daemon" # TODO _SNAP_NAME = "charmed-mysql" From d1362eb9a7e180ef829aa9e4d504439532cc3aa6 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 12 Jun 2023 16:29:32 +0000 Subject: [PATCH 16/57] fix --- src/charm.py | 2 +- src/snap.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/charm.py b/src/charm.py index e7c8e42c..5da75f15 100755 --- a/src/charm.py +++ b/src/charm.py @@ -150,7 +150,7 @@ def reconcile_database_relations(self, event=None) -> None: if isinstance(workload_, workload.AuthenticatedWorkload) and 
workload_.container_ready: workload_.enable( unit_name=self.unit.name, - tls=False, # TODO + tls=None, # TODO ) elif workload_.container_ready: workload_.disable() diff --git a/src/snap.py b/src/snap.py index 00e6d053..cecd8033 100644 --- a/src/snap.py +++ b/src/snap.py @@ -60,8 +60,8 @@ class Snap(container.Container): def __init__(self) -> None: super().__init__( - mysql_router_command=f"/snap/bin/{_SNAP_NAME}.mysqlrouter", - mysql_shell_command=f"/snap/bin/{_SNAP_NAME}.mysqlsh", + mysql_router_command=f"{_SNAP_NAME}.mysqlrouter", + mysql_shell_command=f"{_SNAP_NAME}.mysqlsh", ) @property From 3a6c42bf88afbde60d2500d334ea4813da02ea10 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 12 Jun 2023 16:59:25 +0000 Subject: [PATCH 17/57] fix socket path --- src/snap.py | 2 +- src/socket_workload.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/snap.py b/src/snap.py index cecd8033..e958023f 100644 --- a/src/snap.py +++ b/src/snap.py @@ -20,7 +20,7 @@ def __new__(cls, *args, **kwargs): "/var/lib/mysqlrouter" ): parent = f"/var/snap/{_SNAP_NAME}/current" - elif str(path).startswith("/run"): + elif str(path).startswith("/var/run"): # TODO: user /run instead of /var/run? parent = f"/var/snap/{_SNAP_NAME}/common" elif str(path).startswith("/tmp"): parent = f"/tmp/snap-private-tmp/snap.{_SNAP_NAME}" diff --git a/src/socket_workload.py b/src/socket_workload.py index eba8d672..ee47255d 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -32,7 +32,9 @@ def _change_socket_file_locations(self) -> None: if not section_name.startswith("routing:"): continue section["socket"] = str( - self._container.path("/run/mysqlrouter") / pathlib.PurePath(section["socket"]).name + # TODO use /run instead of /var/run? + self._container.path("/var/run/mysqlrouter") + / pathlib.PurePath(section["socket"]).name ) with io.StringIO() as output: config.write(output) From 6d077e457a21b5f590f509dbf0f1326d06b169cb Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 12 Jun 2023 17:55:51 +0000 Subject: [PATCH 18/57] remove username --- src/container.py | 10 ---------- src/snap.py | 4 ---- src/workload.py | 2 -- 3 files changed, 16 deletions(-) diff --git a/src/container.py b/src/container.py index bd95c3d1..a6139501 100644 --- a/src/container.py +++ b/src/container.py @@ -15,11 +15,6 @@ def uninstall(self): class Path(pathlib.PurePosixPath, abc.ABC): - @property - @abc.abstractmethod - def _UNIX_USERNAME(self) -> str: - pass - @property @abc.abstractmethod def relative_to_container(self) -> pathlib.PurePosixPath: @@ -57,11 +52,6 @@ def __init__(self, *, returncode: int, cmd: list[str], output: str, stderr: str) class Container(abc.ABC): - @property - @abc.abstractmethod - def UNIX_USERNAME(self) -> str: - pass - @property def router_config_directory(self) -> Path: return self.path("/etc/mysqlrouter") diff --git a/src/snap.py b/src/snap.py index e958023f..6ab0bf11 100644 --- a/src/snap.py +++ b/src/snap.py @@ -7,13 +7,10 @@ import container -_UNIX_USERNAME = "snap_daemon" # TODO _SNAP_NAME = "charmed-mysql" class _Path(pathlib.PosixPath, container.Path): - _UNIX_USERNAME = _UNIX_USERNAME - def __new__(cls, *args, **kwargs): path = super().__new__(cls, *args, **kwargs) if str(path).startswith("/etc/mysqlrouter") or str(path).startswith( @@ -54,7 +51,6 @@ def rmtree(self): class Snap(container.Container): - UNIX_USERNAME = _UNIX_USERNAME _SNAP_REVISION = "51" _SERVICE_NAME = "mysqlrouter-service" diff --git a/src/workload.py b/src/workload.py index 88898608..ed94703b 100644 --- 
a/src/workload.py +++ b/src/workload.py @@ -146,8 +146,6 @@ def _get_bootstrap_command(self, password: str) -> list[str]: + ":" + self._connection_info.port, "--strict", - "--user", - self._container.UNIX_USERNAME, "--conf-set-option", "http_server.bind_address=127.0.0.1", "--conf-use-gr-notifications", From 2558aae6173cac95277760df2c64a8688a6a86c3 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Thu, 15 Jun 2023 17:28:47 +0000 Subject: [PATCH 19/57] remove installer --- src/container.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/container.py b/src/container.py index a6139501..7011ec2f 100644 --- a/src/container.py +++ b/src/container.py @@ -4,16 +4,6 @@ import typing -class Installer(abc.ABC): - @abc.abstractmethod - def install(self): - pass - - @abc.abstractmethod - def uninstall(self): - pass - - class Path(pathlib.PurePosixPath, abc.ABC): @property @abc.abstractmethod From 6d8ab1c86e1617ec510752e9e4f4f4e0c83c478e Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 16 Jun 2023 15:58:18 +0000 Subject: [PATCH 20/57] sync --- src/mysql_shell.py | 1 - src/workload.py | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/mysql_shell.py b/src/mysql_shell.py index 5940980a..21545a33 100644 --- a/src/mysql_shell.py +++ b/src/mysql_shell.py @@ -43,7 +43,6 @@ def _run_commands(self, commands: list[str]) -> str: """Connect to MySQL cluster and run commands.""" # Redact password from log logged_commands = commands.copy() - # TODO: Password is still logged on user creation logged_commands.insert( 0, f"shell.connect('{self.username}:***@{self._host}:{self._port}')" ) diff --git a/src/workload.py b/src/workload.py index ed94703b..ae9da15c 100644 --- a/src/workload.py +++ b/src/workload.py @@ -179,11 +179,8 @@ def _bootstrap_router(self, *, tls: bool) -> None: def _router_username(self) -> str: """Read MySQL Router username from config file. - During bootstrap, MySQL Router creates a config file at - `/etc/mysqlrouter/mysqlrouter.conf`. This file contains the username that was created - during bootstrap. + During bootstrap, MySQL Router creates a config file which includes a generated username. """ - # TODO: remove path from docstring config = configparser.ConfigParser() config.read_string(self._container.router_config_file.read_text()) return config["metadata_cache:bootstrap"]["user"] From 27ea0de3df150425d967a488d6bfa16632dc037e Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 16 Jun 2023 18:44:14 +0000 Subject: [PATCH 21/57] sync & socket endpoints --- src/charm.py | 12 ++----- src/relations/database_provides.py | 57 ++++++++++++++++++++---------- src/socket_workload.py | 8 +++++ src/workload.py | 15 ++++++-- 4 files changed, 62 insertions(+), 30 deletions(-) diff --git a/src/charm.py b/src/charm.py index 5da75f15..3377d2f5 100755 --- a/src/charm.py +++ b/src/charm.py @@ -45,17 +45,10 @@ def get_workload(self, *, event): container_=container, connection_info=connection_info, charm_=self, + host="", # TODO: replace with IP address when enabling TCP ) return socket_workload.SocketWorkload(container_=container) - @property - def _endpoint(self) -> str: - """K8s endpoint for MySQL Router""" - # TODO: remove - # Example: mysql-router-k8s.my-model.svc.cluster.local - return "foo" - # return f"{self.app.name}.{self.model_service_domain}" - @staticmethod def _prioritize_statuses(statuses: list[ops.StatusBase]) -> ops.StatusBase: """Report the highest priority status. 
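A quick reference for reviewers, with illustrative values that are not part of the patch: the two endpoint flavours wired through this change. TCP endpoints come from `AuthenticatedWorkload` using the host passed in by the charm; unix-socket endpoints come from `AuthenticatedSocketWorkload`, assuming the snap remaps `/run` to `/var/snap/charmed-mysql/common`:

```python
# Hypothetical unit address; 6446/6447 are MySQL Router's classic read-write/read-only ports.
tcp_read_write_endpoint = "10.73.0.5:6446"
tcp_read_only_endpoint = "10.73.0.5:6447"
# Unix-socket URIs handed to related applications on machine charms
socket_read_write_endpoint = "file:///var/snap/charmed-mysql/common/run/mysqlrouter/mysql.sock"
socket_read_only_endpoint = "file:///var/snap/charmed-mysql/common/run/mysqlrouter/mysqlro.sock"
```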
@@ -144,7 +137,8 @@ def reconcile_database_relations(self, event=None) -> None: ): self.database_provides.reconcile_users( event=event, - router_endpoint=self._endpoint, + router_read_write_endpoint=workload_.read_write_endpoint, + router_read_only_endpoint=workload_.read_only_endpoint, shell=workload_.shell, ) if isinstance(workload_, workload.AuthenticatedWorkload) and workload_.container_ready: diff --git a/src/relations/database_provides.py b/src/relations/database_provides.py index bbeda7bc..1dc6f0fe 100644 --- a/src/relations/database_provides.py +++ b/src/relations/database_provides.py @@ -72,34 +72,44 @@ def __init__( app_name=relation.app.name, endpoint_name=relation.name ) - def _set_databag(self, *, username: str, password: str, router_endpoint: str) -> None: + def _set_databag( + self, + *, + username: str, + password: str, + router_read_write_endpoint: str, + router_read_only_endpoint: str, + ) -> None: """Share connection information with application charm.""" - # TODO: remove `file://`? - # TODO: get socket path from variable - read_write_endpoint = ( - "file:///var/snap/charmed-mysql/common/var/run/mysqlrouter/mysql.sock" - ) - read_only_endpoint = ( - "file:///var/snap/charmed-mysql/common/var/run/mysqlrouter/mysqlro.sock" - ) logger.debug( - f"Setting databag {self._id=} {self._database=}, {username=}, {read_write_endpoint=}, {read_only_endpoint=}" + f"Setting databag {self._id=} {self._database=}, {username=}, {router_read_write_endpoint=}, {router_read_only_endpoint=}" ) self._interface.set_database(self._id, self._database) self._interface.set_credentials(self._id, username, password) - self._interface.set_endpoints(self._id, read_write_endpoint) - self._interface.set_read_only_endpoints(self._id, read_only_endpoint) + self._interface.set_endpoints(self._id, router_read_write_endpoint) + self._interface.set_read_only_endpoints(self._id, router_read_only_endpoint) logger.debug( - f"Set databag {self._id=} {self._database=}, {username=}, {read_write_endpoint=}, {read_only_endpoint=}" + f"Set databag {self._id=} {self._database=}, {username=}, {router_read_write_endpoint=}, {router_read_only_endpoint=}" ) - def create_database_and_user(self, *, router_endpoint: str, shell: mysql_shell.Shell) -> None: + def create_database_and_user( + self, + *, + router_read_write_endpoint: str, + router_read_only_endpoint: str, + shell: mysql_shell.Shell, + ) -> None: """Create database & user and update databag.""" username = self._get_username(shell.username) password = shell.create_application_database_and_user( username=username, database=self._database ) - self._set_databag(username=username, password=password, router_endpoint=router_endpoint) + self._set_databag( + username=username, + password=password, + router_read_write_endpoint=router_read_write_endpoint, + router_read_only_endpoint=router_read_only_endpoint, + ) class _UserNotCreated(Exception): @@ -166,7 +176,8 @@ def reconcile_users( self, *, event, - router_endpoint: str, + router_read_write_endpoint: str, + router_read_only_endpoint: str, shell: mysql_shell.Shell, ) -> None: """Create requested users and delete inactive users. @@ -175,7 +186,9 @@ def reconcile_users( created by this charm. Therefore, this charm does not need to delete users when that relation is broken. 
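A compact sketch, with placeholder relation IDs, of the reconciliation loop that follows; the real code compares relation wrapper objects built from the databags rather than strings:

```python
# Illustration only: reconcile_users is a set difference over relations.
requested_users = ["relation-3", "relation-7"]  # relations currently requesting a user
created_users = ["relation-3", "relation-5"]    # relations whose user already exists

for relation in requested_users:
    if relation not in created_users:
        print(f"create database & user for {relation}")  # relation-7
for relation in created_users:
    if relation not in requested_users:
        print(f"delete user for {relation}")  # relation-5
```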
""" - logger.debug(f"Reconciling users {event=}, {router_endpoint=}") + logger.debug( + f"Reconciling users {event=}, {router_read_write_endpoint=}, {router_read_only_endpoint=}" + ) requested_users = [] for relation in self._interface.relations: try: @@ -193,11 +206,17 @@ def reconcile_users( logger.debug(f"State of reconcile users {requested_users=}, {self._created_users=}") for relation in requested_users: if relation not in self._created_users: - relation.create_database_and_user(router_endpoint=router_endpoint, shell=shell) + relation.create_database_and_user( + router_read_write_endpoint=router_read_write_endpoint, + router_read_only_endpoint=router_read_only_endpoint, + shell=shell, + ) for relation in self._created_users: if relation not in requested_users: relation.delete_user(shell=shell) - logger.debug(f"Reconciled users {event=}, {router_endpoint=}") + logger.debug( + f"Reconciled users {event=}, {router_read_write_endpoint=}, {router_read_only_endpoint=}" + ) def delete_all_databags(self) -> None: """Remove connection information from all databags. diff --git a/src/socket_workload.py b/src/socket_workload.py index ee47255d..b8b90340 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -11,6 +11,14 @@ class SocketWorkload(workload.Workload): class AuthenticatedSocketWorkload(workload.AuthenticatedWorkload, SocketWorkload): + @property + def read_write_endpoint(self) -> str: + return f'file://{self._container.path("/run/mysqlrouter/mysql.sock")}' + + @property + def read_only_endpoint(self) -> str: + return f'file://{self._container.path("/run/mysqlrouter/mysqlro.sock")}' + def _get_bootstrap_command(self, password: str): command = super()._get_bootstrap_command(password) command.extend( diff --git a/src/workload.py b/src/workload.py index ae9da15c..f7b68773 100644 --- a/src/workload.py +++ b/src/workload.py @@ -22,7 +22,7 @@ class Workload: """MySQL Router workload""" - def __init__(self, container_: container.Container) -> None: + def __init__(self, *, container_: container.Container) -> None: self._container = container_ self._router_data_directory = self._container.path("/var/lib/mysqlrouter") self._tls_key_file = self._container.router_config_directory / "custom-key.pem" @@ -97,14 +97,25 @@ class AuthenticatedWorkload(Workload): def __init__( self, + *, container_: container.Container, connection_info: "relations.database_requires.ConnectionInformation", + host: str, charm_: "charm.MySQLRouterOperatorCharm", ) -> None: - super().__init__(container_) + super().__init__(container_=container_) self._connection_info = connection_info + self._host = host self._charm = charm_ + @property + def read_write_endpoint(self) -> str: + return f"{self._host}:6446" + + @property + def read_only_endpoint(self) -> str: + return f"{self._host}:6447" + @property def shell(self) -> mysql_shell.Shell: """MySQL Shell""" From 4f47eea2523e0b85f79faa814a4ca577dc119222 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 16 Jun 2023 18:44:35 +0000 Subject: [PATCH 22/57] /run instead of /var/run --- src/snap.py | 2 +- src/socket_workload.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/snap.py b/src/snap.py index 6ab0bf11..3dcd28eb 100644 --- a/src/snap.py +++ b/src/snap.py @@ -17,7 +17,7 @@ def __new__(cls, *args, **kwargs): "/var/lib/mysqlrouter" ): parent = f"/var/snap/{_SNAP_NAME}/current" - elif str(path).startswith("/var/run"): # TODO: user /run instead of /var/run? 
+ elif str(path).startswith("/run"): parent = f"/var/snap/{_SNAP_NAME}/common" elif str(path).startswith("/tmp"): parent = f"/tmp/snap-private-tmp/snap.{_SNAP_NAME}" diff --git a/src/socket_workload.py b/src/socket_workload.py index b8b90340..0eb0edf5 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -40,9 +40,7 @@ def _change_socket_file_locations(self) -> None: if not section_name.startswith("routing:"): continue section["socket"] = str( - # TODO use /run instead of /var/run? - self._container.path("/var/run/mysqlrouter") - / pathlib.PurePath(section["socket"]).name + self._container.path("/run/mysqlrouter") / pathlib.PurePath(section["socket"]).name ) with io.StringIO() as output: config.write(output) From fb8e8404b99dfe4d41240fe371ec59e1945fb5ca Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 16 Jun 2023 18:52:33 +0000 Subject: [PATCH 23/57] Add missing_ok to unlink() --- src/container.py | 2 +- src/snap.py | 2 -- src/workload.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/container.py b/src/container.py index 7011ec2f..1bafa7b9 100644 --- a/src/container.py +++ b/src/container.py @@ -22,7 +22,7 @@ def write_text(self, data: str): """Open the file in text mode, write to it, and close the file.""" @abc.abstractmethod - def unlink(self): + def unlink(self, *, missing_ok=False): """Remove this file or link.""" @abc.abstractmethod diff --git a/src/snap.py b/src/snap.py index 3dcd28eb..c7f8d06f 100644 --- a/src/snap.py +++ b/src/snap.py @@ -44,8 +44,6 @@ def read_text(self, encoding="utf-8", *args) -> str: def write_text(self, data: str, encoding="utf-8", *args): return super().write_text(data, encoding, *args) - # TODO: override unlink with not exists no fail? - def rmtree(self): shutil.rmtree(self) diff --git a/src/workload.py b/src/workload.py index f7b68773..b039fe09 100644 --- a/src/workload.py +++ b/src/workload.py @@ -88,7 +88,7 @@ def disable_tls(self) -> None: self._tls_key_file, self._tls_certificate_file, ): - file.unlink() + file.unlink(missing_ok=True) logger.debug("Disabled TLS") From 9ce5d195d6664a08163ba8dcc8521b2bae26a71c Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 16 Jun 2023 18:54:47 +0000 Subject: [PATCH 24/57] remove comment --- src/socket_workload.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/socket_workload.py b/src/socket_workload.py index 0eb0edf5..3b45ff15 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -5,7 +5,6 @@ import workload -# TODO: rename to Workload? class SocketWorkload(workload.Workload): pass From b039fad0fd4e591965acf95f89b7c916d81677e7 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 16 Jun 2023 19:06:23 +0000 Subject: [PATCH 25/57] Add docstring to _update_configured_socket_file_locations() --- src/socket_workload.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/socket_workload.py b/src/socket_workload.py index 3b45ff15..cde9b584 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -31,8 +31,17 @@ def _get_bootstrap_command(self, password: str): ) return command - def _change_socket_file_locations(self) -> None: - # TODO: rename + def _update_configured_socket_file_locations(self) -> None: + """Update configured socket file locations from `/tmp` to `/run/mysqlrouter`. + + Called after MySQL Router bootstrap & before MySQL Router service is enabled + + Change configured location of socket files before socket files are created by MySQL Router + service. 
+ + Needed since `/tmp` inside a snap is not accessible to non-root users. The socket files + must be accessible to applications related via database_provides endpoint. + """ config = configparser.ConfigParser() config.read_string(self._container.router_config_file.read_text()) for section_name, section in config.items(): @@ -47,4 +56,4 @@ def _change_socket_file_locations(self) -> None: def _bootstrap_router(self, *, tls: bool) -> None: super()._bootstrap_router(tls=tls) - self._change_socket_file_locations() + self._update_configured_socket_file_locations() From dd433b03d0a5e310870c94f262f2cf66d0f44005 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Tue, 20 Jun 2023 15:09:37 +0000 Subject: [PATCH 26/57] docstring sync --- src/container.py | 17 ++++++++++++++++- src/workload.py | 2 ++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/container.py b/src/container.py index 1bafa7b9..b680eaef 100644 --- a/src/container.py +++ b/src/container.py @@ -1,3 +1,8 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Workload container (snap or ROCK/OCI)""" + import abc import pathlib import subprocess @@ -5,6 +10,8 @@ class Path(pathlib.PurePosixPath, abc.ABC): + """Workload container (snap or ROCK) filesystem path""" + @property @abc.abstractmethod def relative_to_container(self) -> pathlib.PurePosixPath: @@ -42,16 +49,24 @@ def __init__(self, *, returncode: int, cmd: list[str], output: str, stderr: str) class Container(abc.ABC): + """Workload container (snap or ROCK)""" + @property def router_config_directory(self) -> Path: + """MySQL Router configuration directory""" return self.path("/etc/mysqlrouter") @property def router_config_file(self) -> Path: + """MySQL Router configuration file + + Automatically generated by MySQL Router bootstrap + """ return self.router_config_directory / "mysqlrouter.conf" @property def tls_config_file(self) -> Path: + """Extra MySQL Router configuration file to enable TLS""" return self.router_config_directory / "tls.conf" def __init__(self, *, mysql_router_command: str, mysql_shell_command: str) -> None: @@ -110,4 +125,4 @@ def run_mysql_shell(self, args: list[str], *, timeout: int = None) -> str: @abc.abstractmethod def path(self, *args) -> Path: - pass + """Container filesystem path""" diff --git a/src/workload.py b/src/workload.py index b039fe09..94e84ddc 100644 --- a/src/workload.py +++ b/src/workload.py @@ -110,10 +110,12 @@ def __init__( @property def read_write_endpoint(self) -> str: + """MySQL Router read-write endpoint""" return f"{self._host}:6446" @property def read_only_endpoint(self) -> str: + """MySQL Router read-only endpoint""" return f"{self._host}:6447" @property From 21c11bcb2e858c40bb25fdc02a7c848445a59997 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Tue, 20 Jun 2023 15:12:08 +0000 Subject: [PATCH 27/57] tls todo --- src/charm.py | 4 ++-- src/snap.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/charm.py b/src/charm.py index 3377d2f5..869e4cbb 100755 --- a/src/charm.py +++ b/src/charm.py @@ -45,7 +45,7 @@ def get_workload(self, *, event): container_=container, connection_info=connection_info, charm_=self, - host="", # TODO: replace with IP address when enabling TCP + host="", # TODO TLS: replace with IP address when enabling TCP ) return socket_workload.SocketWorkload(container_=container) @@ -144,7 +144,7 @@ def reconcile_database_relations(self, event=None) -> None: if isinstance(workload_, workload.AuthenticatedWorkload) and workload_.container_ready: 
workload_.enable( unit_name=self.unit.name, - tls=None, # TODO + tls=False, # TODO TLS ) elif workload_.container_ready: workload_.disable() diff --git a/src/snap.py b/src/snap.py index c7f8d06f..407cfd66 100644 --- a/src/snap.py +++ b/src/snap.py @@ -71,10 +71,9 @@ def mysql_router_service_enabled(self) -> bool: return self._snap.services[self._SERVICE_NAME]["active"] def update_mysql_router_service(self, *, enabled: bool, tls: bool = None) -> None: - # TODO: uncomment when TLS is implemented - # super().update_mysql_router_service(enabled=enabled, tls=tls) - if tls is not None: - raise NotImplementedError + super().update_mysql_router_service(enabled=enabled, tls=tls) + if tls: + raise NotImplementedError # TODO TLS if enabled: self._snap.start([self._SERVICE_NAME], enable=True) else: From ff2681bc72cd0b3e144ee79e0b0c346782bc72ba Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Tue, 20 Jun 2023 15:14:55 +0000 Subject: [PATCH 28/57] update snap --- src/charm.py | 2 +- src/snap.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/charm.py b/src/charm.py index 869e4cbb..7c5792b7 100755 --- a/src/charm.py +++ b/src/charm.py @@ -156,7 +156,7 @@ def _on_install(self, _) -> None: # TODO: move to workload.py? # TODO set workload version _SNAP_NAME = "charmed-mysql" - _SNAP_REVISION = "51" + _SNAP_REVISION = "57" mysql_snap = snap_lib.SnapCache()[_SNAP_NAME] if mysql_snap.present: logger.error(f"{_SNAP_NAME} snap already installed on machine. Installation aborted") diff --git a/src/snap.py b/src/snap.py index 407cfd66..2f0b8c2a 100644 --- a/src/snap.py +++ b/src/snap.py @@ -49,7 +49,7 @@ def rmtree(self): class Snap(container.Container): - _SNAP_REVISION = "51" + _SNAP_REVISION = "57" _SERVICE_NAME = "mysqlrouter-service" def __init__(self) -> None: From 6ad27560c4cece0d2af9bf620389e3f7282898f3 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Tue, 20 Jun 2023 19:02:01 +0000 Subject: [PATCH 29/57] sync add attributes before enable --- src/workload.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/workload.py b/src/workload.py index 94e84ddc..aba0bf27 100644 --- a/src/workload.py +++ b/src/workload.py @@ -207,11 +207,10 @@ def enable(self, *, tls: bool, unit_name: str) -> None: return logger.debug("Enabling MySQL Router service") self._bootstrap_router(tls=tls) - self._container.update_mysql_router_service(enabled=True, tls=tls) - # TODO: move before enable service self.shell.add_attributes_to_mysql_router_user( username=self._router_username, router_id=self._router_id, unit_name=unit_name ) + self._container.update_mysql_router_service(enabled=True, tls=tls) logger.debug("Enabled MySQL Router service") self._charm.wait_until_mysql_router_ready() From 97f74c3e58abc98ea9cf3b2f7c03165ecb5b6d8d Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Tue, 20 Jun 2023 19:47:27 +0000 Subject: [PATCH 30/57] add installer to container --- src/charm.py | 28 +++++++--------------------- src/container.py | 12 ++++++++++++ src/snap.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 21 deletions(-) diff --git a/src/charm.py b/src/charm.py index 7c5792b7..c3c7f81a 100755 --- a/src/charm.py +++ b/src/charm.py @@ -9,7 +9,6 @@ import logging import socket -import charms.operator_libs_linux.v2.snap as snap_lib import ops import tenacity @@ -100,9 +99,10 @@ def wait_until_mysql_router_ready(self) -> None: self.unit.status = ops.WaitingStatus("MySQL Router starting") try: for attempt in tenacity.Retrying( - reraise=True, 
stop=tenacity.stop_after_delay(30), wait=tenacity.wait_fixed(5), + retry=tenacity.retry_if_exception_type(AssertionError), + reraise=True, ): with attempt: for port in (6446, 6447): @@ -151,27 +151,13 @@ def reconcile_database_relations(self, event=None) -> None: self.set_status(event=event) def _on_install(self, _) -> None: - """Patch existing k8s service to include read-write and read-only services.""" - # TODO update docstring - # TODO: move to workload.py? - # TODO set workload version - _SNAP_NAME = "charmed-mysql" - _SNAP_REVISION = "57" - mysql_snap = snap_lib.SnapCache()[_SNAP_NAME] - if mysql_snap.present: - logger.error(f"{_SNAP_NAME} snap already installed on machine. Installation aborted") - raise Exception(f"Multiple {_SNAP_NAME} snap installs not supported on one machine") - logger.debug(f"Installing {_SNAP_NAME=}, {_SNAP_REVISION=}") - # TODO: set status - # TODO catch/retry on error? - mysql_snap.ensure(snap_lib.SnapState.Present, revision=_SNAP_REVISION) - logger.debug(f"Installed {_SNAP_NAME=}, {_SNAP_REVISION=}") - self.unit.set_workload_version(self.get_workload(event=None).version) + snap.Installer().install(unit=self.unit) + workload_ = self.get_workload(event=None) + if workload_.container_ready: # check for VM instead? + self.unit.set_workload_version(workload_.version) def _on_remove(self, _) -> None: - _SNAP_NAME = "charmed-mysql" - mysql_snap = snap_lib.SnapCache()[_SNAP_NAME] - mysql_snap.ensure(snap_lib.SnapState.Absent) + snap.Installer().uninstall() def _on_start(self, _) -> None: # Set status on first start if no relations active diff --git a/src/container.py b/src/container.py index b680eaef..7ee5f922 100644 --- a/src/container.py +++ b/src/container.py @@ -8,6 +8,18 @@ import subprocess import typing +import ops + + +class Installer(abc.ABC): + @abc.abstractmethod + def install(self, *, unit: ops.Unit, model_name: str, app_name: str): + pass + + @abc.abstractmethod + def uninstall(self): + pass + class Path(pathlib.PurePosixPath, abc.ABC): """Workload container (snap or ROCK) filesystem path""" diff --git a/src/snap.py b/src/snap.py index 2f0b8c2a..e4a1a939 100644 --- a/src/snap.py +++ b/src/snap.py @@ -1,14 +1,56 @@ +import logging import pathlib import shutil import subprocess import typing import charms.operator_libs_linux.v2.snap as snap_lib +import ops +import tenacity import container _SNAP_NAME = "charmed-mysql" +logger = logging.getLogger(__name__) + + +class Installer(container.Installer): + _SNAP_REVISION = "57" + + @property + def _snap(self) -> snap_lib.Snap: + return snap_lib.SnapCache()[_SNAP_NAME] + + def install(self, *, unit: ops.Unit, **_): + if self._snap.present: + logger.error(f"{_SNAP_NAME} snap already installed on machine. Installation aborted") + raise Exception(f"Multiple {_SNAP_NAME} snap installs not supported on one machine") + logger.debug(f"Installing {_SNAP_NAME=}, {self._SNAP_REVISION=}") + unit.status = ops.MaintenanceStatus("Installing snap") + + def _set_retry_status(_) -> None: + unit.status = ops.MaintenanceStatus("Snap install failed. 
Retrying...") + + try: + for attempt in tenacity.Retrying( + stop=tenacity.stop_after_delay(60 * 5), + wait=tenacity.wait_exponential(multiplier=10), + retry=tenacity.retry_if_exception_type(snap_lib.SnapError), + after=_set_retry_status, + reraise=True, + ): + with attempt: + self._snap.ensure( + state=snap_lib.SnapState.Present, revision=self._SNAP_REVISION + ) + except snap_lib.SnapError: + raise + logger.debug(f"Installed {_SNAP_NAME=}, {self._SNAP_REVISION=}") + + def uninstall(self): + self._snap.ensure(state=snap_lib.SnapState.Absent) + class _Path(pathlib.PosixPath, container.Path): def __new__(cls, *args, **kwargs): From 15ad29a9121d624c0411f112b124078cde5abf6a Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Wed, 21 Jun 2023 18:31:51 +0000 Subject: [PATCH 31/57] Add abstract_charm.py --- charmcraft.yaml | 1 + src/{charm.py => abstract_charm.py} | 90 ++++++++++++++--------------- src/container.py | 12 ---- src/machine_charm.py | 53 +++++++++++++++++ src/relations/database_provides.py | 4 +- src/relations/database_requires.py | 4 +- src/snap.py | 6 +- src/socket_workload.py | 8 --- src/workload.py | 24 +++----- 9 files changed, 112 insertions(+), 90 deletions(-) rename src/{charm.py => abstract_charm.py} (67%) mode change 100755 => 100644 create mode 100755 src/machine_charm.py diff --git a/charmcraft.yaml b/charmcraft.yaml index 5f2b42d3..17fd6108 100644 --- a/charmcraft.yaml +++ b/charmcraft.yaml @@ -20,5 +20,6 @@ bases: channel: "22.04" parts: charm: + charm-entrypoint: src/machine_charm.py charm-binary-python-packages: - mysql-connector-python==8.0.32 diff --git a/src/charm.py b/src/abstract_charm.py old mode 100755 new mode 100644 similarity index 67% rename from src/charm.py rename to src/abstract_charm.py index c3c7f81a..5740ba5c --- a/src/charm.py +++ b/src/abstract_charm.py @@ -1,52 +1,65 @@ -#!/usr/bin/env python3 -# Copyright 2022 Canonical Ltd. +# Copyright 2023 Canonical Ltd. # See LICENSE file for licensing details. 
-# -# Learn more at: https://juju.is/docs/sdk -"""MySQL Router kubernetes (k8s) charm""" +"""MySQL Router charm""" +import abc import logging import socket import ops import tenacity +import container import relations.database_provides import relations.database_requires -import snap -import socket_workload import workload logger = logging.getLogger(__name__) -class MySQLRouterOperatorCharm(ops.CharmBase): - """Operator charm for MySQL Router""" +class MySQLRouterCharm(ops.CharmBase, abc.ABC): + """MySQL Router charm""" def __init__(self, *args) -> None: super().__init__(*args) + self._workload_type = workload.Workload + self._authenticated_workload_type = workload.AuthenticatedWorkload + self._database_requires = relations.database_requires.RelationEndpoint(self) + self._database_provides = relations.database_provides.RelationEndpoint(self) + self.framework.observe(self.on.start, self._on_start) + self.framework.observe(self.on.leader_elected, self._on_leader_elected) - self.database_requires = relations.database_requires.RelationEndpoint(self) + @property + def _tls_certificate_saved(self) -> bool: + """Whether a TLS certificate is available to use""" + # TODO VM TLS: Remove property after implementing TLS on machine charm + return False - self.database_provides = relations.database_provides.RelationEndpoint(self) + @property + @abc.abstractmethod + def _container(self) -> container.Container: + """Workload container (snap or ROCK)""" - self.framework.observe(self.on.install, self._on_install) - self.framework.observe(self.on.remove, self._on_remove) - self.framework.observe(self.on.start, self._on_start) - self.framework.observe(self.on.leader_elected, self._on_leader_elected) + @property + @abc.abstractmethod + def _read_write_endpoint(self) -> str: + """MySQL Router read-write endpoint""" + + @property + @abc.abstractmethod + def _read_only_endpoint(self) -> str: + """MySQL Router read-only endpoint""" def get_workload(self, *, event): """MySQL Router workload""" - container = snap.Snap() - if connection_info := self.database_requires.get_connection_info(event=event): - return socket_workload.AuthenticatedSocketWorkload( - container_=container, + if connection_info := self._database_requires.get_connection_info(event=event): + return self._authenticated_workload_type( + container_=self._container, connection_info=connection_info, charm_=self, - host="", # TODO TLS: replace with IP address when enabling TCP ) - return socket_workload.SocketWorkload(container_=container) + return self._workload_type(container_=self._container) @staticmethod def _prioritize_statuses(statuses: list[ops.StatusBase]) -> ops.StatusBase: @@ -70,7 +83,7 @@ def _prioritize_statuses(statuses: list[ops.StatusBase]) -> ops.StatusBase: def _determine_app_status(self, *, event) -> ops.StatusBase: """Report app status.""" statuses = [] - for endpoint in (self.database_requires, self.database_provides): + for endpoint in (self._database_requires, self._database_provides): if status := endpoint.get_status(event): statuses.append(status) return self._prioritize_statuses(statuses) @@ -99,10 +112,9 @@ def wait_until_mysql_router_ready(self) -> None: self.unit.status = ops.WaitingStatus("MySQL Router starting") try: for attempt in tenacity.Retrying( + reraise=True, stop=tenacity.stop_after_delay(30), wait=tenacity.wait_fixed(5), - retry=tenacity.retry_if_exception_type(AssertionError), - reraise=True, ): with attempt: for port in (6446, 6447): @@ -126,39 +138,27 @@ def reconcile_database_relations(self, event=None) -> None: 
f"{self.unit.is_leader()=}, " f"{isinstance(workload_, workload.AuthenticatedWorkload)=}, " f"{workload_.container_ready=}, " - f"{self.database_requires.is_relation_breaking(event)=}, " + f"{self._database_requires.is_relation_breaking(event)=}" ) - if self.unit.is_leader() and self.database_requires.is_relation_breaking(event): - self.database_provides.delete_all_databags() + if self.unit.is_leader() and self._database_requires.is_relation_breaking(event): + self._database_provides.delete_all_databags() elif ( self.unit.is_leader() and isinstance(workload_, workload.AuthenticatedWorkload) and workload_.container_ready ): - self.database_provides.reconcile_users( + self._database_provides.reconcile_users( event=event, - router_read_write_endpoint=workload_.read_write_endpoint, - router_read_only_endpoint=workload_.read_only_endpoint, + router_read_write_endpoint=self._read_write_endpoint, + router_read_only_endpoint=self._read_only_endpoint, shell=workload_.shell, ) if isinstance(workload_, workload.AuthenticatedWorkload) and workload_.container_ready: - workload_.enable( - unit_name=self.unit.name, - tls=False, # TODO TLS - ) + workload_.enable(tls=self._tls_certificate_saved, unit_name=self.unit.name) elif workload_.container_ready: workload_.disable() self.set_status(event=event) - def _on_install(self, _) -> None: - snap.Installer().install(unit=self.unit) - workload_ = self.get_workload(event=None) - if workload_.container_ready: # check for VM instead? - self.unit.set_workload_version(workload_.version) - - def _on_remove(self, _) -> None: - snap.Installer().uninstall() - def _on_start(self, _) -> None: # Set status on first start if no relations active self.set_status(event=None) @@ -166,7 +166,3 @@ def _on_start(self, _) -> None: def _on_leader_elected(self, _) -> None: # Update app status self.set_status(event=None) - - -if __name__ == "__main__": - ops.main.main(MySQLRouterOperatorCharm) diff --git a/src/container.py b/src/container.py index 7ee5f922..b680eaef 100644 --- a/src/container.py +++ b/src/container.py @@ -8,18 +8,6 @@ import subprocess import typing -import ops - - -class Installer(abc.ABC): - @abc.abstractmethod - def install(self, *, unit: ops.Unit, model_name: str, app_name: str): - pass - - @abc.abstractmethod - def uninstall(self): - pass - class Path(pathlib.PurePosixPath, abc.ABC): """Workload container (snap or ROCK) filesystem path""" diff --git a/src/machine_charm.py b/src/machine_charm.py new file mode 100755 index 00000000..f70daa42 --- /dev/null +++ b/src/machine_charm.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. 
+# +# Learn more at: https://juju.is/docs/sdk + +"""MySQL Router machine charm""" + +import logging + +import ops + +import abstract_charm +import snap +import socket_workload + +logger = logging.getLogger(__name__) + + +class MachineRouterCharm(abstract_charm.MySQLRouterCharm): + def __init__(self, *args) -> None: + super().__init__(*args) + self._workload_type = socket_workload.SocketWorkload + self._authenticated_workload_type = socket_workload.AuthenticatedSocketWorkload + self.framework.observe(self.on.install, self._on_install) + self.framework.observe(self.on.remove, self._on_remove) + + @property + def _container(self) -> snap.Snap: + return snap.Snap() + + @property + def _read_write_endpoint(self) -> str: + return f'file://{self._container.path("/run/mysqlrouter/mysql.sock")}' + + @property + def _read_only_endpoint(self) -> str: + return f'file://{self._container.path("/run/mysqlrouter/mysqlro.sock")}' + + # ======================= + # Handlers + # ======================= + + def _on_install(self, _) -> None: + snap.Installer().install(unit=self.unit) + + @staticmethod + def _on_remove(_) -> None: + snap.Installer().uninstall() + + +if __name__ == "__main__": + ops.main.main(MachineRouterCharm) diff --git a/src/relations/database_provides.py b/src/relations/database_provides.py index 1dc6f0fe..3d9f4e8f 100644 --- a/src/relations/database_provides.py +++ b/src/relations/database_provides.py @@ -14,7 +14,7 @@ import status_exception if typing.TYPE_CHECKING: - import charm + import abstract_charm logger = logging.getLogger(__name__) @@ -145,7 +145,7 @@ class RelationEndpoint: NAME = "database" - def __init__(self, charm_: "charm.MySQLRouterOperatorCharm") -> None: + def __init__(self, charm_: "abstract_charm.MySQLRouterCharm") -> None: self._interface = data_interfaces.DatabaseProvides(charm_, relation_name=self.NAME) charm_.framework.observe( charm_.on[self.NAME].relation_joined, diff --git a/src/relations/database_requires.py b/src/relations/database_requires.py index 307d73ee..f26e014f 100644 --- a/src/relations/database_requires.py +++ b/src/relations/database_requires.py @@ -13,7 +13,7 @@ import status_exception if typing.TYPE_CHECKING: - import charm + import abstract_charm logger = logging.getLogger(__name__) @@ -67,7 +67,7 @@ class RelationEndpoint: NAME = "backend-database" - def __init__(self, charm_: "charm.MySQLRouterOperatorCharm") -> None: + def __init__(self, charm_: "abstract_charm.MySQLRouterCharm") -> None: self._interface = data_interfaces.DatabaseRequires( charm_, relation_name=self.NAME, diff --git a/src/snap.py b/src/snap.py index e4a1a939..915488d8 100644 --- a/src/snap.py +++ b/src/snap.py @@ -15,14 +15,14 @@ logger = logging.getLogger(__name__) -class Installer(container.Installer): +class Installer: _SNAP_REVISION = "57" @property def _snap(self) -> snap_lib.Snap: return snap_lib.SnapCache()[_SNAP_NAME] - def install(self, *, unit: ops.Unit, **_): + def install(self, *, unit: ops.Unit): if self._snap.present: logger.error(f"{_SNAP_NAME} snap already installed on machine. 
Installation aborted") raise Exception(f"Multiple {_SNAP_NAME} snap installs not supported on one machine") @@ -115,7 +115,7 @@ def mysql_router_service_enabled(self) -> bool: def update_mysql_router_service(self, *, enabled: bool, tls: bool = None) -> None: super().update_mysql_router_service(enabled=enabled, tls=tls) if tls: - raise NotImplementedError # TODO TLS + raise NotImplementedError # TODO VM TLS if enabled: self._snap.start([self._SERVICE_NAME], enable=True) else: diff --git a/src/socket_workload.py b/src/socket_workload.py index cde9b584..1d2bb680 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -10,14 +10,6 @@ class SocketWorkload(workload.Workload): class AuthenticatedSocketWorkload(workload.AuthenticatedWorkload, SocketWorkload): - @property - def read_write_endpoint(self) -> str: - return f'file://{self._container.path("/run/mysqlrouter/mysql.sock")}' - - @property - def read_only_endpoint(self) -> str: - return f'file://{self._container.path("/run/mysqlrouter/mysqlro.sock")}' - def _get_bootstrap_command(self, password: str): command = super()._get_bootstrap_command(password) command.extend( diff --git a/src/workload.py b/src/workload.py index aba0bf27..88892169 100644 --- a/src/workload.py +++ b/src/workload.py @@ -13,7 +13,7 @@ import mysql_shell if typing.TYPE_CHECKING: - import charm + import abstract_charm import relations.database_requires logger = logging.getLogger(__name__) @@ -100,24 +100,12 @@ def __init__( *, container_: container.Container, connection_info: "relations.database_requires.ConnectionInformation", - host: str, - charm_: "charm.MySQLRouterOperatorCharm", + charm_: "abstract_charm.MySQLRouterCharm", ) -> None: super().__init__(container_=container_) self._connection_info = connection_info - self._host = host self._charm = charm_ - @property - def read_write_endpoint(self) -> str: - """MySQL Router read-write endpoint""" - return f"{self._host}:6446" - - @property - def read_only_endpoint(self) -> str: - """MySQL Router read-only endpoint""" - return f"{self._host}:6447" - @property def shell(self) -> mysql_shell.Shell: """MySQL Shell""" @@ -138,15 +126,17 @@ def _router_id(self) -> str: # MySQL Router is bootstrapped without `--directory`—there is one system-wide instance. return f"{socket.getfqdn()}::system" - def cleanup_after_potential_container_restart(self, *, unit_name: str) -> None: + def _cleanup_after_potential_container_restart(self) -> None: """Remove MySQL Router cluster metadata & user after (potential) container restart. (Storage is not persisted on container restart—MySQL Router's config file is deleted. Therefore, MySQL Router needs to be bootstrapped again.) """ - if user_info := self.shell.get_mysql_router_user_for_unit(unit_name): + if user_info := self.shell.get_mysql_router_user_for_unit(self._charm.unit.name): + logger.debug("Cleaning up after container restart") self.shell.remove_router_from_cluster_metadata(user_info.router_id) self.shell.delete_user(user_info.username) + logger.debug("Cleaned up after container restart") def _get_bootstrap_command(self, password: str) -> list[str]: return [ @@ -206,6 +196,8 @@ def enable(self, *, tls: bool, unit_name: str) -> None: # Therefore, if the host or port changes, we do not need to restart MySQL Router. return logger.debug("Enabling MySQL Router service") + # TODO: VM? 
+ self._cleanup_after_potential_container_restart() self._bootstrap_router(tls=tls) self.shell.add_attributes_to_mysql_router_user( username=self._router_username, router_id=self._router_id, unit_name=unit_name From bbd54970127efa34f1307095be9b64f2e3b3ed85 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Wed, 21 Jun 2023 18:56:16 +0000 Subject: [PATCH 32/57] update metadata.yaml, format --- metadata.yaml | 21 ++++++++++++--------- pyproject.toml | 4 +++- src/machine_charm.py | 2 ++ src/snap.py | 13 +++++++++++++ src/socket_workload.py | 9 ++++++++- 5 files changed, 38 insertions(+), 11 deletions(-) diff --git a/metadata.yaml b/metadata.yaml index 759ae30d..0bea6b2e 100644 --- a/metadata.yaml +++ b/metadata.yaml @@ -3,8 +3,9 @@ name: mysql-router display-name: MySQL Router maintainers: - - Shayan Patel + - Carl Csaposs - Paulo Machado + - Shayan Patel description: | Machine charmed operator for mysql-router. summary: | @@ -12,20 +13,22 @@ summary: | Enables effective access to a MySQL cluster with Group Replication. subordinate: true provides: - shared-db: - interface: mysql-shared - scope: container +# TODO: re-implement legacy relation +# shared-db: +# interface: mysql-shared +# scope: container database: interface: mysql_client scope: container requires: backend-database: interface: mysql_client + # Workaround: Subordinate charms are required to have at least one `requires` endpoint with + # `scope: container` juju-info: interface: juju-info scope: container -peers: - mysql-router-peers: - interface: mysql_router_peers -series: - - focal +# TODO TLS VM: re-enable peer relation +#peers: +# mysql-router-peers: +# interface: mysql_router_peers diff --git a/pyproject.toml b/pyproject.toml index ba93d4a8..3217885c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,9 @@ select = ["E", "W", "F", "C", "N", "R", "D", "H"] # Ignore D415 Docstring first line punctuation (doesn't make sense for properties) # Ignore D403 First word of the first line should be properly capitalized (false positive on "MySQL") # Ignore N818 Exception should be named with an Error suffix -ignore = ["W503", "E501", "D107", "D105", "D415", "D403", "N818"] +# Ignore D102 Missing docstring in public method (pydocstyle doesn't look for docstrings in super class +# https://github.com/PyCQA/pydocstyle/issues/309) TODO: add pylint check? https://github.com/PyCQA/pydocstyle/issues/309#issuecomment-1284142716 +ignore = ["W503", "E501", "D107", "D105", "D415", "D403", "N818", "D102"] # D100, D101, D102, D103: Ignore missing docstrings in tests per-file-ignores = ["tests/*:D100,D101,D102,D103,D104"] docstring-convention = "google" diff --git a/src/machine_charm.py b/src/machine_charm.py index f70daa42..105b37a3 100755 --- a/src/machine_charm.py +++ b/src/machine_charm.py @@ -18,6 +18,8 @@ class MachineRouterCharm(abstract_charm.MySQLRouterCharm): + """MySQL Router machine charm""" + def __init__(self, *args) -> None: super().__init__(*args) self._workload_type = socket_workload.SocketWorkload diff --git a/src/snap.py b/src/snap.py index 915488d8..34d96970 100644 --- a/src/snap.py +++ b/src/snap.py @@ -1,3 +1,8 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. 
+ +"""Workload snap container & installer""" + import logging import pathlib import shutil @@ -16,6 +21,8 @@ class Installer: + """Workload snap installer""" + _SNAP_REVISION = "57" @property @@ -23,6 +30,7 @@ def _snap(self) -> snap_lib.Snap: return snap_lib.SnapCache()[_SNAP_NAME] def install(self, *, unit: ops.Unit): + """Install snap.""" if self._snap.present: logger.error(f"{_SNAP_NAME} snap already installed on machine. Installation aborted") raise Exception(f"Multiple {_SNAP_NAME} snap installs not supported on one machine") @@ -49,7 +57,10 @@ def _set_retry_status(_) -> None: logger.debug(f"Installed {_SNAP_NAME=}, {self._SNAP_REVISION=}") def uninstall(self): + """Uninstall snap.""" + logger.debug(f"Uninstalling {_SNAP_NAME=}") self._snap.ensure(state=snap_lib.SnapState.Absent) + logger.debug(f"Uninstalled {_SNAP_NAME=}") class _Path(pathlib.PosixPath, container.Path): @@ -91,6 +102,8 @@ def rmtree(self): class Snap(container.Container): + """Workload snap container""" + _SNAP_REVISION = "57" _SERVICE_NAME = "mysqlrouter-service" diff --git a/src/socket_workload.py b/src/socket_workload.py index 1d2bb680..98f49938 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -1,3 +1,8 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +"""MySQl Router workload with Unix sockets enabled""" + import configparser import io import pathlib @@ -6,10 +11,12 @@ class SocketWorkload(workload.Workload): - pass + """MySQl Router workload with Unix sockets enabled""" class AuthenticatedSocketWorkload(workload.AuthenticatedWorkload, SocketWorkload): + """Workload with connection to MySQL cluster and with Unix sockets enabled""" + def _get_bootstrap_command(self, password: str): command = super()._get_bootstrap_command(password) command.extend( From d507f2c7ed9aa3953d8053daf76eb436c3c4055b Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Wed, 21 Jun 2023 18:59:01 +0000 Subject: [PATCH 33/57] remove pytest-order from integration tests --- tests/integration/test_database.py | 1 - tests/integration/test_shared_db.py | 1 - tests/integration/test_tls.py | 2 -- tox.ini | 3 --- 4 files changed, 7 deletions(-) diff --git a/tests/integration/test_database.py b/tests/integration/test_database.py index 758c321b..cf3f684f 100644 --- a/tests/integration/test_database.py +++ b/tests/integration/test_database.py @@ -24,7 +24,6 @@ SLOW_TIMEOUT = 15 * 60 -@pytest.mark.order(1) @pytest.mark.abort_on_fail async def test_database_relation(ops_test: OpsTest, mysql_router_charm_series: str) -> None: """Test the database relation.""" diff --git a/tests/integration/test_shared_db.py b/tests/integration/test_shared_db.py index 4a0798e8..08c6e24e 100644 --- a/tests/integration/test_shared_db.py +++ b/tests/integration/test_shared_db.py @@ -18,7 +18,6 @@ TIMEOUT = 15 * 60 -@pytest.mark.order(1) @pytest.mark.abort_on_fail async def test_shared_db(ops_test: OpsTest, mysql_router_charm_series: str): """Test the shared-db legacy relation.""" diff --git a/tests/integration/test_tls.py b/tests/integration/test_tls.py index 98cf443a..705e5d1f 100644 --- a/tests/integration/test_tls.py +++ b/tests/integration/test_tls.py @@ -18,7 +18,6 @@ @pytest.mark.abort_on_fail -@pytest.mark.order(1) async def test_build_deploy_and_relate(ops_test: OpsTest, mysql_router_charm_series: str) -> None: """Test encryption when backend database is using TLS.""" # Deploy TLS Certificates operator. 
@@ -61,7 +60,6 @@ async def test_build_deploy_and_relate(ops_test: OpsTest, mysql_router_charm_ser ops_test.model.wait_for_idle(TEST_APP_NAME, status="active", timeout=15 * 60) -@pytest.mark.order(2) async def test_connected_encryption(ops_test: OpsTest) -> None: """Test encryption when backend database is using TLS.""" test_app_unit = ops_test.model.applications[TEST_APP_NAME].units[0] diff --git a/tox.ini b/tox.ini index 7b390070..f3f862c8 100644 --- a/tox.ini +++ b/tox.ini @@ -79,7 +79,6 @@ deps = mysql-connector-python pytest pytest-operator - pytest-order -r {tox_root}/requirements.txt commands = pytest -v --tb native --log-cli-level=INFO -s {posargs} {[vars]tests_path}/integration/test_shared_db.py @@ -95,7 +94,6 @@ deps = mysql-connector-python pytest pytest-operator - pytest-order -r {tox_root}/requirements.txt commands = pytest -v --tb native --log-cli-level=INFO -s {posargs} {[vars]tests_path}/integration/test_database.py @@ -111,7 +109,6 @@ deps = mysql-connector-python pytest pytest-operator - pytest-order -r {tox_root}/requirements.txt commands = pytest -v --tb native --log-cli-level=INFO -s {posargs} {[vars]tests_path}/integration/test_tls.py From 46de3fc0e42e4f1982f5aa3d0b3daa8d74bc9ee7 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Wed, 21 Jun 2023 19:13:12 +0000 Subject: [PATCH 34/57] don't use staticmethod for _on_remove --- src/machine_charm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/machine_charm.py b/src/machine_charm.py index 105b37a3..7f08066b 100755 --- a/src/machine_charm.py +++ b/src/machine_charm.py @@ -46,8 +46,7 @@ def _read_only_endpoint(self) -> str: def _on_install(self, _) -> None: snap.Installer().install(unit=self.unit) - @staticmethod - def _on_remove(_) -> None: + def _on_remove(self, _) -> None: snap.Installer().uninstall() From b17ee864d6d4d718d8b62b5bf84f12c4b071a0f1 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 11:02:16 +0000 Subject: [PATCH 35/57] remove unused snap revision var --- src/snap.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/snap.py b/src/snap.py index 34d96970..ee1774f6 100644 --- a/src/snap.py +++ b/src/snap.py @@ -104,7 +104,6 @@ def rmtree(self): class Snap(container.Container): """Workload snap container""" - _SNAP_REVISION = "57" _SERVICE_NAME = "mysqlrouter-service" def __init__(self) -> None: From fb4cfe2ceb4541262125869f728ba1ca2e59924e Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 11:54:46 +0000 Subject: [PATCH 36/57] chown directories --- src/snap.py | 19 ++++++++++++++----- src/workload.py | 1 + 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/snap.py b/src/snap.py index ee1774f6..4980bfda 100644 --- a/src/snap.py +++ b/src/snap.py @@ -23,7 +23,7 @@ class Installer: """Workload snap installer""" - _SNAP_REVISION = "57" + _SNAP_REVISION = "61" @property def _snap(self) -> snap_lib.Snap: @@ -39,6 +39,7 @@ def install(self, *, unit: ops.Unit): def _set_retry_status(_) -> None: unit.status = ops.MaintenanceStatus("Snap install failed. Retrying...") + logger.debug("Snap install failed. 
Retrying...") try: for attempt in tenacity.Retrying( @@ -64,6 +65,8 @@ def uninstall(self): class _Path(pathlib.PosixPath, container.Path): + _UNIX_USERNAME = "snap_daemon" + def __new__(cls, *args, **kwargs): path = super().__new__(cls, *args, **kwargs) if str(path).startswith("/etc/mysqlrouter") or str(path).startswith( @@ -91,11 +94,17 @@ def relative_to_container(self) -> pathlib.PurePosixPath: return pathlib.PurePosixPath("/", self.relative_to(parent)) return self - def read_text(self, encoding="utf-8", *args) -> str: - return super().read_text(encoding, *args) + def read_text(self, encoding="utf-8", *args, **kwargs) -> str: + return super().read_text(encoding, *args, **kwargs) + + def write_text(self, data: str, encoding="utf-8", *args, **kwargs): + output = super().write_text(data, encoding, *args, **kwargs) + shutil.chown(self, user=self._UNIX_USERNAME, group=self._UNIX_USERNAME) + return output - def write_text(self, data: str, encoding="utf-8", *args): - return super().write_text(data, encoding, *args) + def mkdir(self, *args, **kwargs) -> None: + super().mkdir(*args, **kwargs) + shutil.chown(self, user=self._UNIX_USERNAME, group=self._UNIX_USERNAME) def rmtree(self): shutil.rmtree(self) diff --git a/src/workload.py b/src/workload.py index 88892169..957222e4 100644 --- a/src/workload.py +++ b/src/workload.py @@ -56,6 +56,7 @@ def disable(self) -> None: self._container.router_config_directory.rmtree() self._container.router_config_directory.mkdir() self._router_data_directory.rmtree() + self._router_data_directory.mkdir() logger.debug("Disabled MySQL Router service") @property From f7a339932b2ed23ad7718400f708ecc461ea00cb Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 12:01:47 +0000 Subject: [PATCH 37/57] update docstring --- src/workload.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/workload.py b/src/workload.py index 957222e4..7bf92b3f 100644 --- a/src/workload.py +++ b/src/workload.py @@ -130,6 +130,8 @@ def _router_id(self) -> str: def _cleanup_after_potential_container_restart(self) -> None: """Remove MySQL Router cluster metadata & user after (potential) container restart. + Only applies to Kubernetes charm + (Storage is not persisted on container restart—MySQL Router's config file is deleted. Therefore, MySQL Router needs to be bootstrapped again.) """ @@ -197,7 +199,6 @@ def enable(self, *, tls: bool, unit_name: str) -> None: # Therefore, if the host or port changes, we do not need to restart MySQL Router. return logger.debug("Enabling MySQL Router service") - # TODO: VM? 
self._cleanup_after_potential_container_restart() self._bootstrap_router(tls=tls) self.shell.add_attributes_to_mysql_router_user( From 8f589e3b98a5f4eea38779002ebfa8a0f66e3a5b Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 12:16:41 +0000 Subject: [PATCH 38/57] lint --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8f380b42..1f34c20c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -81,7 +81,7 @@ jobs: bases-index: 1 name: ${{ matrix.tox-environments }} | ${{ matrix.ubuntu-versions.series }} needs: -# - lint + - lint # TODO: re-enable after adding unit tests # - unit-test - build From 2297e0982a3596c45c335ba08c3686c636733e08 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 12:25:44 +0000 Subject: [PATCH 39/57] subordinate in class name --- src/machine_charm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/machine_charm.py b/src/machine_charm.py index 7f08066b..674227ca 100755 --- a/src/machine_charm.py +++ b/src/machine_charm.py @@ -17,8 +17,8 @@ logger = logging.getLogger(__name__) -class MachineRouterCharm(abstract_charm.MySQLRouterCharm): - """MySQL Router machine charm""" +class MachineSubordinateRouterCharm(abstract_charm.MySQLRouterCharm): + """MySQL Router machine subordinate charm""" def __init__(self, *args) -> None: super().__init__(*args) @@ -51,4 +51,4 @@ def _on_remove(self, _) -> None: if __name__ == "__main__": - ops.main.main(MachineRouterCharm) + ops.main.main(MachineSubordinateRouterCharm) From c8132ea46c1b24c6711b6ec84bbe9cb05f1a3236 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 12:35:14 +0000 Subject: [PATCH 40/57] bind address localhost (socket) --- src/socket_workload.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/socket_workload.py b/src/socket_workload.py index 98f49938..d92c3db6 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -21,6 +21,8 @@ def _get_bootstrap_command(self, password: str): command = super()._get_bootstrap_command(password) command.extend( [ + "--conf-bind-address", + "127.0.0.1", "--conf-use-sockets", # For unix sockets, authentication fails on first connection if this option is not # set. 
Workaround for https://bugs.mysql.com/bug.php?id=107291 From e384b52ced0fb9b13eab122f8e554c23f1ef7ac4 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 12:39:00 +0000 Subject: [PATCH 41/57] logger --- src/mysql_shell.py | 3 ++- src/snap.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/mysql_shell.py b/src/mysql_shell.py index 21545a33..8b7763e7 100644 --- a/src/mysql_shell.py +++ b/src/mysql_shell.py @@ -15,9 +15,10 @@ import container -_PASSWORD_LENGTH = 24 logger = logging.getLogger(__name__) +_PASSWORD_LENGTH = 24 + # TODO python3.10 min version: Add `(kw_only=True)` @dataclasses.dataclass diff --git a/src/snap.py b/src/snap.py index 4980bfda..2e37b8c8 100644 --- a/src/snap.py +++ b/src/snap.py @@ -15,10 +15,10 @@ import container -_SNAP_NAME = "charmed-mysql" - logger = logging.getLogger(__name__) +_SNAP_NAME = "charmed-mysql" + class Installer: """Workload snap installer""" From 15633f6443f1ffe95045ae036e8502c118f2dd36 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 12:46:38 +0000 Subject: [PATCH 42/57] private constant --- src/relations/database_provides.py | 10 +++++----- src/relations/database_requires.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/relations/database_provides.py b/src/relations/database_provides.py index 3d9f4e8f..0bdd10be 100644 --- a/src/relations/database_provides.py +++ b/src/relations/database_provides.py @@ -143,12 +143,12 @@ def delete_user(self, *, shell: mysql_shell.Shell) -> None: class RelationEndpoint: """Relation endpoint for application charm(s)""" - NAME = "database" + _NAME = "database" def __init__(self, charm_: "abstract_charm.MySQLRouterCharm") -> None: - self._interface = data_interfaces.DatabaseProvides(charm_, relation_name=self.NAME) + self._interface = data_interfaces.DatabaseProvides(charm_, relation_name=self._NAME) charm_.framework.observe( - charm_.on[self.NAME].relation_joined, + charm_.on[self._NAME].relation_joined, charm_.reconcile_database_relations, ) charm_.framework.observe( @@ -156,7 +156,7 @@ def __init__(self, charm_: "abstract_charm.MySQLRouterCharm") -> None: charm_.reconcile_database_relations, ) charm_.framework.observe( - charm_.on[self.NAME].relation_broken, + charm_.on[self._NAME].relation_broken, charm_.reconcile_database_relations, ) @@ -257,4 +257,4 @@ def get_status(self, event) -> typing.Optional[ops.StatusBase]: for exception in exceptions: if isinstance(exception, remote_databag.IncompleteDatabag): return exception.status - return ops.BlockedStatus(f"Missing relation: {self.NAME}") + return ops.BlockedStatus(f"Missing relation: {self._NAME}") diff --git a/src/relations/database_requires.py b/src/relations/database_requires.py index f26e014f..6c702a56 100644 --- a/src/relations/database_requires.py +++ b/src/relations/database_requires.py @@ -65,18 +65,18 @@ def __init__(self, *, interface: data_interfaces.DatabaseRequires, event) -> Non class RelationEndpoint: """Relation endpoint for MySQL charm""" - NAME = "backend-database" + _NAME = "backend-database" def __init__(self, charm_: "abstract_charm.MySQLRouterCharm") -> None: self._interface = data_interfaces.DatabaseRequires( charm_, - relation_name=self.NAME, + relation_name=self._NAME, # Database name disregarded by MySQL charm if "mysqlrouter" extra user role requested database_name="mysql_innodb_cluster_metadata", extra_user_roles="mysqlrouter", ) charm_.framework.observe( - charm_.on[self.NAME].relation_created, + charm_.on[self._NAME].relation_created, 
charm_.reconcile_database_relations, ) charm_.framework.observe( @@ -88,7 +88,7 @@ def __init__(self, charm_: "abstract_charm.MySQLRouterCharm") -> None: charm_.reconcile_database_relations, ) charm_.framework.observe( - charm_.on[self.NAME].relation_broken, + charm_.on[self._NAME].relation_broken, charm_.reconcile_database_relations, ) From eefb5429baa7ece40de377f6386f1ebb665d1ddc Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 12:51:11 +0000 Subject: [PATCH 43/57] rename var --- src/snap.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/snap.py b/src/snap.py index 2e37b8c8..22783601 100644 --- a/src/snap.py +++ b/src/snap.py @@ -98,9 +98,9 @@ def read_text(self, encoding="utf-8", *args, **kwargs) -> str: return super().read_text(encoding, *args, **kwargs) def write_text(self, data: str, encoding="utf-8", *args, **kwargs): - output = super().write_text(data, encoding, *args, **kwargs) + return_value = super().write_text(data, encoding, *args, **kwargs) shutil.chown(self, user=self._UNIX_USERNAME, group=self._UNIX_USERNAME) - return output + return return_value def mkdir(self, *args, **kwargs) -> None: super().mkdir(*args, **kwargs) From 61ad7e5aad9ce04510a65b001f53604b20840345 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 12:53:32 +0000 Subject: [PATCH 44/57] type annotation --- src/socket_workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/socket_workload.py b/src/socket_workload.py index d92c3db6..1c8e7852 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -17,7 +17,7 @@ class SocketWorkload(workload.Workload): class AuthenticatedSocketWorkload(workload.AuthenticatedWorkload, SocketWorkload): """Workload with connection to MySQL cluster and with Unix sockets enabled""" - def _get_bootstrap_command(self, password: str): + def _get_bootstrap_command(self, password: str) -> list[str]: command = super()._get_bootstrap_command(password) command.extend( [ From 42f44290a84e105560808f2a9f48ce2f88f27e17 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 12:57:33 +0000 Subject: [PATCH 45/57] set workload version --- src/machine_charm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/machine_charm.py b/src/machine_charm.py index 674227ca..42e5d603 100755 --- a/src/machine_charm.py +++ b/src/machine_charm.py @@ -45,6 +45,7 @@ def _read_only_endpoint(self) -> str: def _on_install(self, _) -> None: snap.Installer().install(unit=self.unit) + self.unit.set_workload_version(self.get_workload(event=None).version) def _on_remove(self, _) -> None: snap.Installer().uninstall() From a23bc2f6655b6529a7c743953222b136118ca51c Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 13:02:29 +0000 Subject: [PATCH 46/57] use pathlib to read file --- src/workload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/workload.py b/src/workload.py index 7bf92b3f..af0fa564 100644 --- a/src/workload.py +++ b/src/workload.py @@ -5,6 +5,7 @@ import configparser import logging +import pathlib import socket import string import typing @@ -65,8 +66,7 @@ def _tls_config_file_data(self) -> str: Config file enables TLS on MySQL Router. 
""" - with open("templates/tls.cnf", "r") as template_file: - template = string.Template(template_file.read()) + template = string.Template(pathlib.Path("templates/tls.cnf").read_text(encoding="utf-8")) config_string = template.substitute( tls_ssl_key_file=self._tls_key_file, tls_ssl_cert_file=self._tls_certificate_file, From 3c62bc16ff35a1a3c072d48268b48778c86f102b Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 13:08:52 +0000 Subject: [PATCH 47/57] update docstring --- src/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/workload.py b/src/workload.py index af0fa564..45999fcc 100644 --- a/src/workload.py +++ b/src/workload.py @@ -158,7 +158,7 @@ def _get_bootstrap_command(self, password: str) -> list[str]: ] def _bootstrap_router(self, *, tls: bool) -> None: - """Bootstrap MySQL Router and enable service.""" + """Bootstrap MySQL Router.""" logger.debug( f"Bootstrapping router {tls=}, {self._connection_info.host=}, {self._connection_info.port=}" ) From 6d7c0df53fefc5885bc45b1c287d5f1f4b3aea58 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 13:26:50 +0000 Subject: [PATCH 48/57] store string in var --- src/snap.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/snap.py b/src/snap.py index 22783601..e0b9c92d 100644 --- a/src/snap.py +++ b/src/snap.py @@ -38,8 +38,9 @@ def install(self, *, unit: ops.Unit): unit.status = ops.MaintenanceStatus("Installing snap") def _set_retry_status(_) -> None: - unit.status = ops.MaintenanceStatus("Snap install failed. Retrying...") - logger.debug("Snap install failed. Retrying...") + message = "Snap install failed. Retrying..." + unit.status = ops.MaintenanceStatus(message) + logger.debug(message) try: for attempt in tenacity.Retrying( From b9dab7d72363cdd0248bf334f80af3b21cb421fb Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 13:31:57 +0000 Subject: [PATCH 49/57] add logging --- src/socket_workload.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/socket_workload.py b/src/socket_workload.py index 1c8e7852..516f26ff 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -5,10 +5,12 @@ import configparser import io +import logging import pathlib import workload +logger = logging.getLogger(__name__) class SocketWorkload(workload.Workload): """MySQl Router workload with Unix sockets enabled""" @@ -43,6 +45,7 @@ def _update_configured_socket_file_locations(self) -> None: Needed since `/tmp` inside a snap is not accessible to non-root users. The socket files must be accessible to applications related via database_provides endpoint. 
""" + logger.debug("Updating configured socket file locations") config = configparser.ConfigParser() config.read_string(self._container.router_config_file.read_text()) for section_name, section in config.items(): @@ -54,6 +57,7 @@ def _update_configured_socket_file_locations(self) -> None: with io.StringIO() as output: config.write(output) self._container.router_config_file.write_text(output.getvalue()) + logger.debug("Updated configured socket file locations") def _bootstrap_router(self, *, tls: bool) -> None: super()._bootstrap_router(tls=tls) From 65bb4454a8225f8c89321a9e7778702ad6647c0c Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 13:51:14 +0000 Subject: [PATCH 50/57] Update snap revision --- src/snap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snap.py b/src/snap.py index e0b9c92d..c9217a73 100644 --- a/src/snap.py +++ b/src/snap.py @@ -23,7 +23,7 @@ class Installer: """Workload snap installer""" - _SNAP_REVISION = "61" + _SNAP_REVISION = "64" @property def _snap(self) -> snap_lib.Snap: From 6293698df4b0ffbce8538e6bd8b8f89f65020abf Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 13:52:43 +0000 Subject: [PATCH 51/57] format --- src/socket_workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/socket_workload.py b/src/socket_workload.py index 516f26ff..4ef0ac78 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -12,6 +12,7 @@ logger = logging.getLogger(__name__) + class SocketWorkload(workload.Workload): """MySQl Router workload with Unix sockets enabled""" From 1dc44817599c6361487be6238ce567a75acdb0ce Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 14:33:06 +0000 Subject: [PATCH 52/57] add todo comment --- src/machine_charm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/machine_charm.py b/src/machine_charm.py index 42e5d603..b51b7c94 100755 --- a/src/machine_charm.py +++ b/src/machine_charm.py @@ -15,6 +15,7 @@ import socket_workload logger = logging.getLogger(__name__) +# TODO VM TLS: open ports for `juju expose` class MachineSubordinateRouterCharm(abstract_charm.MySQLRouterCharm): From 3650fddf9a2e1f05d47eafe6c840cdf7acb4b1fd Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 16:14:13 +0000 Subject: [PATCH 53/57] add missing limit to backend-database endpoint --- metadata.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/metadata.yaml b/metadata.yaml index 0bea6b2e..a8577650 100644 --- a/metadata.yaml +++ b/metadata.yaml @@ -23,6 +23,7 @@ provides: requires: backend-database: interface: mysql_client + limit: 1 # Workaround: Subordinate charms are required to have at least one `requires` endpoint with # `scope: container` juju-info: From eb1ec2c7c2d3ad175332143bf9f7a5b4407b3b5d Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Fri, 23 Jun 2023 16:15:03 +0000 Subject: [PATCH 54/57] remove installer class --- src/machine_charm.py | 4 +- src/snap.py | 88 +++++++++++++++++++------------------------- 2 files changed, 40 insertions(+), 52 deletions(-) diff --git a/src/machine_charm.py b/src/machine_charm.py index b51b7c94..ea607092 100755 --- a/src/machine_charm.py +++ b/src/machine_charm.py @@ -45,11 +45,11 @@ def _read_only_endpoint(self) -> str: # ======================= def _on_install(self, _) -> None: - snap.Installer().install(unit=self.unit) + snap.install(unit=self.unit) self.unit.set_workload_version(self.get_workload(event=None).version) def _on_remove(self, _) -> None: - snap.Installer().uninstall() + snap.uninstall() if __name__ == 
"__main__": diff --git a/src/snap.py b/src/snap.py index c9217a73..6d454793 100644 --- a/src/snap.py +++ b/src/snap.py @@ -18,51 +18,43 @@ logger = logging.getLogger(__name__) _SNAP_NAME = "charmed-mysql" +_REVISION = "64" +_snap = snap_lib.SnapCache()[_SNAP_NAME] + + +def install(*, unit: ops.Unit): + """Install snap.""" + if _snap.present: + logger.error(f"{_SNAP_NAME} snap already installed on machine. Installation aborted") + raise Exception(f"Multiple {_SNAP_NAME} snap installs not supported on one machine") + logger.debug(f"Installing {_SNAP_NAME=}, {_REVISION=}") + unit.status = ops.MaintenanceStatus("Installing snap") + + def _set_retry_status(_) -> None: + message = "Snap install failed. Retrying..." + unit.status = ops.MaintenanceStatus(message) + logger.debug(message) + + try: + for attempt in tenacity.Retrying( + stop=tenacity.stop_after_delay(60 * 5), + wait=tenacity.wait_exponential(multiplier=10), + retry=tenacity.retry_if_exception_type(snap_lib.SnapError), + after=_set_retry_status, + reraise=True, + ): + with attempt: + _snap.ensure(state=snap_lib.SnapState.Present, revision=_REVISION) + except snap_lib.SnapError: + raise + logger.debug(f"Installed {_SNAP_NAME=}, {_REVISION=}") -class Installer: - """Workload snap installer""" - - _SNAP_REVISION = "64" - - @property - def _snap(self) -> snap_lib.Snap: - return snap_lib.SnapCache()[_SNAP_NAME] - - def install(self, *, unit: ops.Unit): - """Install snap.""" - if self._snap.present: - logger.error(f"{_SNAP_NAME} snap already installed on machine. Installation aborted") - raise Exception(f"Multiple {_SNAP_NAME} snap installs not supported on one machine") - logger.debug(f"Installing {_SNAP_NAME=}, {self._SNAP_REVISION=}") - unit.status = ops.MaintenanceStatus("Installing snap") - - def _set_retry_status(_) -> None: - message = "Snap install failed. Retrying..." 
- unit.status = ops.MaintenanceStatus(message) - logger.debug(message) - - try: - for attempt in tenacity.Retrying( - stop=tenacity.stop_after_delay(60 * 5), - wait=tenacity.wait_exponential(multiplier=10), - retry=tenacity.retry_if_exception_type(snap_lib.SnapError), - after=_set_retry_status, - reraise=True, - ): - with attempt: - self._snap.ensure( - state=snap_lib.SnapState.Present, revision=self._SNAP_REVISION - ) - except snap_lib.SnapError: - raise - logger.debug(f"Installed {_SNAP_NAME=}, {self._SNAP_REVISION=}") - - def uninstall(self): - """Uninstall snap.""" - logger.debug(f"Uninstalling {_SNAP_NAME=}") - self._snap.ensure(state=snap_lib.SnapState.Absent) - logger.debug(f"Uninstalled {_SNAP_NAME=}") +def uninstall(): + """Uninstall snap.""" + logger.debug(f"Uninstalling {_SNAP_NAME=}") + _snap.ensure(state=snap_lib.SnapState.Absent) + logger.debug(f"Uninstalled {_SNAP_NAME=}") class _Path(pathlib.PosixPath, container.Path): @@ -126,22 +118,18 @@ def __init__(self) -> None: def ready(self) -> bool: return True - @property - def _snap(self) -> snap_lib.Snap: - return snap_lib.SnapCache()[_SNAP_NAME] - @property def mysql_router_service_enabled(self) -> bool: - return self._snap.services[self._SERVICE_NAME]["active"] + return _snap.services[self._SERVICE_NAME]["active"] def update_mysql_router_service(self, *, enabled: bool, tls: bool = None) -> None: super().update_mysql_router_service(enabled=enabled, tls=tls) if tls: raise NotImplementedError # TODO VM TLS if enabled: - self._snap.start([self._SERVICE_NAME], enable=True) + _snap.start([self._SERVICE_NAME], enable=True) else: - self._snap.stop([self._SERVICE_NAME], disable=True) + _snap.stop([self._SERVICE_NAME], disable=True) def _run_command(self, command: list[str], *, timeout: typing.Optional[int]) -> str: try: From 12a58a83878e3dde83f2b564fc016a541e2f91a1 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 26 Jun 2023 11:47:16 +0000 Subject: [PATCH 55/57] Fix type annotation for python 3.8 compatability --- src/abstract_charm.py | 4 +++- src/container.py | 14 ++++++++++---- src/mysql_shell.py | 6 ++++-- src/relations/database_provides.py | 6 ++++-- src/snap.py | 3 ++- src/socket_workload.py | 4 +++- src/workload.py | 3 ++- 7 files changed, 28 insertions(+), 12 deletions(-) diff --git a/src/abstract_charm.py b/src/abstract_charm.py index 5740ba5c..2d19d30c 100644 --- a/src/abstract_charm.py +++ b/src/abstract_charm.py @@ -6,6 +6,7 @@ import abc import logging import socket +import typing import ops import tenacity @@ -62,7 +63,8 @@ def get_workload(self, *, event): return self._workload_type(container_=self._container) @staticmethod - def _prioritize_statuses(statuses: list[ops.StatusBase]) -> ops.StatusBase: + # TODO python3.10 min version: Use `list` instead of `typing.List` + def _prioritize_statuses(statuses: typing.List[ops.StatusBase]) -> ops.StatusBase: """Report the highest priority status. 
(Statuses of the same type are reported in the order they were added to `statuses`) diff --git a/src/container.py b/src/container.py index b680eaef..44f10445 100644 --- a/src/container.py +++ b/src/container.py @@ -44,7 +44,10 @@ def rmtree(self): class CalledProcessError(subprocess.CalledProcessError): """Command returned non-zero exit code""" - def __init__(self, *, returncode: int, cmd: list[str], output: str, stderr: str) -> None: + # TODO python3.10 min version: Use `list` instead of `typing.List` + def __init__( + self, *, returncode: int, cmd: typing.List[str], output: str, stderr: str + ) -> None: super().__init__(returncode=returncode, cmd=cmd, output=output, stderr=stderr) @@ -98,14 +101,16 @@ def update_mysql_router_service(self, *, enabled: bool, tls: bool = None) -> Non assert tls is not None, "`tls` argument required when enabled=True" @abc.abstractmethod - def _run_command(self, command: list[str], *, timeout: typing.Optional[int]) -> str: + # TODO python3.10 min version: Use `list` instead of `typing.List` + def _run_command(self, command: typing.List[str], *, timeout: typing.Optional[int]) -> str: """Run command in container. Raises: CalledProcessError: Command returns non-zero exit code """ - def run_mysql_router(self, args: list[str], *, timeout: int = None) -> str: + # TODO python3.10 min version: Use `list` instead of `typing.List` + def run_mysql_router(self, args: typing.List[str], *, timeout: int = None) -> str: """Run MySQL Router command. Raises: @@ -114,7 +119,8 @@ def run_mysql_router(self, args: list[str], *, timeout: int = None) -> str: args.insert(0, self._mysql_router_command) return self._run_command(args, timeout=timeout) - def run_mysql_shell(self, args: list[str], *, timeout: int = None) -> str: + # TODO python3.10 min version: Use `list` instead of `typing.List` + def run_mysql_shell(self, args: typing.List[str], *, timeout: int = None) -> str: """Run MySQL Shell command. 
Raises: diff --git a/src/mysql_shell.py b/src/mysql_shell.py index 8b7763e7..3dc4cf46 100644 --- a/src/mysql_shell.py +++ b/src/mysql_shell.py @@ -40,7 +40,8 @@ class Shell: _host: str _port: str - def _run_commands(self, commands: list[str]) -> str: + # TODO python3.10 min version: Use `list` instead of `typing.List` + def _run_commands(self, commands: typing.List[str]) -> str: """Connect to MySQL cluster and run commands.""" # Redact password from log logged_commands = commands.copy() @@ -69,7 +70,8 @@ def _run_commands(self, commands: list[str]) -> str: temporary_script_file.unlink() return output - def _run_sql(self, sql_statements: list[str]) -> None: + # TODO python3.10 min version: Use `list` instead of `typing.List` + def _run_sql(self, sql_statements: typing.List[str]) -> None: """Connect to MySQL cluster and execute SQL.""" commands = [] for statement in sql_statements: diff --git a/src/relations/database_provides.py b/src/relations/database_provides.py index 0bdd10be..9c039cc9 100644 --- a/src/relations/database_provides.py +++ b/src/relations/database_provides.py @@ -161,7 +161,8 @@ def __init__(self, charm_: "abstract_charm.MySQLRouterCharm") -> None: ) @property - def _created_users(self) -> list[_RelationWithCreatedUser]: + # TODO python3.10 min version: Use `list` instead of `typing.List` + def _created_users(self) -> typing.List[_RelationWithCreatedUser]: created_users = [] for relation in self._interface.relations: try: @@ -235,7 +236,8 @@ def delete_all_databags(self) -> None: def get_status(self, event) -> typing.Optional[ops.StatusBase]: """Report non-active status.""" requested_users = [] - exceptions: list[status_exception.StatusException] = [] + # TODO python3.10 min version: Use `list` instead of `typing.List` + exceptions: typing.List[status_exception.StatusException] = [] for relation in self._interface.relations: try: requested_users.append( diff --git a/src/snap.py b/src/snap.py index 6d454793..22012773 100644 --- a/src/snap.py +++ b/src/snap.py @@ -131,7 +131,8 @@ def update_mysql_router_service(self, *, enabled: bool, tls: bool = None) -> Non else: _snap.stop([self._SERVICE_NAME], disable=True) - def _run_command(self, command: list[str], *, timeout: typing.Optional[int]) -> str: + # TODO python3.10 min version: Use `list` instead of `typing.List` + def _run_command(self, command: typing.List[str], *, timeout: typing.Optional[int]) -> str: try: output = subprocess.run( command, diff --git a/src/socket_workload.py b/src/socket_workload.py index 4ef0ac78..d09e3178 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -7,6 +7,7 @@ import io import logging import pathlib +import typing import workload @@ -20,7 +21,8 @@ class SocketWorkload(workload.Workload): class AuthenticatedSocketWorkload(workload.AuthenticatedWorkload, SocketWorkload): """Workload with connection to MySQL cluster and with Unix sockets enabled""" - def _get_bootstrap_command(self, password: str) -> list[str]: + # TODO python3.10 min version: Use `list` instead of `typing.List` + def _get_bootstrap_command(self, password: str) -> typing.List[str]: command = super()._get_bootstrap_command(password) command.extend( [ diff --git a/src/workload.py b/src/workload.py index 45999fcc..05eab117 100644 --- a/src/workload.py +++ b/src/workload.py @@ -141,7 +141,8 @@ def _cleanup_after_potential_container_restart(self) -> None: self.shell.delete_user(user_info.username) logger.debug("Cleaned up after container restart") - def _get_bootstrap_command(self, password: str) -> list[str]: + # TODO 
python3.10 min version: Use `list` instead of `typing.List` + def _get_bootstrap_command(self, password: str) -> typing.List[str]: return [ "--bootstrap", self._connection_info.username From 7855afa646dfcc2a49fee90b20d2a4925ead048f Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 26 Jun 2023 17:21:09 +0000 Subject: [PATCH 56/57] Remove SocketWorkload --- src/machine_charm.py | 1 - src/socket_workload.py | 4 ---- 2 files changed, 5 deletions(-) diff --git a/src/machine_charm.py b/src/machine_charm.py index ea607092..2adbf608 100755 --- a/src/machine_charm.py +++ b/src/machine_charm.py @@ -23,7 +23,6 @@ class MachineSubordinateRouterCharm(abstract_charm.MySQLRouterCharm): def __init__(self, *args) -> None: super().__init__(*args) - self._workload_type = socket_workload.SocketWorkload self._authenticated_workload_type = socket_workload.AuthenticatedSocketWorkload self.framework.observe(self.on.install, self._on_install) self.framework.observe(self.on.remove, self._on_remove) diff --git a/src/socket_workload.py b/src/socket_workload.py index d09e3178..2ba7544a 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -14,10 +14,6 @@ logger = logging.getLogger(__name__) -class SocketWorkload(workload.Workload): - """MySQl Router workload with Unix sockets enabled""" - - class AuthenticatedSocketWorkload(workload.AuthenticatedWorkload, SocketWorkload): """Workload with connection to MySQL cluster and with Unix sockets enabled""" From ea7a788373252949a2343ba595a2df308bd45ab5 Mon Sep 17 00:00:00 2001 From: Carl Csaposs Date: Mon, 26 Jun 2023 17:25:46 +0000 Subject: [PATCH 57/57] fixup --- src/socket_workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/socket_workload.py b/src/socket_workload.py index 2ba7544a..411fd8fe 100644 --- a/src/socket_workload.py +++ b/src/socket_workload.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -class AuthenticatedSocketWorkload(workload.AuthenticatedWorkload, SocketWorkload): +class AuthenticatedSocketWorkload(workload.AuthenticatedWorkload): """Workload with connection to MySQL cluster and with Unix sockets enabled""" # TODO python3.10 min version: Use `list` instead of `typing.List`
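
A minimal, self-contained sketch of the `tenacity.Retrying` pattern used around `snap.ensure(...)` in `src/snap.py` (PATCH 36/57, reworked in PATCH 54/57); the `_install` helper, its failure counter, and the shortened wait are illustrative stand-ins, not charm code:

```python
# Sketch only: mirrors the retry/status-callback shape used for the snap install.
import logging

import tenacity

logger = logging.getLogger(__name__)
_attempts = {"count": 0}


def _install() -> None:
    """Placeholder for `snap.ensure(...)`: fails twice, then succeeds."""
    _attempts["count"] += 1
    if _attempts["count"] < 3:
        raise RuntimeError("simulated snap error")


def _set_retry_status(_) -> None:
    logger.debug("Snap install failed. Retrying...")


for attempt in tenacity.Retrying(
    stop=tenacity.stop_after_delay(60 * 5),
    wait=tenacity.wait_exponential(multiplier=0.1),  # shortened for the sketch
    retry=tenacity.retry_if_exception_type(RuntimeError),
    after=_set_retry_status,
    reraise=True,
):
    with attempt:
        _install()
```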
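Likewise, a sketch of the `configparser` round trip behind `_update_configured_socket_file_locations` in `src/socket_workload.py` (PATCH 49/57), which moves MySQL Router's Unix socket files out of `/tmp` because that path is not accessible to non-root users inside the snap; the sample config text, the `routing:` section filter, and the target directory are assumptions for illustration:

```python
# Sketch only: read the generated router config, rewrite socket paths, write it back.
import configparser
import io
import pathlib

_SAMPLE_CONFIG = """
[routing:bootstrap_rw]
bind_address = 127.0.0.1
socket = /tmp/mysql.sock
"""

config = configparser.ConfigParser()
config.read_string(_SAMPLE_CONFIG)
for section_name, section in config.items():
    if not section_name.startswith("routing:"):
        continue
    # Assumed target directory; the real charm resolves it via the snap container path.
    socket_name = pathlib.PurePosixPath(section["socket"]).name
    section["socket"] = str(pathlib.PurePosixPath("/run/mysqlrouter") / socket_name)
with io.StringIO() as output:
    config.write(output)
    print(output.getvalue())
```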
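Finally, the annotation style applied across PATCH 55/57 for Python 3.8 compatibility, shown on a stand-alone helper; `run_command` and its body are hypothetical, and only the `typing.List`/`typing.Optional` pattern mirrors the patch:

```python
# Sketch only: Python 3.8-compatible annotations (list[str] requires Python 3.9+).
import subprocess
import typing


# TODO python3.10 min version: Use `list` instead of `typing.List`
def run_command(command: typing.List[str], *, timeout: typing.Optional[int] = None) -> str:
    """Run a command and return its stdout (illustrative helper, not charm code)."""
    return subprocess.run(
        command, capture_output=True, check=True, encoding="utf-8", timeout=timeout
    ).stdout


if __name__ == "__main__":
    print(run_command(["echo", "hello"]))
```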