diff --git a/.gitignore b/.gitignore index 0c8b4dec..2ab17a65 100644 --- a/.gitignore +++ b/.gitignore @@ -18,7 +18,6 @@ dist/ downloads/ eggs/ .eggs/ -lib/ lib64/ parts/ sdist/ diff --git a/mite/__init__.py b/mite/__init__.py index 87567607..25a4d878 100755 --- a/mite/__init__.py +++ b/mite/__init__.py @@ -9,6 +9,7 @@ import random from pkg_resources import get_distribution, DistributionNotFound +from .scenario import time_function # noqa: F401 from .exceptions import MiteError # noqa: F401 from .context import Context import mite.utils @@ -29,7 +30,7 @@ def test_context(extensions=('http',), **config): return c -class ensure_separation_from_callable: +class _ensure_separation_from_callable: def __init__(self, sep_callable, loop=None): self._sep_callable = sep_callable self._loop = loop @@ -65,7 +66,7 @@ def ensure_fixed_separation(separation, loop=None): def fixed_separation(): return separation - return ensure_separation_from_callable(fixed_separation, loop=loop) + return _ensure_separation_from_callable(fixed_separation, loop=loop) def ensure_average_separation(mean_separation, plus_minus=None, loop=None): @@ -88,4 +89,4 @@ def ensure_average_separation(mean_separation, plus_minus=None, loop=None): def average_separation(): return mean_separation + (random.random() * plus_minus * 2) - plus_minus - return ensure_separation_from_callable(average_separation, loop=loop) + return _ensure_separation_from_callable(average_separation, loop=loop) diff --git a/mite/__main__.py b/mite/__main__.py index bf968127..8a4cc111 100755 --- a/mite/__main__.py +++ b/mite/__main__.py @@ -62,25 +62,21 @@ import os import sys import threading -from urllib.request import Request as UrlLibRequest -from urllib.request import urlopen import docopt import msgpack -import ujson import uvloop -from .cli.common import (_create_config_manager, _create_runner, - _create_scenario_manager) +from .cli.common import _create_runner, _create_sender +from .cli.controller import controller from .cli.duplicator import duplicator from .cli.stats import stats from .cli.test import journey_cmd, scenario_cmd from .collector import Collector -from .controller import Controller from .har_to_mite import har_convert_to_mite from .recorder import Recorder -from .utils import _msg_backend_module, spec_import +from .utils import _msg_backend_module from .web import app, prometheus_metrics @@ -98,13 +94,6 @@ def _recorder_receiver(opts): return receiver -def _create_sender(opts): - socket = opts['--message-socket'] - sender = _msg_backend_module(opts).Sender() - sender.connect(socket) - return sender - - def _create_prometheus_exporter_receiver(opts): socket = opts['--stats-out-socket'] receiver = _msg_backend_module(opts).Receiver() @@ -117,11 +106,6 @@ def _create_runner_transport(opts): return _msg_backend_module(opts).RunnerTransport(socket) -def _create_controller_server(opts): - socket = opts['--controller-socket'] - return _msg_backend_module(opts).ControllerServer(socket) - - logger = logging.getLogger(__name__) @@ -147,103 +131,6 @@ def _start_web_in_thread(opts): t.start() -def _controller_log_start(scenario_spec, logging_url): - if not logging_url.endswith("/"): - logging_url += "/" - - # The design decision has been made to do this logging synchronously - # rather than using the usual mite data pipeline, because we want to make - # sure the log is nailed down before we start doing any test activity. - url = logging_url + "start" - logger.info(f"Logging test start to {url}") - resp = urlopen( - UrlLibRequest( - url, - data=ujson.dumps( - { - 'testname': scenario_spec, - # TODO: log other properties as well, - # like the endpoint URLs we are - # hitting. - } - ).encode(), - method="POST", - ) - ) - logger.debug("Logging test start complete") - if resp.status == 200: - return ujson.loads(resp.read())['newid'] - else: - logger.warning( - f"Could not complete test start logging; status was {resp.status_code}" - ) - - -def _controller_log_end(logging_id, logging_url): - if logging_id is None: - return - - if not logging_url.endswith("/"): - logging_url += "/" - - url = logging_url + "end" - logger.info(f"Logging test end to {url}") - resp = urlopen(UrlLibRequest(url, data=ujson.dumps({'id': logging_id}).encode())) - if resp.status != 204: - logger.warning( - f"Could not complete test end logging; status was {resp.status_code}" - ) - logger.debug("Logging test end complete") - - -def controller(opts): - config_manager = _create_config_manager(opts) - scenario_spec = opts['SCENARIO_SPEC'] - scenarios_fn = spec_import(scenario_spec) - scenario_manager = _create_scenario_manager(opts) - try: - scenarios = scenarios_fn(config_manager) - except TypeError: - scenarios = scenarios_fn() - for journey_spec, datapool, volumemodel in scenarios: - scenario_manager.add_scenario(journey_spec, datapool, volumemodel) - controller = Controller(scenario_spec, scenario_manager, config_manager) - server = _create_controller_server(opts) - sender = _create_sender(opts) - loop = asyncio.get_event_loop() - logging_id = None - logging_url = opts["--logging-webhook"] - if logging_url is None: - try: - logging_url = os.environ["MITE_LOGGING_URL"] - except KeyError: - pass - if logging_url is not None: - logging_id = _controller_log_start(scenario_spec, logging_url) - - async def controller_report(): - while True: - if controller.should_stop(): - return - await asyncio.sleep(1) - controller.report(sender.send) - - try: - loop.run_until_complete( - asyncio.gather( - controller_report(), server.run(controller, controller.should_stop) - ) - ) - except KeyboardInterrupt: - # TODO: kill runners, do other shutdown tasks - logging.info("Received interrupt signal, shutting down") - finally: - _controller_log_end(logging_id, logging_url) - # TODO: cancel all loop tasks? Something must be done to stop this - # from hanging - loop.close() - - def runner(opts): transport = _create_runner_transport(opts) sender = _create_sender(opts) diff --git a/mite/cli/common.py b/mite/cli/common.py index a01c3a80..665619fb 100644 --- a/mite/cli/common.py +++ b/mite/cli/common.py @@ -1,7 +1,7 @@ from ..config import ConfigManager from ..runner import Runner from ..scenario import ScenarioManager -from ..utils import spec_import +from ..utils import spec_import, _msg_backend_module def _create_config_manager(opts): @@ -15,8 +15,9 @@ def _create_config_manager(opts): return config_manager -def _create_scenario_manager(opts): +def _create_scenario_manager(spec, opts): return ScenarioManager( + spec=spec, start_delay=float(opts['--delay-start-seconds']), period=float(opts['--max-loop-delay']), spawn_rate=int(opts['--spawn-rate']), @@ -37,3 +38,10 @@ def _create_runner(opts, transport, msg_sender): max_work=max_work, debug=opts['--debugging'], ) + + +def _create_sender(opts): + socket = opts['--message-socket'] + sender = _msg_backend_module(opts).Sender() + sender.connect(socket) + return sender diff --git a/mite/cli/controller.py b/mite/cli/controller.py new file mode 100644 index 00000000..ec5d96cb --- /dev/null +++ b/mite/cli/controller.py @@ -0,0 +1,97 @@ +import asyncio +import logging +from functools import partial + +from ..controller import Controller +from ..utils import _msg_backend_module, spec_import +from .common import _create_config_manager, _create_scenario_manager, _create_sender + + +def _create_controller_server(opts): + socket = opts['--controller-socket'] + return _msg_backend_module(opts).ControllerServer(socket) + + +async def _run_time_function(time_fn, start_event, end_event): + if time_fn is not None: + await time_fn(start_event, end_event) + if not end_event.is_set(): + logging.error( + "The time function exited before the scenario ended, which seems like a bug" + ) + else: + start_event.set() + await end_event.wait() + + +async def _send_controller_report(sender, controller_obj): + while True: + if controller_obj.should_stop(): + return + await asyncio.sleep(1) + controller_obj.report(sender.send) + + +async def _run_controller(server, controller_obj, start_event, end_event): + await start_event.wait() + await server.run(controller_obj, controller_obj.should_stop) + end_event.set() + + +def _run( + scenario_spec, + opts, + scenarios_fn, + server, + sender, + get_controller=Controller, + extra_tasks=(), +): + config_manager = _create_config_manager(opts) + scenario_manager = _create_scenario_manager(scenario_spec, opts) + try: + scenarios = scenarios_fn(config_manager) + except TypeError: + scenarios = scenarios_fn() + for journey_spec, datapool, volumemodel in scenarios: + scenario_manager.add_scenario(journey_spec, datapool, volumemodel) + controller_object = get_controller(scenario_manager, config_manager) + + time_fn = getattr(scenarios_fn, "_mite_time_function", None) + if time_fn is not None: + time_fn = partial(time_fn, scenario_spec, config_manager) + + loop = asyncio.get_event_loop() + start_event = asyncio.Event() + end_event = asyncio.Event() + + time_fn_task = loop.create_task(_run_time_function(time_fn, start_event, end_event)) + + try: + loop.run_until_complete( + asyncio.gather( + _send_controller_report(sender, controller_object), + _run_controller(server, controller_object, start_event, end_event), + time_fn_task, + *extra_tasks, + ) + ) + except KeyboardInterrupt: + # TODO: kill runners, do other shutdown tasks + logging.info("Received interrupt signal, shutting down") + if not end_event.is_set(): + time_fn_task.cancel() + loop.run_until_complete(time_fn_task) + finally: + # TODO: cancel all loop tasks? Something must be done to stop this + # from hanging + loop.close() + + +def controller(opts): + server = _create_controller_server(opts) + sender = _create_sender(opts) + scenario_spec = opts['SCENARIO_SPEC'] + scenarios_fn = spec_import(scenario_spec) + + _run(scenario_spec, opts, scenarios_fn, server, sender) diff --git a/mite/cli/test.py b/mite/cli/test.py index 2c33116f..28b921dc 100644 --- a/mite/cli/test.py +++ b/mite/cli/test.py @@ -5,14 +5,11 @@ from ..controller import Controller from ..recorder import Recorder from ..utils import pack_msg, spec_import -from .common import (_create_config_manager, _create_runner, - _create_scenario_manager) +from .common import _create_runner +from .controller import _run as _run_controller class DirectRunnerTransport: - def __init__(self, controller): - self._controller = controller - async def hello(self): return self._controller.hello() @@ -25,7 +22,7 @@ async def bye(self, runner_id): return self._controller.bye(runner_id) -class DirectReciever: +class DirectReceiver: def __init__(self): self._listeners = [] self._raw_listeners = [] @@ -36,13 +33,22 @@ def add_listener(self, listener): def add_raw_listener(self, raw_listener): self._raw_listeners.append(raw_listener) - def recieve(self, msg): + def receive(self, msg): for listener in self._listeners: listener(msg) packed_msg = pack_msg(msg) for raw_listener in self._raw_listeners: raw_listener(packed_msg) + def send(self, msg): + self.receive(msg) + + +class DummyServer: + async def run(self, controller, stop_func): + while stop_func is None or not stop_func(): + await asyncio.sleep(5) + def _setup_msg_processors(receiver, opts): collector = Collector(opts['--collector-dir'], int(opts['--collector-roll'])) @@ -50,50 +56,44 @@ def _setup_msg_processors(receiver, opts): receiver.add_listener(recorder.process_message) receiver.add_raw_listener(collector.process_raw_message) - extra_processors = [ - spec_import(x)() - for x in opts["--message-processors"].split(",") - ] + extra_processors = [spec_import(x)() for x in opts["--message-processors"].split(",")] for processor in extra_processors: if hasattr(processor, "process_message"): receiver.add_listener(processor.process_message) elif hasattr(processor, "process_raw_message"): receiver.add_raw_listener(processor.process_raw_message) else: - logging.error(f"Class {processor.__name__} does not have a process(_raw)_message method!") - - -def test_scenarios(test_name, opts, scenarios, config_manager): - scenario_manager = _create_scenario_manager(opts) - for journey_spec, datapool, volumemodel in scenarios: - scenario_manager.add_scenario(journey_spec, datapool, volumemodel) - controller = Controller(test_name, scenario_manager, config_manager) - transport = DirectRunnerTransport(controller) - receiver = DirectReciever() - _setup_msg_processors(receiver, opts) - loop = asyncio.get_event_loop() - - async def controller_report(): - while True: - await asyncio.sleep(1) - controller.report(receiver.recieve) - - loop.run_until_complete( - asyncio.gather( - controller_report(), _create_runner(opts, transport, receiver.recieve).run() - ) + logging.error( + f"Class {processor.__name__} does not have a process(_raw)_message method!" + ) + + +def test_scenarios(scenario_spec, opts, scenario_fn): + server = DummyServer() + sender = DirectReceiver() + transport = DirectRunnerTransport() + + def get_controller(*args, **kwargs): + c = Controller(*args, **kwargs) + transport._controller = c + return c + + _setup_msg_processors(sender, opts) + _run_controller( + scenario_spec, + opts, + scenario_fn, + server, + sender, + get_controller=get_controller, + extra_tasks=(_create_runner(opts, transport, sender.receive).run(),), ) def scenario_test_cmd(opts): scenario_spec = opts['SCENARIO_SPEC'] scenarios_fn = spec_import(scenario_spec) - config_manager = _create_config_manager(opts) - try: - scenarios = scenarios_fn(config_manager) - except TypeError: - scenarios = scenarios_fn() - test_scenarios(scenario_spec, opts, scenarios, config_manager) + test_scenarios(scenario_spec, opts, scenarios_fn) def journey_test_cmd(opts): @@ -104,11 +104,12 @@ def journey_test_cmd(opts): else: datapool = None volumemodel = lambda start, end: int(opts['--volume']) + + def dummy_scenario(): + return [(journey_spec, datapool, volumemodel)] + test_scenarios( - journey_spec, - opts, - [(journey_spec, datapool, volumemodel)], - _create_config_manager(opts), + journey_spec, opts, dummy_scenario, ) diff --git a/mite/controller.py b/mite/controller.py index 5687567d..4e870c7f 100644 --- a/mite/controller.py +++ b/mite/controller.py @@ -71,8 +71,7 @@ def get_active_count(self): class Controller: - def __init__(self, testname, scenario_manager, config_manager): - self._testname = testname + def __init__(self, scenario_manager, config_manager): self._scenario_manager = scenario_manager self._runner_id_gen = count(1) self._work_tracker = WorkTracker() @@ -83,7 +82,7 @@ def hello(self): runner_id = next(self._runner_id_gen) return ( runner_id, - self._testname, + self._scenario_manager.spec, self._config_manager.get_changes_for_runner(runner_id), ) @@ -125,7 +124,7 @@ def report(self, sender): { 'type': 'controller_report', 'time': time.time(), - 'test': self._testname, + 'test': self._scenario_manager.spec, 'required': required, 'actual': dict(actual), 'num_runners': len(active_runner_ids), diff --git a/mite/lib/logging_webhook.py b/mite/lib/logging_webhook.py new file mode 100644 index 00000000..8a520594 --- /dev/null +++ b/mite/lib/logging_webhook.py @@ -0,0 +1,61 @@ +import logging +from urllib.request import Request as UrlLibRequest +from urllib.request import urlopen + +import ujson + + +def _controller_log_start(scenario_spec, logging_url): + if not logging_url.endswith("/"): + logging_url += "/" + + url = logging_url + "start" + logging.info(f"Logging test start to {url}") + resp = urlopen( + UrlLibRequest( + url, + data=ujson.dumps( + { + 'testname': scenario_spec, + # TODO: log other properties as well, + # like the endpoint URLs we are + # hitting. + } + ).encode(), + method="POST", + ) + ) + logging.debug("Logging test start complete") + if resp.status == 200: + return ujson.loads(resp.read())['newid'] + else: + logging.warning( + f"Could not complete test start logging; status was {resp.status_code}" + ) + + +def _controller_log_end(logging_id, logging_url): + if logging_id is None: + return + + if not logging_url.endswith("/"): + logging_url += "/" + + url = logging_url + "end" + logging.info(f"Logging test end to {url}") + resp = urlopen(UrlLibRequest(url, data=ujson.dumps({'id': logging_id}).encode())) + if resp.status != 204: + logging.warning( + f"Could not complete test end logging; status was {resp.status_code}" + ) + logging.debug("Logging test end complete") + + +async def log(scenario_spec, config_manager, start_event, end_event): + url = config_manager.get("logging_webhook_url") + logging_id = _controller_log_start(scenario_spec, url) + start_event.set() + try: + await end_event.wait() + finally: + _controller_log_end(logging_id, url) diff --git a/mite/scenario.py b/mite/scenario.py index f3f338a0..779dc166 100644 --- a/mite/scenario.py +++ b/mite/scenario.py @@ -27,7 +27,8 @@ def _volume_dicts_remove_a_from_b(a, b): class ScenarioManager: - def __init__(self, start_delay=0, period=1, spawn_rate=None): + def __init__(self, spec, start_delay=0, period=1, spawn_rate=None): + self._spec = spec self._period = period self._scenario_id_gen = count(1) self._in_start = start_delay > 0 @@ -38,6 +39,10 @@ def __init__(self, start_delay=0, period=1, spawn_rate=None): self._required = {} self._scenarios = {} + @property + def spec(self): + return self._spec + def _now(self): return time.time() - self._start_time @@ -167,3 +172,11 @@ async def checkin_data(self, ids): for scenario_id, scenario_data_id in ids: if scenario_id in self._scenarios: await self._scenarios[scenario_id].datapool.checkin(scenario_data_id) + + +def time_function(tf_name): + def decorator_inner(fn): + fn._mite_time_function = tf_name + return fn + + return decorator_inner diff --git a/test/test_cli_common.py b/test/test_cli_common.py new file mode 100644 index 00000000..1227a2ff --- /dev/null +++ b/test/test_cli_common.py @@ -0,0 +1,20 @@ +from mite.cli.common import _create_config_manager + + +def dummy_config(): + return {"foo": "bar"} + + +def test_create_config_manager(): + cm = _create_config_manager( + {"--config": "test_cli_common:dummy_config", "--add-to-config": ()} + ) + assert cm.get("foo") == "bar" + + +def test_create_config_manager_add_to_config(): + cm = _create_config_manager( + {"--config": "test_cli_common:dummy_config", "--add-to-config": ("abc:xyz",)} + ) + assert cm.get("foo") == "bar" + assert cm.get("abc") == "xyz" diff --git a/test/test_cli_controller.py b/test/test_cli_controller.py new file mode 100644 index 00000000..6781b5b4 --- /dev/null +++ b/test/test_cli_controller.py @@ -0,0 +1,59 @@ +import asyncio +from unittest import mock +from unittest.mock import call + +import pytest +from mite.cli.controller import _run_time_function + + +@pytest.mark.asyncio +class TestRunTimeFunction: + async def test_none(self): + s = asyncio.Event() + e = asyncio.Event() + + async def co(): + await asyncio.sleep(0.01) + e.set() + + await asyncio.gather(co(), _run_time_function(None, s, e)) + assert s.is_set() + + async def test_time_fn(self): + s = asyncio.Event() + e = asyncio.Event() + + async def co(): + await asyncio.sleep(0.01) + e.set() + + async def tf(s, e): + assert not s.is_set() + s.set() + await e.wait() + + await asyncio.gather(co(), _run_time_function(tf, s, e)) + assert s.is_set() + assert e.is_set() + + async def test_early_return(self): + s = asyncio.Event() + e = asyncio.Event() + + async def co(): + await asyncio.sleep(0.01) + e.set() + + async def tf(s, e): + assert not s.is_set() + s.set() + + with mock.patch("logging.error") as m: + await asyncio.gather(co(), _run_time_function(tf, s, e)) + assert s.is_set() + assert e.is_set() + assert m.call_args_list == [ + call( + 'The time function exited before the scenario ended, which seems like a bug' + ), + ] diff --git a/test/test_controller.py b/test/test_controller.py index 19ad5752..d8d2c161 100644 --- a/test/test_controller.py +++ b/test/test_controller.py @@ -3,21 +3,18 @@ from mite.config import ConfigManager from werkzeug.wrappers import Response import ujson -from mite.__main__ import _controller_log_start, _controller_log_end +from mite.lib.logging_webhook import _controller_log_start, _controller_log_end from pytest_httpserver.httpserver import HandlerType import json -TESTNAME = "unit_test_name" - - def test_controller_hello(): - scenario_manager = ScenarioManager() + scenario_manager = ScenarioManager("test_123") config_manager = ConfigManager() - controller = Controller(TESTNAME, scenario_manager, config_manager) + controller = Controller(scenario_manager, config_manager) for i in range(5): runner_id, test_name, config_list = controller.hello() - assert test_name == controller._testname + assert test_name == "test_123" assert runner_id == i + 1 diff --git a/test/test_scenario.py b/test/test_scenario.py index f1678062..02b12a2c 100644 --- a/test/test_scenario.py +++ b/test/test_scenario.py @@ -57,7 +57,7 @@ def volumemodel(start, end): def test_add_scenario(): - scenario_manager = ScenarioManager() + scenario_manager = ScenarioManager("test") for row in test_scenarios: scenario_manager.add_scenario(row[0], row[1], row[2]) for i in scenario_manager._scenarios: @@ -65,7 +65,7 @@ def test_add_scenario(): def test_upadate_required_work(): - scenario_manager = ScenarioManager() + scenario_manager = ScenarioManager("test") for row in test_scenarios: scenario_manager.add_scenario(row[0], row[1], row[2]) scenario_manager._update_required_and_period(5, 10) @@ -76,7 +76,7 @@ def test_upadate_required_work(): @pytest.mark.asyncio async def test_get_work(): - scenario_manager = ScenarioManager() + scenario_manager = ScenarioManager("test") for row in test_scenarios: scenario_manager.add_scenario(row[0], row[1], row[2]) scenario_manager._update_required_and_period(5, 10)