Very simple LRU-based garbage collector (GC) #144

Open: wants to merge 16 commits into base: main
16 changes: 8 additions & 8 deletions examples/cholesky/blocked_cholesky_automatic.py
@@ -208,7 +208,7 @@ def cholesky_blocked_inplace(a, block_size):

@spawn(gemm1[j, k], [solve[j, k], gemm1[j, 0:k]], input=[a[j][k]], inout=[a[j][j]], placement=loc_syrk, memory=mem)
def t1():
#print(f"+SYRK: ({j}, {k}) - Requires rw({j},{j}) r({j}, {k})", flush=True)
print(f"+SYRK: ({j}, {k}) - Requires rw({j},{j}) r({j}, {k})", flush=True)
out = a[j][j].array
rhs = a[j][k].array
out = update(rhs, rhs, out)
@@ -217,7 +217,7 @@ def t1():
stream.synchronize()
a[j][j].update(out)
stream.synchronize()
#print(f"-SYRK: ({j}, {k}) - Requires rw({j},{j}) r({j}, {k})", flush=True)
print(f"-SYRK: ({j}, {k}) - Requires rw({j},{j}) r({j}, {k})", flush=True)

# Cholesky on block
mem = 8*block_size**2
@@ -229,7 +229,7 @@ def t1():

@spawn(subcholesky[j], [gemm1[j, 0:j]], inout=[a[j][j]], placement=loc_potrf, memory=mem)
def t2():
#print(f"+POTRF: ({j}) - Requires rw({j},{j})", flush=True)
print(f"+POTRF: ({j}) - Requires rw({j},{j})", flush=True)
dblock = a[j][j].array

log_memory()
@@ -239,7 +239,7 @@ def t2():
stream.synchronize()
a[j][j].update(dblock)
stream.synchronize()
#print(f"-POTRF: ({j}) - Requires rw({j},{j})", flush=True)
print(f"-POTRF: ({j}) - Requires rw({j},{j})", flush=True)
for i in range(j+1, len(a)):
for k in range(j):
# Inter-block GEMM
@@ -251,7 +251,7 @@ def t2():

@spawn(gemm2[i, j, k], [solve[j, k], solve[i, k], gemm2[i, j, 0:k]], inout=[a[i][j]], input=[a[i][k], a[j][k]], placement=loc_gemm, memory=mem)
def t3():
#print(f"+GEMM: ({i}, {j}, {k}) - Requires rw({i},{j}), r({i}, {k}), r({j}, {k})", flush=True)
print(f"+GEMM: ({i}, {j}, {k}) - Requires rw({i},{j}), r({i}, {k}), r({j}, {k})", flush=True)
out = a[i][j].array
rhs1 = a[i][k].array
rhs2 = a[j][k].array
@@ -263,7 +263,7 @@ def t3():
stream.synchronize()
a[i][j].update(out)
stream.synchronize()
#print(f"-GEMM: ({i}, {j}, {k}) - Requires rw({i},{j}), r({i}, {k}), r({j}, {k})", flush=True)
print(f"-GEMM: ({i}, {j}, {k}) - Requires rw({i},{j}), r({i}, {k}), r({j}, {k})", flush=True)

# Triangular solve
mem = 8*2*block_size**2
@@ -275,7 +275,7 @@ def t3():

@spawn(solve[i, j], [gemm2[i, j, 0:j], subcholesky[j]], inout=[a[i][j]], input=[a[j][j]], placement=loc_trsm, memory=mem)
def t4():
#print(f"+TRSM: ({i}, {j}) - Requires rw({i},{j}), r({j}, {j})", flush=True)
print(f"+TRSM: ({i}, {j}) - Requires rw({i},{j}), r({j}, {j})", flush=True)
factor = a[j][j].array
panel = a[i][j].array

@@ -285,7 +285,7 @@ def t4():
stream.synchronize()
a[i][j].update(out)
stream.synchronize()
#print(f"-TRSM: ({i}, {j}) - Requires rw({i},{j}), r({j}, {j})", flush=True)
print(f"-TRSM: ({i}, {j}) - Requires rw({i},{j}), r({j}, {j})", flush=True)

return subcholesky[len(a) - 1]

50 changes: 23 additions & 27 deletions parla/parray/coherence.py
@@ -535,6 +535,7 @@ def evict(self, device_id: int, keep_one_copy: bool = True) -> List[MemoryOperat
"""
device_local_state = self._local_states[device_id]
operations = []
evict_last_copy = False

if device_local_state == self.INVALID: # already evicted, do nothing
operations.append(MemoryOperation.noop())
@@ -547,40 +548,35 @@ def evict(self, device_id: int, keep_one_copy: bool = True) -> List[MemoryOperat
new_owner = device
break

# this device owns the last copy
if new_owner is None:
if keep_one_copy:
if device_id == CPU_INDEX:
# the last copy is already at CPU,
# do nothing and skip the rest of the code
return [MemoryOperation.noop()]
else:
# write back the last copy to CPU
operations.append(MemoryOperation.load(CPU_INDEX, device_id))

# now CPU has exclusive access to the data
self._global_state = self.MODIFIED
self._local_states[CPU_INDEX] = self.MODIFIED

new_owner = CPU_INDEX
else:
self._global_state = self.INVALID # the system lose the last copy
self.owner = new_owner
if new_owner is None:
evict_last_copy = True
else:
# update states
self._local_states[device_id] = self.INVALID
operations.append(MemoryOperation.evict(device_id))
self._versions[device_id] = -1
self._is_complete[device_id] = None

# update states
self._local_states[device_id] = self.INVALID
operations.append(MemoryOperation.evict(device_id))
else: # Modified, this device owns the last copy
evict_last_copy = True

if evict_last_copy:
if keep_one_copy: # write back to CPU
self.owner = CPU_INDEX
self._local_states[CPU_INDEX] = self.MODIFIED
if device_id != CPU_INDEX:
operations.extend(self._write_back_to(CPU_INDEX, self.MODIFIED, on_different_device=True, this_device_id=device_id)[0])

operations.append(MemoryOperation.load(CPU_INDEX, device_id))
self.owner = CPU_INDEX
self._local_states[CPU_INDEX] = self.MODIFIED
self._is_complete[device_id] = True
else:
return [MemoryOperation.noop()]
else:
self._global_state = self.INVALID # the system lose the last copy
self.owner = None
self._versions[device_id] = -1
self._is_complete[device_id] = None

self._local_states[device_id] = self.INVALID
operations.append(MemoryOperation.evict(device_id))
self._local_states[device_id] = self.INVALID
operations.append(MemoryOperation.evict(device_id))

return operations
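
For readers following the refactor above, here is a minimal, self-contained restatement of the same eviction decision, assuming an MSI-style protocol (INVALID/SHARED/MODIFIED). The function name `evict_block` and the string operations are illustrative stand-ins, not the module's `MemoryOperation` API:

```python
# Illustrative restatement of the evict() decision above (assumed MSI-style states);
# not the module's real MemoryOperation API -- operations are plain strings here.
INVALID, SHARED, MODIFIED = "INVALID", "SHARED", "MODIFIED"
CPU_INDEX = -1

def evict_block(local_states, device_id, keep_one_copy=True):
    """Return the abstract operations needed to evict one device's copy."""
    state = local_states[device_id]
    if state == INVALID:                       # already evicted: nothing to do
        return ["noop"]

    others = [d for d, s in local_states.items()
              if d != device_id and s != INVALID]
    if state == SHARED and others:             # other valid copies exist: just drop this one
        local_states[device_id] = INVALID
        return [f"evict({device_id})"]

    # This device holds the last valid copy.
    ops = []
    if keep_one_copy:
        if device_id == CPU_INDEX:             # last copy already lives on the CPU: keep it
            return ["noop"]
        ops.append(f"write_back({device_id} -> CPU)")
        local_states[CPU_INDEX] = MODIFIED     # CPU becomes the (modified) owner
    local_states[device_id] = INVALID          # either way, free this device's copy
    ops.append(f"evict({device_id})")
    return ops

# Example: GPU 0 holds the only (modified) copy; evicting writes it back to the CPU first.
states = {CPU_INDEX: INVALID, 0: MODIFIED}
print(evict_block(states, 0))  # ['write_back(0 -> CPU)', 'evict(0)']
```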
10 changes: 5 additions & 5 deletions parla/parray/core.py
@@ -206,8 +206,8 @@ def update(self, array) -> None:

# update shape
self._array.shape = array.shape

cupy.cuda.stream.get_current_stream().synchronize()
if num_gpu > 0:
cupy.cuda.stream.get_current_stream().synchronize()

# slicing/indexing

@@ -282,8 +282,7 @@ def evict(self, device_id: int = None, keep_one_copy: bool = True) -> None:

with self._coherence_cv[device_id]:
operations = self._coherence.evict(device_id, keep_one_copy)
for op in operations:
self._process_operation(op)
self._process_operations(operations)


# Coherence update operations:
@@ -362,7 +361,8 @@ def _process_operations(self, operations: List[MemoryOperation], slices: SlicesT
self._array.copy_data_between_device(op.dst, op.src, dst_is_current_device)

# sync stream before set it as ready, so asyc call is ensured to be done
cupy.cuda.stream.get_current_stream().synchronize()
if num_gpu > 0:
cupy.cuda.stream.get_current_stream().synchronize()

# data is ready now
if MemoryOperation.LOAD_SUBARRAY in op.flag:
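
The two hunks above guard CuPy stream synchronization behind `num_gpu > 0` so that CPU-only runs never touch CuPy. A hedged sketch of the same pattern as a helper; the `num_gpu` parameter stands in for the module-level GPU count that core.py is assumed to track:

```python
# Illustrative helper for the guarded-synchronization pattern above; `num_gpu`
# stands in for the module-level count of visible GPUs (0 on CPU-only installs).
try:
    import cupy
except ImportError:
    cupy = None

def sync_current_stream_if_gpu(num_gpu: int) -> None:
    """Synchronize CuPy's current stream, but only when CuPy and a GPU are available."""
    if num_gpu > 0 and cupy is not None:
        cupy.cuda.stream.get_current_stream().synchronize()
```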
17 changes: 17 additions & 0 deletions parla/parray/memory.py
@@ -3,6 +3,9 @@

import numpy

import sys
#import gc

#TODO: Fix this to be more stable and less of a hack.
try:
import cupy
@@ -514,5 +517,19 @@ def clear(self, device_id) -> None:
"""
Clear data in device_id
"""
import psutil
import os
mempool = cupy.get_default_memory_pool()
pinned_mempool = cupy.get_default_pinned_memory_pool()
#proc = psutil.Process(os.getpid())
#mem0 = proc.memory_info().rss
#print("Before:\n\t Used bytes:", mempool.used_bytes(), " total bytes: ", mempool.total_bytes(), " free blocks:", pinned_mempool.n_free_blocks(), flush=True)
#del self._buffer[device_id]
import psutil
self._indices_map[device_id] = None
self._buffer[device_id] = None
#mem1 = proc.memory_info().rss
#gc.collect()
mempool.free_all_blocks()
#print("After:\n\t Used bytes:", mempool.used_bytes(), " total bytes: ", mempool.total_bytes(), " free blocks:", pinned_mempool.n_free_blocks(), flush=True)
#print("\t Before deallcation: ", mem0, " after: ", mem1, flush=True)
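
For context on the new `clear()` body: dropping the Python references and then calling `free_all_blocks()` on CuPy's default memory pool is what actually returns the cached GPU memory to the driver. A standalone sketch using only public CuPy APIs, independent of PArray's buffer bookkeeping:

```python
# Standalone illustration of the release pattern in clear(): drop the Python
# reference first, then ask CuPy's default pool to return cached blocks to the driver.
import cupy

def release_and_report():
    mempool = cupy.get_default_memory_pool()
    block = cupy.zeros((1024, 1024), dtype=cupy.float32)   # ~4 MB allocation
    print("used:", mempool.used_bytes(), "cached:", mempool.total_bytes())
    block = None               # no references left; the allocation becomes a free block in the pool
    mempool.free_all_blocks()  # hand cached (free) blocks back to the CUDA driver
    print("used:", mempool.used_bytes(), "cached:", mempool.total_bytes())
```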
76 changes: 67 additions & 9 deletions parla/task_runtime.py
@@ -17,6 +17,8 @@
from parla.cpu_impl import cpu
from parla.dataflow import Dataflow

from parla.tracking import LRUManager

# Logger configuration (uncomment and adjust level if needed)
#logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@@ -353,6 +355,11 @@ def _finish(self, ctx: 'SchedulerContext'):
"""Cleanup works after executing the task."""
raise NotImplementedError()

@abstractmethod
def _invoke_garbage_collector(self):
"""Invoke a garbage collector; for now, invoke it for
each task execution."""
raise NotImplementedError()

def run(self):
assert self._assigned, "Task was not assigned before running."
@@ -362,6 +369,7 @@ def run(self):
with self._mutex:
# A default state to avoid exceptions during catch
task_state = TaskException(RuntimeError("Unknown fatal error"))
event_exists = False
# Run the task and assign the new task state
try:
assert isinstance(self._state, TaskRunning), " Task is not running state: {} on task: {}".format(self._state, self.taskid)
@@ -383,10 +391,12 @@
# Events could be multiple for multiple devices task.
env.record_events()
if len(events) > 0:
event_exists = True
# If any event created by the current task exist,
# notify dependents and make them wait for that event,
# not Parla task completion.
if not isinstance(task_state, TaskRunning):
self._invoke_garbage_collector()
self._notify_dependents(events)
env.sync_events()
task_state = task_state or TaskCompleted(None)
@@ -404,10 +414,12 @@
# new dependents could be added after the above
# notifications, while other devices are running
# their kernels asynchronously.
if not isinstance(task_state, TaskRunning):
if event_exists == False and not isinstance(task_state, TaskRunning):
self._invoke_garbage_collector()
self._notify_dependents()
self._set_state(task_state)
self._finish(ctx)
if isinstance(self._state, TaskCompleted):
self._finish(ctx)
except Exception as e:
logger.exception("Task %r: Exception in task handling", self)
raise e
@@ -611,8 +623,30 @@ def _handle_dependency_spawn(self, dependency: "Task"):
if self.num_unspawned_dependencies == 0:
self._ready_to_map()

def acquire_parray(self):
ctx = get_scheduler_context()
if self.dataflow is not None:
for parray in (self.dataflow.input + \
self.dataflow.inout + \
self.dataflow.output):
for d in self.req.devices:
ctx.scheduler.lrum._acquire_data(parray, d, str(self.taskid))

def _invoke_garbage_collector(self):
print(f"Garbage collector is activated", flush=True)
ctx = get_scheduler_context()
if self.dataflow is not None:
for parray in (self.dataflow.input + \
self.dataflow.inout + \
self.dataflow.output):
for d in self.req.devices:
ctx.scheduler.lrum._release_data(parray, d, str(self.taskid))
ctx.scheduler.lrum._evict()

def _execute_task(self):
return self._state.func(self, *self._state.args)
self.acquire_parray()
result = self._state.func(self, *self._state.args)
return result

def cleanup(self):
self._func = None
@@ -630,14 +664,26 @@ def _finish(self, ctx):
ctx.scheduler.update_mapped_task_count_mutex(self, d, -1)
ctx.scheduler.update_launched_task_count_mutex(self, d, -1)


# _finish() can be called more than once on global task.
if (self.dataflow != None):
"""
self.release_parray()
for parray in self.dataflow.input:
for d in self.req.devices:
ctx.scheduler.lrum._release_data(parray, d, str(self.taskid))
"""
# Update OUT parrays which may have changed size from 0 to something
# We assume all IN and INOUT params don't change size
for parray in (self.dataflow.output):
for parray in (self.dataflow.output + self.dataflow.inout):
"""
for d in self.req.devices:
ctx.scheduler.lrum._release_data(parray, d, str(self.taskid))
"""
ctx.scheduler._available_resources.update_parray_nbytes(
parray, self.req.devices)
"""
ctx.scheduler.lrum._evict()
"""
ctx.scheduler.decr_active_compute_tasks()
self.cleanup()

@@ -669,13 +715,19 @@ def _execute_task(self):
if (self._operand_type == OperandType.IN):
write_flag = False
# Move data to current device
dev_type = get_current_devices()[0]
self.dev_type = get_current_devices()[0]
dev_no = -1
if (dev_type.architecture is not cpu):
dev_no = dev_type.index
if (self.dev_type.architecture is not cpu):
dev_no = self.dev_type.index
ctx = get_scheduler_context()
ctx.scheduler.lrum._start_prefetch_data(self._target_data, self.dev_type, str(self.taskid))
self._target_data._auto_move(device_id=dev_no, do_write=write_flag)
ctx.scheduler.lrum._stop_prefetch_data(self._target_data, self.dev_type, str(self.taskid))
return TaskCompleted(None)

def _invoke_garbage_collector(self):
pass

def cleanup(self):
self._target_data = None

@@ -688,7 +740,7 @@ def _finish(self, ctx):
# Don't update parray tracking information either
# The scheduler already registered the new location
# If size changes, the ComputeTask will take care of that

ctx = get_scheduler_context()
# Decrease the number of running tasks on the device d.
for d in self.req.devices:
ctx.scheduler.update_mapped_task_count_mutex(self, d, -1)
@@ -1479,6 +1531,8 @@ def __init__(self, environments: Collection[TaskEnvironment], n_threads: Optiona
self._device_launched_datamove_task_counts = {
dev: 0 for dev in self._available_resources.get_resources()}

self._lrum = LRUManager()
Contributor:

Just looking through this because we'll be looking at it again soon, and I've been thinking about memory counting.
It should be a separate LRUManager for each device object (memory space), shouldn't it?

Contributor Author:

It depends on the design. My original idea was to have a single GC (not necessarily now, but I was thinking of a daemon similar to the scheduler thread) that manages all memory regardless of its location. It might be simpler than a GC per device? Could you please let me know your idea and its pros?

Contributor:

The reason we would call into the EvictionPolicy is when a specific device is full and we want to bring more data onto it. This would look only at data on that device to start evicting (which means it needs a device-specific hash-list structure for LRU). I am not sure how a single global manager for all devices would handle this without evicting data from other devices unnecessarily.
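
A device-specific hash-list structure of the kind described here can be as small as one `OrderedDict` per device. The sketch below is purely illustrative; the class and method names are not part of this PR:

```python
# Hypothetical per-device LRU bookkeeping: an OrderedDict gives the
# hash-list behaviour (O(1) touch, pop the oldest entry first).
from collections import OrderedDict

class DeviceLRU:
    def __init__(self):
        self._entries = OrderedDict()          # data_id -> nbytes, oldest first

    def touch(self, data_id, nbytes):
        """Mark data as most recently used on this device."""
        self._entries[data_id] = nbytes
        self._entries.move_to_end(data_id)

    def evict_until(self, bytes_needed):
        """Pop least recently used entries until enough bytes would be reclaimed."""
        freed, victims = 0, []
        while self._entries and freed < bytes_needed:
            data_id, nbytes = self._entries.popitem(last=False)
            victims.append(data_id)
            freed += nbytes
        return victims
```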

Contributor:

Also (just terminology): I'm using "GC" specifically for cleaning up data that will not be used in the future, and "eviction" for kicking out any data (possibly data that will be used in the future) due to space limitations on a device.

Contributor Author:

I am on board with you, but I think there is a different way to implement what you said. What I meant was that the global memory manager can still manage all devices: for example, as you said, we maintain a list of arrays for each device, and the global manager can traverse and evict/deallocate a subset of them to free more memory. Or we can have a dedicated memory manager thread per device and let each of them manage the PArrays owned by that device. I think this is an implementation detail?
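
Both positions in this thread can be layered on the same underlying data: a global manager that owns one LRU list per device and evicts only on the device that is short on space. A minimal, hypothetical sketch (not the `LRUManager` added by this PR):

```python
# Hypothetical global manager keeping one LRU list per device (illustration only,
# not the LRUManager imported from parla.tracking).
from collections import OrderedDict, defaultdict

class GlobalLRUManager:
    def __init__(self):
        # device_id -> OrderedDict(data_id -> nbytes), least recently used first
        self._per_device = defaultdict(OrderedDict)

    def touch(self, device_id, data_id, nbytes):
        """Record that `data_id` was just used on `device_id`."""
        lru = self._per_device[device_id]
        lru[data_id] = nbytes
        lru.move_to_end(data_id)

    def evict_on(self, device_id, bytes_needed):
        """Evict only from the device that ran out of space; other devices are untouched."""
        lru, freed, victims = self._per_device[device_id], 0, []
        while lru and freed < bytes_needed:
            data_id, nbytes = lru.popitem(last=False)
            victims.append(data_id)
            freed += nbytes
        return victims
```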


# Dictionary mapping data block to task lists.
self._datablock_dict = defaultdict(list)

@@ -1506,6 +1560,10 @@ def components(self) -> List[EnvironmentComponentInstance]:
def scheduler(self):
return self

@ property
def lrum(self):
return self._lrum

def __enter__(self):
if self._active_task_count != 1:
raise InvalidSchedulerAccessException(