diff --git a/lib/charms/data_platform_libs/v0/database_provides.py b/lib/charms/data_platform_libs/v0/database_provides.py index 8135da9d..fae004ab 100644 --- a/lib/charms/data_platform_libs/v0/database_provides.py +++ b/lib/charms/data_platform_libs/v0/database_provides.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Relation provider side abstraction for database relation. +"""[DEPRECATED] Relation provider side abstraction for database relation. This library is a uniform interface to a selection of common database metadata, with added custom events that add convenience to database management, @@ -80,7 +80,7 @@ def _on_database_requested(self, event: DatabaseRequestedEvent) -> None: # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 2 +LIBPATCH = 3 logger = logging.getLogger(__name__) diff --git a/lib/charms/data_platform_libs/v0/database_requires.py b/lib/charms/data_platform_libs/v0/database_requires.py index 53d61912..6f425e71 100644 --- a/lib/charms/data_platform_libs/v0/database_requires.py +++ b/lib/charms/data_platform_libs/v0/database_requires.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Relation 'requires' side abstraction for database relation. +"""[DEPRECATED] Relation 'requires' side abstraction for database relation. This library is a uniform interface to a selection of common database metadata, with added custom events that add convenience to database management, @@ -23,7 +23,10 @@ ```python -from charms.data_platform_libs.v0.database_requires import DatabaseRequires +from charms.data_platform_libs.v0.database_requires import ( + DatabaseCreatedEvent, + DatabaseRequires, +) class ApplicationCharm(CharmBase): # Application charm that connects to database charms. @@ -84,7 +87,10 @@ def _on_database_created(self, event: DatabaseCreatedEvent) -> None: ```python -from charms.data_platform_libs.v0.database_requires import DatabaseRequires +from charms.data_platform_libs.v0.database_requires import ( + DatabaseCreatedEvent, + DatabaseRequires, +) class ApplicationCharm(CharmBase): # Application charm that connects to database charms. @@ -154,7 +160,7 @@ def _on_cluster2_database_created(self, event: DatabaseCreatedEvent) -> None: # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version. -LIBPATCH = 4 +LIBPATCH = 5 logger = logging.getLogger(__name__) diff --git a/lib/charms/operator_libs_linux/v0/apt.py b/lib/charms/operator_libs_linux/v0/apt.py deleted file mode 100644 index 42090a74..00000000 --- a/lib/charms/operator_libs_linux/v0/apt.py +++ /dev/null @@ -1,1329 +0,0 @@ -# Copyright 2022 Canonical Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Abstractions for the system's Debian/Ubuntu package information and repositories. 
- -This module contains abstractions and wrappers around Debian/Ubuntu-style repositories and -packages, in order to easily provide an idiomatic and Pythonic mechanism for adding packages and/or -repositories to systems for use in machine charms. - -A sane default configuration is attainable through nothing more than instantiation of the -appropriate classes. `DebianPackage` objects provide information about the architecture, version, -name, and status of a package. - -`DebianPackage` will try to look up a package either from `dpkg -L` or from `apt-cache` when -provided with a string indicating the package name. If it cannot be located, `PackageNotFoundError` -will be raised, as `apt` and `dpkg` otherwise return `100` for all errors, and a meaningful error -message if the package is not known is desirable. - -To install packages with convenience methods: - -```python -try: - # Run `apt-get update` - apt.update() - apt.add_package("zsh") - apt.add_package(["vim", "htop", "wget"]) -except PackageNotFoundError: - logger.error("a specified package not found in package cache or on system") -except PackageError as e: - logger.error("could not install package. Reason: %s", e.message) -``` - -To find details of a specific package: - -```python -try: - vim = apt.DebianPackage.from_system("vim") - - # To find from the apt cache only - # apt.DebianPackage.from_apt_cache("vim") - - # To find from installed packages only - # apt.DebianPackage.from_installed_package("vim") - - vim.ensure(PackageState.Latest) - logger.info("updated vim to version: %s", vim.fullversion) -except PackageNotFoundError: - logger.error("a specified package not found in package cache or on system") -except PackageError as e: - logger.error("could not install package. Reason: %s", e.message) -``` - - -`RepositoryMapping` will return a dict-like object containing enabled system repositories -and their properties (available groups, baseuri, gpg key). This class can add, disable, or -manipulate repositories. Items can be retrieved as `DebianRepository` objects. - -In order to add a new repository with explicit details for fields, a new `DebianRepository` can -be added to `RepositoryMapping`. - -`RepositoryMapping` provides an abstraction around the existing repositories on the system, -and can be accessed and iterated over like any `Mapping` object, to retrieve values by key, -iterate, or perform other operations. - -Keys are constructed as `{repo_type}-{}-{release}` in order to uniquely identify a repository. - -Repositories can be added with explicit values through a Python constructor. - -Example: - -```python -repositories = apt.RepositoryMapping() - -if "deb-example.com-focal" not in repositories: - repositories.add(DebianRepository(enabled=True, repotype="deb", - uri="https://example.com", release="focal", groups=["universe"])) -``` - -Alternatively, any valid `sources.list` line may be used to construct a new -`DebianRepository`.
- -Example: - -```python -repositories = apt.RepositoryMapping() - -if "deb-us.archive.ubuntu.com-xenial" not in repositories: - line = "deb http://us.archive.ubuntu.com/ubuntu xenial main restricted" - repo = DebianRepository.from_repo_line(line) - repositories.add(repo) -``` -""" - -import fileinput -import glob -import logging -import os -import re -import subprocess -from collections.abc import Mapping -from enum import Enum -from subprocess import PIPE, CalledProcessError, check_call, check_output -from typing import Iterable, List, Optional, Tuple, Union -from urllib.parse import urlparse - -logger = logging.getLogger(__name__) - -# The unique Charmhub library identifier, never change it -LIBID = "7c3dbc9c2ad44a47bd6fcb25caa270e5" - -# Increment this major API version when introducing breaking changes -LIBAPI = 0 - -# Increment this PATCH version before using `charmcraft publish-lib` or reset -# to 0 if you are raising the major API version -LIBPATCH = 7 - - -VALID_SOURCE_TYPES = ("deb", "deb-src") -OPTIONS_MATCHER = re.compile(r"\[.*?\]") - - -class Error(Exception): - """Base class of most errors raised by this library.""" - - def __repr__(self): - """String representation of Error.""" - return "<{}.{} {}>".format(type(self).__module__, type(self).__name__, self.args) - - @property - def name(self): - """Return a string representation of the model plus class.""" - return "<{}.{}>".format(type(self).__module__, type(self).__name__) - - @property - def message(self): - """Return the message passed as an argument.""" - return self.args[0] - - -class PackageError(Error): - """Raised when there's an error installing or removing a package.""" - - -class PackageNotFoundError(Error): - """Raised when a requested package is not known to the system.""" - - -class PackageState(Enum): - """A class to represent possible package states.""" - - Present = "present" - Absent = "absent" - Latest = "latest" - Available = "available" - - -class DebianPackage: - """Represents a traditional Debian package and its utility functions. - - `DebianPackage` wraps information and functionality around a known package, whether installed - or available. The version, epoch, name, and architecture can be easily queried and compared - against other `DebianPackage` objects to determine the latest version or to install a specific - version. - - The representation of this object as a string mimics the output from `dpkg` for familiarity. - - Installation and removal of packages is handled through the `state` property or `ensure` - method, with the following options: - - apt.PackageState.Absent - apt.PackageState.Available - apt.PackageState.Present - apt.PackageState.Latest - - When `DebianPackage` is initialized, the state of a given `DebianPackage` object will be set to - `Available`, `Present`, or `Latest`, with `Absent` implemented as a convenience for removal - (though it operates essentially the same as `Available`). - """ - - def __init__( - self, name: str, version: str, epoch: str, arch: str, state: PackageState - ) -> None: - self._name = name - self._arch = arch - self._state = state - self._version = Version(version, epoch) - - def __eq__(self, other) -> bool: - """Equality for comparison. 
- - Args: - other: a `DebianPackage` object for comparison - - Returns: - A boolean reflecting equality - """ - return isinstance(other, self.__class__) and ( - self._name, - self._version.number, - ) == (other._name, other._version.number) - - def __hash__(self): - """A basic hash so this class can be used in Mappings and dicts.""" - return hash((self._name, self._version.number)) - - def __repr__(self): - """A representation of the package.""" - return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) - - def __str__(self): - """A human-readable representation of the package.""" - return "<{}: {}-{}.{} -- {}>".format( - self.__class__.__name__, - self._name, - self._version, - self._arch, - str(self._state), - ) - - @staticmethod - def _apt( - command: str, - package_names: Union[str, List], - optargs: Optional[List[str]] = None, - ) -> None: - """Wrap package management commands for Debian/Ubuntu systems. - - Args: - command: the command given to `apt-get` - package_names: a package name or list of package names to operate on - optargs: an (Optional) list of additioanl arguments - - Raises: - PackageError if an error is encountered - """ - optargs = optargs if optargs is not None else [] - if isinstance(package_names, str): - package_names = [package_names] - _cmd = ["apt-get", "-y", *optargs, command, *package_names] - try: - check_call(_cmd, stderr=PIPE, stdout=PIPE) - except CalledProcessError as e: - raise PackageError( - "Could not {} package(s) [{}]: {}".format(command, [*package_names], e.output) - ) from None - - def _add(self) -> None: - """Add a package to the system.""" - self._apt( - "install", - "{}={}".format(self.name, self.version), - optargs=["--option=Dpkg::Options::=--force-confold"], - ) - - def _remove(self) -> None: - """Removes a package from the system. Implementation-specific.""" - return self._apt("remove", "{}={}".format(self.name, self.version)) - - @property - def name(self) -> str: - """Returns the name of the package.""" - return self._name - - def ensure(self, state: PackageState): - """Ensures that a package is in a given state. - - Args: - state: a `PackageState` to reconcile the package to - - Raises: - PackageError from the underlying call to apt - """ - if self._state is not state: - if state not in (PackageState.Present, PackageState.Latest): - self._remove() - else: - self._add() - self._state = state - - @property - def present(self) -> bool: - """Returns whether or not a package is present.""" - return self._state in (PackageState.Present, PackageState.Latest) - - @property - def latest(self) -> bool: - """Returns whether the package is the most recent version.""" - return self._state is PackageState.Latest - - @property - def state(self) -> PackageState: - """Returns the current package state.""" - return self._state - - @state.setter - def state(self, state: PackageState) -> None: - """Sets the package state to a given value. - - Args: - state: a `PackageState` to reconcile the package to - - Raises: - PackageError from the underlying call to apt - """ - if state in (PackageState.Latest, PackageState.Present): - self._add() - else: - self._remove() - self._state = state - - @property - def version(self) -> "Version": - """Returns the version for a package.""" - return self._version - - @property - def epoch(self) -> str: - """Returns the epoch for a package. 
May be unset.""" - return self._version.epoch - - @property - def arch(self) -> str: - """Returns the architecture for a package.""" - return self._arch - - @property - def fullversion(self) -> str: - """Returns the name+epoch for a package.""" - return "{}.{}".format(self._version, self._arch) - - @staticmethod - def _get_epoch_from_version(version: str) -> Tuple[str, str]: - """Pull the epoch, if any, out of a version string.""" - epoch_matcher = re.compile(r"^((?P<epoch>\d+):)?(?P<version>.*)") - matches = epoch_matcher.search(version).groupdict() - return matches.get("epoch", ""), matches.get("version") - - @classmethod - def from_system( - cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" - ) -> "DebianPackage": - """Locates a package, either on the system or known to apt, and serializes the information. - - Args: - package: a string representing the package - version: an optional string if a specific version is requested - arch: an optional architecture, defaulting to `dpkg --print-architecture`. If an - architecture is not specified, this will be used for selection. - - """ - try: - return DebianPackage.from_installed_package(package, version, arch) - except PackageNotFoundError: - logger.debug( - "package '%s' is not currently installed or has the wrong architecture.", package - ) - - # Ok, try `apt-cache ...` - try: - return DebianPackage.from_apt_cache(package, version, arch) - except (PackageNotFoundError, PackageError): - # If we get here, it's not known to the systems. - # This seems unnecessary, but virtually all `apt` commands have a return code of `100`, - # and providing meaningful error messages without this is ugly. - raise PackageNotFoundError( - "Package '{}{}' could not be found on the system or in the apt cache!".format( - package, ".{}".format(arch) if arch else "" - ) - ) from None - - @classmethod - def from_installed_package( - cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" - ) -> "DebianPackage": - """Check whether the package is already installed and return an instance. - - Args: - package: a string representing the package - version: an optional string if a specific version is requested - arch: an optional architecture, defaulting to `dpkg --print-architecture`. - If an architecture is not specified, this will be used for selection. - """ - system_arch = check_output( - ["dpkg", "--print-architecture"], universal_newlines=True - ).strip() - arch = arch if arch else system_arch - - # Regexps are a really terrible way to do this.
Thanks dpkg - output = "" - try: - output = check_output(["dpkg", "-l", package], stderr=PIPE, universal_newlines=True) - except CalledProcessError: - raise PackageNotFoundError("Package is not installed: {}".format(package)) from None - - # Pop off the output from `dpkg -l' because there's no flag to - # omit it` - lines = str(output).splitlines()[5:] - - dpkg_matcher = re.compile( - r""" - ^(?P<package_status>\w+?)\s+ - (?P<package_name>.*?)(?P<throwaway_arch>:\w+?)?\s+ - (?P<version>.*?)\s+ - (?P<arch>\w+?)\s+ - (?P<description>.*) - """, - re.VERBOSE, - ) - - for line in lines: - try: - matches = dpkg_matcher.search(line).groupdict() - package_status = matches["package_status"] - - if not package_status.endswith("i"): - logger.debug( - "package '%s' in dpkg output but not installed, status: '%s'", - package, - package_status, - ) - break - - epoch, split_version = DebianPackage._get_epoch_from_version(matches["version"]) - pkg = DebianPackage( - matches["package_name"], - split_version, - epoch, - matches["arch"], - PackageState.Present, - ) - if (pkg.arch == "all" or pkg.arch == arch) and ( - version == "" or str(pkg.version) == version - ): - return pkg - except AttributeError: - logger.warning("dpkg matcher could not parse line: %s", line) - - # If we didn't find it, fail through - raise PackageNotFoundError("Package {}.{} is not installed!".format(package, arch)) - - @classmethod - def from_apt_cache( - cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" - ) -> "DebianPackage": - """Check whether the package is already installed and return an instance. - - Args: - package: a string representing the package - version: an optional string if a specific version is requested - arch: an optional architecture, defaulting to `dpkg --print-architecture`. - If an architecture is not specified, this will be used for selection. - """ - system_arch = check_output( - ["dpkg", "--print-architecture"], universal_newlines=True - ).strip() - arch = arch if arch else system_arch - - # Regexps are a really terrible way to do this. Thanks dpkg - keys = ("Package", "Architecture", "Version") - - try: - output = check_output( - ["apt-cache", "show", package], stderr=PIPE, universal_newlines=True - ) - except CalledProcessError as e: - raise PackageError( - "Could not list packages in apt-cache: {}".format(e.output) - ) from None - - pkg_groups = output.strip().split("\n\n") - keys = ("Package", "Architecture", "Version") - - for pkg_raw in pkg_groups: - lines = str(pkg_raw).splitlines() - vals = {} - for line in lines: - if line.startswith(keys): - items = line.split(":", 1) - vals[items[0]] = items[1].strip() - else: - continue - - epoch, split_version = DebianPackage._get_epoch_from_version(vals["Version"]) - pkg = DebianPackage( - vals["Package"], - split_version, - epoch, - vals["Architecture"], - PackageState.Available, - ) - - if (pkg.arch == "all" or pkg.arch == arch) and ( - version == "" or str(pkg.version) == version - ): - return pkg - - # If we didn't find it, fail through - raise PackageNotFoundError("Package {}.{} is not in the apt cache!".format(package, arch)) - - -class Version: - """An abstraction around package versions. - - This seems like it should be strictly unnecessary, except that `apt_pkg` is not usable inside a - venv, and wedging version comparisons into `DebianPackage` would overcomplicate it.
- - This class implements the algorithm found here: - https://www.debian.org/doc/debian-policy/ch-controlfields.html#version - """ - - def __init__(self, version: str, epoch: str): - self._version = version - self._epoch = epoch or "" - - def __repr__(self): - """A representation of the package.""" - return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) - - def __str__(self): - """A human-readable representation of the package.""" - return "{}{}".format("{}:".format(self._epoch) if self._epoch else "", self._version) - - @property - def epoch(self): - """Returns the epoch for a package. May be empty.""" - return self._epoch - - @property - def number(self) -> str: - """Returns the version number for a package.""" - return self._version - - def _get_parts(self, version: str) -> Tuple[str, str]: - """Separate the version into component upstream and Debian pieces.""" - try: - version.rindex("-") - except ValueError: - # No hyphens means no Debian version - return version, "0" - - upstream, debian = version.rsplit("-", 1) - return upstream, debian - - def _listify(self, revision: str) -> List[str]: - """Split a revision string into a listself. - - This list is comprised of alternating between strings and numbers, - padded on either end to always be "str, int, str, int..." and - always be of even length. This allows us to trivially implement the - comparison algorithm described. - """ - result = [] - while revision: - rev_1, remains = self._get_alphas(revision) - rev_2, remains = self._get_digits(remains) - result.extend([rev_1, rev_2]) - revision = remains - return result - - def _get_alphas(self, revision: str) -> Tuple[str, str]: - """Return a tuple of the first non-digit characters of a revision.""" - # get the index of the first digit - for i, char in enumerate(revision): - if char.isdigit(): - if i == 0: - return "", revision - return revision[0:i], revision[i:] - # string is entirely alphas - return revision, "" - - def _get_digits(self, revision: str) -> Tuple[int, str]: - """Return a tuple of the first integer characters of a revision.""" - # If the string is empty, return (0,'') - if not revision: - return 0, "" - # get the index of the first non-digit - for i, char in enumerate(revision): - if not char.isdigit(): - if i == 0: - return 0, revision - return int(revision[0:i]), revision[i:] - # string is entirely digits - return int(revision), "" - - def _dstringcmp(self, a, b): # noqa: C901 - """Debian package version string section lexical sort algorithm. - - The lexical comparison is a comparison of ASCII values modified so - that all the letters sort earlier than all the non-letters and so that - a tilde sorts before anything, even the end of a part. - """ - if a == b: - return 0 - try: - for i, char in enumerate(a): - if char == b[i]: - continue - # "a tilde sorts before anything, even the end of a part" - # (emptyness) - if char == "~": - return -1 - if b[i] == "~": - return 1 - # "all the letters sort earlier than all the non-letters" - if char.isalpha() and not b[i].isalpha(): - return -1 - if not char.isalpha() and b[i].isalpha(): - return 1 - # otherwise lexical sort - if ord(char) > ord(b[i]): - return 1 - if ord(char) < ord(b[i]): - return -1 - except IndexError: - # a is longer than b but otherwise equal, greater unless there are tildes - if char == "~": - return -1 - return 1 - # if we get here, a is shorter than b but otherwise equal, so check for tildes... 
- if b[len(a)] == "~": - return 1 - return -1 - - def _compare_revision_strings(self, first: str, second: str): # noqa: C901 - """Compare two debian revision strings.""" - if first == second: - return 0 - - # listify pads results so that we will always be comparing ints to ints - # and strings to strings (at least until we fall off the end of a list) - first_list = self._listify(first) - second_list = self._listify(second) - if first_list == second_list: - return 0 - try: - for i, item in enumerate(first_list): - # explicitly raise IndexError if we've fallen off the edge of list2 - if i >= len(second_list): - raise IndexError - # if the items are equal, next - if item == second_list[i]: - continue - # numeric comparison - if isinstance(item, int): - if item > second_list[i]: - return 1 - if item < second_list[i]: - return -1 - else: - # string comparison - return self._dstringcmp(item, second_list[i]) - except IndexError: - # rev1 is longer than rev2 but otherwise equal, hence greater - # ...except for goddamn tildes - if first_list[len(second_list)][0][0] == "~": - return 1 - return 1 - # rev1 is shorter than rev2 but otherwise equal, hence lesser - # ...except for goddamn tildes - if second_list[len(first_list)][0][0] == "~": - return -1 - return -1 - - def _compare_version(self, other) -> int: - if (self.number, self.epoch) == (other.number, other.epoch): - return 0 - - if self.epoch < other.epoch: - return -1 - if self.epoch > other.epoch: - return 1 - - # If none of these are true, follow the algorithm - upstream_version, debian_version = self._get_parts(self.number) - other_upstream_version, other_debian_version = self._get_parts(other.number) - - upstream_cmp = self._compare_revision_strings(upstream_version, other_upstream_version) - if upstream_cmp != 0: - return upstream_cmp - - debian_cmp = self._compare_revision_strings(debian_version, other_debian_version) - if debian_cmp != 0: - return debian_cmp - - return 0 - - def __lt__(self, other) -> bool: - """Less than magic method impl.""" - return self._compare_version(other) < 0 - - def __eq__(self, other) -> bool: - """Equality magic method impl.""" - return self._compare_version(other) == 0 - - def __gt__(self, other) -> bool: - """Greater than magic method impl.""" - return self._compare_version(other) > 0 - - def __le__(self, other) -> bool: - """Less than or equal to magic method impl.""" - return self.__eq__(other) or self.__lt__(other) - - def __ge__(self, other) -> bool: - """Greater than or equal to magic method impl.""" - return self.__gt__(other) or self.__eq__(other) - - def __ne__(self, other) -> bool: - """Not equal to magic method impl.""" - return not self.__eq__(other) - - -def add_package( - package_names: Union[str, List[str]], - version: Optional[str] = "", - arch: Optional[str] = "", - update_cache: Optional[bool] = False, -) -> Union[DebianPackage, List[DebianPackage]]: - """Add a package or list of packages to the system. - - Args: - name: the name(s) of the package(s) - version: an (Optional) version as a string. Defaults to the latest known - arch: an optional architecture for the package - update_cache: whether or not to run `apt-get update` prior to operating - - Raises: - PackageNotFoundError if the package is not in the cache. 
- """ - cache_refreshed = False - if update_cache: - update() - cache_refreshed = True - - packages = {"success": [], "retry": [], "failed": []} - - package_names = [package_names] if type(package_names) is str else package_names - if not package_names: - raise TypeError("Expected at least one package name to add, received zero!") - - if len(package_names) != 1 and version: - raise TypeError( - "Explicit version should not be set if more than one package is being added!" - ) - - for p in package_names: - pkg, success = _add(p, version, arch) - if success: - packages["success"].append(pkg) - else: - logger.warning("failed to locate and install/update '%s'", pkg) - packages["retry"].append(p) - - if packages["retry"] and not cache_refreshed: - logger.info("updating the apt-cache and retrying installation of failed packages.") - update() - - for p in packages["retry"]: - pkg, success = _add(p, version, arch) - if success: - packages["success"].append(pkg) - else: - packages["failed"].append(p) - - if packages["failed"]: - raise PackageError("Failed to install packages: {}".format(", ".join(packages["failed"]))) - - return packages["success"] if len(packages["success"]) > 1 else packages["success"][0] - - -def _add( - name: str, - version: Optional[str] = "", - arch: Optional[str] = "", -) -> Tuple[Union[DebianPackage, str], bool]: - """Adds a package. - - Args: - name: the name(s) of the package(s) - version: an (Optional) version as a string. Defaults to the latest known - arch: an optional architecture for the package - - Returns: a tuple of `DebianPackage` if found, or a :str: if it is not, and - a boolean indicating success - """ - try: - pkg = DebianPackage.from_system(name, version, arch) - pkg.ensure(state=PackageState.Present) - return pkg, True - except PackageNotFoundError: - return name, False - - -def remove_package( - package_names: Union[str, List[str]] -) -> Union[DebianPackage, List[DebianPackage]]: - """Removes a package from the system. - - Args: - package_names: the name of a package - - Raises: - PackageNotFoundError if the package is not found. 
- """ - packages = [] - - package_names = [package_names] if type(package_names) is str else package_names - if not package_names: - raise TypeError("Expected at least one package name to add, received zero!") - - for p in package_names: - try: - pkg = DebianPackage.from_installed_package(p) - pkg.ensure(state=PackageState.Absent) - packages.append(pkg) - except PackageNotFoundError: - logger.info("package '%s' was requested for removal, but it was not installed.", p) - - # the list of packages will be empty when no package is removed - logger.debug("packages: '%s'", packages) - return packages[0] if len(packages) == 1 else packages - - -def update() -> None: - """Updates the apt cache via `apt-get update`.""" - check_call(["apt-get", "update"], stderr=PIPE, stdout=PIPE) - - -class InvalidSourceError(Error): - """Exceptions for invalid source entries.""" - - -class GPGKeyError(Error): - """Exceptions for GPG keys.""" - - -class DebianRepository: - """An abstraction to represent a repository.""" - - def __init__( - self, - enabled: bool, - repotype: str, - uri: str, - release: str, - groups: List[str], - filename: Optional[str] = "", - gpg_key_filename: Optional[str] = "", - options: Optional[dict] = None, - ): - self._enabled = enabled - self._repotype = repotype - self._uri = uri - self._release = release - self._groups = groups - self._filename = filename - self._gpg_key_filename = gpg_key_filename - self._options = options - - @property - def enabled(self): - """Return whether or not the repository is enabled.""" - return self._enabled - - @property - def repotype(self): - """Return whether it is binary or source.""" - return self._repotype - - @property - def uri(self): - """Return the URI.""" - return self._uri - - @property - def release(self): - """Return which Debian/Ubuntu releases it is valid for.""" - return self._release - - @property - def groups(self): - """Return the enabled package groups.""" - return self._groups - - @property - def filename(self): - """Returns the filename for a repository.""" - return self._filename - - @filename.setter - def filename(self, fname: str) -> None: - """Sets the filename used when a repo is written back to diskself. - - Args: - fname: a filename to write the repository information to. - """ - if not fname.endswith(".list"): - raise InvalidSourceError("apt source filenames should end in .list!") - - self._filename = fname - - @property - def gpg_key(self): - """Returns the path to the GPG key for this repository.""" - return self._gpg_key_filename - - @property - def options(self): - """Returns any additional repo options which are set.""" - return self._options - - def make_options_string(self) -> str: - """Generate the complete options string for a a repository. - - Combining `gpg_key`, if set, and the rest of the options to find - a complex repo string. 
- """ - options = self._options if self._options else {} - if self._gpg_key_filename: - options["signed-by"] = self._gpg_key_filename - - return ( - "[{}] ".format(" ".join(["{}={}".format(k, v) for k, v in options.items()])) - if options - else "" - ) - - @staticmethod - def prefix_from_uri(uri: str) -> str: - """Get a repo list prefix from the uri, depending on whether a path is set.""" - uridetails = urlparse(uri) - path = ( - uridetails.path.lstrip("/").replace("/", "-") if uridetails.path else uridetails.netloc - ) - return "/etc/apt/sources.list.d/{}".format(path) - - @staticmethod - def from_repo_line(repo_line: str, write_file: Optional[bool] = True) -> "DebianRepository": - """Instantiate a new `DebianRepository` a `sources.list` entry line. - - Args: - repo_line: a string representing a repository entry - write_file: boolean to enable writing the new repo to disk - """ - repo = RepositoryMapping._parse(repo_line, "UserInput") - fname = "{}-{}.list".format( - DebianRepository.prefix_from_uri(repo.uri), repo.release.replace("/", "-") - ) - repo.filename = fname - - options = repo.options if repo.options else {} - if repo.gpg_key: - options["signed-by"] = repo.gpg_key - - # For Python 3.5 it's required to use sorted in the options dict in order to not have - # different results in the order of the options between executions. - options_str = ( - "[{}] ".format(" ".join(["{}={}".format(k, v) for k, v in sorted(options.items())])) - if options - else "" - ) - - if write_file: - with open(fname, "wb") as f: - f.write( - ( - "{}".format("#" if not repo.enabled else "") - + "{} {}{} ".format(repo.repotype, options_str, repo.uri) - + "{} {}\n".format(repo.release, " ".join(repo.groups)) - ).encode("utf-8") - ) - - return repo - - def disable(self) -> None: - """Remove this repository from consideration. - - Disable it instead of removing from the repository file. - """ - searcher = "{} {}{} {}".format( - self.repotype, self.make_options_string(), self.uri, self.release - ) - for line in fileinput.input(self._filename, inplace=True): - if re.match(r"^{}\s".format(re.escape(searcher)), line): - print("# {}".format(line), end="") - else: - print(line, end="") - - def import_key(self, key: str) -> None: - """Import an ASCII Armor key. - - A Radix64 format keyid is also supported for backwards - compatibility. In this case Ubuntu keyserver will be - queried for a key via HTTPS by its keyid. This method - is less preferrable because https proxy servers may - require traffic decryption which is equivalent to a - man-in-the-middle attack (a proxy server impersonates - keyserver TLS certificates and has to be explicitly - trusted by the system). - - Args: - key: A GPG key in ASCII armor format, - including BEGIN and END markers or a keyid. - - Raises: - GPGKeyError if the key could not be imported - """ - key = key.strip() - if "-" in key or "\n" in key: - # Send everything not obviously a keyid to GPG to import, as - # we trust its validation better than our own. eg. handling - # comments before the key. 
- logger.debug("PGP key found (looks like ASCII Armor format)") - if ( - "-----BEGIN PGP PUBLIC KEY BLOCK-----" in key - and "-----END PGP PUBLIC KEY BLOCK-----" in key - ): - logger.debug("Writing provided PGP key in the binary format") - key_bytes = key.encode("utf-8") - key_name = self._get_keyid_by_gpg_key(key_bytes) - key_gpg = self._dearmor_gpg_key(key_bytes) - self._gpg_key_filename = "/etc/apt/trusted.gpg.d/{}.gpg".format(key_name) - self._write_apt_gpg_keyfile(key_name=self._gpg_key_filename, key_material=key_gpg) - else: - raise GPGKeyError("ASCII armor markers missing from GPG key") - else: - logger.warning( - "PGP key found (looks like Radix64 format). " - "SECURELY importing PGP key from keyserver; " - "full key not provided." - ) - # as of bionic add-apt-repository uses curl with an HTTPS keyserver URL - # to retrieve GPG keys. `apt-key adv` command is deprecated as is - # apt-key in general as noted in its manpage. See lp:1433761 for more - # history. Instead, /etc/apt/trusted.gpg.d is used directly to drop - # gpg - key_asc = self._get_key_by_keyid(key) - # write the key in GPG format so that apt-key list shows it - key_gpg = self._dearmor_gpg_key(key_asc.encode("utf-8")) - self._gpg_key_filename = "/etc/apt/trusted.gpg.d/{}.gpg".format(key) - self._write_apt_gpg_keyfile(key_name=key, key_material=key_gpg) - - @staticmethod - def _get_keyid_by_gpg_key(key_material: bytes) -> str: - """Get a GPG key fingerprint by GPG key material. - - Gets a GPG key fingerprint (40-digit, 160-bit) by the ASCII armor-encoded - or binary GPG key material. Can be used, for example, to generate file - names for keys passed via charm options. - """ - # Use the same gpg command for both Xenial and Bionic - cmd = ["gpg", "--with-colons", "--with-fingerprint"] - ps = subprocess.run( - cmd, - stdout=PIPE, - stderr=PIPE, - input=key_material, - ) - out, err = ps.stdout.decode(), ps.stderr.decode() - if "gpg: no valid OpenPGP data found." in err: - raise GPGKeyError("Invalid GPG key material provided") - # from gnupg2 docs: fpr :: Fingerprint (fingerprint is in field 10) - return re.search(r"^fpr:{9}([0-9A-F]{40}):$", out, re.MULTILINE).group(1) - - @staticmethod - def _get_key_by_keyid(keyid: str) -> str: - """Get a key via HTTPS from the Ubuntu keyserver. - - Different key ID formats are supported by SKS keyservers (the longer ones - are more secure, see "dead beef attack" and https://evil32.com/). Since - HTTPS is used, if SSLBump-like HTTPS proxies are in place, they will - impersonate keyserver.ubuntu.com and generate a certificate with - keyserver.ubuntu.com in the CN field or in SubjAltName fields of a - certificate. If such proxy behavior is expected it is necessary to add the - CA certificate chain containing the intermediate CA of the SSLBump proxy to - every machine that this code runs on via ca-certs cloud-init directive (via - cloudinit-userdata model-config) or via other means (such as through a - custom charm option). Also note that DNS resolution for the hostname in a - URL is done at a proxy server - not at the client side. 
- 8-digit (32 bit) key ID - https://keyserver.ubuntu.com/pks/lookup?search=0x4652B4E6 - 16-digit (64 bit) key ID - https://keyserver.ubuntu.com/pks/lookup?search=0x6E85A86E4652B4E6 - 40-digit key ID: - https://keyserver.ubuntu.com/pks/lookup?search=0x35F77D63B5CEC106C577ED856E85A86E4652B4E6 - - Args: - keyid: An 8, 16 or 40 hex digit keyid to find a key for - - Returns: - A string contining key material for the specified GPG key id - - - Raises: - subprocess.CalledProcessError - """ - # options=mr - machine-readable output (disables html wrappers) - keyserver_url = ( - "https://keyserver.ubuntu.com" "/pks/lookup?op=get&options=mr&exact=on&search=0x{}" - ) - curl_cmd = ["curl", keyserver_url.format(keyid)] - # use proxy server settings in order to retrieve the key - return check_output(curl_cmd).decode() - - @staticmethod - def _dearmor_gpg_key(key_asc: bytes) -> bytes: - """Converts a GPG key in the ASCII armor format to the binary format. - - Args: - key_asc: A GPG key in ASCII armor format. - - Returns: - A GPG key in binary format as a string - - Raises: - GPGKeyError - """ - ps = subprocess.run(["gpg", "--dearmor"], stdout=PIPE, stderr=PIPE, input=key_asc) - out, err = ps.stdout, ps.stderr.decode() - if "gpg: no valid OpenPGP data found." in err: - raise GPGKeyError( - "Invalid GPG key material. Check your network setup" - " (MTU, routing, DNS) and/or proxy server settings" - " as well as destination keyserver status." - ) - else: - return out - - @staticmethod - def _write_apt_gpg_keyfile(key_name: str, key_material: bytes) -> None: - """Writes GPG key material into a file at a provided path. - - Args: - key_name: A key name to use for a key file (could be a fingerprint) - key_material: A GPG key material (binary) - """ - with open(key_name, "wb") as keyf: - keyf.write(key_material) - - -class RepositoryMapping(Mapping): - """An representation of known repositories. - - Instantiation of `RepositoryMapping` will iterate through the - filesystem, parse out repository files in `/etc/apt/...`, and create - `DebianRepository` objects in this list. - - Typical usage: - - repositories = apt.RepositoryMapping() - repositories.add(DebianRepository( - enabled=True, repotype="deb", uri="https://example.com", release="focal", - groups=["universe"] - )) - """ - - def __init__(self): - self._repository_map = {} - # Repositories that we're adding -- used to implement mode param - self.default_file = "/etc/apt/sources.list" - - # read sources.list if it exists - if os.path.isfile(self.default_file): - self.load(self.default_file) - - # read sources.list.d - for file in glob.iglob("/etc/apt/sources.list.d/*.list"): - self.load(file) - - def __contains__(self, key: str) -> bool: - """Magic method for checking presence of repo in mapping.""" - return key in self._repository_map - - def __len__(self) -> int: - """Return number of repositories in map.""" - return len(self._repository_map) - - def __iter__(self) -> Iterable[DebianRepository]: - """Iterator magic method for RepositoryMapping.""" - return iter(self._repository_map.values()) - - def __getitem__(self, repository_uri: str) -> DebianRepository: - """Return a given `DebianRepository`.""" - return self._repository_map[repository_uri] - - def __setitem__(self, repository_uri: str, repository: DebianRepository) -> None: - """Add a `DebianRepository` to the cache.""" - self._repository_map[repository_uri] = repository - - def load(self, filename: str): - """Load a repository source file into the cache. 
- - Args: - filename: the path to the repository file - """ - parsed = [] - skipped = [] - with open(filename, "r") as f: - for n, line in enumerate(f): - try: - repo = self._parse(line, filename) - except InvalidSourceError: - skipped.append(n) - else: - repo_identifier = "{}-{}-{}".format(repo.repotype, repo.uri, repo.release) - self._repository_map[repo_identifier] = repo - parsed.append(n) - logger.debug("parsed repo: '%s'", repo_identifier) - - if skipped: - skip_list = ", ".join(str(s) for s in skipped) - logger.debug("skipped the following lines in file '%s': %s", filename, skip_list) - - if parsed: - logger.info("parsed %d apt package repositories", len(parsed)) - else: - raise InvalidSourceError("all repository lines in '{}' were invalid!".format(filename)) - - @staticmethod - def _parse(line: str, filename: str) -> DebianRepository: - """Parse a line in a sources.list file. - - Args: - line: a single line from `load` to parse - filename: the filename being read - - Raises: - InvalidSourceError if the source type is unknown - """ - enabled = True - repotype = uri = release = gpg_key = "" - options = {} - groups = [] - - line = line.strip() - if line.startswith("#"): - enabled = False - line = line[1:] - - # Check for "#" in the line and treat a part after it as a comment then strip it off. - i = line.find("#") - if i > 0: - line = line[:i] - - # Split a source into substrings to initialize a new repo. - source = line.strip() - if source: - # Match any repo options, and get a dict representation. - for v in re.findall(OPTIONS_MATCHER, source): - opts = dict(o.split("=") for o in v.strip("[]").split()) - # Extract the 'signed-by' option for the gpg_key - gpg_key = opts.pop("signed-by", "") - options = opts - - # Remove any options from the source string and split the string into chunks - source = re.sub(OPTIONS_MATCHER, "", source) - chunks = source.split() - - # Check we've got a valid list of chunks - if len(chunks) < 3 or chunks[0] not in VALID_SOURCE_TYPES: - raise InvalidSourceError("An invalid sources line was found in %s!", filename) - - repotype = chunks[0] - uri = chunks[1] - release = chunks[2] - groups = chunks[3:] - - return DebianRepository( - enabled, repotype, uri, release, groups, filename, gpg_key, options - ) - else: - raise InvalidSourceError("An invalid sources line was found in %s!", filename) - - def add(self, repo: DebianRepository, default_filename: Optional[bool] = False) -> None: - """Add a new repository to the system. - - Args: - repo: a `DebianRepository` object - default_filename: an (Optional) filename if the default is not desirable - """ - new_filename = "{}-{}.list".format( - DebianRepository.prefix_from_uri(repo.uri), repo.release.replace("/", "-") - ) - - fname = repo.filename or new_filename - - options = repo.options if repo.options else {} - if repo.gpg_key: - options["signed-by"] = repo.gpg_key - - with open(fname, "wb") as f: - f.write( - ( - "{}".format("#" if not repo.enabled else "") - + "{} {}{} ".format(repo.repotype, repo.make_options_string(), repo.uri) - + "{} {}\n".format(repo.release, " ".join(repo.groups)) - ).encode("utf-8") - ) - - self._repository_map["{}-{}-{}".format(repo.repotype, repo.uri, repo.release)] = repo - - def disable(self, repo: DebianRepository) -> None: - """Remove a repository. Disable by default. 
- - Args: - repo: a `DebianRepository` to disable - """ - searcher = "{} {}{} {}".format( - repo.repotype, repo.make_options_string(), repo.uri, repo.release - ) - - for line in fileinput.input(repo.filename, inplace=True): - if re.match(r"^{}\s".format(re.escape(searcher)), line): - print("# {}".format(line), end="") - else: - print(line, end="") - - self._repository_map["{}-{}-{}".format(repo.repotype, repo.uri, repo.release)] = repo diff --git a/lib/charms/operator_libs_linux/v0/passwd.py b/lib/charms/operator_libs_linux/v0/passwd.py deleted file mode 100644 index b692e700..00000000 --- a/lib/charms/operator_libs_linux/v0/passwd.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright 2021 Canonical Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Simple library for managing Linux users and groups. - -The `passwd` module provides convenience methods and abstractions around users and groups on a -Linux system, in order to make adding and managing users and groups easy. - -Example of adding a user named 'test': - -```python -import passwd -passwd.add_group(name='special_group') -passwd.add_user(username='test', secondary_groups=['sudo']) - -if passwd.user_exists('some_user'): - do_stuff() -``` -""" - -import grp -import logging -import pwd -from subprocess import STDOUT, check_output -from typing import List, Optional, Union - -logger = logging.getLogger(__name__) - -# The unique Charmhub library identifier, never change it -LIBID = "cf7655b2bf914d67ac963f72b930f6bb" - -# Increment this major API version when introducing breaking changes -LIBAPI = 0 - -# Increment this PATCH version before using `charmcraft publish-lib` or reset -# to 0 if you are raising the major API version -LIBPATCH = 3 - - -def user_exists(user: Union[str, int]) -> Optional[pwd.struct_passwd]: - """Check if a user exists. - - Args: - user: username or gid of user whose existence to check - - Raises: - TypeError: where neither a string or int is passed as the first argument - """ - try: - if type(user) is int: - return pwd.getpwuid(user) - elif type(user) is str: - return pwd.getpwnam(user) - else: - raise TypeError("specified argument '%r' should be a string or int", user) - except KeyError: - logger.info("specified user '%s' doesn't exist", str(user)) - return None - - -def group_exists(group: Union[str, int]) -> Optional[grp.struct_group]: - """Check if a group exists. 
- - Args: - group: username or gid of user whose existence to check - - Raises: - TypeError: where neither a string or int is passed as the first argument - """ - try: - if type(group) is int: - return grp.getgrgid(group) - elif type(group) is str: - return grp.getgrnam(group) - else: - raise TypeError("specified argument '%r' should be a string or int", group) - except KeyError: - logger.info("specified group '%s' doesn't exist", str(group)) - return None - - -def add_user( - username: str, - password: Optional[str] = None, - shell: str = "/bin/bash", - system_user: bool = False, - primary_group: str = None, - secondary_groups: List[str] = None, - uid: int = None, - home_dir: str = None, -) -> str: - """Add a user to the system. - - Will log but otherwise succeed if the user already exists. - - Arguments: - username: Username to create - password: Password for user; if ``None``, create a system user - shell: The default shell for the user - system_user: Whether to create a login or system user - primary_group: Primary group for user; defaults to username - secondary_groups: Optional list of additional groups - uid: UID for user being created - home_dir: Home directory for user - - Returns: - The password database entry struct, as returned by `pwd.getpwnam` - """ - try: - if uid: - user_info = pwd.getpwuid(int(uid)) - logger.info("user '%d' already exists", uid) - return user_info - user_info = pwd.getpwnam(username) - logger.info("user with uid '%s' already exists", username) - return user_info - except KeyError: - logger.info("creating user '%s'", username) - - cmd = ["useradd", "--shell", shell] - - if uid: - cmd.extend(["--uid", str(uid)]) - if home_dir: - cmd.extend(["--home", str(home_dir)]) - if password: - cmd.extend(["--password", password, "--create-home"]) - if system_user or password is None: - cmd.append("--system") - - if not primary_group: - try: - grp.getgrnam(username) - primary_group = username # avoid "group exists" error - except KeyError: - pass - - if primary_group: - cmd.extend(["-g", primary_group]) - if secondary_groups: - cmd.extend(["-G", ",".join(secondary_groups)]) - - cmd.append(username) - check_output(cmd, stderr=STDOUT) - user_info = pwd.getpwnam(username) - return user_info - - -def add_group(group_name: str, system_group: bool = False, gid: int = None): - """Add a group to the system. - - Will log but otherwise succeed if the group already exists. - - Args: - group_name: group to create - system_group: Create system group - gid: GID for user being created - - Returns: - The group's password database entry struct, as returned by `grp.getgrnam` - """ - try: - group_info = grp.getgrnam(group_name) - logger.info("group '%s' already exists", group_name) - if gid: - group_info = grp.getgrgid(gid) - logger.info("group with gid '%d' already exists", gid) - except KeyError: - logger.info("creating group '%s'", group_name) - cmd = ["addgroup"] - if gid: - cmd.extend(["--gid", str(gid)]) - if system_group: - cmd.append("--system") - else: - cmd.extend(["--group"]) - cmd.append(group_name) - check_output(cmd, stderr=STDOUT) - group_info = grp.getgrnam(group_name) - return group_info - - -def add_user_to_group(username: str, group: str): - """Add a user to a group. 
- - Args: - username: user to add to specified group - group: name of group to add user to - - Returns: - The group's password database entry struct, as returned by `grp.getgrnam` - """ - if not user_exists(username): - raise ValueError("user '{}' does not exist".format(username)) - if not group_exists(group): - raise ValueError("group '{}' does not exist".format(group)) - - logger.info("adding user '%s' to group '%s'", username, group) - check_output(["gpasswd", "-a", username, group], stderr=STDOUT) - return grp.getgrnam(group) - - -def remove_user(user: Union[str, int], remove_home: bool = False) -> bool: - """Remove a user from the system. - - Args: - user: the username or uid of the user to remove - remove_home: indicates whether the user's home directory should be removed - """ - u = user_exists(user) - if not u: - logger.info("user '%s' does not exist", str(u)) - return True - - cmd = ["userdel"] - if remove_home: - cmd.append("-f") - cmd.append(u.pw_name) - - logger.info("removing user '%s'", u.pw_name) - check_output(cmd, stderr=STDOUT) - return True - - -def remove_group(group: Union[str, int], force: bool = False) -> bool: - """Remove a user from the system. - - Args: - group: the name or gid of the group to remove - force: force group removal even if it's the primary group for a user - """ - g = group_exists(group) - if not g: - logger.info("group '%s' does not exist", str(g)) - return True - - cmd = ["groupdel"] - if force: - cmd.append("-f") - cmd.append(g.gr_name) - - logger.info("removing group '%s'", g.gr_name) - check_output(cmd, stderr=STDOUT) - return True diff --git a/lib/charms/operator_libs_linux/v1/snap.py b/lib/charms/operator_libs_linux/v1/snap.py new file mode 100644 index 00000000..71cdee39 --- /dev/null +++ b/lib/charms/operator_libs_linux/v1/snap.py @@ -0,0 +1,1065 @@ +# Copyright 2021 Canonical Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Representations of the system's Snaps, and abstractions around managing them. + +The `snap` module provides convenience methods for listing, installing, refreshing, and removing +Snap packages, in addition to setting and getting configuration options for them. + +In the `snap` module, `SnapCache` creates a dict-like mapping of `Snap` objects at when +instantiated. Installed snaps are fully populated, and available snaps are lazily-loaded upon +request. This module relies on an installed and running `snapd` daemon to perform operations over +the `snapd` HTTP API. + +`SnapCache` objects can be used to install or modify Snap packages by name in a manner similar to +using the `snap` command from the commandline. + +An example of adding Juju to the system with `SnapCache` and setting a config value: + +```python +try: + cache = snap.SnapCache() + juju = cache["juju"] + + if not juju.present: + juju.ensure(snap.SnapState.Latest, channel="beta") + juju.set({"some.key": "value", "some.key2": "value2"}) +except snap.SnapError as e: + logger.error("An exception occurred when installing charmcraft. 
Reason: %s", e.message) +``` + +In addition, the `snap` module provides "bare" methods which can act on Snap packages as +simple function calls. :meth:`add`, :meth:`remove`, and :meth:`ensure` are provided, as +well as :meth:`add_local` for installing directly from a local `.snap` file. These return +`Snap` objects. + +As an example of installing several Snaps and checking details: + +```python +try: + nextcloud, charmcraft = snap.add(["nextcloud", "charmcraft"]) + if nextcloud.get("mode") != "production": + nextcloud.set({"mode": "production"}) +except snap.SnapError as e: + logger.error("An exception occurred when installing snaps. Reason: %s" % e.message) +``` +""" + +import http.client +import json +import logging +import os +import re +import socket +import subprocess +import sys +import urllib.error +import urllib.parse +import urllib.request +from collections.abc import Mapping +from datetime import datetime, timedelta, timezone +from enum import Enum +from subprocess import CalledProcessError, CompletedProcess +from typing import Any, Dict, Iterable, List, Optional, Union + +logger = logging.getLogger(__name__) + +# The unique Charmhub library identifier, never change it +LIBID = "05394e5893f94f2d90feb7cbe6b633cd" + +# Increment this major API version when introducing breaking changes +LIBAPI = 1 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 12 + + +# Regex to locate 7-bit C1 ANSI sequences +ansi_filter = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + + +def _cache_init(func): + def inner(*args, **kwargs): + if _Cache.cache is None: + _Cache.cache = SnapCache() + return func(*args, **kwargs) + + return inner + + +# recursive hints seems to error out pytest +JSONType = Union[Dict[str, Any], List[Any], str, int, float] + + +class SnapService: + """Data wrapper for snap services.""" + + def __init__( + self, + daemon: Optional[str] = None, + daemon_scope: Optional[str] = None, + enabled: bool = False, + active: bool = False, + activators: List[str] = [], + **kwargs, + ): + self.daemon = daemon + self.daemon_scope = kwargs.get("daemon-scope", None) or daemon_scope + self.enabled = enabled + self.active = active + self.activators = activators + + def as_dict(self) -> Dict: + """Return instance representation as dict.""" + return { + "daemon": self.daemon, + "daemon_scope": self.daemon_scope, + "enabled": self.enabled, + "active": self.active, + "activators": self.activators, + } + + +class MetaCache(type): + """MetaCache class used for initialising the snap cache.""" + + @property + def cache(cls) -> "SnapCache": + """Property for returning the snap cache.""" + return cls._cache + + @cache.setter + def cache(cls, cache: "SnapCache") -> None: + """Setter for the snap cache.""" + cls._cache = cache + + def __getitem__(cls, name) -> "Snap": + """Snap cache getter.""" + return cls._cache[name] + + +class _Cache(object, metaclass=MetaCache): + _cache = None + + +class Error(Exception): + """Base class of most errors raised by this library.""" + + def __repr__(self): + """Represent the Error class.""" + return "<{}.{} {}>".format(type(self).__module__, type(self).__name__, self.args) + + @property + def name(self): + """Return a string representation of the model plus class.""" + return "<{}.{}>".format(type(self).__module__, type(self).__name__) + + @property + def message(self): + """Return the message passed as an argument.""" + return self.args[0] + + +class SnapAPIError(Error): + 
"""Raised when an HTTP API error occurs talking to the Snapd server.""" + + def __init__(self, body: Dict, code: int, status: str, message: str): + super().__init__(message) # Makes str(e) return message + self.body = body + self.code = code + self.status = status + self._message = message + + def __repr__(self): + """Represent the SnapAPIError class.""" + return "APIError({!r}, {!r}, {!r}, {!r})".format( + self.body, self.code, self.status, self._message + ) + + +class SnapState(Enum): + """The state of a snap on the system or in the cache.""" + + Present = "present" + Absent = "absent" + Latest = "latest" + Available = "available" + + +class SnapError(Error): + """Raised when there's an error running snap control commands.""" + + +class SnapNotFoundError(Error): + """Raised when a requested snap is not known to the system.""" + + +class Snap(object): + """Represents a snap package and its properties. + + `Snap` exposes the following properties about a snap: + - name: the name of the snap + - state: a `SnapState` representation of its install status + - channel: "stable", "candidate", "beta", and "edge" are common + - revision: a string representing the snap's revision + - confinement: "classic" or "strict" + """ + + def __init__( + self, + name, + state: SnapState, + channel: str, + revision: int, + confinement: str, + apps: Optional[List[Dict[str, str]]] = None, + cohort: Optional[str] = "", + ) -> None: + self._name = name + self._state = state + self._channel = channel + self._revision = revision + self._confinement = confinement + self._cohort = cohort + self._apps = apps or [] + self._snap_client = SnapClient() + + def __eq__(self, other) -> bool: + """Equality for comparison.""" + return isinstance(other, self.__class__) and ( + self._name, + self._revision, + ) == (other._name, other._revision) + + def __hash__(self): + """Calculate a hash for this snap.""" + return hash((self._name, self._revision)) + + def __repr__(self): + """Represent the object such that it can be reconstructed.""" + return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) + + def __str__(self): + """Represent the snap object as a string.""" + return "<{}: {}-{}.{} -- {}>".format( + self.__class__.__name__, + self._name, + self._revision, + self._channel, + str(self._state), + ) + + def _snap(self, command: str, optargs: Optional[Iterable[str]] = None) -> str: + """Perform a snap operation. + + Args: + command: the snap command to execute + optargs: an (optional) list of additional arguments to pass, + commonly confinement or channel + + Raises: + SnapError if there is a problem encountered + """ + optargs = optargs or [] + _cmd = ["snap", command, self._name, *optargs] + try: + return subprocess.check_output(_cmd, universal_newlines=True) + except CalledProcessError as e: + raise SnapError( + "Snap: {!r}; command {!r} failed with output = {!r}".format( + self._name, _cmd, e.output + ) + ) + + def _snap_daemons( + self, + command: List[str], + services: Optional[List[str]] = None, + ) -> CompletedProcess: + """Perform snap app commands. 
+ + Args: + command: the snap command to execute + services: the snap service to execute command on + + Raises: + SnapError if there is a problem encountered + """ + if services: + # an attempt to keep the command constrained to the snap instance's services + services = ["{}.{}".format(self._name, service) for service in services] + else: + services = [self._name] + + _cmd = ["snap", *command, *services] + + try: + return subprocess.run(_cmd, universal_newlines=True, check=True, capture_output=True) + except CalledProcessError as e: + raise SnapError("Could not {} for snap [{}]: {}".format(_cmd, self._name, e.stderr)) + + def get(self, key) -> str: + """Fetch a snap configuration value. + + Args: + key: the key to retrieve + """ + return self._snap("get", [key]).strip() + + def set(self, config: Dict) -> str: + """Set a snap configuration value. + + Args: + config: a dictionary containing keys and values specifying the config to set. + """ + args = ['{}="{}"'.format(key, val) for key, val in config.items()] + + return self._snap("set", [*args]) + + def unset(self, key) -> str: + """Unset a snap configuration value. + + Args: + key: the key to unset + """ + return self._snap("unset", [key]) + + def start(self, services: Optional[List[str]] = None, enable: Optional[bool] = False) -> None: + """Start a snap's services. + + Args: + services (list): (optional) list of individual snap services to start (otherwise all) + enable (bool): (optional) flag to enable snap services on start. Default `false` + """ + args = ["start", "--enable"] if enable else ["start"] + self._snap_daemons(args, services) + + def stop(self, services: Optional[List[str]] = None, disable: Optional[bool] = False) -> None: + """Stop a snap's services. + + Args: + services (list): (optional) list of individual snap services to stop (otherwise all) + disable (bool): (optional) flag to disable snap services on stop. Default `False` + """ + args = ["stop", "--disable"] if disable else ["stop"] + self._snap_daemons(args, services) + + def logs(self, services: Optional[List[str]] = None, num_lines: Optional[int] = 10) -> str: + """Fetch a snap services' logs. + + Args: + services (list): (optional) list of individual snap services to show logs from + (otherwise all) + num_lines (int): (optional) integer number of log lines to return. Default `10` + """ + args = ["logs", "-n={}".format(num_lines)] if num_lines else ["logs"] + return self._snap_daemons(args, services).stdout + + def connect( + self, plug: str, service: Optional[str] = None, slot: Optional[str] = None + ) -> None: + """Connect a plug to a slot. + + Args: + plug (str): the plug to connect + service (str): (optional) the snap service name to plug into + slot (str): (optional) the snap service slot to plug in to + + Raises: + SnapError if there is a problem encountered + """ + command = ["connect", "{}:{}".format(self._name, plug)] + + if service and slot: + command = command + ["{}:{}".format(service, slot)] + elif slot: + command = command + [slot] + + _cmd = ["snap", *command] + try: + subprocess.run(_cmd, universal_newlines=True, check=True, capture_output=True) + except CalledProcessError as e: + raise SnapError("Could not {} for snap [{}]: {}".format(_cmd, self._name, e.stderr)) + + def hold(self, duration: Optional[timedelta] = None) -> None: + """Add a refresh hold to a snap. + + Args: + duration: duration for the hold, or None (the default) to hold this snap indefinitely. 
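
To make the configuration and service helpers above concrete, a sketch that assumes a snap named `nextcloud` is already installed on the unit:

```python
import logging

from charms.operator_libs_linux.v1 import snap

logger = logging.getLogger(__name__)

nextcloud = snap.SnapCache()["nextcloud"]
nextcloud.set({"mode": "production"})       # snap set nextcloud mode="production"
if nextcloud.get("mode") != "production":
    logger.error("failed to switch nextcloud into production mode")
nextcloud.start(enable=True)                # start and enable all of its services
logger.debug(nextcloud.logs(num_lines=20))  # last 20 log lines from those services
```
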
+ """ + hold_str = "forever" + if duration is not None: + seconds = round(duration.total_seconds()) + hold_str = f"{seconds}s" + self._snap("refresh", [f"--hold={hold_str}"]) + + def unhold(self) -> None: + """Remove the refresh hold of a snap.""" + self._snap("refresh", ["--unhold"]) + + def restart( + self, services: Optional[List[str]] = None, reload: Optional[bool] = False + ) -> None: + """Restarts a snap's services. + + Args: + services (list): (optional) list of individual snap services to show logs from. + (otherwise all) + reload (bool): (optional) flag to use the service reload command, if available. + Default `False` + """ + args = ["restart", "--reload"] if reload else ["restart"] + self._snap_daemons(args, services) + + def _install( + self, + channel: Optional[str] = "", + cohort: Optional[str] = "", + revision: Optional[int] = None, + ) -> None: + """Add a snap to the system. + + Args: + channel: the channel to install from + cohort: optional, the key of a cohort that this snap belongs to + revision: optional, the revision of the snap to install + """ + cohort = cohort or self._cohort + + args = [] + if self.confinement == "classic": + args.append("--classic") + if channel: + args.append('--channel="{}"'.format(channel)) + if revision: + args.append('--revision="{}"'.format(revision)) + if cohort: + args.append('--cohort="{}"'.format(cohort)) + + self._snap("install", args) + + def _refresh( + self, + channel: Optional[str] = "", + cohort: Optional[str] = "", + revision: Optional[int] = None, + leave_cohort: Optional[bool] = False, + ) -> None: + """Refresh a snap. + + Args: + channel: the channel to install from + cohort: optionally, specify a cohort. + revision: optionally, specify the revision of the snap to refresh + leave_cohort: leave the current cohort. + """ + args = [] + if channel: + args.append('--channel="{}"'.format(channel)) + + if revision: + args.append('--revision="{}"'.format(revision)) + + if not cohort: + cohort = self._cohort + + if leave_cohort: + self._cohort = "" + args.append("--leave-cohort") + elif cohort: + args.append('--cohort="{}"'.format(cohort)) + + self._snap("refresh", args) + + def _remove(self) -> str: + """Remove a snap from the system.""" + return self._snap("remove") + + @property + def name(self) -> str: + """Returns the name of the snap.""" + return self._name + + def ensure( + self, + state: SnapState, + classic: Optional[bool] = False, + channel: Optional[str] = "", + cohort: Optional[str] = "", + revision: Optional[int] = None, + ): + """Ensure that a snap is in a given state. + + Args: + state: a `SnapState` to reconcile to. + classic: an (Optional) boolean indicating whether classic confinement should be used + channel: the channel to install from + cohort: optional. Specify the key of a snap cohort. + revision: optional. the revision of the snap to install/refresh + + While both channel and revision could be specified, the underlying snap install/refresh + command will determine which one takes precedence (revision at this time) + + Raises: + SnapError if an error is encountered + """ + self._confinement = "classic" if classic or self._confinement == "classic" else "" + + if state not in (SnapState.Present, SnapState.Latest): + # We are attempting to remove this snap. + if self._state in (SnapState.Present, SnapState.Latest): + # The snap is installed, so we run _remove. + self._remove() + else: + # The snap is not installed -- no need to do anything. + pass + else: + # We are installing or refreshing a snap. 
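
Tying the install/refresh plumbing above together, a sketch of the usual charm-side call: `ensure()` decides between install and refresh from the recorded state, and a refresh hold can be layered on top (the 24-hour window is an arbitrary example):

```python
from datetime import timedelta

from charms.operator_libs_linux.v1 import snap

charmcraft = snap.SnapCache()["charmcraft"]
# Installs if absent, otherwise refreshes onto the requested channel.
charmcraft.ensure(snap.SnapState.Latest, classic=True, channel="latest/stable")
charmcraft.hold(timedelta(hours=24))  # pause refreshes for a day
charmcraft.unhold()                   # ...or lift the hold again
```
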
+ if self._state not in (SnapState.Present, SnapState.Latest): + # The snap is not installed, so we install it. + self._install(channel, cohort, revision) + else: + # The snap is installed, but we are changing it (e.g., switching channels). + self._refresh(channel, cohort, revision) + + self._update_snap_apps() + self._state = state + + def _update_snap_apps(self) -> None: + """Update a snap's apps after snap changes state.""" + try: + self._apps = self._snap_client.get_installed_snap_apps(self._name) + except SnapAPIError: + logger.debug("Unable to retrieve snap apps for {}".format(self._name)) + self._apps = [] + + @property + def present(self) -> bool: + """Report whether or not a snap is present.""" + return self._state in (SnapState.Present, SnapState.Latest) + + @property + def latest(self) -> bool: + """Report whether the snap is the most recent version.""" + return self._state is SnapState.Latest + + @property + def state(self) -> SnapState: + """Report the current snap state.""" + return self._state + + @state.setter + def state(self, state: SnapState) -> None: + """Set the snap state to a given value. + + Args: + state: a `SnapState` to reconcile the snap to. + + Raises: + SnapError if an error is encountered + """ + if self._state is not state: + self.ensure(state) + self._state = state + + @property + def revision(self) -> int: + """Returns the revision for a snap.""" + return self._revision + + @property + def channel(self) -> str: + """Returns the channel for a snap.""" + return self._channel + + @property + def confinement(self) -> str: + """Returns the confinement for a snap.""" + return self._confinement + + @property + def apps(self) -> List: + """Returns (if any) the installed apps of the snap.""" + self._update_snap_apps() + return self._apps + + @property + def services(self) -> Dict: + """Returns (if any) the installed services of the snap.""" + self._update_snap_apps() + services = {} + for app in self._apps: + if "daemon" in app: + services[app["name"]] = SnapService(**app).as_dict() + + return services + + @property + def held(self) -> bool: + """Report whether the snap has a hold.""" + info = self._snap("info") + return "hold:" in info + + +class _UnixSocketConnection(http.client.HTTPConnection): + """Implementation of HTTPConnection that connects to a named Unix socket.""" + + def __init__(self, host, timeout=None, socket_path=None): + if timeout is None: + super().__init__(host) + else: + super().__init__(host, timeout=timeout) + self.socket_path = socket_path + + def connect(self): + """Override connect to use Unix socket (instead of TCP socket).""" + if not hasattr(socket, "AF_UNIX"): + raise NotImplementedError("Unix sockets not supported on {}".format(sys.platform)) + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.sock.connect(self.socket_path) + if self.timeout is not None: + self.sock.settimeout(self.timeout) + + +class _UnixSocketHandler(urllib.request.AbstractHTTPHandler): + """Implementation of HTTPHandler that uses a named Unix socket.""" + + def __init__(self, socket_path: str): + super().__init__() + self.socket_path = socket_path + + def http_open(self, req) -> http.client.HTTPResponse: + """Override http_open to use a Unix socket connection (instead of TCP).""" + return self.do_open(_UnixSocketConnection, req, socket_path=self.socket_path) + + +class SnapClient: + """Snapd API client to talk to HTTP over UNIX sockets. 
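
The `services` property above is what this charm now uses in place of `systemd.service_running`; a sketch of that check, using the snap and service names introduced in `src/constants.py` further down this diff:

```python
import logging

from charms.operator_libs_linux.v1 import snap

logger = logging.getLogger(__name__)

charmed_mysql = snap.SnapCache()["charmed-mysql"]
router = charmed_mysql.services.get("mysqlrouter-service")
if router is None or not router["active"]:
    logger.warning("mysqlrouter snap service is not active")
```
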
+ + In order to avoid shelling out and/or involving sudo in calling the snapd API, + use a wrapper based on the Pebble Client, trimmed down to only the utility methods + needed for talking to snapd. + """ + + def __init__( + self, + socket_path: str = "/run/snapd.socket", + opener: Optional[urllib.request.OpenerDirector] = None, + base_url: str = "http://localhost/v2/", + timeout: float = 5.0, + ): + """Initialize a client instance. + + Args: + socket_path: a path to the socket on the filesystem. Defaults to /run/snap/snapd.socket + opener: specifies an opener for unix socket, if unspecified a default is used + base_url: base url for making requests to the snap client. Defaults to + http://localhost/v2/ + timeout: timeout in seconds to use when making requests to the API. Default is 5.0s. + """ + if opener is None: + opener = self._get_default_opener(socket_path) + self.opener = opener + self.base_url = base_url + self.timeout = timeout + + @classmethod + def _get_default_opener(cls, socket_path): + """Build the default opener to use for requests (HTTP over Unix socket).""" + opener = urllib.request.OpenerDirector() + opener.add_handler(_UnixSocketHandler(socket_path)) + opener.add_handler(urllib.request.HTTPDefaultErrorHandler()) + opener.add_handler(urllib.request.HTTPRedirectHandler()) + opener.add_handler(urllib.request.HTTPErrorProcessor()) + return opener + + def _request( + self, + method: str, + path: str, + query: Dict = None, + body: Dict = None, + ) -> JSONType: + """Make a JSON request to the Snapd server with the given HTTP method and path. + + If query dict is provided, it is encoded and appended as a query string + to the URL. If body dict is provided, it is serialied as JSON and used + as the HTTP body (with Content-Type: "application/json"). The resulting + body is decoded from JSON. + """ + headers = {"Accept": "application/json"} + data = None + if body is not None: + data = json.dumps(body).encode("utf-8") + headers["Content-Type"] = "application/json" + + response = self._request_raw(method, path, query, headers, data) + return json.loads(response.read().decode())["result"] + + def _request_raw( + self, + method: str, + path: str, + query: Dict = None, + headers: Dict = None, + data: bytes = None, + ) -> http.client.HTTPResponse: + """Make a request to the Snapd server; return the raw HTTPResponse object.""" + url = self.base_url + path + if query: + url = url + "?" + urllib.parse.urlencode(query) + + if headers is None: + headers = {} + request = urllib.request.Request(url, method=method, data=data, headers=headers) + + try: + response = self.opener.open(request, timeout=self.timeout) + except urllib.error.HTTPError as e: + code = e.code + status = e.reason + message = "" + try: + body = json.loads(e.read().decode())["result"] + except (IOError, ValueError, KeyError) as e2: + # Will only happen on read error or if Pebble sends invalid JSON. 
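
For completeness, a sketch of using `SnapClient` directly through the read-only helpers defined just below; it talks JSON to snapd over `/run/snapd.socket`, so no shelling out or sudo is involved:

```python
from charms.operator_libs_linux.v1 import snap

client = snap.SnapClient()  # defaults to /run/snapd.socket
for installed in client.get_installed_snaps():
    print(installed["name"], installed["revision"], installed["channel"])

# Point lookups go through the same JSON API:
info = client.get_snap_information("charmed-mysql")
print(info["confinement"])
```
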
+ body = {} + message = "{} - {}".format(type(e2).__name__, e2) + raise SnapAPIError(body, code, status, message) + except urllib.error.URLError as e: + raise SnapAPIError({}, 500, "Not found", e.reason) + return response + + def get_installed_snaps(self) -> Dict: + """Get information about currently installed snaps.""" + return self._request("GET", "snaps") + + def get_snap_information(self, name: str) -> Dict: + """Query the snap server for information about single snap.""" + return self._request("GET", "find", {"name": name})[0] + + def get_installed_snap_apps(self, name: str) -> List: + """Query the snap server for apps belonging to a named, currently installed snap.""" + return self._request("GET", "apps", {"names": name, "select": "service"}) + + +class SnapCache(Mapping): + """An abstraction to represent installed/available packages. + + When instantiated, `SnapCache` iterates through the list of installed + snaps using the `snapd` HTTP API, and a list of available snaps by reading + the filesystem to populate the cache. Information about available snaps is lazily-loaded + from the `snapd` API when requested. + """ + + def __init__(self): + if not self.snapd_installed: + raise SnapError("snapd is not installed or not in /usr/bin") from None + self._snap_client = SnapClient() + self._snap_map = {} + if self.snapd_installed: + self._load_available_snaps() + self._load_installed_snaps() + + def __contains__(self, key: str) -> bool: + """Check if a given snap is in the cache.""" + return key in self._snap_map + + def __len__(self) -> int: + """Report number of items in the snap cache.""" + return len(self._snap_map) + + def __iter__(self) -> Iterable["Snap"]: + """Provide iterator for the snap cache.""" + return iter(self._snap_map.values()) + + def __getitem__(self, snap_name: str) -> Snap: + """Return either the installed version or latest version for a given snap.""" + snap = self._snap_map.get(snap_name, None) + if snap is None: + # The snapd cache file may not have existed when _snap_map was + # populated. This is normal. + try: + self._snap_map[snap_name] = self._load_info(snap_name) + except SnapAPIError: + raise SnapNotFoundError("Snap '{}' not found!".format(snap_name)) + + return self._snap_map[snap_name] + + @property + def snapd_installed(self) -> bool: + """Check whether snapd has been installled on the system.""" + return os.path.isfile("/usr/bin/snap") + + def _load_available_snaps(self) -> None: + """Load the list of available snaps from disk. + + Leave them empty and lazily load later if asked for. + """ + if not os.path.isfile("/var/cache/snapd/names"): + # The snap catalog may not be populated yet; this is normal. + # snapd updates the cache infrequently and the cache file may not + # currently exist. + return + + with open("/var/cache/snapd/names", "r") as f: + for line in f: + if line.strip(): + self._snap_map[line.strip()] = None + + def _load_installed_snaps(self) -> None: + """Load the installed snaps into the dict.""" + installed = self._snap_client.get_installed_snaps() + + for i in installed: + snap = Snap( + name=i["name"], + state=SnapState.Latest, + channel=i["channel"], + revision=int(i["revision"]), + confinement=i["confinement"], + apps=i.get("apps", None), + ) + self._snap_map[snap.name] = snap + + def _load_info(self, name) -> Snap: + """Load info for snaps which are not installed if requested. 
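
A sketch of the mapping behaviour described above, using the snap name and revision this charm pins in `src/constants.py`:

```python
from charms.operator_libs_linux.v1 import snap

cache = snap.SnapCache()
# Installed snaps are loaded eagerly; anything else is looked up lazily via
# snapd on first access (SnapNotFoundError if snapd has never heard of it).
charmed_mysql = cache["charmed-mysql"]
if not charmed_mysql.present:
    charmed_mysql.ensure(snap.SnapState.Latest, revision=48)
```
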
+ + Args: + name: a string representing the name of the snap + """ + info = self._snap_client.get_snap_information(name) + + return Snap( + name=info["name"], + state=SnapState.Available, + channel=info["channel"], + revision=int(info["revision"]), + confinement=info["confinement"], + apps=None, + ) + + +@_cache_init +def add( + snap_names: Union[str, List[str]], + state: Union[str, SnapState] = SnapState.Latest, + channel: Optional[str] = "", + classic: Optional[bool] = False, + cohort: Optional[str] = "", + revision: Optional[int] = None, +) -> Union[Snap, List[Snap]]: + """Add a snap to the system. + + Args: + snap_names: the name or names of the snaps to install + state: a string or `SnapState` representation of the desired state, one of + [`Present` or `Latest`] + channel: an (Optional) channel as a string. Defaults to 'latest' + classic: an (Optional) boolean specifying whether it should be added with classic + confinement. Default `False` + cohort: an (Optional) string specifying the snap cohort to use + revision: an (Optional) integer specifying the snap revision to use + + Raises: + SnapError if some snaps failed to install or were not found. + """ + if not channel and not revision: + channel = "latest" + + snap_names = [snap_names] if type(snap_names) is str else snap_names + if not snap_names: + raise TypeError("Expected at least one snap to add, received zero!") + + if type(state) is str: + state = SnapState(state) + + return _wrap_snap_operations(snap_names, state, channel, classic, cohort, revision) + + +@_cache_init +def remove(snap_names: Union[str, List[str]]) -> Union[Snap, List[Snap]]: + """Remove specified snap(s) from the system. + + Args: + snap_names: the name or names of the snaps to install + + Raises: + SnapError if some snaps failed to install. + """ + snap_names = [snap_names] if type(snap_names) is str else snap_names + if not snap_names: + raise TypeError("Expected at least one snap to add, received zero!") + + return _wrap_snap_operations(snap_names, SnapState.Absent, "", False) + + +@_cache_init +def ensure( + snap_names: Union[str, List[str]], + state: str, + channel: Optional[str] = "", + classic: Optional[bool] = False, + cohort: Optional[str] = "", + revision: Optional[int] = None, +) -> Union[Snap, List[Snap]]: + """Ensure specified snaps are in a given state on the system. + + Args: + snap_names: the name(s) of the snaps to operate on + state: a string representation of the desired state, from `SnapState` + channel: an (Optional) channel as a string. Defaults to 'latest' + classic: an (Optional) boolean specifying whether it should be added with classic + confinement. Default `False` + cohort: an (Optional) string specifying the snap cohort to use + revision: an (Optional) integer specifying the snap revision to use + + When both channel and revision are specified, the underlying snap install/refresh + command will determine the precedence (revision at the time of adding this) + + Raises: + SnapError if the snap is not in the cache. 
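
A sketch of the bare helpers documented above, with illustrative snap names; `add` accepts a single name or a list and returns the matching `Snap` object(s):

```python
import logging

from charms.operator_libs_linux.v1 import snap

logger = logging.getLogger(__name__)

try:
    charmcraft = snap.add("charmcraft", classic=True, channel="latest/stable")
    hello, jq = snap.add(["hello-world", "jq"])
    snap.remove("hello-world")
except snap.SnapError as e:
    logger.error("snap operation failed: %s", e.message)
```
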
+ """ + if not revision and not channel: + channel = "latest" + + if state in ("present", "latest") or revision: + return add(snap_names, SnapState(state), channel, classic, cohort, revision) + else: + return remove(snap_names) + + +def _wrap_snap_operations( + snap_names: List[str], + state: SnapState, + channel: str, + classic: bool, + cohort: Optional[str] = "", + revision: Optional[int] = None, +) -> Union[Snap, List[Snap]]: + """Wrap common operations for bare commands.""" + snaps = {"success": [], "failed": []} + + op = "remove" if state is SnapState.Absent else "install or refresh" + + for s in snap_names: + try: + snap = _Cache[s] + if state is SnapState.Absent: + snap.ensure(state=SnapState.Absent) + else: + snap.ensure( + state=state, classic=classic, channel=channel, cohort=cohort, revision=revision + ) + snaps["success"].append(snap) + except SnapError as e: + logger.warning("Failed to {} snap {}: {}!".format(op, s, e.message)) + snaps["failed"].append(s) + except SnapNotFoundError: + logger.warning("Snap '{}' not found in cache!".format(s)) + snaps["failed"].append(s) + + if len(snaps["failed"]): + raise SnapError( + "Failed to install or refresh snap(s): {}".format(", ".join(list(snaps["failed"]))) + ) + + return snaps["success"] if len(snaps["success"]) > 1 else snaps["success"][0] + + +def install_local( + filename: str, classic: Optional[bool] = False, dangerous: Optional[bool] = False +) -> Snap: + """Perform a snap operation. + + Args: + filename: the path to a local .snap file to install + classic: whether to use classic confinement + dangerous: whether --dangerous should be passed to install snaps without a signature + + Raises: + SnapError if there is a problem encountered + """ + _cmd = [ + "snap", + "install", + filename, + ] + if classic: + _cmd.append("--classic") + if dangerous: + _cmd.append("--dangerous") + try: + result = subprocess.check_output(_cmd, universal_newlines=True).splitlines()[-1] + snap_name, _ = result.split(" ", 1) + snap_name = ansi_filter.sub("", snap_name) + + c = SnapCache() + + try: + return c[snap_name] + except SnapAPIError as e: + logger.error( + "Could not find snap {} when querying Snapd socket: {}".format(snap_name, e.body) + ) + raise SnapError("Failed to find snap {} in Snap cache".format(snap_name)) + except CalledProcessError as e: + raise SnapError("Could not install snap {}: {}".format(filename, e.output)) + + +def _system_set(config_item: str, value: str) -> None: + """Set system snapd config values. + + Args: + config_item: name of snap system setting. E.g. 'refresh.hold' + value: value to assign + """ + _cmd = ["snap", "set", "system", "{}={}".format(config_item, value)] + try: + subprocess.check_call(_cmd, universal_newlines=True) + except CalledProcessError: + raise SnapError("Failed setting system config '{}' to '{}'".format(config_item, value)) + + +def hold_refresh(days: int = 90, forever: bool = False) -> bool: + """Set the system-wide snap refresh hold. + + Args: + days: number of days to hold system refreshes for. Maximum 90. Set to zero to remove hold. + forever: if True, will set a hold forever. 
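
Rounding off the module-level helpers, a sketch combining `install_local` and `hold_refresh`; the file path is made up, and `--dangerous` is needed because a locally built snap is unsigned:

```python
from charms.operator_libs_linux.v1 import snap

# Hypothetical path to a locally built, unsigned snap.
local_snap = snap.install_local("./charmed-mysql_8.0_amd64.snap", dangerous=True)

# Hold automatic refreshes system-wide for 30 days (snapd caps this at 90);
# days=0 clears the hold again, forever=True holds indefinitely.
snap.hold_refresh(days=30)
```
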
+ """ + if not isinstance(forever, bool): + raise TypeError("forever must be a bool") + if not isinstance(days, int): + raise TypeError("days must be an int") + if forever: + _system_set("refresh.hold", "forever") + logger.info("Set system-wide snap refresh hold to: forever") + elif days == 0: + _system_set("refresh.hold", "") + logger.info("Removed system-wide snap refresh hold") + else: + # Currently the snap daemon can only hold for a maximum of 90 days + if not 1 <= days <= 90: + raise ValueError("days must be between 1 and 90") + # Add the number of days to current time + target_date = datetime.now(timezone.utc).astimezone() + timedelta(days=days) + # Format for the correct datetime format + hold_date = target_date.strftime("%Y-%m-%dT%H:%M:%S%z") + # Python dumps the offset in format '+0100', we need '+01:00' + hold_date = "{0}:{1}".format(hold_date[:-2], hold_date[-2:]) + # Actually set the hold date + _system_set("refresh.hold", hold_date) + logger.info("Set system-wide snap refresh hold to: %s", hold_date) diff --git a/lib/charms/operator_libs_linux/v1/systemd.py b/lib/charms/operator_libs_linux/v1/systemd.py deleted file mode 100644 index 5be34c17..00000000 --- a/lib/charms/operator_libs_linux/v1/systemd.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2021 Canonical Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -"""Abstractions for stopping, starting and managing system services via systemd. - -This library assumes that your charm is running on a platform that uses systemd. E.g., -Centos 7 or later, Ubuntu Xenial (16.04) or later. - -For the most part, we transparently provide an interface to a commonly used selection of -systemd commands, with a few shortcuts baked in. For example, service_pause and -service_resume with run the mask/unmask and enable/disable invocations. - -Example usage: -```python -from charms.operator_libs_linux.v0.systemd import service_running, service_reload - -# Start a service -if not service_running("mysql"): - success = service_start("mysql") - -# Attempt to reload a service, restarting if necessary -success = service_reload("nginx", restart_on_failure=True) -``` - -""" - -import logging -import subprocess - -__all__ = [ # Don't export `_systemctl`. (It's not the intended way of using this lib.) 
- "service_pause", - "service_reload", - "service_restart", - "service_resume", - "service_running", - "service_start", - "service_stop", - "daemon_reload", -] - -logger = logging.getLogger(__name__) - -# The unique Charmhub library identifier, never change it -LIBID = "045b0d179f6b4514a8bb9b48aee9ebaf" - -# Increment this major API version when introducing breaking changes -LIBAPI = 1 - -# Increment this PATCH version before using `charmcraft publish-lib` or reset -# to 0 if you are raising the major API version -LIBPATCH = 0 - - -class SystemdError(Exception): - pass - - -def _popen_kwargs(): - return dict( - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - bufsize=1, - universal_newlines=True, - encoding="utf-8", - ) - - -def _systemctl( - sub_cmd: str, service_name: str = None, now: bool = None, quiet: bool = None -) -> bool: - """Control a system service. - - Args: - sub_cmd: the systemctl subcommand to issue - service_name: the name of the service to perform the action on - now: passes the --now flag to the shell invocation. - quiet: passes the --quiet flag to the shell invocation. - """ - cmd = ["systemctl", sub_cmd] - - if service_name is not None: - cmd.append(service_name) - if now is not None: - cmd.append("--now") - if quiet is not None: - cmd.append("--quiet") - if sub_cmd != "is-active": - logger.debug("Attempting to {} '{}' with command {}.".format(cmd, service_name, cmd)) - else: - logger.debug("Checking if '{}' is active".format(service_name)) - - proc = subprocess.Popen(cmd, **_popen_kwargs()) - last_line = "" - for line in iter(proc.stdout.readline, ""): - last_line = line - logger.debug(line) - - proc.wait() - - if sub_cmd == "is-active": - # If we are just checking whether a service is running, return True/False, rather - # than raising an error. - if proc.returncode < 1: - return True - if proc.returncode == 3: # Code returned when service is not active. - return False - - if proc.returncode < 1: - return True - - raise SystemdError( - "Could not {}{}: systemd output: {}".format( - sub_cmd, " {}".format(service_name) if service_name else "", last_line - ) - ) - - -def service_running(service_name: str) -> bool: - """Determine whether a system service is running. - - Args: - service_name: the name of the service - """ - return _systemctl("is-active", service_name, quiet=True) - - -def service_start(service_name: str) -> bool: - """Start a system service. - - Args: - service_name: the name of the service to stop - """ - return _systemctl("start", service_name) - - -def service_stop(service_name: str) -> bool: - """Stop a system service. - - Args: - service_name: the name of the service to stop - """ - return _systemctl("stop", service_name) - - -def service_restart(service_name: str) -> bool: - """Restart a system service. - - Args: - service_name: the name of the service to restart - """ - return _systemctl("restart", service_name) - - -def service_reload(service_name: str, restart_on_failure: bool = False) -> bool: - """Reload a system service, optionally falling back to restart if reload fails. - - Args: - service_name: the name of the service to reload - restart_on_failure: boolean indicating whether to fallback to a restart if the - reload fails. - """ - try: - return _systemctl("reload", service_name) - except SystemdError: - if restart_on_failure: - return _systemctl("restart", service_name) - else: - raise - - -def service_pause(service_name: str) -> bool: - """Pause a system service. - - Stop it, and prevent it from starting again at boot. 
- - Args: - service_name: the name of the service to pause - """ - _systemctl("disable", service_name, now=True) - _systemctl("mask", service_name) - - if not service_running(service_name): - return True - - raise SystemdError("Attempted to pause '{}', but it is still running.".format(service_name)) - - -def service_resume(service_name: str) -> bool: - """Resume a system service. - - Re-enable starting again at boot. Start the service. - - Args: - service_name: the name of the service to resume - """ - _systemctl("unmask", service_name) - _systemctl("enable", service_name, now=True) - - if service_running(service_name): - return True - - raise SystemdError("Attempted to resume '{}', but it is not running.".format(service_name)) - - -def daemon_reload() -> bool: - """Reload systemd manager configuration.""" - return _systemctl("daemon-reload") diff --git a/src/charm.py b/src/charm.py index d5420d3f..d52d1ef9 100755 --- a/src/charm.py +++ b/src/charm.py @@ -11,12 +11,14 @@ import subprocess from typing import Optional -from charms.operator_libs_linux.v1 import systemd +from charms.operator_libs_linux.v1 import snap from ops.charm import CharmBase, RelationChangedEvent from ops.main import main from ops.model import ActiveStatus, BlockedStatus, MaintenanceStatus, WaitingStatus from constants import ( + CHARMED_MYSQL_ROUTER_SERVICE, + CHARMED_MYSQL_SNAP, LEGACY_SHARED_DB, MYSQL_ROUTER_LEADER_BOOTSTRAPED, MYSQL_ROUTER_REQUIRES_DATA, @@ -25,7 +27,7 @@ from mysql_router_helpers import ( MySQLRouter, MySQLRouterBootstrapError, - MySQLRouterInstallAndConfigureError, + MySQLRouterInstallCharmedMySQLError, ) from relations.database_provides import DatabaseProvidesRelation from relations.database_requires import DatabaseRequiresRelation @@ -110,8 +112,8 @@ def _on_install(self, _) -> None: self.unit.status = MaintenanceStatus("Installing packages") try: - MySQLRouter.install_and_configure_mysql_router() - except MySQLRouterInstallAndConfigureError: + MySQLRouter.install_charmed_mysql() + except MySQLRouterInstallCharmedMySQLError: self.unit.status = BlockedStatus("Failed to install mysqlrouter") return @@ -129,17 +131,11 @@ def _on_upgrade_charm(self, _) -> None: self.unit.status = MaintenanceStatus("Upgrading charm") requires_data = json.loads(self.app_peer_data.get(MYSQL_ROUTER_REQUIRES_DATA)) - related_app_name = ( - self.shared_db_relation._get_related_app_name() - if self.shared_db_relation._shared_db_relation_exists() - else self.database_provides_relation._get_related_app_name() - ) try: MySQLRouter.bootstrap_and_start_mysql_router( requires_data["username"], self._get_secret("app", "database-password"), - related_app_name, requires_data["endpoints"].split(",")[0].split(":")[0], "3306", force=True, @@ -160,9 +156,14 @@ def _on_peer_relation_changed(self, event: RelationChangedEvent) -> None: MYSQL_ROUTER_LEADER_BOOTSTRAPED ): try: - mysqlrouter_running = MySQLRouter.is_mysqlrouter_running() - except systemd.SystemdError as e: - logger.exception("Failed to check if mysqlrouter with systemd", exc_info=e) + cache = snap.SnapCache() + charmed_mysql = cache[CHARMED_MYSQL_SNAP] + + mysqlrouter_running = charmed_mysql.services[CHARMED_MYSQL_ROUTER_SERVICE][ + "active" + ] + except snap.SnapError: + logger.exception("Failed to check if mysqlrouter service is running") self.unit.status = BlockedStatus("Failed to bootstrap mysqlrouter") return @@ -176,17 +177,11 @@ def _on_peer_relation_changed(self, event: RelationChangedEvent) -> None: return requires_data = 
json.loads(self.app_peer_data.get(MYSQL_ROUTER_REQUIRES_DATA)) - related_app_name = ( - self.shared_db_relation._get_related_app_name() - if self.shared_db_relation._shared_db_relation_exists() - else self.database_provides_relation._get_related_app_name() - ) try: MySQLRouter.bootstrap_and_start_mysql_router( requires_data["username"], self._get_secret("app", "database-password"), - related_app_name, requires_data["endpoints"].split(",")[0].split(":")[0], "3306", ) diff --git a/src/constants.py b/src/constants.py index fe0ddd99..80fb4804 100644 --- a/src/constants.py +++ b/src/constants.py @@ -3,17 +3,17 @@ """File containing constants to be used in the charm.""" -MYSQL_ROUTER_APT_PACKAGE = "mysql-router" -MYSQL_ROUTER_GROUP = "mysql" -MYSQL_ROUTER_USER = "mysql" -MYSQL_HOME_DIRECTORY = "/var/lib/mysql" +CHARMED_MYSQL_SNAP = "charmed-mysql" +CHARMED_MYSQL_SNAP_REVISION = 48 +SNAP_DAEMON_USER = "snap_daemon" +CHARMED_MYSQL_DATA_DIRECTORY = "/var/snap/charmed-mysql/current" +CHARMED_MYSQL_COMMON_DIRECTORY = "/var/snap/charmed-mysql/common" +CHARMED_MYSQL_ROUTER = "charmed-mysql.mysqlrouter" +CHARMED_MYSQL_ROUTER_SERVICE = "mysqlrouter-service" PEER = "mysql-router-peers" DATABASE_REQUIRES_RELATION = "backend-database" DATABASE_PROVIDES_RELATION = "database" MYSQL_ROUTER_LEADER_BOOTSTRAPED = "mysql-router-leader-bootstraped" -MYSQL_ROUTER_UNIT_TEMPLATE = "templates/mysqlrouter.service.j2" -MYSQL_ROUTER_SERVICE_NAME = "mysqlrouter.service" -MYSQL_ROUTER_SYSTEMD_DIRECTORY = "/etc/systemd/system" MYSQL_ROUTER_REQUIRES_DATA = "requires-database-data" MYSQL_ROUTER_PROVIDES_DATA = "provides-database-data" PASSWORD_LENGTH = 24 diff --git a/src/mysql_router_helpers.py b/src/mysql_router_helpers.py index 9ba39d73..49dc481a 100644 --- a/src/mysql_router_helpers.py +++ b/src/mysql_router_helpers.py @@ -3,25 +3,20 @@ """Helper class to manage the MySQL Router lifecycle.""" -import grp import logging -import os -import pwd import subprocess -import jinja2 import mysql.connector -from charms.operator_libs_linux.v0 import apt, passwd -from charms.operator_libs_linux.v1 import systemd +from charms.operator_libs_linux.v1 import snap from constants import ( - MYSQL_HOME_DIRECTORY, - MYSQL_ROUTER_APT_PACKAGE, - MYSQL_ROUTER_GROUP, - MYSQL_ROUTER_SERVICE_NAME, - MYSQL_ROUTER_SYSTEMD_DIRECTORY, - MYSQL_ROUTER_UNIT_TEMPLATE, - MYSQL_ROUTER_USER, + CHARMED_MYSQL_COMMON_DIRECTORY, + CHARMED_MYSQL_DATA_DIRECTORY, + CHARMED_MYSQL_ROUTER, + CHARMED_MYSQL_ROUTER_SERVICE, + CHARMED_MYSQL_SNAP, + CHARMED_MYSQL_SNAP_REVISION, + SNAP_DAEMON_USER, ) logger = logging.getLogger(__name__) @@ -45,8 +40,8 @@ def message(self): return self.args[0] -class MySQLRouterInstallAndConfigureError(Error): - """Exception raised when there is an issue installing MySQLRouter.""" +class MySQLRouterInstallCharmedMySQLError(Error): + """Exception raised when there is an issue installing charmed-mysql snap.""" class MySQLRouterBootstrapError(Error): @@ -61,55 +56,24 @@ class MySQLRouter: """Class to encapsulate all operations related to MySQLRouter.""" @staticmethod - def install_and_configure_mysql_router() -> None: - """Install and configure MySQLRouter.""" + def install_charmed_mysql() -> None: + """Install charmed-mysql snap and configure MySQLRouter.""" try: - apt.update() - apt.add_package(MYSQL_ROUTER_APT_PACKAGE) + logger.debug("Retrieving snap cache") + cache = snap.SnapCache() + charmed_mysql = cache[CHARMED_MYSQL_SNAP] - if not passwd.group_exists(MYSQL_ROUTER_GROUP): - passwd.add_group(MYSQL_ROUTER_GROUP, system_group=True) - - if 
not passwd.user_exists(MYSQL_ROUTER_USER): - passwd.add_user( - MYSQL_ROUTER_USER, - shell="/usr/sbin/nologin", - system_user=True, - primary_group=MYSQL_ROUTER_GROUP, - home_dir=MYSQL_HOME_DIRECTORY, - ) - - if not os.path.exists(MYSQL_HOME_DIRECTORY): - os.makedirs(MYSQL_HOME_DIRECTORY, mode=0o755, exist_ok=True) - - user_id = pwd.getpwnam(MYSQL_ROUTER_USER).pw_uid - group_id = grp.getgrnam(MYSQL_ROUTER_GROUP).gr_gid - - os.chown(MYSQL_HOME_DIRECTORY, user_id, group_id) + if not charmed_mysql.present: + logger.debug("Install charmed-mysql snap") + charmed_mysql.ensure(snap.SnapState.Latest, revision=CHARMED_MYSQL_SNAP_REVISION) except Exception as e: - logger.exception(f"Failed to install the {MYSQL_ROUTER_APT_PACKAGE} apt package.") - raise MySQLRouterInstallAndConfigureError(e.stderr) - - @staticmethod - def _render_and_copy_mysqlrouter_systemd_unit_file(app_name): - with open(MYSQL_ROUTER_UNIT_TEMPLATE, "r") as file: - template = jinja2.Template(file.read()) - - rendered_template = template.render(charm_app_name=app_name) - systemd_file_path = f"{MYSQL_ROUTER_SYSTEMD_DIRECTORY}/mysqlrouter.service" - - with open(systemd_file_path, "w+") as file: - file.write(rendered_template) - - os.chmod(systemd_file_path, 0o644) - mysql_user = pwd.getpwnam(MYSQL_ROUTER_USER) - os.chown(systemd_file_path, uid=mysql_user.pw_uid, gid=mysql_user.pw_gid) + logger.exception(f"Failed to install the {CHARMED_MYSQL_SNAP} snap.") + raise MySQLRouterInstallCharmedMySQLError(e.stderr) @staticmethod def bootstrap_and_start_mysql_router( user, password, - name, db_host, port, force=False, @@ -119,7 +83,6 @@ def bootstrap_and_start_mysql_router( Args: user: The user to connect to the database with password: The password to connect to the database with - name: The name of application that will use mysqlrouter db_host: The hostname of the database to connect to port: The port at which to bootstrap mysqlrouter to force: Overwrite existing config if any @@ -132,15 +95,11 @@ def bootstrap_and_start_mysql_router( # https://dev.mysql.com/doc/refman/8.0/en/caching-sha2-pluggable-authentication.html) bootstrap_mysqlrouter_command = [ "sudo", - "/usr/bin/mysqlrouter", + CHARMED_MYSQL_ROUTER, "--user", - MYSQL_ROUTER_USER, - "--name", - name, + SNAP_DAEMON_USER, "--bootstrap", f"{user}:{password}@{db_host}", - "--directory", - f"{MYSQL_HOME_DIRECTORY}/{name}", "--conf-use-sockets", "--conf-bind-address", "127.0.0.1", @@ -157,35 +116,34 @@ def bootstrap_and_start_mysql_router( bootstrap_mysqlrouter_command.append("--force") try: - subprocess.run(bootstrap_mysqlrouter_command) + subprocess.run(bootstrap_mysqlrouter_command, check=True) - subprocess.run(f"sudo chmod 755 {MYSQL_HOME_DIRECTORY}/{name}".split()) + replace_socket_location_command = [ + "sudo", + "sed", + "-Ei", + f"s:/tmp/(.+).sock:{CHARMED_MYSQL_COMMON_DIRECTORY}/var/run/mysqlrouter/\\1.sock:g", + f"{CHARMED_MYSQL_DATA_DIRECTORY}/etc/mysqlrouter/mysqlrouter.conf", + ] + subprocess.run(replace_socket_location_command, check=True) - MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file(name) + cache = snap.SnapCache() + charmed_mysql = cache[CHARMED_MYSQL_SNAP] - if not systemd.daemon_reload(): - error_message = "Failed to load the mysqlrouter systemd service" - logger.exception(error_message) - raise MySQLRouterBootstrapError(error_message) + charmed_mysql.start(services=[CHARMED_MYSQL_ROUTER_SERVICE]) - systemd.service_start(MYSQL_ROUTER_SERVICE_NAME) - if not MySQLRouter.is_mysqlrouter_running(): - error_message = "Failed to start the mysqlrouter systemd 
service" + if not charmed_mysql.services[CHARMED_MYSQL_ROUTER_SERVICE]["active"]: + error_message = "Failed to start the mysqlrouter snap service" logger.exception(error_message) raise MySQLRouterBootstrapError(error_message) except subprocess.CalledProcessError as e: - logger.exception("Failed to bootstrap mysqlrouter") + logger.exception("Failed to bootstrap and start mysqlrouter") raise MySQLRouterBootstrapError(e.stderr) - except systemd.SystemdError: - error_message = "Failed to set up mysqlrouter as a systemd service" + except snap.SnapError: + error_message = f"Failed to start snap service {CHARMED_MYSQL_ROUTER_SERVICE}" logger.exception(error_message) raise MySQLRouterBootstrapError(error_message) - @staticmethod - def is_mysqlrouter_running() -> bool: - """Indicates whether MySQLRouter is running as a systemd service.""" - return systemd.service_running(MYSQL_ROUTER_SERVICE_NAME) - @staticmethod def create_user_with_database_privileges( username, password, hostname, database, db_username, db_password, db_host, db_port diff --git a/src/relations/database_provides.py b/src/relations/database_provides.py index 13b2b43f..ef45a4aa 100644 --- a/src/relations/database_provides.py +++ b/src/relations/database_provides.py @@ -11,9 +11,10 @@ DatabaseRequestedEvent, ) from ops.framework import Object -from ops.model import Application, BlockedStatus +from ops.model import BlockedStatus from constants import ( + CHARMED_MYSQL_COMMON_DIRECTORY, DATABASE_PROVIDES_RELATION, MYSQL_ROUTER_LEADER_BOOTSTRAPED, MYSQL_ROUTER_PROVIDES_DATA, @@ -54,17 +55,6 @@ def _database_provides_relation_exists(self) -> bool: database_provides_relations = self.charm.model.relations.get(DATABASE_PROVIDES_RELATION) return bool(database_provides_relations) - def _get_related_app_name(self) -> str: - """Helper to get the name of the related `database-provides` application.""" - if not self._database_provides_relation_exists(): - return None - - for key in self.charm.model.relations[DATABASE_PROVIDES_RELATION][0].data: - if type(key) == Application and key.name != self.charm.app.name: - return key.name - - return None - # ======================= # Handlers # ======================= @@ -106,13 +96,11 @@ def _on_peer_relation_changed(self, _) -> None: db_host = parsed_database_requires_data["endpoints"].split(",")[0].split(":")[0] mysqlrouter_username = parsed_database_requires_data["username"] mysqlrouter_user_password = self.charm._get_secret("app", "database-password") - related_app_name = self._get_related_app_name() try: MySQLRouter.bootstrap_and_start_mysql_router( mysqlrouter_username, mysqlrouter_user_password, - related_app_name, db_host, "3306", ) @@ -147,10 +135,12 @@ def _on_peer_relation_changed(self, _) -> None: provides_relation_id, application_username, application_password ) self.database.set_endpoints( - provides_relation_id, f"file:///var/lib/mysql/{related_app_name}/mysql.sock" + provides_relation_id, + f"file://{CHARMED_MYSQL_COMMON_DIRECTORY}/var/run/mysqlrouter/mysql.sock", ) self.database.set_read_only_endpoints( - provides_relation_id, f"file:///var/lib/mysql/{related_app_name}/mysqlro.sock" + provides_relation_id, + f"file://{CHARMED_MYSQL_COMMON_DIRECTORY}/var/run/mysqlrouter/mysqlro.sock", ) self.charm.app_peer_data[MYSQL_ROUTER_LEADER_BOOTSTRAPED] = "true" diff --git a/src/relations/shared_db.py b/src/relations/shared_db.py index b1ea0995..e65fd9ae 100644 --- a/src/relations/shared_db.py +++ b/src/relations/shared_db.py @@ -8,7 +8,7 @@ from ops.charm import RelationChangedEvent from 
ops.framework import Object -from ops.model import Application, BlockedStatus, Unit +from ops.model import BlockedStatus, Unit from constants import ( LEGACY_SHARED_DB, @@ -53,17 +53,6 @@ def _shared_db_relation_exists(self) -> bool: shared_db_relations = self.charm.model.relations.get(LEGACY_SHARED_DB) return bool(shared_db_relations) - def _get_related_app_name(self) -> str: - """Helper to get the name of the related `shared-db` application.""" - if not self._shared_db_relation_exists(): - return None - - for key in self.charm.model.relations[LEGACY_SHARED_DB][0].data: - if type(key) == Application and key.name != self.charm.app.name: - return key.name - - return None - def _get_related_unit_name(self) -> str: """Helper to get the name of the related `shared-db` unit.""" if not self._shared_db_relation_exists(): @@ -136,14 +125,12 @@ def _on_peer_relation_changed(self, _) -> None: parsed_shared_db_data = json.loads(self.charm.app_peer_data[LEGACY_SHARED_DB_DATA]) db_host = parsed_requires_data["endpoints"].split(",")[0].split(":")[0] - related_app_name = self._get_related_app_name() application_password = generate_random_password(PASSWORD_LENGTH) try: MySQLRouter.bootstrap_and_start_mysql_router( parsed_requires_data["username"], database_password, - related_app_name, db_host, "3306", ) diff --git a/templates/mysqlrouter.service.j2 b/templates/mysqlrouter.service.j2 deleted file mode 100644 index 850b1a8c..00000000 --- a/templates/mysqlrouter.service.j2 +++ /dev/null @@ -1,19 +0,0 @@ -# MySQL Router systemd service file - -[Unit] -Description=MySQL Router -After=network.target - -[Service] -Type=forking -User=mysql -Group=mysql -RuntimeDirectory=mysql -ExecStart=/var/lib/mysql/{{ charm_app_name }}/start.sh -ExecStop=/var/lib/mysql/{{ charm_app_name }}/stop.sh -RemainAfterExit=yes -Restart=on-failure -LimitNOFILE=65535 - -[Install] -WantedBy=multi-user.target diff --git a/tests/unit/test_charm.py b/tests/unit/test_charm.py index a95660d5..d1478b1b 100644 --- a/tests/unit/test_charm.py +++ b/tests/unit/test_charm.py @@ -12,7 +12,7 @@ from constants import MYSQL_ROUTER_REQUIRES_DATA, PEER from mysql_router_helpers import ( MySQLRouterBootstrapError, - MySQLRouterInstallAndConfigureError, + MySQLRouterInstallCharmedMySQLError, ) @@ -67,30 +67,26 @@ def test_set_secret(self): ) @patch("subprocess.check_call") - @patch("mysql_router_helpers.MySQLRouter.install_and_configure_mysql_router") - def test_on_install(self, _install_and_configure_mysql_router, _check_call): + @patch("mysql_router_helpers.MySQLRouter.install_charmed_mysql") + def test_on_install(self, _install_charmed_mysql, _check_call): self.charm.on.install.emit() self.assertTrue(isinstance(self.harness.model.unit.status, WaitingStatus)) @patch( - "mysql_router_helpers.MySQLRouter.install_and_configure_mysql_router", - side_effect=MySQLRouterInstallAndConfigureError(), + "mysql_router_helpers.MySQLRouter.install_charmed_mysql", + side_effect=MySQLRouterInstallCharmedMySQLError(), ) - def test_on_install_exception(self, _install_and_configure_mysql_router): + def test_on_install_exception(self, _install_charmed_mysql): self.charm.on.install.emit() self.assertTrue(isinstance(self.harness.model.unit.status, BlockedStatus)) - @patch("charm.DatabaseProvidesRelation._get_related_app_name") @patch("charm.MySQLRouterOperatorCharm._get_secret") @patch("mysql_router_helpers.MySQLRouter.bootstrap_and_start_mysql_router") - def test_on_upgrade_charm( - self, bootstrap_and_start_mysql_router, get_secret, get_related_app_name - ): + def 
test_on_upgrade_charm(self, bootstrap_and_start_mysql_router, get_secret): self.charm.unit.status = ActiveStatus() get_secret.return_value = "s3kr1t" - get_related_app_name.return_value = "testapp" self.charm.app_peer_data[MYSQL_ROUTER_REQUIRES_DATA] = json.dumps( { "username": "test_user", @@ -101,7 +97,7 @@ def test_on_upgrade_charm( self.assertTrue(isinstance(self.harness.model.unit.status, ActiveStatus)) bootstrap_and_start_mysql_router.assert_called_with( - "test_user", "s3kr1t", "testapp", "10.10.0.1", "3306", force=True + "test_user", "s3kr1t", "10.10.0.1", "3306", force=True ) @patch("mysql_router_helpers.MySQLRouter.bootstrap_and_start_mysql_router") @@ -112,15 +108,11 @@ def test_on_upgrade_charm_waiting(self, bootstrap_and_start_mysql_router): self.assertTrue(isinstance(self.harness.model.unit.status, WaitingStatus)) bootstrap_and_start_mysql_router.assert_not_called() - @patch("charm.DatabaseProvidesRelation._get_related_app_name") @patch("charm.MySQLRouterOperatorCharm._get_secret") @patch("mysql_router_helpers.MySQLRouter.bootstrap_and_start_mysql_router") - def test_on_upgrade_charm_error( - self, bootstrap_and_start_mysql_router, get_secret, get_related_app_name - ): + def test_on_upgrade_charm_error(self, bootstrap_and_start_mysql_router, get_secret): bootstrap_and_start_mysql_router.side_effect = MySQLRouterBootstrapError() get_secret.return_value = "s3kr1t" - get_related_app_name.return_value = "testapp" self.charm.unit.status = ActiveStatus() self.charm.app_peer_data[MYSQL_ROUTER_REQUIRES_DATA] = json.dumps( { @@ -132,5 +124,5 @@ def test_on_upgrade_charm_error( self.assertTrue(isinstance(self.harness.model.unit.status, BlockedStatus)) bootstrap_and_start_mysql_router.assert_called_with( - "test_user", "s3kr1t", "testapp", "10.10.0.1", "3306", force=True + "test_user", "s3kr1t", "10.10.0.1", "3306", force=True ) diff --git a/tests/unit/test_mysql_router_helpers.py b/tests/unit/test_mysql_router_helpers.py index 9d08cd2c..0e4af34c 100644 --- a/tests/unit/test_mysql_router_helpers.py +++ b/tests/unit/test_mysql_router_helpers.py @@ -3,24 +3,20 @@ import unittest from subprocess import CalledProcessError -from unittest.mock import call, patch +from unittest.mock import MagicMock, call, patch -from charms.operator_libs_linux.v1.systemd import SystemdError +from charms.operator_libs_linux.v1 import snap -from constants import MYSQL_ROUTER_SERVICE_NAME +from constants import CHARMED_MYSQL_ROUTER_SERVICE, CHARMED_MYSQL_SNAP from mysql_router_helpers import MySQLRouter, MySQLRouterBootstrapError bootstrap_cmd = [ "sudo", - "/usr/bin/mysqlrouter", + "charmed-mysql.mysqlrouter", "--user", - "mysql", - "--name", - "testapp", + "snap_daemon", "--bootstrap", "test_user:qweqwe@10.10.0.1", - "--directory", - "/var/lib/mysql/testapp", "--conf-use-sockets", "--conf-bind-address", "127.0.0.1", @@ -32,158 +28,119 @@ "http_server.bind_address=127.0.0.1", "--conf-use-gr-notifications", ] -chmod_cmd = [ +replace_socket_location_cmd = [ "sudo", - "chmod", - "755", - "/var/lib/mysql/testapp", + "sed", + "-Ei", + "s:/tmp/(.+).sock:/var/snap/charmed-mysql/common/var/run/mysqlrouter/\\1.sock:g", + "/var/snap/charmed-mysql/current/etc/mysqlrouter/mysqlrouter.conf", ] class TestMysqlRouterHelpers(unittest.TestCase): - @patch("mysql_router_helpers.MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file") - @patch("mysql_router_helpers.systemd") @patch("mysql_router_helpers.subprocess.run") - def test_bootstrap_and_start_mysql_router(self, run, systemd, render_and_copy): - 
MySQLRouter.bootstrap_and_start_mysql_router( - "test_user", "qweqwe", "testapp", "10.10.0.1", "3306" - ) + @patch("mysql_router_helpers.snap.SnapCache") + def test_bootstrap_and_start_mysql_router(self, _snap_cache, _run): + _charmed_mysql_mock = MagicMock() + _cache = {CHARMED_MYSQL_SNAP: _charmed_mysql_mock} + _snap_cache.return_value.__getitem__.side_effect = _cache.__getitem__ + + MySQLRouter.bootstrap_and_start_mysql_router("test_user", "qweqwe", "10.10.0.1", "3306") self.assertEqual( - sorted(run.mock_calls), + sorted(_run.mock_calls), sorted( [ - call(bootstrap_cmd), - call(chmod_cmd), + call(bootstrap_cmd, check=True), + call(replace_socket_location_cmd, check=True), ] ), ) - render_and_copy.assert_called_with("testapp") - systemd.daemon_reload.assert_called_with() - systemd.service_start.assert_called_with(MYSQL_ROUTER_SERVICE_NAME) + _charmed_mysql_mock.start.assert_called_once() - @patch("mysql_router_helpers.MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file") - @patch("mysql_router_helpers.systemd") @patch("mysql_router_helpers.subprocess.run") - def test_bootstrap_and_start_mysql_router_force(self, run, systemd, render_and_copy): + @patch("mysql_router_helpers.snap.SnapCache") + def test_bootstrap_and_start_mysql_router_force(self, _snap_cache, _run): + _charmed_mysql_mock = MagicMock() + _cache = {CHARMED_MYSQL_SNAP: _charmed_mysql_mock} + _snap_cache.return_value.__getitem__.side_effect = _cache.__getitem__ + MySQLRouter.bootstrap_and_start_mysql_router( - "test_user", "qweqwe", "testapp", "10.10.0.1", "3306", force=True + "test_user", "qweqwe", "10.10.0.1", "3306", force=True ) self.assertEqual( - sorted(run.mock_calls), + sorted(_run.mock_calls), sorted( [ - call(bootstrap_cmd + ["--force"]), - call(chmod_cmd), + call(bootstrap_cmd + ["--force"], check=True), + call(replace_socket_location_cmd, check=True), ] ), ) - render_and_copy.assert_called_with("testapp") - systemd.daemon_reload.assert_called_with() - systemd.service_start.assert_called_with(MYSQL_ROUTER_SERVICE_NAME) + _charmed_mysql_mock.start.assert_called_once() @patch("mysql_router_helpers.logger") - @patch("mysql_router_helpers.MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file") - @patch("mysql_router_helpers.systemd") @patch("mysql_router_helpers.subprocess.run") - def test_bootstrap_and_start_mysql_router_subprocess_error( - self, run, systemd, render_and_copy, logger - ): + @patch("mysql_router_helpers.snap.SnapCache") + def test_bootstrap_and_start_mysql_router_subprocess_error(self, _snap_cache, _run, _logger): e = CalledProcessError(1, bootstrap_cmd) - run.side_effect = e + _run.side_effect = e with self.assertRaises(MySQLRouterBootstrapError): MySQLRouter.bootstrap_and_start_mysql_router( - "test_user", "qweqwe", "testapp", "10.10.0.1", "3306" + "test_user", "qweqwe", "10.10.0.1", "3306" ) - run.assert_called_with(bootstrap_cmd) - render_and_copy.assert_not_called() - systemd.daemon_reload.assert_not_called() - systemd.service_start.assert_not_called() - logger.exception.assert_called_with("Failed to bootstrap mysqlrouter") + _run.assert_called_once_with(bootstrap_cmd, check=True) + _logger.exception.assert_called_with("Failed to bootstrap and start mysqlrouter") @patch("mysql_router_helpers.logger") - @patch("mysql_router_helpers.MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file") - @patch("mysql_router_helpers.systemd.service_start") - @patch("mysql_router_helpers.systemd.daemon_reload") @patch("mysql_router_helpers.subprocess.run") - def 
test_bootstrap_and_start_mysql_router_systemd_error( - self, run, daemon_reload, service_start, render_and_copy, logger - ): - e = SystemdError() - daemon_reload.side_effect = e + @patch("mysql_router_helpers.snap.SnapCache") + def test_bootstrap_and_start_mysql_router_snap_error(self, _snap_cache, _run, _logger): + e = snap.SnapError() + _snap_cache.return_value.__getitem__.side_effect = e with self.assertRaises(MySQLRouterBootstrapError): MySQLRouter.bootstrap_and_start_mysql_router( - "test_user", "qweqwe", "testapp", "10.10.0.1", "3306" + "test_user", "qweqwe", "10.10.0.1", "3306" ) self.assertEqual( - sorted(run.mock_calls), + sorted(_run.mock_calls), sorted( [ - call(bootstrap_cmd), - call(chmod_cmd), + call(bootstrap_cmd, check=True), + call(replace_socket_location_cmd, check=True), ] ), ) - render_and_copy.assert_called_with("testapp") - daemon_reload.assert_called_with() - service_start.assert_not_called() - logger.exception.assert_called_with("Failed to set up mysqlrouter as a systemd service") + _logger.exception.assert_called_with( + f"Failed to start snap service {CHARMED_MYSQL_ROUTER_SERVICE}" + ) @patch("mysql_router_helpers.logger") - @patch("mysql_router_helpers.MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file") - @patch("mysql_router_helpers.systemd.service_start") - @patch("mysql_router_helpers.systemd.daemon_reload") @patch("mysql_router_helpers.subprocess.run") - def test_bootstrap_and_start_mysql_router_no_daemon_reload( - self, run, daemon_reload, service_start, render_and_copy, logger - ): - daemon_reload.return_value = False - with self.assertRaises(MySQLRouterBootstrapError): - MySQLRouter.bootstrap_and_start_mysql_router( - "test_user", "qweqwe", "testapp", "10.10.0.1", "3306" - ) + @patch("mysql_router_helpers.snap.SnapCache") + def test_bootstrap_and_start_mysql_router_no_service_start(self, _snap_cache, _run, _logger): + _charmed_mysql_mock = MagicMock() + _cache = {CHARMED_MYSQL_SNAP: _charmed_mysql_mock} + _snap_cache.return_value.__getitem__.side_effect = _cache.__getitem__ - self.assertEqual( - sorted(run.mock_calls), - sorted( - [ - call(bootstrap_cmd), - call(chmod_cmd), - ] - ), - ) - render_and_copy.assert_called_with("testapp") - daemon_reload.assert_called_with() - service_start.assert_not_called() - logger.exception.assert_called_with("Failed to load the mysqlrouter systemd service") + _services = {CHARMED_MYSQL_ROUTER_SERVICE: {"active": False}} + _charmed_mysql_mock.services.__getitem__.side_effect = _services.__getitem__ - @patch("mysql_router_helpers.logger") - @patch("mysql_router_helpers.MySQLRouter._render_and_copy_mysqlrouter_systemd_unit_file") - @patch("mysql_router_helpers.systemd.service_start") - @patch("mysql_router_helpers.systemd.daemon_reload") - @patch("mysql_router_helpers.subprocess.run") - def test_bootstrap_and_start_mysql_router_no_service_start( - self, run, daemon_reload, service_start, render_and_copy, logger - ): - service_start.return_value = False with self.assertRaises(MySQLRouterBootstrapError): MySQLRouter.bootstrap_and_start_mysql_router( - "test_user", "qweqwe", "testapp", "10.10.0.1", "3306" + "test_user", "qweqwe", "10.10.0.1", "3306" ) self.assertEqual( - sorted(run.mock_calls), + sorted(_run.mock_calls), sorted( [ - call(bootstrap_cmd), - call(chmod_cmd), + call(bootstrap_cmd, check=True), + call(replace_socket_location_cmd, check=True), ] ), ) - render_and_copy.assert_called_with("testapp") - daemon_reload.assert_called_with() - service_start.assert_called_with(MYSQL_ROUTER_SERVICE_NAME) - 
logger.exception.assert_called_with("Failed to start the mysqlrouter systemd service") + _logger.exception.assert_called_with("Failed to start the mysqlrouter snap service")
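
Pulling the charm-side changes in this diff together, a condensed sketch of the new bootstrap flow from `src/mysql_router_helpers.py` (most bootstrap flags and all error handling omitted; constants as introduced in `src/constants.py`):

```python
import subprocess

from charms.operator_libs_linux.v1 import snap

CHARMED_MYSQL_SNAP = "charmed-mysql"
CHARMED_MYSQL_ROUTER = "charmed-mysql.mysqlrouter"
CHARMED_MYSQL_ROUTER_SERVICE = "mysqlrouter-service"
CHARMED_MYSQL_DATA_DIRECTORY = "/var/snap/charmed-mysql/current"
CHARMED_MYSQL_COMMON_DIRECTORY = "/var/snap/charmed-mysql/common"


def bootstrap(user: str, password: str, db_host: str) -> None:
    # 1. Bootstrap the router against the database as the snap_daemon user.
    subprocess.run(
        ["sudo", CHARMED_MYSQL_ROUTER, "--user", "snap_daemon",
         "--bootstrap", f"{user}:{password}@{db_host}", "--conf-use-sockets"],
        check=True,
    )
    # 2. Rewrite the generated socket paths from /tmp into the snap's common dir.
    subprocess.run(
        ["sudo", "sed", "-Ei",
         f"s:/tmp/(.+).sock:{CHARMED_MYSQL_COMMON_DIRECTORY}/var/run/mysqlrouter/\\1.sock:g",
         f"{CHARMED_MYSQL_DATA_DIRECTORY}/etc/mysqlrouter/mysqlrouter.conf"],
        check=True,
    )
    # 3. Start the snap's bundled service instead of a hand-rolled systemd unit.
    charmed_mysql = snap.SnapCache()[CHARMED_MYSQL_SNAP]
    charmed_mysql.start(services=[CHARMED_MYSQL_ROUTER_SERVICE])
```

The unit tests above exercise exactly this path by patching `mysql_router_helpers.snap.SnapCache` and asserting on the mocked `Snap.start` call.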