diff --git a/src/cirrus/management/cli.py b/src/cirrus/management/cli.py index b7917f4..954b40b 100644 --- a/src/cirrus/management/cli.py +++ b/src/cirrus/management/cli.py @@ -1,6 +1,8 @@ import sys + +from collections.abc import Callable from functools import wraps -from typing import Any, Callable +from typing import Any import boto3 import botocore.exceptions @@ -13,11 +15,11 @@ logger = logging.getLogger(__name__) -from cirrus.management import DESCRIPTION, NAME # noqa: E -from cirrus.management.commands.deployments import list_deployments # noqa: E -from cirrus.management.commands.manage import manage as manage_group # noqa: E -from cirrus.management.commands.payload import payload as payload_group # noqa: E -from cirrus.management.exceptions import SSOError # noqa: E +from cirrus.management import DESCRIPTION, NAME # noqa: E402 +from cirrus.management.commands.deployments import list_deployments # noqa: E402 +from cirrus.management.commands.manage import manage as manage_group # noqa: E402 +from cirrus.management.commands.payload import payload as payload_group # noqa: E402 +from cirrus.management.exceptions import SSOError # noqa: E402 def handle_sso_error(func: Callable) -> Callable: diff --git a/src/cirrus/management/commands/manage.py b/src/cirrus/management/commands/manage.py index ac15274..eacb0ec 100644 --- a/src/cirrus/management/commands/manage.py +++ b/src/cirrus/management/commands/manage.py @@ -1,14 +1,13 @@ import json import logging import sys + from functools import wraps from subprocess import CalledProcessError -from typing import Optional import boto3 import botocore.exceptions import click -from click_option_group import RequiredMutuallyExclusiveOptionGroup, optgroup from cirrus.management.deployment import WORKFLOW_POLL_INTERVAL, Deployment from cirrus.management.utils.click import ( @@ -17,6 +16,7 @@ pass_session, silence_templating_errors, ) +from click_option_group import RequiredMutuallyExclusiveOptionGroup, optgroup logger = logging.getLogger(__name__) @@ -44,7 +44,7 @@ def execution_arn(func): cls=RequiredMutuallyExclusiveOptionGroup, help="Identifer type and value to get execution", )(func) - return func + return func # noqa: RET504 def raw_option(func): @@ -79,7 +79,7 @@ def wrapper(*args, **kwargs): ) @pass_session @click.pass_context -def manage(ctx, session: boto3.Session, deployment: str, profile: Optional[str] = None): +def manage(ctx, session: boto3.Session, deployment: str, profile: str | None = None): """ Commands to run management operations against a cirrus deployment. """ @@ -152,7 +152,7 @@ def download(output_fileobj): json.dump(json.load(b), sys.stdout, indent=4) # ensure we end with a newline - print() + click.echo("") @manage.command("get-execution") @@ -188,7 +188,8 @@ def get_execution_input(deployment, arn, payload_id, raw): @raw_option @pass_deployment def get_execution_output(deployment, arn, payload_id, raw): - """Get a workflow execution's output payload using its ARN or its input payload ID""" + """Get a workflow execution's output payload using its ARN or its input + payload ID""" output = json.loads(_get_execution(deployment, arn, payload_id)["output"]) if raw: @@ -223,7 +224,7 @@ def process(deployment): def invoke_lambda(deployment, lambda_name): """Invoke lambda with event (from stdin)""" click.echo( - json.dumps(deployment.invoke_lambda(sys.stdin.read(), lambda_name), indent=4) + json.dumps(deployment.invoke_lambda(sys.stdin.read(), lambda_name), indent=4), ) @@ -283,7 +284,8 @@ def _exec(ctx, deployment, command, include_user_vars): @pass_deployment @click.pass_context def _call(ctx, deployment, command, include_user_vars): - """Run an executable, in a new process, with the deployment environment vars loaded""" + """Run an executable, in a new process, with the deployment environment + vars loaded""" if not command: return try: @@ -299,8 +301,10 @@ def list_lambdas(ctx, deployment): """List lambda functions""" click.echo( json.dumps( - {"Functions": deployment.get_lambda_functions()}, indent=4, default=str - ) + {"Functions": deployment.get_lambda_functions()}, + indent=4, + default=str, + ), ) diff --git a/src/cirrus/management/commands/payload.py b/src/cirrus/management/commands/payload.py index a71cf95..afc3c1f 100644 --- a/src/cirrus/management/commands/payload.py +++ b/src/cirrus/management/commands/payload.py @@ -47,6 +47,8 @@ def template(additional_variables, silence_templating_errors): click.echo( template_payload( - sys.stdin.read(), additional_variables, silence_templating_errors - ) + sys.stdin.read(), + additional_variables, + silence_templating_errors, + ), ) diff --git a/src/cirrus/management/deployment.py b/src/cirrus/management/deployment.py index c25f120..1ec468e 100644 --- a/src/cirrus/management/deployment.py +++ b/src/cirrus/management/deployment.py @@ -4,8 +4,9 @@ import json import logging import os + from collections.abc import Iterator -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from subprocess import check_call from time import sleep, time @@ -13,9 +14,9 @@ import backoff import boto3 + from cirrus.lib.process_payload import ProcessPayload from cirrus.lib.utils import get_client - from cirrus.management import exceptions from cirrus.management.deployment_pointer import DeploymentPointer @@ -29,7 +30,7 @@ def now_isoformat(): - return datetime.now(timezone.utc).isoformat() + return datetime.now(UTC).isoformat() def _maybe_use_buffer(fileobj: IO): @@ -57,17 +58,6 @@ def asjson(self, *args, **kwargs) -> str: return json.dumps(self.asdict(), *args, **kwargs) -# @staticmethod -# def _get_session(profile: str = None): -# # TODO: MFA session should likely be used only with the cli, -# # so this probably needs to be parameterized by the caller -# -# def get_session(self): -# if not self._session: -# self._session = self._get_session(profile=self.profile) -# return self._session - - @dataclasses.dataclass class Deployment(DeploymentMeta): def __init__( @@ -132,20 +122,20 @@ def exec(self, command, include_user_vars=True, isolated=False): env = self.environment.copy() if include_user_vars: env.update(self.user_vars) - os.execlpe(command[0], *command, env) + os.execlpe(command[0], *command, env) # noqa: S606 self.set_env(include_user_vars=include_user_vars) - os.execlp(command[0], *command) + os.execlp(command[0], *command) # noqa: S606 def call(self, command, include_user_vars=True, isolated=False): if isolated: env = self.environment.copy() if include_user_vars: env.update(self.user_vars) - check_call(command, env=env) + check_call(command, env=env) # noqa: S603 else: self.set_env(include_user_vars=include_user_vars) - check_call(command) + check_call(command) # noqa: S603 def get_payload_state(self, payload_id): from cirrus.lib.statedb import StateDB @@ -215,8 +205,8 @@ def get_execution_by_payload_id(self, payload_id): execs = self.get_payload_state(payload_id).get("executions", []) try: exec_arn = execs[-1] - except IndexError: - raise exceptions.NoExecutionsError(payload_id) + except IndexError as e: + raise exceptions.NoExecutionsError(payload_id) from e return self.get_execution(exec_arn) @@ -224,7 +214,7 @@ def invoke_lambda(self, event, function_name): aws_lambda = get_client("lambda", session=self.session) if function_name not in self.get_lambda_functions(): raise ValueError( - f"lambda named '{function_name}' not found in deployment '{self.name}'" + f"lambda named '{function_name}' not found in deployment '{self.name}'", ) full_name = f"{self.stackname}-{function_name}" response = aws_lambda.invoke(FunctionName=full_name, Payload=event) diff --git a/src/cirrus/management/deployment_pointer.py b/src/cirrus/management/deployment_pointer.py index b00f1e8..2d12efa 100644 --- a/src/cirrus/management/deployment_pointer.py +++ b/src/cirrus/management/deployment_pointer.py @@ -32,14 +32,12 @@ class PointerObject(Protocol): # pragma: no cover region: str @classmethod - def from_string(cls: type[Self], string: str) -> Self: - ... + def from_string(cls: type[Self], string: str) -> Self: ... def fetch( self: Self, session: boto3.Session | None = None, - ) -> str: - ... + ) -> str: ... class SecretArn: @@ -98,7 +96,7 @@ class Pointer: @classmethod def from_string(cls: type[Self], string: str) -> Self: obj = json.loads(string) - obj['_type'] = obj.pop('type') + obj["_type"] = obj.pop("type") return cls(**obj) def resolve(self) -> PointerObject: diff --git a/src/cirrus/management/utils/click.py b/src/cirrus/management/utils/click.py index b4835b5..e6ce48c 100644 --- a/src/cirrus/management/utils/click.py +++ b/src/cirrus/management/utils/click.py @@ -62,7 +62,7 @@ def get_command(self, ctx, cmd_name): self.resolve_alias(cmd) for cmd in self.list_commands(ctx) + list(self._alias2cmd.keys()) if cmd.startswith(cmd_name) - } + }, ) # no matches no command @@ -71,13 +71,14 @@ def get_command(self, ctx, cmd_name): # one match then we can resolve the match # and try getting the command again - elif len(matches) == 1: + if len(matches) == 1: return super().get_command(ctx, matches[0]) # otherwise the string matched but was not unique # to a single command and we have to bail out - ctx.fail( - f"Unknown command '{cmd_name}. Did you mean any of these: {', '.join(sorted(matches))}?", + ctx.fail( # noqa: RET503 + f"Unknown command '{cmd_name}. Did you mean any of these: " + f"{', '.join(sorted(matches))}?", ) def format_commands(self, ctx, formatter): @@ -144,8 +145,6 @@ class Variable(click.ParamType): name = "key/val pair" def convert(self, value, param, ctx): - print(22, value) - print(33, param) return {value[0]: value[1]} @@ -177,7 +176,6 @@ def additional_variables(func): "additional_variables", nargs=2, multiple=True, - # type=Variable(), callback=merge_vars2, help="Additional templating variables", )(func) diff --git a/src/cirrus/management/utils/logging_classes.py b/src/cirrus/management/utils/logging_classes.py index 142ef21..6fd146d 100644 --- a/src/cirrus/management/utils/logging_classes.py +++ b/src/cirrus/management/utils/logging_classes.py @@ -1,12 +1,14 @@ import logging +from typing import Any, ClassVar + import click # Inspired from https://github.com/click-contrib/click-log class ClickFormatter(logging.Formatter): - colors = { + colors: ClassVar[dict[str, dict[str, Any]]] = { "error": {"fg": "red"}, "exception": {"fg": "red"}, "critical": {"fg": "red"}, @@ -28,5 +30,6 @@ def emit(self, record): msg = self.format(record) record.levelname.lower() click.echo(msg, err=True) - except Exception: + # not sure if we can narrow this exception down or not... + except Exception: # noqa: BLE001 self.handleError(record) diff --git a/src/cirrus/management/utils/templating.py b/src/cirrus/management/utils/templating.py index dcfaa90..1fcbe37 100644 --- a/src/cirrus/management/utils/templating.py +++ b/src/cirrus/management/utils/templating.py @@ -1,4 +1,5 @@ import logging + from string import Template logger = logging.getLogger(__name__) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py index 02dfb77..3801837 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,9 @@ import json import os import shlex -import shutil + +from collections.abc import Iterator +from copy import deepcopy from pathlib import Path from unittest.mock import patch @@ -9,26 +11,19 @@ import botocore import moto import pytest -from cirrus.cli.commands import cli -from click.testing import CliRunner -try: - # temporary measure while waiting on pending PRs - from cirrus.lib2.eventdb import EventDB -except ImportError: - EventDB = None - -from cirrus.core.project import Project -from cirrus.lib2.process_payload import ProcessPayload, ProcessPayloads -from cirrus.lib2.statedb import StateDB +from cirrus.lib.process_payload import ProcessPayload, ProcessPayloads +from cirrus.lib.statedb import StateDB +from cirrus.management.cli import cli +from click.testing import CliRunner def set_fake_creds(): """Mocked AWS Credentials for moto.""" os.environ["AWS_ACCESS_KEY_ID"] = "testing" - os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" - os.environ["AWS_SECURITY_TOKEN"] = "testing" - os.environ["AWS_SESSION_TOKEN"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" # noqa: S105 + os.environ["AWS_SECURITY_TOKEN"] = "testing" # noqa: S105 + os.environ["AWS_SESSION_TOKEN"] = "testing" # noqa: S105 os.environ["AWS_DEFAULT_REGION"] = "us-east-1" os.environ["AWS_REGION"] = "us-east-1" @@ -37,7 +32,7 @@ def set_fake_creds(): @pytest.fixture(autouse=True) -def aws_credentials(): +def _aws_credentials(): set_fake_creds() @@ -51,109 +46,86 @@ def statedb_schema(fixtures): return json.loads(fixtures.joinpath("statedb-schema.json").read_text()) -@pytest.fixture(scope="module") -def project_testdir(): - pdir = Path(__file__).parent.joinpath("output") - if pdir.is_dir(): - shutil.rmtree(pdir) - pdir.mkdir() - Project.new(pdir) - old_cwd = os.getcwd() - os.chdir(pdir) - yield pdir - os.chdir(old_cwd) - - -@pytest.fixture -def project(project_testdir): - return Project.resolve(strict=True) - - -@pytest.fixture +@pytest.fixture() def s3(aws_credentials): with moto.mock_s3(): yield boto3.client("s3", region_name="us-east-1") -@pytest.fixture +@pytest.fixture() def sqs(aws_credentials): with moto.mock_sqs(): yield boto3.client("sqs", region_name="us-east-1") -@pytest.fixture +@pytest.fixture() def dynamo(): with moto.mock_dynamodb(): yield boto3.client("dynamodb", region_name="us-east-1") -@pytest.fixture +@pytest.fixture() def stepfunctions(aws_credentials): with moto.mock_stepfunctions(): yield boto3.client("stepfunctions", region_name="us-east-1") -@pytest.fixture +@pytest.fixture() def iam(aws_credentials): with moto.mock_iam(): yield boto3.client("iam", region_name="us-east-1") @pytest.fixture(autouse=True) -def sts(aws_credentials): +def sts(aws_credentials): # noqa: PT004 with moto.mock_sts(): yield -@pytest.fixture +@pytest.fixture() def payloads(s3): name = "payloads" s3.create_bucket(Bucket=name) return name -@pytest.fixture +@pytest.fixture() def data(s3): name = "data" s3.create_bucket(Bucket=name) return name -@pytest.fixture +@pytest.fixture() def queue(sqs): q = sqs.create_queue(QueueName="test-queue") q["Arn"] = "arn:aws:sqs:us-east-1:123456789012:test-queue" return q -@pytest.fixture +@pytest.fixture() def timestream_write_client(): with moto.mock_timestreamwrite(): yield boto3.client("timestream-write", region_name="us-east-1") -if EventDB: - - @pytest.fixture - def eventdb(timestream_write_client): - timestream_write_client.create_database(DatabaseName="event-db-1") - timestream_write_client.create_table( - DatabaseName="event-db-1", TableName="event-table-1" - ) - return EventDB("event-db-1|event-table-1") +@pytest.fixture() +def _eventdb(timestream_write_client): + timestream_write_client.create_database(DatabaseName="event-db-1") + timestream_write_client.create_table( + DatabaseName="event-db-1", + TableName="event-table-1", + ) -@pytest.fixture -def statedb(dynamo, statedb_schema, eventdb=None) -> str: +@pytest.fixture() +def statedb(dynamo, statedb_schema, _eventdb) -> StateDB: dynamo.create_table(**statedb_schema) table_name = statedb_schema["TableName"] - if eventdb: - return StateDB(table_name=table_name, eventdb=eventdb) - else: - return StateDB(table_name=table_name) + return StateDB(table_name=table_name) -@pytest.fixture +@pytest.fixture() def workflow(stepfunctions, iam): defn = { "StartAt": "FirstState", @@ -173,7 +145,7 @@ def workflow(stepfunctions, iam): "Service": "states.us-east-1.amazonaws.com", }, "Action": "sts:AssumeRole", - } + }, ], } role = iam.create_role( @@ -194,7 +166,7 @@ def workflow(stepfunctions, iam): LAMBDA_ENV_VARS = {"var": "value"} -@pytest.fixture +@pytest.fixture() def lambda_env(): return LAMBDA_ENV_VARS @@ -205,8 +177,8 @@ def mock_make_api_call(self, operation_name, kwarg): return orig(self, operation_name, kwarg) -@pytest.fixture -def mock_lambda_get_conf(): +@pytest.fixture() +def mock_lambda_get_conf(): # noqa: PT004 with patch( "botocore.client.BaseClient._make_api_call", new=mock_make_api_call, @@ -228,12 +200,22 @@ def _invoke(cmd, **kwargs): return _invoke -@pytest.fixture +@pytest.fixture() def basic_payloads(fixtures): return ProcessPayloads( process_payloads=[ ProcessPayload( - json.loads(fixtures.joinpath("basic_payload.json").read_text()) - ) - ] + json.loads(fixtures.joinpath("basic_payload.json").read_text()), + ), + ], ) + + +@pytest.fixture() +def _environment() -> Iterator[None]: + current_env = deepcopy(os.environ) # stash env + try: + yield + finally: + os.environ.clear() + os.environ = current_env # noqa: B003 diff --git a/tests/test_deployments.py b/tests/test_deployments.py index 0679dc2..34afa97 100644 --- a/tests/test_deployments.py +++ b/tests/test_deployments.py @@ -3,7 +3,7 @@ DEPLYOMENT_NAME = "test-deployment" -@pytest.fixture +@pytest.fixture() def deployments(invoke): def _deployments(cmd): return invoke("deployments " + cmd) diff --git a/tests/test_manage.py b/tests/test_manage.py index 2aba873..e0bd48a 100644 --- a/tests/test_manage.py +++ b/tests/test_manage.py @@ -1,10 +1,8 @@ import json -import os -from copy import deepcopy import pytest -from cirrus.plugins.management.deployment import ( +from cirrus.management.deployment import ( CONFIG_VERSION, DEFAULT_DEPLOYMENTS_DIR_NAME, Deployment, @@ -14,7 +12,7 @@ STACK_NAME = "cirrus-test" -@pytest.fixture +@pytest.fixture() def manage(invoke): def _manage(cmd): return invoke("manage " + cmd) @@ -22,7 +20,7 @@ def _manage(cmd): return _manage -@pytest.fixture +@pytest.fixture() def deployment_meta(queue, statedb, payloads, data, workflow): return { "name": DEPLYOMENT_NAME, @@ -33,22 +31,21 @@ def deployment_meta(queue, statedb, payloads, data, workflow): "environment": { "CIRRUS_STATE_DB": statedb.table_name, "CIRRUS_BASE_WORKFLOW_ARN": workflow["stateMachineArn"].replace( - "workflow1", "" + "workflow1", + "", ), "CIRRUS_LOG_LEVEL": "DEBUG", "CIRRUS_STACK": STACK_NAME, "CIRRUS_DATA_BUCKET": data, "CIRRUS_PAYLOAD_BUCKET": payloads, "CIRRUS_PROCESS_QUEUE_URL": queue["QueueUrl"], - # "CIRRUS_INVALID_TOPIC_ARN": , - # "CIRRUS_FAILED_TOPIC_ARN": , }, "user_vars": {}, "config_version": CONFIG_VERSION, } -@pytest.fixture +@pytest.fixture() def deployment(manage, project, deployment_meta): def _manage(deployment, cmd): return manage(f"{deployment.name} {cmd}") @@ -89,8 +86,9 @@ def test_manage_get_path(deployment, project): assert result.exit_code == 0 assert result.stdout.strip() == str( project.dot_dir.joinpath( - DEFAULT_DEPLOYMENTS_DIR_NAME, f"{DEPLYOMENT_NAME}.json" - ) + DEFAULT_DEPLOYMENTS_DIR_NAME, + f"{DEPLYOMENT_NAME}.json", + ), ) @@ -101,11 +99,15 @@ def test_manage_refresh(deployment, mock_lambda_get_conf, lambda_env): assert new["environment"] == lambda_env -def test_manage_get_execution_by_payload_id(deployment, basic_payloads, statedb): - """Adds causes two workflow executions, and confirms that the second call to - get_execution_by_payload_id gets a different executionArn value from the first execution. - """ - current_env = deepcopy(os.environ) # stash env +@pytest.mark.usefixtures("_environment") +def test_manage_get_execution_by_payload_id( + deployment, + basic_payloads, + statedb, +) -> None: + """Adds causes two workflow executions, and confirms that the second call + to get_execution_by_payload_id gets a different executionArn value from the + first execution.""" deployment.set_env() basic_payloads.process() pid = basic_payloads[0]["id"] @@ -114,10 +116,15 @@ def test_manage_get_execution_by_payload_id(deployment, basic_payloads, statedb) basic_payloads.process() sfn_exe2 = deployment.get_execution_by_payload_id(pid) assert sfn_exe1["executionArn"] != sfn_exe2["executionArn"] - os.environ = current_env # pop stash -@pytest.mark.parametrize("command,expect_exit_zero", (("true", True), ("false", False))) +@pytest.mark.parametrize( + ("command", "expect_exit_zero"), + [ + ("true", True), + ("false", False), + ], +) def test_call_cli_return_values(deployment, command, expect_exit_zero): result = deployment(f"call {command}") assert result.exit_code == 0 if expect_exit_zero else result.exit_code != 0