Issue #32: Initial implementation of the non-blocking tagging task allocator. #47

Open · wants to merge 1 commit into main
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,4 +4,5 @@ graphlib-backport; python_version < '3.9'
psutil
pycapnp
pyzmq
sortedcontainers
Collaborator Author commented:
This does not really add a new dependency, as pycapnp already depends on sortedcontainers.

tblib
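The new allocator (`tagged_allocator.py` below) uses `sortedcontainers.SortedList` to keep workers ordered by the number of queued tasks while balancing. A minimal standalone sketch of that pattern, illustrative only and not part of this diff:

```python
from sortedcontainers import SortedList

# Order (queued_task_count, worker_id) pairs by load; SortedList keeps them sorted on insertion.
workers = SortedList([(3, b"w1"), (0, b"w2"), (5, b"w3")], key=lambda item: item[0])

most_loaded = workers.pop(-1)  # (5, b"w3"): highest load
least_loaded = workers[0]      # (0, b"w2"): lowest load

# Re-insert the worker after its load changes; ordering is maintained automatically.
workers.add((4, most_loaded[1]))
```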
2 changes: 1 addition & 1 deletion scaler/about.py
@@ -1 +1 @@
__version__ = "1.9.1"
__version__ = "1.9.2"
10 changes: 9 additions & 1 deletion scaler/client/client.py
@@ -261,6 +261,7 @@ def get(
metadata=b"",
func_object_id=b"",
function_args=[],
tags=set(),
),
is_delayed=not block,
group_task_id=None,
@@ -382,6 +383,7 @@ def __submit(self, function_object_id: bytes, args: Tuple[Any, ...], delayed: bo
metadata=task_flags_bytes,
func_object_id=function_object_id,
function_args=arguments,
tags=set(),
)

future = self._future_factory(task=task, is_delayed=delayed, group_task_id=None)
@@ -498,6 +500,7 @@ def __construct_graph(
metadata=task_flags_bytes,
func_object_id=function_cache.object_id,
function_args=arguments,
tags=set(),
)

result_task_ids = [node_name_to_task_id[key] for key in keys if key in call_graph]
@@ -515,7 +518,12 @@
argument, data = node_name_to_arguments[key]
future: ScalerFuture = self._future_factory(
task=Task.new_msg(
task_id=argument.data, source=self._identity, metadata=b"", func_object_id=b"", function_args=[]
task_id=argument.data,
source=self._identity,
metadata=b"",
func_object_id=b"",
function_args=[],
tags=set(),
),
is_delayed=False,
group_task_id=graph_task_id,
2 changes: 2 additions & 0 deletions scaler/protocol/capnp/message.capnp
@@ -9,6 +9,7 @@ struct Task {
metadata @2 :Data;
funcObjectId @3 :Data;
functionArgs @4 :List(Argument);
tags @5 :List(Text);

struct Argument {
type @0 :ArgumentType;
@@ -64,6 +65,7 @@ struct WorkerHeartbeat {
latencyUS @3 :UInt32;
taskLock @4 :Bool;
processors @5 :List(Status.ProcessorStatus);
tags @6 :List(Text);
}

struct WorkerHeartbeatEcho {
18 changes: 17 additions & 1 deletion scaler/protocol/python/message.py
@@ -64,9 +64,18 @@ def function_args(self) -> List[Argument]:
for arg in self._msg.functionArgs
]

@property
def tags(self) -> Set[str]:
return set(self._msg.tags)

@staticmethod
def new_msg(
task_id: bytes, source: bytes, metadata: bytes, func_object_id: bytes, function_args: List[Argument]
task_id: bytes,
source: bytes,
metadata: bytes,
func_object_id: bytes,
function_args: List[Argument],
tags: Set[str],
) -> "Task":
return Task(
_message.Task(
@@ -75,6 +84,7 @@ def new_msg(
metadata=metadata,
funcObjectId=func_object_id,
functionArgs=[_message.Task.Argument(type=arg.type.value, data=arg.data) for arg in function_args],
tags=list(tags),
)
)

@@ -256,6 +266,10 @@ def task_lock(self) -> bool:
def processors(self) -> List[ProcessorStatus]:
return [ProcessorStatus(p) for p in self._msg.processors]

@property
def tags(self) -> Set[str]:
return set(self._msg.tags)

@staticmethod
def new_msg(
agent: Resource,
@@ -264,6 +278,7 @@ def new_msg(
latency_us: int,
task_lock: bool,
processors: List[ProcessorStatus],
tags: Set[str],
) -> "WorkerHeartbeat":
return WorkerHeartbeat(
_message.WorkerHeartbeat(
@@ -273,6 +288,7 @@
latencyUS=latency_us,
taskLock=task_lock,
processors=[p.get_message() for p in processors],
tags=list(tags),
)
)

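Illustrative only (not part of the diff): with the extended `new_msg` signature above, a caller constructs a tagged task roughly as follows; the ids below are placeholder values.

```python
from scaler.protocol.python.message import Task

# Placeholder ids; real tasks carry serialized object ids and arguments.
task = Task.new_msg(
    task_id=b"task-0001",
    source=b"client-0001",
    metadata=b"",
    func_object_id=b"",
    function_args=[],
    tags={"gpu", "linux"},
)

assert task.tags == {"gpu", "linux"}  # stored as a capnp List(Text), exposed as a Python set
```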
242 changes: 242 additions & 0 deletions scaler/scheduler/allocators/tagged_allocator.py
@@ -0,0 +1,242 @@
import dataclasses
import typing
from collections import OrderedDict, defaultdict
from itertools import takewhile
from typing import Dict, Iterable, List, Optional, Set

from sortedcontainers import SortedList

from scaler.protocol.python.message import Task


@dataclasses.dataclass(frozen=True)
class _TaskHolder:
    task_id: bytes = dataclasses.field()
    tags: Set[str] = dataclasses.field()


@dataclasses.dataclass(frozen=True)
class _WorkerHolder:
    worker_id: bytes = dataclasses.field()
    tags: Set[str] = dataclasses.field()

    # Queued tasks, ordered from oldest to youngest.
    task_id_to_task: typing.OrderedDict[bytes, _TaskHolder] = dataclasses.field(default_factory=OrderedDict)

    def n_tasks(self) -> int:
        return len(self.task_id_to_task)

    def copy(self) -> "_WorkerHolder":
        return _WorkerHolder(self.worker_id, self.tags, self.task_id_to_task.copy())


class TaggedAllocator:  # FIXME: remove async methods from the TaskAllocator mixin to make it a derivative.
    def __init__(self, max_tasks_per_worker: int):
        self._max_tasks_per_worker = max_tasks_per_worker

        self._worker_id_to_worker: Dict[bytes, _WorkerHolder] = {}

        self._task_id_to_worker_id: Dict[bytes, bytes] = {}
        self._tag_to_worker_ids: Dict[str, Set[bytes]] = {}

    def add_worker(self, worker_id: bytes, tags: Set[str]) -> bool:
        if worker_id in self._worker_id_to_worker:
            return False

        worker = _WorkerHolder(worker_id=worker_id, tags=tags)
        self._worker_id_to_worker[worker_id] = worker

        for tag in tags:
            if tag not in self._tag_to_worker_ids:
                self._tag_to_worker_ids[tag] = set()

            self._tag_to_worker_ids[tag].add(worker.worker_id)

        return True

    def remove_worker(self, worker_id: bytes) -> List[bytes]:
        worker = self._worker_id_to_worker.pop(worker_id, None)

        if worker is None:
            return []

        for tag in worker.tags:
            self._tag_to_worker_ids[tag].discard(worker.worker_id)
            if len(self._tag_to_worker_ids[tag]) == 0:
                self._tag_to_worker_ids.pop(tag)

        task_ids = list(worker.task_id_to_task.keys())
        for task_id in task_ids:
            self._task_id_to_worker_id.pop(task_id)

        return task_ids

    def get_worker_ids(self) -> Set[bytes]:
        return set(self._worker_id_to_worker.keys())

    def assign_task(self, task: Task) -> Optional[bytes]:
        available_workers = self.__get_available_workers_for_tags(task.tags)

        if len(available_workers) <= 0:
            return None

        min_loaded_worker = min(available_workers, key=lambda worker: worker.n_tasks())
        min_loaded_worker.task_id_to_task[task.task_id] = _TaskHolder(task.task_id, task.tags)

        self._task_id_to_worker_id[task.task_id] = min_loaded_worker.worker_id

        return min_loaded_worker.worker_id

    def remove_task(self, task_id: bytes) -> Optional[bytes]:
        worker_id = self._task_id_to_worker_id.pop(task_id, None)

        if worker_id is None:
            return None

        worker = self._worker_id_to_worker[worker_id]
        worker.task_id_to_task.pop(task_id)

        return worker_id

    def balance(self) -> Dict[bytes, List[bytes]]:
        """Returns, for every worker id, the list of task ids to balance out."""

        has_idle_workers = any(worker.n_tasks() == 0 for worker in self._worker_id_to_worker.values())

        if not has_idle_workers:
            return {}

        # The balancing algorithm works by trying to move tasks from workers that have more queued tasks than the
        # average (high-load workers) to workers that have fewer tasks than the average (low-load workers).
        #
        # Because of the tag constraints, this might result in less-than-optimal balancing. However, it greatly
        # limits the number of messages transmitted to workers, and reduces the algorithmic worst-case of the
        # balancing process.
        #
        # The overall worst-case time complexity of the balancing algorithm is:
        #
        #     O(n_workers * log(n_workers) + n_tasks * n_workers * n_tags)
        #
        # However, if the cluster does not use any tag, time complexity is always:
        #
        #     O(n_workers * log(n_workers) + n_tasks * log(n_workers))
        #
        # See <https://github.com/Citi/scaler/issues/32#issuecomment-2541897645> for more details.

        n_tasks = sum(worker.n_tasks() for worker in self._worker_id_to_worker.values())
        avg_tasks_per_worker = n_tasks / len(self._worker_id_to_worker)

        def is_balanced(worker: _WorkerHolder) -> bool:
            return abs(worker.n_tasks() - avg_tasks_per_worker) < 1

        # First, we create a copy of the current worker objects so that we can modify their respective task queues.
        # We also filter out workers that are already balanced, as we will not touch these.
        #
        # Time complexity is O(n_workers).

        workers = [worker.copy() for worker in self._worker_id_to_worker.values() if not is_balanced(worker)]

        # Then, we sort the remaining workers by the number of queued tasks.
        #
        # Time complexity is O(n_workers * log(n_workers)).

        sorted_workers: SortedList[_WorkerHolder] = SortedList(workers, key=lambda worker: worker.n_tasks())

        # Finally, we repeatedly remove one task from the most loaded worker until either:
        #
        # - all workers are balanced;
        # - we cannot find a low-load worker that can accept tasks from a high-load worker.
        #
        # Worst-case time complexity is O(n_tasks * n_workers * n_tags). If no tag is used in the cluster,
        # complexity is always O(n_tasks * log(n_workers)).

        balancing_advice: Dict[bytes, List[bytes]] = defaultdict(list)
        unbalanceable_tasks: Set[bytes] = set()

        while len(sorted_workers) >= 2:
            most_loaded_worker: _WorkerHolder = sorted_workers.pop(-1)

            if most_loaded_worker.n_tasks() - avg_tasks_per_worker < 1:
                # Most loaded worker is not high-load, stop.
                break

            # Go through the most loaded worker's tasks, trying to find a low-load worker that can accept one of them.

            receiving_worker: Optional[_WorkerHolder] = None
            moved_task: Optional[_TaskHolder] = None

            for task in reversed(most_loaded_worker.task_id_to_task.values()):  # Try to balance youngest tasks first.
                if task.task_id in unbalanceable_tasks:
                    continue

                worker_candidates = takewhile(lambda worker: worker.n_tasks() < avg_tasks_per_worker, sorted_workers)
                receiving_worker_index = self.__balance_try_reassign_task(task, worker_candidates)

                if receiving_worker_index is not None:
                    receiving_worker = sorted_workers.pop(receiving_worker_index)
                    moved_task = task
                    break
                else:
                    # We could not find a receiving worker for this task; remember it as unbalanceable in case the
                    # worker pops up again. This greatly reduces the worst-case big-O complexity of the algorithm.
                    unbalanceable_tasks.add(task.task_id)

            # Re-insert the workers into the sorted list if they can be balanced further.

            if moved_task is not None:
                assert receiving_worker is not None

                balancing_advice[most_loaded_worker.worker_id].append(moved_task.task_id)

                most_loaded_worker.task_id_to_task.pop(moved_task.task_id)
                receiving_worker.task_id_to_task[moved_task.task_id] = moved_task

                if not is_balanced(most_loaded_worker):
                    sorted_workers.add(most_loaded_worker)

                if not is_balanced(receiving_worker):
                    sorted_workers.add(receiving_worker)

        return balancing_advice

    @staticmethod
    def __balance_try_reassign_task(task: _TaskHolder, worker_candidates: Iterable[_WorkerHolder]) -> Optional[int]:
        """Returns the index of the first worker that can accept the task."""

        # Time complexity is O(n_workers * n_tags).

        for worker_index, worker in enumerate(worker_candidates):
            if task.tags.issubset(worker.tags):
                return worker_index

        return None

    def get_assigned_worker(self, task_id: bytes) -> Optional[bytes]:
        if task_id not in self._task_id_to_worker_id:
            return None

        return self._task_id_to_worker_id[task_id]

    def has_available_worker(self, tags: Optional[Set[str]] = None) -> bool:
        if tags is None:
            tags = set()

        return len(self.__get_available_workers_for_tags(tags)) > 0

    def statistics(self) -> Dict:
        return {
            worker.worker_id: {"free": self._max_tasks_per_worker - worker.n_tasks(), "sent": worker.n_tasks()}
            for worker in self._worker_id_to_worker.values()
        }

    def __get_available_workers_for_tags(self, tags: Set[str]) -> List[_WorkerHolder]:
        if any(tag not in self._tag_to_worker_ids for tag in tags):
            return []

        matching_worker_ids = set(self._worker_id_to_worker.keys())

        for tag in tags:
            matching_worker_ids.intersection_update(self._tag_to_worker_ids[tag])

        matching_workers = [self._worker_id_to_worker[worker_id] for worker_id in matching_worker_ids]

        return [worker for worker in matching_workers if worker.n_tasks() < self._max_tasks_per_worker]
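A minimal usage sketch of the new allocator, illustrative only and not part of the diff; the worker ids, task ids, and tags below are placeholders:

```python
from scaler.protocol.python.message import Task
from scaler.scheduler.allocators.tagged_allocator import TaggedAllocator

allocator = TaggedAllocator(max_tasks_per_worker=10)
allocator.add_worker(b"worker-cpu", tags=set())
allocator.add_worker(b"worker-gpu", tags={"gpu"})

# A task is only assignable to workers whose tags are a superset of the task's tags.
task = Task.new_msg(
    task_id=b"task-1", source=b"client-1", metadata=b"", func_object_id=b"", function_args=[], tags={"gpu"}
)
assert allocator.assign_task(task) == b"worker-gpu"
assert allocator.get_assigned_worker(b"task-1") == b"worker-gpu"

# balance() returns, per worker id, the task ids that should be moved off that worker.
advice = allocator.balance()

# Once the task completes (or is cancelled), drop it from the allocator's bookkeeping.
allocator.remove_task(b"task-1")
```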
1 change: 1 addition & 0 deletions scaler/scheduler/graph_manager.py
@@ -191,6 +191,7 @@ async def __check_one_graph(self, graph_task_id: bytes):
metadata=task_info.task.metadata,
func_object_id=task_info.task.func_object_id,
function_args=[self.__get_argument(graph_task_id, arg) for arg in task_info.task.function_args],
tags=set(),
)

await self._task_manager.on_task_new(graph_info.client, task)
1 change: 1 addition & 0 deletions scaler/worker/agent/heartbeat_manager.py
@@ -70,6 +70,7 @@ async def routine(self):
self._latency_us,
self._processor_manager.can_accept_task(),
[self.__get_processor_status_from_holder(processor) for processor in processors],
set(),
)
)
self._start_timestamp_ns = time.time_ns()
1 change: 1 addition & 0 deletions scaler/worker/symphony/heartbeat_manager.py
@@ -52,6 +52,7 @@ async def routine(self):
self._latency_us,
self._worker_task_manager.can_accept_task(),
[],
set(),
)
)
self._start_timestamp_ns = time.time_ns()