diff --git a/tests/test_aws.py b/tests/test_aws.py deleted file mode 100644 index cd8f017..0000000 --- a/tests/test_aws.py +++ /dev/null @@ -1,271 +0,0 @@ -from botocore.exceptions import WaiterError -import pytest -from moto import mock_aws -import os -from gha_runner.clouddeployment import AWS - - -@pytest.fixture(scope="function") -def aws_credentials(): - """Mocked AWS Credentials for moto.""" - os.environ["AWS_ACCESS_KEY_ID"] = "testing" - os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" - os.environ["AWS_SECURITY_TOKEN"] = "testing" - os.environ["AWS_SESSION_TOKEN"] = "testing" - os.environ["AWS_DEFAULT_REGION"] = "us-east-1" - - -@pytest.fixture(scope="function") -def aws(aws_credentials): - with mock_aws(): - params = { - "image_id": "ami-0772db4c976d21e9b", - "instance_type": "t2.micro", - # "tags": {}, - "region_name": "us-east-1", - # "gh_runner_token": "testing", - "gh_runner_tokens": ["testing"], - "home_dir": "/home/ec2-user", - "runner_release": "testing", - "repo": "omsf-eco-infra/awsinfratesting", - } - yield AWS(**params) - - -def test_build_aws_params(): - params = { - "image_id": "ami-0772db4c976d21e9b", - "instance_type": "t2.micro", - "tags": [ - {"Key": "Name", "Value": "test"}, - {"Key": "Owner", "Value": "test"}, - ], - "region_name": "us-east-1", - "gh_runner_tokens": ["testing"], - "home_dir": "/home/ec2-user", - "runner_release": "", - "repo": "omsf-eco-infra/awsinfratesting", - "subnet_id": "test", - "security_group_id": "test", - "iam_role": "test", - } - user_data_params = { - "token": "test", - "repo": "omsf-eco-infra/awsinfratesting", - "homedir": "/home/ec2-user", - "script": "echo 'Hello, World!'", - "runner_release": "test.tar.gz", - "labels": "label", - } - aws = AWS(**params) - params = aws._build_aws_params(user_data_params) - assert params == { - "ImageId": "ami-0772db4c976d21e9b", - "InstanceType": "t2.micro", - "MinCount": 1, - "MaxCount": 1, - "SubnetId": "test", - "SecurityGroupIds": ["test"], - "IamInstanceProfile": {"Name": "test"}, - "UserData": """#!/bin/bash -cd "/home/ec2-user" -echo "echo 'Hello, World!'" > pre-runner-script.sh -source pre-runner-script.sh -export RUNNER_ALLOW_RUNASROOT=1 -# We will get the latest release from the GitHub API -curl -L test.tar.gz -o runner.tar.gz -tar xzf runner.tar.gz -./config.sh --url https://github.com/omsf-eco-infra/awsinfratesting --token test --labels label --ephemeral -./run.sh -""", - "TagSpecifications": [ - { - "ResourceType": "instance", - "Tags": [ - {"Key": "Name", "Value": "test"}, - {"Key": "Owner", "Value": "test"}, - ], - } - ], - } - - -def test_create_instance_with_labels(aws): - aws.labels = "test" - ids = aws.create_instances() - assert len(ids) == 1 - - -def test_create_instances(aws): - ids = aws.create_instances() - assert len(ids) == 1 - - -def test_create_instances_missing_release(aws): - aws.runner_release = "" - with pytest.raises( - ValueError, match="No runner release provided, cannot create instances." - ): - aws.create_instances() - - -def test_create_instances_missing_home_dir(aws): - aws.home_dir = "" - with pytest.raises( - ValueError, match="No home directory provided, cannot create instances." - ): - aws.create_instances() - - -def test_create_instances_missing_tokens(aws): - aws.gh_runner_tokens = [] - with pytest.raises( - ValueError, - match="No GitHub runner tokens provided, cannot create instances.", - ): - aws.create_instances() - - -def test_create_instances_missing_image_id(aws): - aws.image_id = "" - with pytest.raises( - ValueError, match="No image ID provided, cannot create instances." - ): - aws.create_instances() - - -def test_create_instances_missing_instance_type(aws): - aws.instance_type = "" - with pytest.raises( - ValueError, match="No instance type provided, cannot create instances." - ): - aws.create_instances() - - -def test_instance_running(aws): - ids = aws.create_instances() - assert len(ids) == 1 - ids = list(ids) - assert aws.instance_running(ids[0]) - - -def test_instance_running_dne(aws): - # This is a fake instance id - ids = ["i-xxxxxxxxxxxxxxxxx"] - with pytest.raises(Exception): - aws.instance_running(ids[0]) - - -def test_instance_running_terminated(aws): - ids = aws.create_instances() - assert len(ids) == 1 - ids = list(ids) - aws.remove_instances(ids) - assert not aws.instance_running(ids[0]) - - -def test_wait_until_ready(aws): - ids = aws.create_instances() - params = { - "MaxAttempts": 1, - "Delay": 5, - } - ids = list(ids) - aws.wait_until_ready(ids, **params) - assert aws.instance_running(ids[0]) - - -def test_wait_until_ready_dne(aws): - # This is a fake instance id - ids = ["i-xxxxxxxxxxxxxxxxx"] - params = { - "MaxAttempts": 1, - "Delay": 5, - } - with pytest.raises(WaiterError): - aws.wait_until_ready(ids, **params) - - -@pytest.mark.slow -def test_wait_until_ready_dne_long(aws): - # This is a fake instance id - ids = ["i-xxxxxxxxxxxxxxxxx"] - with pytest.raises(WaiterError): - aws.wait_until_ready(ids) - - -def test_remove_instances(aws): - ids = aws.create_instances() - assert len(ids) == 1 - ids = list(ids) - aws.remove_instances(ids) - assert not aws.instance_running(ids[0]) - - -def test_wait_until_removed(aws): - ids = aws.create_instances() - assert len(ids) == 1 - ids = list(ids) - aws.remove_instances(ids) - params = { - "MaxAttempts": 1, - "Delay": 5, - } - aws.wait_until_removed(ids, **params) - assert not aws.instance_running(ids[0]) - - -def test_wait_until_removed_dne(aws): - # This is a fake instance id - ids = ["i-xxxxxxxxxxxxxxxxx"] - params = { - "MaxAttempts": 1, - "Delay": 5, - } - with pytest.raises(WaiterError): - aws.wait_until_removed(ids, **params) - - -@pytest.mark.slow -def test_wait_until_removed_dne_long(aws): - # This is a fake instance id - ids = ["i-xxxxxxxxxxxxxxxxx"] - with pytest.raises(WaiterError): - aws.wait_until_removed(ids) - - -def test_build_user_data(aws): - params = { - "homedir": "/home/ec2-user", - "script": "echo 'Hello, World!'", - "repo": "omsf-eco-infra/awsinfratesting", - "token": "test", - "labels": "label", - "runner_release": "test.tar.gz", - } - # We strip this to ensure that we don't have any extra whitespace to fail our test - user_data = aws._build_user_data(**params).strip() - # We also strip here - file = """#!/bin/bash -cd "/home/ec2-user" -echo "echo 'Hello, World!'" > pre-runner-script.sh -source pre-runner-script.sh -export RUNNER_ALLOW_RUNASROOT=1 -# We will get the latest release from the GitHub API -curl -L test.tar.gz -o runner.tar.gz -tar xzf runner.tar.gz -./config.sh --url https://github.com/omsf-eco-infra/awsinfratesting --token test --labels label --ephemeral -./run.sh - """.strip() - assert user_data == file - - -def test_build_user_data_missing_params(aws): - params = { - "homedir": "/home/ec2-user", - "script": "echo 'Hello, World!'", - "repo": "omsf-eco-infra/awsinfratesting", - "token": "test", - } - with pytest.raises(Exception): - aws._build_user_data(**params) diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index 7f41b3d..0000000 --- a/tests/test_main.py +++ /dev/null @@ -1,237 +0,0 @@ -import pytest -import os -from gha_runner.__main__ import ( - parse_aws_params, - _env_parse_helper, - start_runner_instances, - stop_runner_instances, -) -from unittest.mock import Mock, patch, mock_open - -pytestmark = pytest.mark.main - - -def test_env_parse_helper(): - base: dict[str, str] = {} - assert _env_parse_helper(base, "TEST_ENV_VAR_GHA_RUNNER", "test") == {} - os.environ["TEST_ENV_VAR_GHA_RUNNER"] = "test" - assert _env_parse_helper(base, "TEST_ENV_VAR_GHA_RUNNER", "test") == { - "test": "test" - } - - -@pytest.mark.parametrize( - "env_vars, expected_output", - [ - ({}, [{}]), - ( - {"INPUT_AWS_IMAGE_ID": "ami-1234567890"}, - [{"image_id": "ami-1234567890"}], - ), - ( - {"INPUT_AWS_INSTANCE_TYPE": "t2.micro"}, - [{"instance_type": "t2.micro"}], - ), - ( - {"INPUT_AWS_SUBNET_ID": "subnet-1234567890"}, - [{"subnet_id": "subnet-1234567890"}], - ), - ( - {"INPUT_AWS_SECURITY_GROUP_ID": "sg-1234567890"}, - [{"security_group_id": "sg-1234567890"}], - ), - ( - {"INPUT_AWS_IAM_ROLE": "role-1234567890"}, - [{"iam_role": "role-1234567890"}], - ), - ( - {"INPUT_AWS_TAGS": '[{"Key": "Name", "Value": "test"}]'}, - [{"tags": [{"Key": "Name", "Value": "test"}]}], - ), - ( - {"INPUT_AWS_REGION_NAME": "us-east-1"}, - [{"region_name": "us-east-1"}], - ), - ( - {"INPUT_AWS_HOME_DIR": "/home/ec2-user"}, - [{"home_dir": "/home/ec2-user"}], - ), - ({"INPUT_EXTRA_GH_LABELS": "test"}, [{"labels": "test"}]), - ( - {"INPUT_AWS_IMAGE_ID": "ami-1234567890"}, - [{"image_id": "ami-1234567890"}], - ), - ( - { - "INPUT_AWS_IMAGE_ID": "ami-1234567890", - "INPUT_AWS_INSTANCE_TYPE": "t2.micro", - }, - [ - {"image_id": "ami-1234567890"}, - {"image_id": "ami-1234567890", "instance_type": "t2.micro"}, - ], - ), - ], -) -def test_parse_aws_params(env_vars, expected_output): - idx = 0 - for key, value in env_vars.items(): - if key == "INPUT_AWS_TAGS": - os.environ[key] = value - else: - os.environ[key] = str(value) - assert parse_aws_params() == expected_output[idx] - idx += 1 - for key in env_vars.keys(): - del os.environ[key] - - -def test_parse_aws_params_empty(): - os.environ["INPUT_AWS_IMAGE_ID"] = "" - os.environ["INPUT_AWS_INSTANCE_TYPE"] = "" - os.environ["INPUT_AWS_SUBNET_ID"] = "" - os.environ["INPUT_AWS_SECURITY_GROUP_ID"] = "" - os.environ["INPUT_AWS_IAM_ROLE"] = "" - os.environ["INPUT_AWS_TAGS"] = "" - os.environ["INPUT_AWS_REGION_NAME"] = "" - os.environ["INPUT_AWS_HOME_DIR"] = "" - os.environ["INPUT_AWS_LABELS"] = "" - assert parse_aws_params() == { - "image_id": "", - "instance_type": "", - "home_dir": "", - "region_name": "", - } - - -# Define a mock CloudProvider class (e.g., MockAWS) -class MockCloudProvider: - def __init__(self, **kwargs): - pass - - def create_instances(self): - return mock_get_instance_mapping() - - def wait_until_ready(self, instance_ids): - pass - - def remove_instances(self, instance_ids): - pass - - def wait_until_removed(self, instance_ids): - pass - - -class FailedCloudProvider: - def __init__(self, **kwargs): - pass - - def create_instances(self): - pass - - def wait_until_ready(self, instance_ids): - pass - - def remove_instances(self, instance_ids): - pass - - def wait_until_removed(self, instance_ids): - raise Exception("Test") - - -def mock_get_instance_mapping(): - return {"instance_id_1": "label_1"} - - -@pytest.fixture -def mock_cloud_deployment_factory(): - with patch.dict( - "gha_runner.clouddeployment.CloudDeploymentFactory.providers", - { - "mock_provider": MockCloudProvider, - "failed_provider": FailedCloudProvider, - }, - ): - yield - - -@pytest.fixture -def mock_gh_output(): - with patch.dict("os.environ", {"GITHUB_OUTPUT": "mock_output"}): - with patch("builtins.open", mock_open()): - yield - - -@pytest.fixture -def mock_gh(): - gh_mock = Mock() - gh_mock.get_latest_runner_release.return_value = "mock_release" - gh_mock.create_runner_tokens.return_value = ["mock_token"] - gh_mock.remove_runner.return_value = None - - return gh_mock - - -def test_start_runner_instances_smoke( - mock_cloud_deployment_factory, mock_gh_output, mock_gh -): - try: - start_runner_instances( - provider="mock_provider", - gh=mock_gh, - count=1, - cloud_params={}, - timeout=0, - ) - except Exception as e: - pytest.fail(f"Exception raised: {e}") - - -def test_stop_runner_instances_smoke( - mock_cloud_deployment_factory, mock_gh_output, mock_gh -): - with patch( - "gha_runner.__main__.get_instance_mapping", - new=mock_get_instance_mapping, - ): - try: - stop_runner_instances( - provider="mock_provider", cloud_params={}, gh=mock_gh - ) - except Exception as e: - pytest.fail(f"stop_runner_instances raised an exception: {e}") - - -def test_stop_runner_instances_failure( - mock_cloud_deployment_factory, mock_gh_output, mock_gh, capsys -): - with patch( - "gha_runner.__main__.get_instance_mapping", - side_effect=Exception("Test"), - ): - with pytest.raises(SystemExit) as fail: - stop_runner_instances( - provider="mock_provider", cloud_params={}, gh=mock_gh - ) - assert fail.type == SystemExit - assert fail.value.code == 1 - - captured = capsys.readouterr() - assert ( - captured.out == "::error title=Malformed instance mapping::Test" - ) - - -def test_stop_runner_instances_aws( - mock_cloud_deployment_factory, mock_gh_output, mock_gh, capsys -): - with patch( - "gha_runner.__main__.get_instance_mapping", - new=mock_get_instance_mapping, - ): - with pytest.raises(SystemExit) as fail: - stop_runner_instances( - provider="failed_provider", cloud_params={}, gh=mock_gh - ) - assert fail.type == SystemExit - assert fail.value.code == 1