From 903f48b17d85aef61407c4e759e949f7a8ffbdce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Julian=20R=C3=BCth?= Date: Thu, 27 Jun 2024 12:45:59 +0300 Subject: [PATCH] Implement a first working version of a dask worker --- environment.yml | 2 + flatsurvey/jobs/orbit_closure.py | 1 + flatsurvey/pipeline/util.py | 31 +- flatsurvey/scheduler.py | 484 +++++++++---------------------- flatsurvey/surfaces/ngons.py | 18 +- flatsurvey/survey.py | 26 +- flatsurvey/worker/dask.py | 61 ++++ flatsurvey/worker/worker.py | 38 +-- 8 files changed, 259 insertions(+), 402 deletions(-) create mode 100644 flatsurvey/worker/dask.py diff --git a/environment.yml b/environment.yml index cc39272..ef394ee 100644 --- a/environment.yml +++ b/environment.yml @@ -7,8 +7,10 @@ channels: dependencies: - black >=22,<23 - click + - dask - humanfriendly - isort + - more-itertools - pip - psutil - pyflatsurf diff --git a/flatsurvey/jobs/orbit_closure.py b/flatsurvey/jobs/orbit_closure.py index 369455b..7fc8d5f 100644 --- a/flatsurvey/jobs/orbit_closure.py +++ b/flatsurvey/jobs/orbit_closure.py @@ -240,6 +240,7 @@ def deform(self, deformation): ), } + # TODO: Probably all command() implementations can be removed now. def command(self): command = [self.name()] if self._limit != self.DEFAULT_LIMIT: diff --git a/flatsurvey/pipeline/util.py b/flatsurvey/pipeline/util.py index fefa0e8..cb04434 100644 --- a/flatsurvey/pipeline/util.py +++ b/flatsurvey/pipeline/util.py @@ -110,14 +110,18 @@ def wrap(**kwargs): f"provide_{name}", ) provider.__module__ = "__main__" - binding = type( + + binding_type = type( f"Partial{prototype.__name__}Binding", (pinject.BindingSpec,), { f"provide_{name}": provider, "__repr__": lambda self: f"{name} binding to {prototype.__name__}", + "__reduce__": lambda self: (PartialBindingSpec_unpickle, ((prototype, name, scope), kwargs)) }, - )() + ) + + binding = binding_type() binding.name = name binding.scope = scope or "DEFAULT" return binding @@ -125,6 +129,11 @@ def wrap(**kwargs): return wrap +def PartialBindingSpec_unpickle(outer_arguments, inner_arguments): + return PartialBindingSpec(*outer_arguments)(**inner_arguments) + + +# TODO: Why are the args different between these functions? def FactoryBindingSpec(name, prototype, scope=None): r""" Return a BindingSpec that calls ``prototype`` as a provider for ``name``. @@ -148,17 +157,29 @@ def FactoryBindingSpec(name, prototype, scope=None): f"provide_{name}", ) provider.__module__ = "__main__" - binding = type( + binding_type = type( f"{name}FactoryBinding", (pinject.BindingSpec,), - {f"provide_{name}": provider, "__repr__": lambda self: f"{name}->{prototype}"}, - )() + { + f"provide_{name}": provider, + "__repr__": lambda self: f"{name}->{prototype}", + # TODO: This is not possible. + # "__reduce__": lambda self: (FactoryBindingSpec_unpickle, (name, prototype, scope)) + }, + ) + + binding = binding_type() binding.name = name binding.scope = scope or "DEFAULT" return binding +# TODO: This is not possible. +# def FactoryBindingSpec_unpickle(args): +# return FactoryBidingSpec(*args) + + def provide(name, objects): src = compile( f""" diff --git a/flatsurvey/scheduler.py b/flatsurvey/scheduler.py index 2785dc5..5b3b24c 100644 --- a/flatsurvey/scheduler.py +++ b/flatsurvey/scheduler.py @@ -1,10 +1,12 @@ r""" -Prepare surfaces for a survey and spawn processes to resolve the goals of the survey. +Runs a survey with dask on the local machine or in cluster. + +TODO: Give full examples. """ # ********************************************************************* # This file is part of flatsurvey. # -# Copyright (C) 2020-2022 Julian Rüth +# Copyright (C) 2020-2024 Julian Rüth # # flatsurvey is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -25,8 +27,8 @@ class Scheduler: r""" - A simple scheduler that splits a survey into commands that are run on the local - machine when the load admits it. + A scheduler that splits a survey into commands that are sent out to workers + via the dask protocol. >>> Scheduler(generators=[], bindings=[], goals=[], reporters=[]) Scheduler @@ -39,33 +41,47 @@ def __init__( bindings, goals, reporters, - queue=128, + scheduler=None, + queue=16, dry_run=False, - load=None, quiet=False, debug=False, ): import os - if load is None: - load = os.cpu_count() * 1.2 - self._generators = generators self._bindings = bindings self._goals = goals self._reporters = reporters + self._scheduler = scheduler self._queue_limit = queue self._dry_run = dry_run - self._load = load self._quiet = quiet - self._jobs = [] self._debug = debug + # TODO: This probably does not work at all. Probably we should ditch + # most of the progress implementation and implement something simpler + # for the survey (and something different for the standalone worker.) self._report = self._enable_shared_bindings() def __repr__(self): return "Scheduler" + async def _create_pool(self): + r""" + Return a new dask pool to schedule jobs. + + TODO: Explain how to use environment variables to configure things here. + TODO: Make configurable (also without environment variables.) and comment what does not need to be configured. + """ + import dask.config + dask.config.set({'distributed.worker.daemon': False}) + + import dask.distributed + pool = await dask.distributed.Client(scheduler_file=self._scheduler, direct_to_workers=True, connection_limit=2**16, asynchronous=True, n_workers=8, nthreads=1, preload="flatsurvey.worker.dask") + + return pool + async def start(self): r""" Run the scheduler until it has run out of jobs to schedule. @@ -75,130 +91,72 @@ async def start(self): >>> asyncio.run(scheduler.start()) """ + pool = await self._create_pool() + try: - with self._report.progress( - self, activity="Survey", count=0, what="tasks queued" - ) as scheduling_progress: - scheduling_progress(activity="running survey") - - with scheduling_progress( - "executing tasks", - activity="executing tasks", - count=0, - what="tasks running", - ) as execution_progress: - submitted_tasks = [] - - from collections import deque - - queued_commands = deque() - - surfaces = [iter(generator) for generator in self._generators] - - try: - while True: - import asyncio - - await asyncio.sleep(0) - - message = [] - - try: - if queued_commands: - # Attempt to run a task (unless the load is too high) - import os - - load = os.getloadavg()[0] - - import psutil - - psutil.cpu_percent(None) - cpu = psutil.cpu_percent(0.01) - - if self._load > 0 and load > self._load: - message.append(f"load {load:.1f} too high") - elif self._load > 0 and cpu >= 100: - message.append(f"CPU {cpu:.1f}% too high") - else: - import asyncio - - surface, command = queued_commands.popleft() - submitted_tasks.append( - asyncio.create_task( - self._run( - command, surface, execution_progress - ) - ) - ) - - continue - - if len(queued_commands) >= self._queue_limit or ( - not surfaces and queued_commands - ): - message.append("queue full") - import asyncio - - await asyncio.sleep(1) - continue - finally: - scheduling_progress( - count=len(queued_commands), - message=" and ".join(message), - ) - - if not surfaces and not queued_commands: - break - - with scheduling_progress( - source="rendering task", activity="rendering task" - ) as rendering_progress: - generator = surfaces[0] - surfaces = surfaces[1:] + surfaces[:1] - - try: - surface = next(generator) - except StopIteration: - surfaces.pop() - continue - - rendering_progress( - message="determining goals", - activity=f"rendering task for {surface}", - ) - - command = await self._render_command( - surface, scheduling_progress - ) - - if command is None: - continue - - queued_commands.append((str(surface), command)) - scheduling_progress(count=len(queued_commands)) - - except KeyboardInterrupt: - scheduling_progress( - message="stopped scheduling of new jobs as requested", - activity="waiting for pending tasks", - ) - else: - scheduling_progress( - message="all jobs have been scheduled", - activity="waiting for pending tasks", - ) + try: + with self._report.progress( + self, activity="Survey", count=0, what="tasks queued" + ) as scheduling_progress: + scheduling_progress(activity="running survey") - import asyncio + with scheduling_progress( + "executing tasks", + activity="executing tasks", + count=0, + what="tasks running", + ) as execution_progress: + from more_itertools import roundrobin + surfaces = roundrobin(*self._generators) + + pending = [] + + async def schedule_one(): + return await self._schedule(pool, pending, surfaces, self._goals, scheduling_progress) + + async def consume_one(): + return await self._consume(pool, pending) + + # Fill the job queue with a base line of queue_limit many jobs. + for i in range(self._queue_limit): + await schedule_one() + + try: + # Wait for a result. For each result, schedule a new task. + while await consume_one(): + if not await schedule_one(): + break + except KeyboardInterrupt: + print("keyboard interrupt") + scheduling_progress( + message="stopped scheduling of new jobs as requested", + activity="waiting for pending tasks", + ) + else: + scheduling_progress( + message="all jobs have been scheduled", + activity="waiting for pending tasks", + ) + + try: + # Wait for all pending tasks to finish. + while await consume_one(): + pass + except KeyboardInterrupt: + execution_progress( + message="not awaiting scheduled jobs anymore as requested", + activity="waiting for pending tasks", + ) - await asyncio.gather(*submitted_tasks) + except Exception: + if self._debug: + import pdb - except Exception: - if self._debug: - import pdb - - pdb.post_mortem() + pdb.post_mortem() - raise + raise + finally: + await pool.close(0) def _enable_shared_bindings(self): shared = [binding for binding in self._bindings if binding.scope == "SHARED"] @@ -241,32 +199,60 @@ def share(binding): return provide("report", objects) - async def _render_command(self, surface, progress=None): - r""" - Return the command to invoke a worker to compute the ``goals`` for ``surface``. + async def _schedule(self, pool, pending, surfaces, goals, scheduling_progress): + while True: + surface = next(surfaces, None) - >>> import asyncio - >>> from flatsurvey.surfaces import Ngon - >>> from flatsurvey.jobs import OrbitClosure + if surface is None: + return False - >>> scheduler = Scheduler(generators=[], bindings=[], goals=[OrbitClosure], reporters=[]) - >>> command = scheduler._render_command(Ngon([1, 1, 1])) - >>> asyncio.run(command) # doctest: +ELLIPSIS - ['orbit-closure', 'pickle', '--base64', '...'] + print(surface) - """ - if progress is None: + if await self._resolve_goals_from_cache(surface, self._goals): + # Everything could be answered from cached data. Proceed to next surface. + continue + + from flatsurvey.worker.worker import Worker + + from flatsurvey.pipeline.util import FactoryBindingSpec, ListBindingSpec + + bindings = list(self._bindings) + bindings.append(SurfaceBindingSpec(surface)) + + from flatsurvey.worker.dask import DaskTask + task = DaskTask(Worker.work, bindings=bindings, goals=self._goals, reporters=self._reporters) + + pending.append(pool.submit(task)) + return True - def progress(source, **kwargs): - return self._report.progress(source=source, **kwargs) + async def _consume(self, pool, pending): + import dask.distributed + completed, still_pending = await dask.distributed.wait(pending, return_when='FIRST_COMPLETED') + + pending.clear() + pending.extend(still_pending) + + if not completed: + return False + + for job in completed: + print(await job) + + return True + + async def _resolve_goals_from_cache(self, surface, goals): + r""" + Return whether all ``goals`` could be resolved from cached data. + """ bindings = list(self._bindings) from flatsurvey.pipeline.util import FactoryBindingSpec, ListBindingSpec bindings.append(FactoryBindingSpec("surface", lambda: surface)) - bindings.append(ListBindingSpec("goals", self._goals)) + bindings.append(ListBindingSpec("goals", goals)) bindings.append(ListBindingSpec("reporters", self._reporters)) + from random import randint bindings.append(FactoryBindingSpec("lot", lambda: randint(0, 2**64))) @@ -288,214 +274,24 @@ def progress(source, **kwargs): binding_specs=bindings, ) - commands = [] - - class Reporters: - def __init__(self, reporters): - self._reporters = reporters - - reporters = objects.provide(Reporters)._reporters - for reporter in reporters: - commands.extend(reporter.command()) - class Goals: def __init__(self, goals): self._goals = goals goals = [goal for goal in objects.provide(Goals)._goals] - with progress( - "resolving goals from cached data", - activity="resolvivg goals from cached data", - total=len(goals), - count=0, - what="goals", - ) as resolving_progress: - for goal in goals: - await goal.consume_cache() - resolving_progress(advance=1) - - goals = [goal for goal in goals if goal._resolved != goal.COMPLETED] - - if not goals: - return None - for goal in goals: - commands.extend(goal.command()) - - for binding in self._bindings: - from flatsurvey.pipeline.util import provide - - binding = provide(binding.name, objects) - if binding in reporters: - continue - if binding in goals: - continue - if binding == surface: - continue - - # We already consumed the cache above. There is no need to have the - # worker reread the cache. - from flatsurvey.cache import Cache - - if binding.name() == Cache.name(): - continue - - commands.extend(binding.command()) - - commands.extend(surface.command()) - - return commands - - async def _run(self, command, name, progress): - command = tuple(command) + await goal.consume_cache() - if self._dry_run: - if not self._quiet: - logging.info(" ".join(command)) - return - - from multiprocessing import Process, Queue - - progress_queue = Queue() - - with self._report.progress(source=command, activity=name) as worker_progress: - - def work(command, progress_queue): - try: - from click.testing import CliRunner - - from flatsurvey.worker.worker import worker - - runner = CliRunner() - - from flatsurvey.reporting.progress import RemoteProgress - - RemoteProgress._progress_queue = progress_queue - - invocation = runner.invoke( - worker, args=command, catch_exceptions=False - ) - output = invocation.output.strip() - if output: - from logging import warning - - warning("Task produced output on stdout:\n" + output) - except Exception as e: - import traceback - from logging import error - - error( - "Process crashed: " - + " ".join(command) - + "\n" - + traceback.format_exc() - ) - progress_queue.put(("crash", str(e))) - else: - progress_queue.put(("exit",)) - - progress(advance=1) - try: - worker = Process(target=work, args=(command, progress_queue)) - worker.start() - - from asyncio import Future, get_event_loop + goals = [goal for goal in goals if goal._resolved != goal.COMPLETED] - done = Future() - loop = get_event_loop() + return not goals - def consume_progress(): - tokens = {} - entered = {} - while True: - try: - report = progress_queue.get() - try: - code = report[0] - if code == "crash": - code, message = report - progress( - source=command, - activity=name, - message=f"process crashed: {message}", - ) - break - elif code == "exit": - import time +import pinject +class SurfaceBindingSpec(pinject.BindingSpec): + def __init__(self, surface): + self._surface = surface - time.sleep(2) - break - elif code == "progress": - ( - code, - identifier, - source, - count, - advance, - total, - what, - message, - parent, - activity, - ) = report - - source = tuple(command) + (source,) - - if parent is None: - parent = command - else: - parent = tuple(command) + (parent,) - - tokens[identifier] = self._report.progress( - source=source, - count=count, - advance=advance, - total=total, - what=what, - message=message, - parent=parent, - activity=activity, - ) - elif code == "enter_context": - code, identifier = report - - entered.setdefault(identifier, []) - entered[identifier].append( - ( - tokens[identifier], - tokens[identifier].__enter__(), - ) - ) - elif code == "exit_context": - code, identifier = report - - context = entered[identifier].pop()[0] - context.__exit__(None, None, None) - else: - raise NotImplementedError(code) - - except Exception: - print("Failed to process", report) - raise - except Exception: - # When anything goes wrong here, we stop to consume - # progress so this thread does not hang forever. - import traceback - - traceback.print_exc() - break - - loop.call_soon_threadsafe(done.set_result, None) - - from threading import Thread - - progress_consumer = Thread(target=consume_progress) - progress_consumer.start() - - await done - progress_consumer.join() - - finally: - progress(advance=-1) + def provide_surface(self): + return self._surface diff --git a/flatsurvey/surfaces/ngons.py b/flatsurvey/surfaces/ngons.py index e7285d0..9383c1b 100644 --- a/flatsurvey/surfaces/ngons.py +++ b/flatsurvey/surfaces/ngons.py @@ -98,17 +98,6 @@ def __init__(self, angles, length=None, polygon=None): self.length = length if polygon is not None: - if isinstance(polygon, tuple): - # At some point we pickled the lengths of the sides instead of - # the actual polygon. We are too lazy to make these pickles - # work (because there are also two different flavors of those…) - import warnings - - warnings.warn( - "ignoring legacy pickle of ngon; reported Ngon will have incorrect edge lengths" - ) - polygon = self.polygon() - self.polygon.set_cache(polygon) if any(a == sum(angles) / (len(angles) - 2) for a in angles): @@ -575,16 +564,19 @@ def to_yaml(cls, representer, self): ) def __reduce__(self): - return (Ngon, (self.angles, self.length, self.polygon())) + return (Ngon, (self.angles, self.length, self.polygon.cache)) def __hash__(self): + if self.polygon.cache is None: + raise Exception("cannot hash Ngon whose polygon() has not been determined yet") + return hash((tuple(self.angles), self.polygon())) def __eq__(self, other): return ( isinstance(other, Ngon) and self.angles == other.angles - and self.polygon() == other.polygon() + and self.polygon.cache == other.polygon.cache ) def __ne__(self, other): diff --git a/flatsurvey/survey.py b/flatsurvey/survey.py index 9809514..4c9d62c 100644 --- a/flatsurvey/survey.py +++ b/flatsurvey/survey.py @@ -19,7 +19,6 @@ --debug --help Show this message and exit. -N, --dry-run Do not spawn any workers. - -l, --load L Do not start workers until load is below L. -q, --queue INTEGER Jobs to prepare in the background for scheduling. -v, --verbose Enable verbose message, repeat for debug message. @@ -63,7 +62,7 @@ # ********************************************************************* # This file is part of flatsurvey. # -# Copyright (C) 2020-2022 Julian Rüth +# Copyright (C) 2020-2024 Julian Rüth # # flatsurvey is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -95,14 +94,6 @@ ) @click.option("--dry-run", "-N", is_flag=True, help="Do not spawn any workers.") @click.option("--debug", is_flag=True) -@click.option( - "--load", - "-l", - metavar="L", - type=float, - default=None, - help="Do not start workers until load is below L.", -) @click.option( "--queue", "-q", @@ -116,15 +107,13 @@ count=True, help="Enable verbose message, repeat for debug message.", ) -def survey(dry_run, load, debug, queue, verbose): +def survey(dry_run, debug, queue, verbose): r""" Main command, runs a survey; specific survey objects and goals are registered automatically as subcommands. """ # For technical reasons, dry_run needs to be a parameter here. It is consumed by process() below. _ = dry_run - # For technical reasons, load needs to be a parameter here. It is consumed by process() below. - _ = load # For technical reasons, debug needs to be a parameter here. It is consumed by process() below. _ = debug # For technical reasons, queue needs to be a parameter here. It is consumed by process() below. @@ -145,17 +134,16 @@ def survey(dry_run, load, debug, queue, verbose): @survey.result_callback() -def process(subcommands, dry_run=False, load=None, debug=False, queue=128, verbose=0): +def process(subcommands, dry_run=False, debug=False, queue=128, verbose=0): r""" Run the specified subcommands of ``survey``. EXAMPLES: - We start an orbit-closure computation for a single triangle without waiting - for the system load to be low:: + We start an orbit-closure computation for a single triangle:: >>> from flatsurvey.test.cli import invoke - >>> invoke(survey, "--load=0", "ngons", "-n", "3", "--limit=3", "--literature=include", "orbit-closure") + >>> invoke(survey, "ngons", "-n", "3", "--limit=3", "--literature=include", "orbit-closure") """ if debug: @@ -184,9 +172,6 @@ def process(subcommands, dry_run=False, load=None, debug=False, queue=128, verbo else: surface_generators.append(subcommand) - if dry_run: - load = 0 - import asyncio import sys @@ -201,7 +186,6 @@ def process(subcommands, dry_run=False, load=None, debug=False, queue=128, verbo reporters=reporters, queue=queue, dry_run=dry_run, - load=load, debug=debug, ).start() ) diff --git a/flatsurvey/worker/dask.py b/flatsurvey/worker/dask.py new file mode 100644 index 0000000..835035c --- /dev/null +++ b/flatsurvey/worker/dask.py @@ -0,0 +1,61 @@ +import multiprocessing + +forkserver = multiprocessing.get_context("forkserver") +multiprocessing.set_forkserver_preload(["sage.all"]) + + +class DaskTask: + def __init__(self, callable, *args, **kwargs): + from pickle import dumps + self._dump = dumps((callable, args, kwargs)) + + def __call__(self): + DaskWorker.process(self) + + def run(self): + from pickle import loads + callable, args, kwargs =loads(self._dump) + + import asyncio + result = asyncio.run(callable(*args, **kwargs)) + print(result) + return result + + +class DaskWorker: + _singleton = None + + def __init__(self): + assert DaskWorker._singleton is None + + self._work_queue = forkserver.Queue() + self._result_queue = forkserver.Queue() + self._processor = forkserver.Process(target=DaskWorker._process, args=(self,), daemon=True) + self._processor.start() + + @staticmethod + def _ensure_started(): + import sys + if 'sage' in sys.modules: + raise Exception("sage must not be loaded in dask worker") + + if DaskWorker._singleton is None: + DaskWorker._singleton = DaskWorker() + + @staticmethod + def _process(self): + while True: + try: + task = self._work_queue.get() + except ValueError: + break + print(task) + + self._result_queue.put(task.run()) + print("done.") + + @staticmethod + def process(task): + DaskWorker._ensure_started() + DaskWorker._singleton._work_queue.put(task) + return DaskWorker._singleton._result_queue.get() diff --git a/flatsurvey/worker/worker.py b/flatsurvey/worker/worker.py index 511ceae..f5bcae4 100644 --- a/flatsurvey/worker/worker.py +++ b/flatsurvey/worker/worker.py @@ -140,21 +140,8 @@ def process(commands, debug, verbose): logger.setLevel(logging.DEBUG if verbose > 1 else logging.INFO) try: - while True: - objects = Worker.make_object_graph(commands) - - try: - import asyncio - - asyncio.run(objects.provide(Worker).start()) - except Restart as restart: - commands = [ - restart.rewrite_command(command, objects=objects) - for command in commands - ] - continue - - break + import asyncio + asyncio.run(Worker.work(commands=commands)) except Exception: if debug: pdb.post_mortem() @@ -183,10 +170,22 @@ def __init__( pass @classmethod - def make_object_graph(cls, commands): - bindings = [] - goals = [] - reporters = [] + async def work(cls, /, bindings=[], goals=[], reporters=[], commands=[]): + objects = Worker.make_object_graph(bindings=bindings, goals=goals, reporters=reporters, commands=commands) + + try: + await objects.provide(Worker).start() + except Restart as restart: + await Worker.work(bindings=bindings, goals=goals, reporters=reporters, commands=[ + restart.rewrite_command(command, objects=objects) + for command in commands + ]) + + @classmethod + def make_object_graph(cls, /, bindings=[], goals=[], reporters=[], commands=[]): + bindings = list(bindings) + goals = list(goals) + reporters = list(reporters) for command in commands: bindings.extend(command.get("bindings", [])) @@ -215,6 +214,7 @@ async def start(self): r""" Run until all our goals are resolved. """ + assert self._goals try: for goal in self._goals: await goal.consume_cache()