Skip to content

Commit

Permalink
Merge branch 'main' into fix/datamove_dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
wlruys authored Jan 18, 2024
2 parents f7da831 + dfe4c2b commit 54667db
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 22 deletions.
4 changes: 2 additions & 2 deletions src/c/backend/include/device_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class DeviceManager {

// TODO(hc): use a customized type for device id.

const DevID_t globalid_to_parrayid(DevID_t global_dev_id) const {
const int globalid_to_parrayid(unsigned int global_dev_id) const {
Device *dev = all_devices_[global_dev_id];
if (dev->get_type() == DeviceType::CPU) {
return -1;
Expand All @@ -107,7 +107,7 @@ class DeviceManager {
}
}

const int parrayid_to_globalid(DevID_t parray_dev_id) const {
const unsigned int parrayid_to_globalid(int parray_dev_id) const {
if (parray_dev_id == -1) {
// XXX: This assumes that a CPU device is always single and
// is added at first.
Expand Down
4 changes: 2 additions & 2 deletions src/python/parla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ def __enter__(self):
self.original_handler = signal.getsignal(self.sig)

def handler(signum, frame):
print("YOU PRESSED CTRL+C, INTERRUPTING ALL TASKS", flush=True)
print("Attempting to interurpt all running tasks...", flush=True)
self._sched.stop()
self.release()
self.interrupted = True
self.interuppted = True

signal.signal(self.sig, handler)
except ValueError:
Expand Down
14 changes: 9 additions & 5 deletions src/python/parla/common/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@
USE_PYTHON_RUNAHEAD = os.getenv("PARLA_ENABLE_PYTHON_RUNAHEAD", "1") == "1"
PREINIT_THREADS = os.getenv("PARLA_PREINIT_THREADS", "1") == "1"

print("USE_PYTHON_RUNAHEAD: ", USE_PYTHON_RUNAHEAD)
print("CUPY_ENABLED: ", CUPY_ENABLED)
print("PREINIT_THREADS: ", PREINIT_THREADS)

_global_data_tasks = {}


Expand All @@ -65,8 +61,16 @@ class SynchronizationType(IntEnum):
else:
default_sync = SynchronizationType.NON_BLOCKING

print("DEFAULT SYNC: ", default_sync)

def print_config():
print("Parla Configuration", flush=True)
print("-------------------", flush=True)
print("Cupy Found: ", CUPY_ENABLED, flush=True)
print("Crosspy Found: ", CROSSPY_ENABLED, flush=True)
print("Preinitialize Cupy + Handles in Threads: ", PREINIT_THREADS, flush=True)
print("Runahead Scheduling Backend: ", USE_PYTHON_RUNAHEAD, flush=True)
print("Default Runahead Behavior: ", default_sync, flush=True)
print("VCU Precision: ", VCU_BASELINE, flush=True)

class DeviceType(IntEnum):
"""
Expand Down
5 changes: 4 additions & 1 deletion src/python/parla/common/parray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,10 @@ def _current_device_index(self) -> int:

# Public API:

def get(self, device: Optional[PyDevice] = None) -> "np.ndarray" | "cp.ndarray":
def set_name(self, name: str):
self._name = name

def get(self, device: Optional[PyDevice] = None) -> 'np.ndarray' | 'cp.ndarray':
if device is None:
return self.array
else:
Expand Down
14 changes: 8 additions & 6 deletions src/python/parla/common/parray/from_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def asarray(a, dtype=None, order=None, like=None, on_gpu=False, name: str = "NA"
)


def asarray_batch(*args):
def asarray_batch(*args, base="array"):
"""Converts numpy/cupy ndarray to Parla array without creating additional copy.
Args:
Expand All @@ -168,27 +168,29 @@ def asarray_batch(*args):
a, b = asarray_batch(a, b) # a and b are now parla array
"""

def get_parray(object): # recursively process Sequence or Dictionary
def get_parray(object, count=0): # recursively process Sequence or Dictionary
if isinstance(object, (numpy.ndarray, cupy.ndarray)):
return asarray(object)
return asarray(object, name=f"{base}::{count}")
elif isinstance(object, PArray):
return object
elif isinstance(object, dict):
accumulator = {}
for key, value in object.items():
accumulator[key] = get_parray(value)
accumulator[key] = get_parray(value, count+1)
return accumulator
elif isinstance(object, (list, tuple, set)):
accumulator = []
for item in object:
accumulator.append(get_parray(item))
accumulator.append(get_parray(item, count+1))
return type(object)(accumulator)
else:
raise TypeError(f"Unsupported Type: {type(object)}")

parla_arrays = []
i = 0
for arg in args:
parla_arrays.append(get_parray(arg))
parla_arrays.append(get_parray(arg, i))
i += 1

if len(parla_arrays) == 1:
return parla_arrays[0]
Expand Down
5 changes: 2 additions & 3 deletions src/python/parla/cython/device_manager.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ cdef extern from "include/device_manager.hpp" nogil:
DeviceManager() except +
void register_device(Device*) except +
void print_registered_devices() except +
int globalid_to_parrayid(int) except +
int parrayid_to_globalid(int) except +

int globalid_to_parrayid(unsigned int) except +
unsigned int parrayid_to_globalid(int) except +

cdef class CyDeviceManager:
cdef DeviceManager* cpp_device_manager_
Expand Down
10 changes: 10 additions & 0 deletions src/python/parla/cython/scheduler.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ class WorkerThread(ControllableThread, SchedulerContext):
with self.scheduler.start_monitor:
self.scheduler.start_monitor.notify_all()

device_manager = self.scheduler.device_manager

while self._should_run:
self.status = "Waiting"

Expand Down Expand Up @@ -246,6 +248,14 @@ class WorkerThread(ControllableThread, SchedulerContext):
Locals.push_task(active_task)

with device_context as env:

if isinstance(active_task, ComputeTask):
# Perform write invalidations
for parray, target_idx in active_task.dataflow.inout:
target_device = parla_devices[target_idx]
global_target_id = target_device.get_global_id()
parray_target_id = device_manager.globalid_to_parrayid(global_target_id)
parray._auto_move(parray_target_id, True)

core.binlog_2("Worker", "Running task: ", active_task.inner_task, " on worker: ", self.inner_worker)
# Run the task body (this may complete the task or return a continuation)
Expand Down
7 changes: 4 additions & 3 deletions src/python/parla/cython/tasks.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -713,15 +713,16 @@ class DataMovementTask(Task):
@brief Run the data movement task. Calls the PArray interface to move the data to the assigned devices.
Devices are given by the local relative device id within the TaskEnvironment.
"""
write_flag = True if self.access_mode != AccessMode.IN else False

# write_flag = True if self.access_mode != AccessMode.IN else False
# Data movement tasks should only perform read operations
write_flag = False

# TODO: Get device manager from task environment instead of scheduler at creation time
device_manager = self.scheduler.device_manager
target_dev = self.assigned_devices[0]
global_id = target_dev.get_global_id()
parray_id = device_manager.globalid_to_parrayid(global_id)

print(self.name, " moves parray:", parray_id, " on device:", target_dev)
self.parray._auto_move(parray_id, write_flag)
return TaskRunahead(0)

Expand Down

0 comments on commit 54667db

Please sign in to comment.