diff --git a/README.md b/README.md index c244be6..91fcc2b 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,8 @@ Note that all images needed to run the specified stages are pulled in parallel d #### Run ``` -run(self, environ={}): +def run(self, environ={}) +async def run_async(self, environ={}) ``` The `Chainlink` run function takes a base environment (`environ`) and executes each container specified by `stages` during construction in sequence. If a stage fails, then no subsequent stages will be run. @@ -86,6 +87,8 @@ The run function returns a list of object, an example of which is annotated belo Note that the returned list will have the same number of elements as there are stages, with element corresponding to the stage with the same index. +`run_async` is an async version of `run`. + ### Cross-Stage Communication A single directory is mounted at `/job` in each container before it is run, and contents in this `/job` directory are persisted across stages. diff --git a/chainlink/__init__.py b/chainlink/__init__.py index b8254c3..1e7d971 100644 --- a/chainlink/__init__.py +++ b/chainlink/__init__.py @@ -26,14 +26,18 @@ def __init__(self, stages, workdir="/tmp"): self._pull_status = {} self._pull_images() + # sync version def run(self, environ={}): + return asyncio.get_event_loop().run_until_complete(self.run_async(environ)) + + async def run_async(self, environ={}): results = [] with tempfile.TemporaryDirectory(dir=self.workdir) as mount: logger.info("using {} for temporary job directory".format(mount)) for (idx, stage) in enumerate(self.stages): logger.info("running stage {}".format(idx + 1)) - results.append(self._run_stage(stage, mount, environ)) + results.append(await self._run_stage(stage, mount, environ)) if not results[-1]["success"]: logger.error("stage {} was unsuccessful".format(idx + 1)) break @@ -74,7 +78,7 @@ def _pull_image(client, image, status): except docker.errors.ImageNotFound: logger.error("image '{}' not found remotely or locally".format(image)) - def _run_stage(self, stage, mount, environ): + async def _run_stage(self, stage, mount, environ): environ = {**environ, **stage.get("env", {})} volumes = {mount: {"bind": "/job", "mode": "rw"}} @@ -94,7 +98,7 @@ def _run_stage(self, stage, mount, environ): "tty": True, } - container, killed = self._wait_for_stage(stage, options) + container, killed = await self._wait_for_stage(stage, options) result = { "data": self.client.api.inspect_container(container.id)["State"], "killed": killed, @@ -108,27 +112,23 @@ def _run_stage(self, stage, mount, environ): return result - def _wait_for_stage(self, stage, options): + async def _wait_for_stage(self, stage, options): timeout = stage.get("timeout", 30) container = self.client.containers.run(stage["image"], **options) + event_loop = asyncio.get_event_loop() - # anonymous async runner for executing and waiting on container - async def __run(loop, executor): - try: - await asyncio.wait_for( - loop.run_in_executor(executor, container.wait), timeout=timeout - ) - except asyncio.TimeoutError: - logger.error("killing container after {} seconds".format(timeout)) - container.kill() - return True - return False + # execute and wait + try: + await asyncio.wait_for( + event_loop.run_in_executor(self._executor, container.wait), + timeout=timeout, + ) + except asyncio.TimeoutError: + logger.error("killing container after {} seconds".format(timeout)) + container.kill() + return container, True - event_loop = asyncio.get_event_loop() - killed = event_loop.run_until_complete( - asyncio.gather(__run(event_loop, self._executor)) - )[0] - return container, killed + return container, False def __del__(self): self.client.close() diff --git a/tests/integration/basic.py b/tests/integration/basic.py index 4fc9e96..3eccce5 100644 --- a/tests/integration/basic.py +++ b/tests/integration/basic.py @@ -2,7 +2,7 @@ from chainlink import Chainlink -stages = [ +stages_1 = [ { "image": "alpine:3.5", "entrypoint": ["env"], @@ -10,15 +10,22 @@ }, {"image": "alpine:3.5", "entrypoint": ["sleep", "2"]}, ] + +stages_2 = [{"image": "no-such-image:3.1415926535", "entrypoint": ["env"]}] + env = {"TEST": "testing", "SEMESTER": "sp18", "ASSIGNMENT": "mp1"} class TestBasicChaining(unittest.TestCase): def test_basic_chain(self): - chain = Chainlink(stages) + chain = Chainlink(stages_1) results = chain.run(env) self.assertFalse(results[0]["killed"]) self.assertTrue("TEST=testing" in results[0]["logs"]["stdout"].decode("utf-8")) self.assertFalse(results[0]["killed"]) self.assertEqual(results[1]["data"]["ExitCode"], 0) + + def test_no_such_image(self): + with self.assertRaises(Exception): + Chainlink(stages_2)