Skip to content

Commit

Permalink
Add the BATS executor
Browse files Browse the repository at this point in the history
  • Loading branch information
anfimovdm committed Sep 9, 2023
1 parent a09b5d7 commit 11d4683
Show file tree
Hide file tree
Showing 14 changed files with 474 additions and 133 deletions.
48 changes: 48 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
name: pytest
on:
pull_request:
branches:
- "**"
jobs:
build:
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- uses: actions/checkout@v3
name: Check out repository
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Prepare ssh
run: |
whoami
ssh-keygen -t ed25519 -f ~/.ssh/whatever -N ''
cat > ~/.ssh/config <<EOF
Host localhost
User runner
HostName 127.0.0.1
IdentityFile ~/.ssh/whatever
ssh localhost 'bash --version'
- name: Prepare python env
run: |
python -m venv env
env/bin/python -m pip install -U pip
env/bin/python -m pip install -r requirements/celery.txt
- name: Run unit tests (pytest)
env:
CELERY_CONFIG_PATH: tests/tests_config.yaml
SSH_USERNAME: runner
run: env/bin/python -m pytest -v --cov-report term-missing:skip-covered
--cov-report xml:/tmp/coverage.xml --junitxml=/tmp/pytest.xml --cov=alts tests/ | tee /tmp/pytest-coverage.txt
- name: Pytest coverage comment
uses: MishaKav/pytest-coverage-comment@main
with:
pytest-coverage-path: /tmp/pytest-coverage.txt
pytest-xml-coverage-path: /tmp/coverage.xml
title: Coverage report for changed files
badge-title: Total coverage
hide-badge: false
hide-report: false
report-only-changed-files: true
hide-comment: false
remove-link-from-badge: false
junitxml-path: /tmp/pytest.xml
24 changes: 12 additions & 12 deletions alts/shared/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,36 +134,36 @@ def broker_url(self) -> str:


class AzureResultsConfig(BaseResultsConfig):
azureblockblob_container_name: typing.Optional[str]
azureblockblob_container_name: str
azureblockblob_base_path: str = 'celery_result_backend/'
azure_connection_string: typing.Optional[str]
azure_connection_string: str


class FilesystemResultsConfig(BaseResultsConfig):
path: typing.Optional[str]
path: str


class RedisResultsConfig(BaseResultsConfig, RedisBrokerConfig):
pass


class S3ResultsConfig(BaseResultsConfig):
s3_access_key_id: typing.Optional[str]
s3_secret_access_key: typing.Optional[str]
s3_bucket: typing.Optional[str]
s3_access_key_id: str
s3_secret_access_key: str
s3_bucket: str
s3_base_path: str = 'celery_result_backend/'
s3_region: typing.Optional[str]
s3_endpoint_url: typing.Optional[str] = None
s3_region: str
s3_endpoint_url: str


class AzureLogsConfig(BaseLogsConfig, AzureResultsConfig):
azure_logs_container: typing.Optional[str]
azure_logs_container: str


class PulpLogsConfig(BaseLogsConfig):
pulp_host: typing.Optional[str]
pulp_user: typing.Optional[str]
pulp_password: typing.Optional[str]
pulp_host: str
pulp_user: str
pulp_password: str


class CeleryConfig(BaseModel):
Expand Down
131 changes: 91 additions & 40 deletions alts/shared/utils/asyncssh.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,12 @@
import asyncio
import logging
import typing
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional
from traceback import format_exc
from typing import Any, Dict, List, Literal, Optional, Tuple

import asyncssh


class AsyncSSHClientSession(asyncssh.SSHClientSession):
def data_received(self, data: str, datatype: asyncssh.DataType):
if datatype == asyncssh.EXTENDED_DATA_STDERR:
logging.error(
'SSH command stderr:\n%s',
data,
)
else:
logging.info(
'SSH command stdout:\n%s',
data,
)

