Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix for out of order dependency locking #147

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 62 additions & 32 deletions parla/task_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from parla.dataflow import Dataflow

# Logger configuration (uncomment and adjust level if needed)
#logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

__all__ = ["Task", "SchedulerContext", "DeviceSetRequirements", "OptionsRequirements",
Expand Down Expand Up @@ -52,12 +52,10 @@ def __init__(self):
self._dependencies = defaultdict(list)

def add(self, dependency: 'TaskID', dependent: 'TaskID'):
if dependent.task is None:
raise ValueError(
"dependent task %s should have been spawned", repr(dependent)
)
print("Adding dependency: {} -> {}".format(dependency, dependent), flush=True)
with self._mutex:
self._dependencies[dependency].append(dependent)
print("Added dependency: {} -> {}".format(dependency, dependent), flush=True)

def get_dependents(self, tid: 'TaskID') -> List['TaskID']:
with self._mutex:
Expand Down Expand Up @@ -85,6 +83,8 @@ def is_terminal(self):

# TODO(lhc): Why do we need dependency information here?
# It is not exploited/managed correctly.


class TaskRunning(TaskState):
__slots__ = ["func", "args", "dependencies"]

Expand Down Expand Up @@ -125,6 +125,7 @@ def __repr__(self):
else:
return "Functionless task"


class TaskCompleted(TaskState):
__slots__ = ["ret"]

Expand Down Expand Up @@ -353,7 +354,6 @@ def _finish(self, ctx: 'SchedulerContext'):
"""Cleanup works after executing the task."""
raise NotImplementedError()


def run(self):
assert self._assigned, "Task was not assigned before running."
assert isinstance(self.req, EnvironmentRequirements), \
Expand All @@ -364,7 +364,8 @@ def run(self):
task_state = TaskException(RuntimeError("Unknown fatal error"))
# Run the task and assign the new task state
try:
assert isinstance(self._state, TaskRunning), " Task is not running state: {} on task: {}".format(self._state, self.taskid)
assert isinstance(self._state, TaskRunning), " Task is not running state: {} on task: {}".format(
self._state, self.taskid)
# TODO(lhc): This assumes Parla only has two devices.
# The reason I am doing this is that importing
# Parla's cuda.py is expensive.
Expand Down Expand Up @@ -437,9 +438,9 @@ def dependencies(self, dependencies: Collection['Task']):

def is_blocked_by_dependencies(self) -> bool:
if self._num_blocking_dependencies == 0:
return False
return False
else:
return True
return True

def is_blocked_by_dependencies_mutex(self) -> bool:
with self._mutex:
Expand All @@ -454,8 +455,8 @@ def _is_schedulable(self) -> bool:
# is schedulable.
# TODO(hc): Task state should be defined better.
return isinstance(self._state, TaskRunning) and \
not self.is_blocked_by_dependencies() and\
self._assigned
not self.is_blocked_by_dependencies() and\
self._assigned

def _enqueue_to_scheduler(self):
get_scheduler_context().enqueue_task(self)
Expand Down Expand Up @@ -560,6 +561,8 @@ def __init__(self, func, args, dependencies: Collection["Task"], taskid: 'TaskID
super(ComputeTask, self).__init__(
dependencies, taskid, req, name, init_state=TaskWaiting()
)

print("Creating task")
with self._mutex:
# This task could be spawned when it is ready.
# To set its state Running when it is running later,
Expand All @@ -575,11 +578,14 @@ def __init__(self, func, args, dependencies: Collection["Task"], taskid: 'TaskID
self.num_unspawned_dependencies = num_unspawned_dependencies

# Enable tasks who have waited for this task to map.
print("Notifying spawned dependents")
self.__notify_spawned_dependents()
print("Notified spawned dependents")

# If this task is not waiting for any dependent tasks,
# enqueue onto the spawned queue.
if self.num_unspawned_dependencies == 0:
print("Ready to map", flush=True)
self._ready_to_map()
get_scheduler_context().scheduler.incr_active_compute_tasks()
logger.debug("Task %r: Creating", self)
Expand All @@ -589,11 +595,17 @@ def __notify_spawned_dependents(self):
PRIVATE USE ONLY. Not thread-safe and should be called WITH ITS MUTEX.
"""
# Get the list of all waiting dependents from the global collection.
print("Getting dependents", flush=True)
dependents = unspawned_dependencies.get_dependents(self.taskid)
print("Got dependents", flush=True)
for d_tid in dependents:
print("Getting dependent: ", d_tid, flush=True)
dt = d_tid.task
assert isinstance(dt, ComputeTask), type(dt)
print("Handling dependent spawn")
self._dependents.append(dt)
dt._handle_dependency_spawn(self)
print("Handled dependent spawn")

def _ready_to_map(self):
assert self.num_unspawned_dependencies == 0
Expand All @@ -606,10 +618,16 @@ def _ready_to_map(self):

def _handle_dependency_spawn(self, dependency: "Task"):
with self._mutex:
print("Inside handle dependency spawn", flush=True)
self.num_unspawned_dependencies -= 1
self._add_dependency(dependency)
self._dependencies.append(dependency)
# self._add_dependency(dependency)
print("Added dependency: ",
self.num_unspawned_dependencies, flush=True)
if self.num_unspawned_dependencies == 0:
print("Ready to map", flush=True)
self._ready_to_map()
print("Finished ready to map", flush=True)

def _execute_task(self):
return self._state.func(self, *self._state.args)
Expand All @@ -630,7 +648,6 @@ def _finish(self, ctx):
ctx.scheduler.update_mapped_task_count_mutex(self, d, -1)
ctx.scheduler.update_launched_task_count_mutex(self, d, -1)


# _finish() can be called more than once on global task.
if (self.dataflow != None):
# Update OUT parrays which may have changed size from 0 to something
Expand All @@ -655,7 +672,8 @@ def __init__(self, computation_task: ComputeTask, taskid,
super(DataMovementTask, self).__init__([], taskid, req, name,
# TODO(lhc): temporary task running state.
# This would be a data movement kernel.
init_state=TaskRunning(None, None, None),
init_state=TaskRunning(
None, None, None),
init_assigned=True
)
with self._mutex:
Expand Down Expand Up @@ -820,15 +838,21 @@ def spawn_task(
# dependent tasks are spawned.
spawned_dependencies = []
for dependency in list(dependencies):
print("Looking at dependency: ", dependency, flush=True)
if isinstance(dependency, TaskID):
# If the dependency is not yet spawned, temporarily remove it from
# a task's dependency list.
print("Dependency is a TaskID", flush=True)
unspawned_dependencies.add(dependency, taskid)
else:
spawned_dependencies.append(dependency)
num_unspawned_dependencies = len(
dependencies) - len(spawned_dependencies)

print("Spawned dependencies: ", spawned_dependencies, flush=True)
print("num_unspawned_dependencies: ",
num_unspawned_dependencies, flush=True)

return ComputeTask(
function, args, spawned_dependencies, taskid, req, dataflow, name,
num_unspawned_dependencies
Expand Down Expand Up @@ -1020,6 +1044,8 @@ def run(self) -> None:
f"[WorkerThread %d] Starting: %s", self.index, self.task.name)
self._status = "Running Task {}".format(self.task)
self.task.run()

print("Finished task: ", self.task.name, flush=True)
self.scheduler.decr_running_tasks()
# Free self back to worker pool
self._remove_task()
Expand Down Expand Up @@ -1620,7 +1646,6 @@ def enqueue_dev_queue(self, dev, task: Task):
else:
self._datamove_task_dev_queues[dev].append(task)


def enqueue_dev_queue_mutex(self, dev, task: Task):
"""Enqueue a task on the device queue.
Note that this enqueue has no data race.
Expand Down Expand Up @@ -1826,8 +1851,8 @@ def fill_curr_spawned_task_queue(self):
with self._spawned_queue_monitor:
if (len(self._new_spawned_task_queue) > 0):
new_q = self._new_spawned_task_queue
# Only map tasks whose dependencies are all mapped
# to avoid resource deadlock.
# Only map tasks whose dependencies are all mapped
# to avoid resource deadlock.
new_tasks = []
failed_tasks = []
for _ in range(len(new_q)):
Expand Down Expand Up @@ -1953,7 +1978,7 @@ def _construct_datamove_task(self, target_data, compute_task: ComputeTask, opera
# create more data movement tasks and make additional dependencies.
# The computation task should not be run until all the data movement
# tasks are created.
#
#
if not datamove_task.is_blocked_by_dependencies_mutex():
return datamove_task
return None
Expand Down Expand Up @@ -1993,19 +2018,19 @@ def _map_tasks(self):
mappable_datamove_tasks = []
for data in task.dataflow.input:
dtask = self._construct_datamove_task(
data, task, OperandType.IN)
data, task, OperandType.IN)
if dtask is not None:
mappable_datamove_tasks.append(dtask)
for data in task.dataflow.output:
dtask = self._construct_datamove_task(
data, task, OperandType.OUT)
data, task, OperandType.OUT)
if dtask is not None:
mappable_datamove_tasks.append(dtask)
mappable_datamove_tasks.append(dtask)
for data in task.dataflow.inout:
dtask = self._construct_datamove_task(
data, task, OperandType.INOUT)
data, task, OperandType.INOUT)
if dtask is not None:
mappable_datamove_tasks.append(dtask)
mappable_datamove_tasks.append(dtask)

# Update parray tracking and task count on the device
for parray in (task.dataflow.input + task.dataflow.inout + task.dataflow.output):
Expand Down Expand Up @@ -2088,7 +2113,7 @@ def _launch_task(self, queue, dev: Device, is_cpu: bool, num_launched_tasks):
while len(queue):
task = queue.pop()
worker = self._free_worker_threads.pop() # grab a worker
#print("Worker thread:", str(worker.index))
# print("Worker thread:", str(worker.index))
logger.info(f"[Scheduler] Launching %s task, %r on %r",
dev.architecture.id, task, worker)
# XXX(lhc): The error that tried to launch a completed task
Expand All @@ -2111,7 +2136,7 @@ def _launch_task(self, queue, dev: Device, is_cpu: bool, num_launched_tasks):
if isinstance(task._state, TaskCompleted):
logger.info(f"This should not be passed.")
continue
self.scheduler.incr_running_tasks()
self.scheduler.incr_running_tasks()
worker.assign_task(task)
logger.debug(f"[Scheduler] Launched %r", task)

Expand All @@ -2136,13 +2161,17 @@ def _launch_tasks(self):
compute_queue = self._compute_task_dev_queues[dev]
datamove_queue = self._datamove_task_dev_queues[dev]
if len(compute_queue) > 0:
num_launched_compute_task_count = self.get_launched_compute_task_count(dev)
num_launched_compute_task_count = self.get_launched_compute_task_count(
dev)
if is_cpu or num_launched_compute_task_count < (self._num_colocatable_tasks + 1):
self._launch_task(compute_queue, dev, is_cpu, num_launched_compute_task_count)
self._launch_task(
compute_queue, dev, is_cpu, num_launched_compute_task_count)
if len(datamove_queue) > 0:
num_launched_datamove_task_count = self.get_launched_datamove_task_count(dev)
num_launched_datamove_task_count = self.get_launched_datamove_task_count(
dev)
if is_cpu or num_launched_datamove_task_count < (self._num_colocatable_tasks + 1):
self._launch_task(datamove_queue, dev, is_cpu, num_launched_datamove_task_count)
self._launch_task(
datamove_queue, dev, is_cpu, num_launched_datamove_task_count)

def start_scheduler_callbacks(self):
map_succeed = True
Expand All @@ -2152,8 +2181,8 @@ def start_scheduler_callbacks(self):
schedule_succeed = self.schedule_tasks_callback()
launch_succeed = self.launch_tasks_callback()
while self.no_running_tasks():
if not map_succeed and not schedule_succeed and not launch_succeed:
break
# if not map_succeed and not schedule_succeed and not launch_succeed:
# break
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't need these lines, we can also remove the assignments below.

map_succeed = self.map_tasks_callback()
schedule_succeed = self.schedule_tasks_callback()
launch_succeed = self.launch_tasks_callback()
Expand Down Expand Up @@ -2233,7 +2262,8 @@ def launch_tasks_callback(self):
# Check runtime conditions
# Are there any tasks to launch?
# Are there any free worker threads?
condition = len(self._free_worker_threads) > 0 and self.num_active_tasks() != 0
condition = len(
self._free_worker_threads) > 0 and self.num_active_tasks() != 0
"""
dev_condition = False
if condition:
Expand Down
15 changes: 15 additions & 0 deletions parla/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,17 +407,28 @@ def decorator(body) -> ComputeTask:

processed_dependencies = tasks(*dependencies)._flat_tasks

print("spawn: ", taskid, processed_dependencies,
req, body, separated_body, flush=True)

# Gather input/output/inout, which are hints about the data flowing from or to this task.
# TODO (ses): I gathered these into lists so I could perform concatenation later. This may be inefficient.
dataflow = Dataflow(list(input), list(output), list(inout))

print("After dataflow", flush=True)

# Get handle to current scheduler
scheduler = task_runtime.get_scheduler_context()

if isinstance(scheduler, WorkerThread):
# If we are in a worker thread, get the real scheduler object instead.
scheduler = scheduler.scheduler

print("After scheduler", flush=True)

logger.debug("Creating: %s %r", taskid, body)

print("Before spawn_task", flush=True)

# Spawn the task via the Parla runtime API
task = scheduler.spawn_task(
function=_task_callback,
Expand All @@ -428,6 +439,8 @@ def decorator(body) -> ComputeTask:
dataflow=dataflow,
name=getattr(body, "__name__", None))

print("After spawn_task", flush=True)

logger.debug("Created: %s %r", taskid, body)

for scope in task_locals.task_scopes:
Expand All @@ -436,6 +449,8 @@ def decorator(body) -> ComputeTask:
# Activate scheduler
scheduler.start_scheduler_callbacks()

print("After start_scheduler_callbacks", flush=True)

# Return the task object to user code
return task

Expand Down