Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DPE-4179, DPE-4219] Add integration test for upgrades #135

Merged
merged 16 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions actions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

resume-upgrade:
description: Upgrade remaining units (after you manually verified that upgraded units are healthy).

force-upgrade:
description: |
Potential of *data loss* and *downtime*
Expand All @@ -12,6 +13,7 @@ force-upgrade:
Use to
- force incompatible upgrade and/or
- continue upgrade if 1+ upgraded units have non-active status

set-tls-private-key:
description:
Set the private key, which will be used for certificate signing requests (CSR). Run
Expand Down
2 changes: 1 addition & 1 deletion src/machine_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class AuthenticatedMachineWorkload(workload.AuthenticatedWorkload):
def _get_bootstrap_command(
self, *, event, connection_info: "relations.database_requires.ConnectionInformation"
) -> typing.List[str]:
command = super()._get_bootstrap_command(connection_info)
command = super()._get_bootstrap_command(event=event, connection_info=connection_info)
if self._charm.is_externally_accessible(event=event):
command.extend(
[
Expand Down
2 changes: 1 addition & 1 deletion src/snap.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
logger = logging.getLogger(__name__)

_SNAP_NAME = "charmed-mysql"
REVISION = "102" # Keep in sync with `workload_version` file
REVISION = "103" # Keep in sync with `workload_version` file
_snap = snap_lib.SnapCache()[_SNAP_NAME]
_UNIX_USERNAME = "snap_daemon"

Expand Down
2 changes: 1 addition & 1 deletion src/upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def authorized(self) -> bool:
"""

@abc.abstractmethod
def upgrade_unit(self, *, workload_: workload.Workload, tls: bool) -> None:
def upgrade_unit(self, *, event, workload_: workload.Workload, tls: bool) -> None:
"""Upgrade this unit.

Only applies to machine charm
Expand Down
4 changes: 2 additions & 2 deletions src/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _cleanup_after_upgrade_or_potential_container_restart(self) -> None:

# TODO python3.10 min version: Use `list` instead of `typing.List`
def _get_bootstrap_command(
self, connection_info: "relations.database_requires.ConnectionInformation"
self, *, event, connection_info: "relations.database_requires.ConnectionInformation"
) -> typing.List[str]:
return [
"--bootstrap",
Expand Down Expand Up @@ -430,7 +430,7 @@ def upgrade(
if enabled:
logger.debug("Disabling MySQL Router service before upgrade")
self._disable_router()
super().upgrade(unit=unit, tls=tls, exporter_config=exporter_config)
super().upgrade(event=event, unit=unit, tls=tls, exporter_config=exporter_config)
if enabled:
logger.debug("Re-enabling MySQL Router service after upgrade")
self._enable_router(event=event, tls=tls, unit_name=unit.name)
Expand Down
26 changes: 26 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2022 Canonical Ltd.
# See LICENSE file for licensing details.

import logging
import os
import pathlib
from unittest.mock import PropertyMock
Expand All @@ -9,6 +10,12 @@
import pytest_operator.plugin
from ops import JujuVersion
from pytest_mock import MockerFixture
from pytest_operator.plugin import OpsTest

from . import juju_
from .helpers import APPLICATION_DEFAULT_APP_NAME, get_application_name

logger = logging.getLogger(__name__)


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -54,3 +61,22 @@ def juju_has_secrets(mocker: MockerFixture, request):
JujuVersion, "has_secrets", new_callable=PropertyMock
).return_value = True
return True


@pytest.fixture
async def continuous_writes(ops_test: OpsTest):
    """Run continuous writes against the MySQL cluster for the duration of a test.

    The `clear-continuous-writes` action is run before the writes are started and
    again once the test has finished.
    """
    test_app_name = get_application_name(ops_test, APPLICATION_DEFAULT_APP_NAME)
    writer_unit = ops_test.model.applications[test_app_name].units[0]

    async def _clear_writes() -> None:
        # Shared by the setup and teardown halves of the fixture.
        logger.info("Clearing continuous writes")
        await juju_.run_action(writer_unit, "clear-continuous-writes")

    await _clear_writes()

    logger.info("Starting continuous writes")
    await juju_.run_action(writer_unit, "start-continuous-writes")

    yield

    await _clear_writes()
232 changes: 221 additions & 11 deletions tests/integration/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,26 @@
# See LICENSE file for licensing details.

import itertools
import logging
import subprocess
import tempfile
from typing import Dict, List, Optional

import tenacity
from juju.model import Model
from juju.unit import Unit
from pytest_operator.plugin import OpsTest
from tenacity import Retrying, stop_after_attempt, wait_fixed

from .connector import MySQLConnector
from .juju_ import run_action

logger = logging.getLogger(__name__)

CONTINUOUS_WRITES_DATABASE_NAME = "continuous_writes_database"
CONTINUOUS_WRITES_TABLE_NAME = "data"

MYSQL_DEFAULT_APP_NAME = "mysql"
APPLICATION_DEFAULT_APP_NAME = "mysql-test-app"


async def get_server_config_credentials(unit: Unit) -> Dict:
Expand All @@ -23,13 +35,10 @@ async def get_server_config_credentials(unit: Unit) -> Dict:
Returns:
A dictionary with the server config username and password
"""
action = await unit.run_action(action_name="get-password", username="serverconfig")
result = await action.wait()
return await run_action(unit, "get-password", username="serverconfig")

return result.results


async def get_inserted_data_by_application(unit: Unit) -> str:
async def get_inserted_data_by_application(unit: Unit) -> Optional[str]:
"""Helper to run an action to retrieve inserted data by the application.

Args:
Expand All @@ -38,10 +47,7 @@ async def get_inserted_data_by_application(unit: Unit) -> str:
Returns:
A string representing the inserted data
"""
action = await unit.run_action("get-inserted-data")
result = await action.wait()

return result.results.get("data")
return (await run_action(unit, "get-inserted-data")).get("data")
shayancanonical marked this conversation as resolved.
Show resolved Hide resolved


async def execute_queries_against_unit(
Expand Down Expand Up @@ -220,7 +226,9 @@ async def stop_running_flush_mysqlrouter_cronjobs(ops_test: OpsTest, unit_name:
)

# hold execution until process is stopped
for attempt in Retrying(reraise=True, stop=stop_after_attempt(45), wait=wait_fixed(2)):
for attempt in tenacity.Retrying(
reraise=True, stop=tenacity.stop_after_attempt(45), wait=tenacity.wait_fixed(2)
):
with attempt:
if await get_process_pid(ops_test, unit_name, "logrotate"):
raise Exception("Failed to stop the flush_mysql_logs logrotate process")
Expand All @@ -242,3 +250,205 @@ async def get_tls_certificate_issuer(
return_code, issuer, _ = await ops_test.juju(*get_tls_certificate_issuer_commands)
assert return_code == 0, f"failed to get TLS certificate issuer on {unit_name=}"
return issuer


def get_application_name(ops_test: OpsTest, application_name_substring: str) -> str:
    """Return the name of the deployed application matching the provided name.

    This enables us to retrieve the name of the deployed application in an existing model.

    Note: despite the parameter's name, the comparison is an exact match against
    the application names in the model, not a substring match.

    Args:
        ops_test: The ops test framework; the applications of its model are searched
        application_name_substring: The exact application name to look for

    Returns:
        The application name if it is present in the model, otherwise an empty string
    """
    if application_name_substring in ops_test.model.applications:
        return application_name_substring

    return ""
Comment on lines +255 to +267
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the purpose of this?

Is it to check whether the application is deployed? If so, can you check `application in ops_test.model.applications` instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a helper function that is imported from MySQL VM. I'd like to keep as many helpers the same as we can, so that we can easily refactor them into common code in the future if necessary.



@tenacity.retry(stop=tenacity.stop_after_attempt(30), wait=tenacity.wait_fixed(5), reraise=True)
async def get_primary_unit(
    ops_test: OpsTest,
    unit: Unit,
    app_name: str,
) -> Unit:
    """Helper to retrieve the primary unit.

    The call is retried on any exception (up to 30 attempts, 5 seconds apart);
    the last exception is re-raised once the attempts are exhausted.

    Args:
        ops_test: The ops test object passed into every test case
        unit: A unit on which to run the `get-cluster-status` action
        app_name: The name of the application whose units are searched

    Returns:
        A juju unit that is a MySQL primary

    Raises:
        ValueError: If no cluster member is reported as primary, or the primary
            member does not correspond to any unit of `app_name`
    """
    units = ops_test.model.applications[app_name].units
    results = await run_action(unit, "get-cluster-status")

    primary_unit = None
    for member_label, member in results["status"]["defaultreplicaset"]["topology"].items():
        if member["memberrole"] != "primary":
            continue

        # the last '-'-separated token of the member label is used as the unit number
        unit_name = f"{app_name}/{member_label.split('-')[-1]}"
        # fall back to None (instead of an IndexError) so that a missing unit is
        # reported through the ValueError below
        primary_unit = next(
            (candidate for candidate in units if candidate.name == unit_name), None
        )
        break

    if primary_unit is None:
        raise ValueError("Unable to find primary unit")
    return primary_unit


async def get_primary_unit_wrapper(ops_test: OpsTest, app_name: str, unit_excluded=None) -> Unit:
    """Find the primary unit of an application, choosing which unit to query.

    Args:
        ops_test: The ops test object passed into every test case
        app_name: The name of the application
        unit_excluded: Optional unit that must not be used to run the command on
            (useful when the workload is stopped on that unit)

    Returns:
        The primary Unit object
    """
    logger.info("Retrieving primary unit")
    app_units = ops_test.model.applications[app_name].units

    if not unit_excluded:
        query_unit = app_units[0]
    else:
        # any unit other than the excluded one can be used to run the command
        candidates = {
            candidate for candidate in app_units if candidate.name != unit_excluded.name
        }
        query_unit = candidates.pop()

    return await get_primary_unit(ops_test, query_unit, app_name)


async def get_max_written_value_in_database(
    ops_test: OpsTest, unit: Unit, credentials: dict
) -> int:
    """Query a unit for the highest number stored in the continuous writes table.

    Args:
        ops_test: The ops test framework (not used by this helper)
        unit: The MySQL unit on which to execute queries on
        credentials: Database credentials to use ("username" and "password" keys)

    Returns:
        The first element of the query output
    """
    max_number_query = (
        f"SELECT MAX(number) FROM `{CONTINUOUS_WRITES_DATABASE_NAME}`.`{CONTINUOUS_WRITES_TABLE_NAME}`;"
    )

    query_output = await execute_queries_against_unit(
        await unit.get_public_address(),
        credentials["username"],
        credentials["password"],
        [max_number_query],
    )

    return query_output[0]


async def ensure_all_units_continuous_writes_incrementing(
    ops_test: OpsTest, mysql_units: Optional[List[Unit]] = None
) -> None:
    """Check on every unit that continuous writes keep advancing and none are missing.

    For each unit, the max written value must be greater than the previously
    observed one, and the written values below that max must be present on the
    unit (ensure that no committed data is lost). Each unit is retried every
    10 seconds for up to 5 minutes before the failure is re-raised.

    Args:
        ops_test: The ops test framework
        mysql_units: The units to check; defaults to all units of the MySQL application
    """
    logger.info("Ensure continuous writes are incrementing")

    mysql_application_name = get_application_name(ops_test, MYSQL_DEFAULT_APP_NAME)
    mysql_units = mysql_units or ops_test.model.applications[mysql_application_name].units

    primary = await get_primary_unit_wrapper(ops_test, mysql_application_name)
    server_config_credentials = await get_server_config_credentials(mysql_units[0])

    # baseline taken from the primary; every unit must report a larger value than
    # the one observed before it
    last_max_written_value = await get_max_written_value_in_database(
        ops_test, primary, server_config_credentials
    )

    select_all_continuous_writes_sql = [
        f"SELECT * FROM `{CONTINUOUS_WRITES_DATABASE_NAME}`.`{CONTINUOUS_WRITES_TABLE_NAME}`"
    ]

    async with ops_test.fast_forward():
        for unit in mysql_units:
            retryer = tenacity.Retrying(
                reraise=True, stop=tenacity.stop_after_delay(5 * 60), wait=tenacity.wait_fixed(10)
            )
            for attempt in retryer:
                with attempt:
                    # ensure that all units are up to date (including the previous primary)
                    unit_address = await unit.get_public_address()

                    # ensure the max written value is incrementing (continuous writes is active)
                    max_written_value = await get_max_written_value_in_database(
                        ops_test, unit, server_config_credentials
                    )
                    assert (
                        max_written_value > last_max_written_value
                    ), "Continuous writes not incrementing"

                    # ensure that the unit contains all values up to the max written value
                    written_on_unit = set(
                        await execute_queries_against_unit(
                            unit_address,
                            server_config_credentials["username"],
                            server_config_credentials["password"],
                            select_all_continuous_writes_sql,
                        )
                    )
                    # NOTE(review): range() stops before max_written_value, so the max
                    # itself is not part of this check - confirm that is intended
                    expected_numbers = set(range(1, max_written_value))
                    assert expected_numbers.issubset(
                        written_on_unit
                    ), f"Missing numbers in database for unit {unit.name}"

                    last_max_written_value = max_written_value


async def get_workload_version(ops_test: OpsTest, unit_name: str) -> str:
    """Get the workload version of the deployed router charm.

    Reads the `workload_version` file from the unit's charm directory via `juju ssh`.

    Args:
        ops_test: The ops test framework
        unit_name: The name of the unit to read the workload version from

    Returns:
        The content of the `workload_version` file, stripped of surrounding whitespace
    """
    # the agent directory name is the unit name with '/' replaced by '-'
    agent_directory = f"unit-{unit_name.replace('/', '-')}"
    return_code, output, _ = await ops_test.juju(
        "ssh",
        unit_name,
        "sudo",
        "cat",
        f"/var/lib/juju/agents/{agent_directory}/charm/workload_version",
    )

    assert return_code == 0, f"failed to get workload version on {unit_name=}"
    return output.strip()


async def get_leader_unit(
    ops_test: Optional[OpsTest], app_name: str, model: Optional[Model] = None
) -> Optional[Unit]:
    """Find the unit that reports itself as leader of a given application.

    Args:
        ops_test: The ops test framework instance
        app_name: The name of the application
        model: The model to use (overrides ops_test.model)

    Returns:
        The first unit whose status marks it as leader, or None if no unit does
    """
    target_model = model if model else ops_test.model

    for candidate in target_model.applications[app_name].units:
        if await candidate.is_leader_from_status():
            return candidate

    return None


def get_juju_status(model_name: str) -> str:
    """Run `juju status` for a model and return its output as text.

    Args:
        model_name: The model for which to retrieve juju status for

    Returns:
        The UTF-8 decoded output of the command
    """
    status_command = ["juju", "status", "--model", model_name]
    raw_status = subprocess.check_output(status_command)
    return raw_status.decode("utf-8")
Loading
Loading