def connection_lost(self, exc: typing.Optional[Exception]):
if exc:
logging.exception(
'SSH session error:',
)
raise exc


class AsyncSSHClient:
def __init__(
self,
Expand All @@ -38,6 +17,16 @@ def __init__(
known_hosts_files: Optional[List[str]] = None,
disable_known_hosts_check: bool = False,
env_vars: Optional[Dict[str, Any]] = None,
logger: Optional[logging.Logger] = None,
logger_name: str = 'asyncssh-client',
logging_level: Literal[
'NOTSET',
'DEBUG',
'INFO',
'WARNING',
'ERROR',
'CRITICAL',
] = 'DEBUG',
):
self.username = username
self.password = password
Expand All @@ -51,6 +40,24 @@ def __init__(
)
if disable_known_hosts_check:
self.known_hosts = None
if not logger:
self.logger = self.setup_logger(logger_name, logging_level)

def setup_logger(
self,
logger_name: str,
logging_level: str,
) -> logging.Logger:
logger = logging.getLogger(logger_name)
logger.setLevel(logging_level)
handler = logging.StreamHandler()
handler.setLevel(logging_level)
formatter = logging.Formatter(
'%(asctime)s [%(name)s:%(levelname)s] - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger

@asynccontextmanager
async def get_connection(self):
Expand All @@ -64,26 +71,70 @@ async def get_connection(self):
) as conn:
yield conn

def sync_run_command(self, command: str):
try:
asyncio.run(self.async_run_command(command))
except Exception as exc:
logging.exception('Cannot execute asyncssh command: %s', command)
raise exc
def get_process_results(
self,
result: asyncssh.SSHCompletedProcess,
) -> Tuple[int, str, str]:
return result.exit_status, result.stdout, result.stderr

async def async_run_command(self, command: str):
def print_process_results(
self,
result: asyncssh.SSHCompletedProcess,
):
self.logger.debug(
'Exit code: %s, stdout: %s, stderr: %s',
*self.get_process_results(result),
)

async def async_run_command(
self,
command: str,
) -> Tuple[int, str, str]:
async with self.get_connection() as conn:
channel, session = await conn.create_session(
AsyncSSHClientSession,
command,
result = await conn.run(command)
self.print_process_results(result)
return self.get_process_results(result)

def sync_run_command(
self,
command: str,
) -> Tuple[int, str, str]:
try:
return asyncio.run(self.async_run_command(command))
except Exception as exc:
self.logger.exception(
'Cannot execute asyncssh command: %s', command
)
await channel.wait_closed()
raise exc

async def async_run_commands(self, commands: List[str]):
async def async_run_commands(
self,
commands: List[str],
) -> Dict[str, Tuple[int, str, str]]:
results = {}
async with self.get_connection() as conn:
for command in commands:
channel, session = await conn.create_session(
AsyncSSHClientSession,
command,
)
await channel.wait_closed()
try:
result = await conn.run(command)
except Exception:
self.logger.exception(
'Cannot execute asyncssh command: %s',
command,
)
results[command] = (1, '', format_exc())
continue
self.print_process_results(result)
results[command] = self.get_process_results(result)
return results

def sync_run_commands(
self,
commands: List[str],
) -> Dict[str, Tuple[int, str, str]]:
try:
return asyncio.run(self.async_run_commands(commands))
except Exception as exc:
self.logger.exception(
'Cannot execute asyncssh commands: %s', commands
)
raise exc
93 changes: 84 additions & 9 deletions alts/worker/executors/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from datetime import datetime
from functools import wraps
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from plumbum import local

Expand Down Expand Up @@ -34,11 +34,20 @@ def wrapped(self, *args, **kwargs):
class BaseExecutor:
def __init__(
self,
binary_name: str,
env_vars: Optional[Dict[str, Any]] = None,
binary_name: Optional[str] = None,
ssh_params: Optional[Union[Dict[str, Any], AsyncSSHParams]] = None,
timeout: Optional[int] = None,
logger: Optional[logging.Logger] = None,
logger_name: str = 'base-executor',
logging_level: Literal[
'NOTSET',
'DEBUG',
'INFO',
'WARNING',
'ERROR',
'CRITICAL',
] = 'DEBUG',
) -> None:
self.ssh_client = None
self.env_vars = {}
Expand All @@ -56,22 +65,88 @@ def __init__(
self.ssh_client = AsyncSSHClient(**ssh_params.dict())
self.logger = logger
if not self.logger:
self.logger = logging.getLogger('executor')
self.logger = self.setup_logger(logger_name, logging_level)
self.check_binary_existence()

@measure_stage('run_local_command')
def run_local_command(self, cmd_args: List[str]) -> Tuple[int, str, str]:
if self.binary_name not in local:
def setup_logger(
self,
logger_name: str,
logging_level: str,
) -> logging.Logger:
logger = logging.getLogger(logger_name)
logger.setLevel(logging_level)
handler = logging.StreamHandler()
handler.setLevel(logging_level)
formatter = logging.Formatter(
'%(asctime)s [%(name)s:%(levelname)s] - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger

def check_binary_existence(self):
cmd_args = ['--version']
func = self.run_local_command
if self.ssh_client:
func = self.ssh_client.sync_run_command
cmd_args = f'{self.binary_name} --version'
try:
exit_code, *_ = func(cmd_args)
except Exception as exc:
self.logger.exception('Cannot check binary existence:')
raise exc
if exit_code != 0:
raise FileNotFoundError(
f'Binary {self.binary_name} is not found in PATH on the machine',
f'Binary "{self.binary_name}" is not found in PATH on the machine',
)

@measure_stage('run_local_command')
def run_local_command(self, cmd_args: List[str]) -> Tuple[int, str, str]:
with local.env(**self.env_vars):
return local[self.binary_name].run(
args=cmd_args,
timeout=self.timeout,
)

@measure_stage('run_ssh_command')
def run_ssh_command(self, cmd: str):
def run_ssh_command(self, cmd: str) -> Tuple[int, str, str]:
if not self.ssh_client:
raise ValueError('SSH params are missing')
return self.ssh_client.sync_run_command(cmd)
return self.ssh_client.sync_run_command(f'{self.binary_name} {cmd}')


class BatsExecutor(BaseExecutor):
def __init__(
self,
binary_name: str = 'bats',
env_vars: Optional[Dict[str, Any]] = None,
ssh_params: Optional[Union[Dict[str, Any], AsyncSSHParams]] = None,
timeout: Optional[int] = None,
logger: Optional[logging.Logger] = None,
logger_name: str = 'bats-executor',
logging_level: Literal[
'NOTSET',
'DEBUG',
'INFO',
'WARNING',
'ERROR',
'CRITICAL',
] = 'DEBUG',
):
super().__init__(
binary_name=binary_name,
env_vars=env_vars,
ssh_params=ssh_params,
timeout=timeout,
logger=logger,
logger_name=logger_name,
logging_level=logging_level,
)

@measure_stage('run_local_bats')
def run_local_command(self, cmd_args: List[str]) -> Tuple[int, str, str]:
return super().run_local_command(['--tap'] + cmd_args)

@measure_stage('run_ssh_bats')
def run_ssh_command(self, cmd: str) -> Tuple[int, str, str]:
return super().run_ssh_command(f'{self.binary_name} --tap {cmd}')
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ ruamel.yaml==0.17.30
cryptography==41.0.2
azure-storage-blob==12.16.0
tap.py==3.1
librabbitmq==2.0.0
# librabbitmq==2.0.0
requests>=2.25.1
filesplit==3.0.2
pulpcore-client==3.17.3
Expand Down
Loading

0 comments on commit 11d4683

Please sign in to comment.