Issue #32: Initial implementation of the non-blocking tagging task allocator. #47
Open

rafa-be wants to merge 1 commit into Citi:main from rafa-be:task_tagging
```diff
@@ -4,4 +4,5 @@ graphlib-backport; python_version < '3.9'
 psutil
 pycapnp
 pyzmq
+sortedcontainers
 tblib
```
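The only functional change in this hunk is the new `sortedcontainers` runtime dependency, which the allocator uses to keep workers ordered by their queued-task count. As a quick illustration of the `SortedList`-with-key pattern the allocator relies on (the `(name, load)` tuples here are stand-ins, not scaler types):

```python
from sortedcontainers import SortedList

# Keep (worker, load) pairs ordered by load; a key function gives the same
# effect as the allocator's SortedList(workers, key=lambda w: w.n_tasks()).
workers = SortedList([("w1", 3), ("w2", 0), ("w3", 5)], key=lambda w: w[1])

busiest = workers.pop(-1)       # removes ("w3", 5), the highest-load worker
print(busiest)                  # ('w3', 5)

workers.add(("w3", 4))          # re-insert after one task was rebalanced away
print([w[0] for w in workers])  # ['w2', 'w1', 'w3']
```

Both `pop` and `add` are O(log n), which is what makes the repeated "pop the most loaded worker, re-insert if still unbalanced" loop in `balance()` cheap.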
```diff
@@ -1 +1 @@
-__version__ = "1.9.1"
+__version__ = "1.9.2"
```
@@ -0,0 +1,242 @@ (new file)

```python
import dataclasses
import typing
from collections import OrderedDict, defaultdict
from itertools import takewhile
from sortedcontainers import SortedList
from typing import Dict, Iterable, List, Optional, Set

from scaler.protocol.python.message import Task


@dataclasses.dataclass(frozen=True)
class _TaskHolder:
    task_id: bytes = dataclasses.field()
    tags: Set[str] = dataclasses.field()


@dataclasses.dataclass(frozen=True)
class _WorkerHolder:
    worker_id: bytes = dataclasses.field()
    tags: Set[str] = dataclasses.field()

    # Queued tasks, ordered from oldest to youngest tasks.
    task_id_to_task: typing.OrderedDict[bytes, _TaskHolder] = dataclasses.field(default_factory=OrderedDict)

    def n_tasks(self) -> int:
        return len(self.task_id_to_task)

    def copy(self) -> "_WorkerHolder":
        return _WorkerHolder(self.worker_id, self.tags, self.task_id_to_task.copy())


class TaggedAllocator:  # FIXME: remove async methods from the TaskAllocator mixin to make it a derivative.
    def __init__(self, max_tasks_per_worker: int):
        self._max_tasks_per_worker = max_tasks_per_worker

        self._worker_id_to_worker: Dict[bytes, _WorkerHolder] = {}

        self._task_id_to_worker_id: Dict[bytes, bytes] = {}
        self._tag_to_worker_ids: Dict[str, Set[bytes]] = {}

    def add_worker(self, worker_id: bytes, tags: Set[str]) -> bool:
        if worker_id in self._worker_id_to_worker:
            return False

        worker = _WorkerHolder(worker_id=worker_id, tags=tags)
        self._worker_id_to_worker[worker_id] = worker

        for tag in tags:
            if tag not in self._tag_to_worker_ids:
                self._tag_to_worker_ids[tag] = set()

            self._tag_to_worker_ids[tag].add(worker.worker_id)

        return True

    def remove_worker(self, worker_id: bytes) -> List[bytes]:
        worker = self._worker_id_to_worker.pop(worker_id, None)

        if worker is None:
            return []

        for tag in worker.tags:
            self._tag_to_worker_ids[tag].discard(worker.worker_id)
            if len(self._tag_to_worker_ids[tag]) == 0:
                self._tag_to_worker_ids.pop(tag)

        task_ids = list(worker.task_id_to_task.keys())
        for task_id in task_ids:
            self._task_id_to_worker_id.pop(task_id)

        return task_ids

    def get_worker_ids(self) -> Set[bytes]:
        return set(self._worker_id_to_worker.keys())

    def assign_task(self, task: Task) -> Optional[bytes]:
        available_workers = self.__get_available_workers_for_tags(task.tags)

        if len(available_workers) <= 0:
            return None

        min_loaded_worker = min(available_workers, key=lambda worker: worker.n_tasks())
        min_loaded_worker.task_id_to_task[task.task_id] = _TaskHolder(task.task_id, task.tags)

        self._task_id_to_worker_id[task.task_id] = min_loaded_worker.worker_id

        return min_loaded_worker.worker_id

    def remove_task(self, task_id: bytes) -> Optional[bytes]:
        worker_id = self._task_id_to_worker_id.pop(task_id, None)

        if worker_id is None:
            return None

        worker = self._worker_id_to_worker[worker_id]
        worker.task_id_to_task.pop(task_id)

        return worker_id

    def balance(self) -> Dict[bytes, List[bytes]]:
        """Returns, for every worker id, the list of task ids to balance out."""

        has_idle_workers = any(worker.n_tasks() == 0 for worker in self._worker_id_to_worker.values())

        if not has_idle_workers:
            return {}

        # The balancing algorithm works by trying to move tasks from workers that have more queued tasks than the
        # average (high-load workers) to workers that have fewer tasks than the average (low-load workers).
        #
        # Because of the tag constraints, this might result in less than optimal balancing. However, it will greatly
        # limit the number of messages transmitted to workers, and reduce the algorithmic worst-case of the balancing
        # process.
        #
        # The overall worst-case time complexity of the balancing algorithm is:
        #
        #     O(n_workers * log(n_workers) + n_tasks * n_workers * n_tags)
        #
        # However, if the cluster does not use any tag, time complexity is always:
        #
        #     O(n_workers * log(n_workers) + n_tasks * log(n_workers))
        #
        # See <https://github.com/Citi/scaler/issues/32#issuecomment-2541897645> for more details.

        n_tasks = sum(worker.n_tasks() for worker in self._worker_id_to_worker.values())
        avg_tasks_per_worker = n_tasks / len(self._worker_id_to_worker)

        def is_balanced(worker: _WorkerHolder) -> bool:
            return abs(worker.n_tasks() - avg_tasks_per_worker) < 1

        # First, we create a copy of the current worker objects so that we can modify their respective task queues.
        # We also filter out workers that are already balanced, as we will not touch these.
        #
        # Time complexity is O(n_workers).

        workers = [worker.copy() for worker in self._worker_id_to_worker.values() if not is_balanced(worker)]

        # Then, we sort the remaining workers by the number of queued tasks.
        #
        # Time complexity is O(n_workers * log(n_workers)).

        sorted_workers: SortedList[_WorkerHolder] = SortedList(workers, key=lambda worker: worker.n_tasks())

        # Finally, we repeatedly remove one task from the most loaded worker until either:
        #
        # - all workers are balanced;
        # - we cannot find a low-load worker that can accept tasks from a high-load worker.
        #
        # Worst-case time complexity is O(n_tasks * n_workers * n_tags). If no tag is used in the cluster, complexity
        # is always O(n_tasks * log(n_workers)).

        balancing_advice: Dict[bytes, List[bytes]] = defaultdict(list)
        unbalanceable_tasks: Set[bytes] = set()

        while len(sorted_workers) >= 2:
            most_loaded_worker: _WorkerHolder = sorted_workers.pop(-1)

            if most_loaded_worker.n_tasks() - avg_tasks_per_worker < 1:
                # Most loaded worker is not high-load, stop.
                break

            # Go through all of the most loaded worker's tasks, trying to find a low-load worker that can accept it.

            receiving_worker: Optional[_WorkerHolder] = None
            moved_task: Optional[_TaskHolder] = None

            for task in reversed(most_loaded_worker.task_id_to_task.values()):  # Try to balance youngest tasks first.
                if task.task_id in unbalanceable_tasks:
                    continue

                worker_candidates = takewhile(lambda worker: worker.n_tasks() < avg_tasks_per_worker, sorted_workers)
                receiving_worker_index = self.__balance_try_reassign_task(task, worker_candidates)

                if receiving_worker_index is not None:
                    receiving_worker = sorted_workers.pop(receiving_worker_index)
                    moved_task = task
                    break
                else:
                    # We could not find a receiving worker for this task; remember the task as unbalanceable in case
                    # the worker pops up again. This greatly reduces the worst-case big-O complexity of the algorithm.
                    unbalanceable_tasks.add(task.task_id)

            # Re-insert the workers in the sorted list if these can be balanced more.

            if moved_task is not None:
                assert receiving_worker is not None

                balancing_advice[most_loaded_worker.worker_id].append(moved_task.task_id)

                most_loaded_worker.task_id_to_task.pop(moved_task.task_id)
                receiving_worker.task_id_to_task[moved_task.task_id] = moved_task

                if not is_balanced(most_loaded_worker):
                    sorted_workers.add(most_loaded_worker)

                if not is_balanced(receiving_worker):
                    sorted_workers.add(receiving_worker)

        return balancing_advice

    @staticmethod
    def __balance_try_reassign_task(task: _TaskHolder, worker_candidates: Iterable[_WorkerHolder]) -> Optional[int]:
        """Returns the index of the first worker that can accept the task."""

        # Time complexity is O(n_workers * n_tags).

        for worker_index, worker in enumerate(worker_candidates):
            if task.tags.issubset(worker.tags):
                return worker_index

        return None

    def get_assigned_worker(self, task_id: bytes) -> Optional[bytes]:
        if task_id not in self._task_id_to_worker_id:
            return None

        return self._task_id_to_worker_id[task_id]

    def has_available_worker(self, tags: Optional[Set[str]] = None) -> bool:
        if tags is None:
            tags = set()

        return len(self.__get_available_workers_for_tags(tags)) > 0

    def statistics(self) -> Dict:
        return {
            worker.worker_id: {"free": self._max_tasks_per_worker - worker.n_tasks(), "sent": worker.n_tasks()}
            for worker in self._worker_id_to_worker.values()
        }

    def __get_available_workers_for_tags(self, tags: Set[str]) -> List[_WorkerHolder]:
        if any(tag not in self._tag_to_worker_ids for tag in tags):
            return []

        matching_worker_ids = set(self._worker_id_to_worker.keys())

        for tag in tags:
            matching_worker_ids.intersection_update(self._tag_to_worker_ids[tag])

        matching_workers = [self._worker_id_to_worker[worker_id] for worker_id in matching_worker_ids]

        return [worker for worker in matching_workers if worker.n_tasks() < self._max_tasks_per_worker]
```
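The file above depends on scaler internals (`scaler.protocol.python.message.Task`), so it cannot be run in isolation. As a rough, self-contained sketch of the core assignment idea — tag-subset filtering plus least-loaded selection under a per-worker capacity limit — the following toy class (not the scaler implementation; all names here are illustrative) shows the behaviour:

```python
from typing import Dict, Optional, Set


class MiniTaggedAllocator:
    """Toy sketch: assign each task to the least-loaded worker whose tag set
    covers the task's tags, subject to a per-worker queue capacity."""

    def __init__(self, max_tasks_per_worker: int):
        self._max = max_tasks_per_worker
        self._worker_tags: Dict[bytes, Set[str]] = {}
        self._worker_load: Dict[bytes, int] = {}

    def add_worker(self, worker_id: bytes, tags: Set[str]) -> None:
        self._worker_tags[worker_id] = tags
        self._worker_load[worker_id] = 0

    def assign_task(self, tags: Set[str]) -> Optional[bytes]:
        # Candidates: workers whose tags are a superset of the task's tags
        # and that still have queue capacity.
        candidates = [
            wid for wid, wtags in self._worker_tags.items()
            if tags.issubset(wtags) and self._worker_load[wid] < self._max
        ]
        if not candidates:
            return None
        chosen = min(candidates, key=lambda wid: self._worker_load[wid])
        self._worker_load[chosen] += 1
        return chosen


alloc = MiniTaggedAllocator(max_tasks_per_worker=2)
alloc.add_worker(b"w1", {"gpu", "linux"})
alloc.add_worker(b"w2", {"linux"})

print(alloc.assign_task({"gpu"}))  # b'w1': the only worker with the "gpu" tag
print(alloc.assign_task(set()))    # b'w2': untagged task goes to the least-loaded worker
print(alloc.assign_task({"tpu"}))  # None: no worker carries the "tpu" tag
```

Note how a task with no tags can land on any worker, while each tag added to a task only ever shrinks the candidate set; this is the same `tags.issubset(worker.tags)` check the real allocator performs in `__balance_try_reassign_task` and `__get_available_workers_for_tags`.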
Review comment: This does not really add a new dependency, as pycapnp already has it as a dependency.