Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hotfix: revert streamclass to old implementation #151

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion example/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse

from parla import Parla, spawn, TaskSpace, sleep_nogil, sleep_gil
from parla.devices import cpu, gpu
import nvtx

free_sleep = sleep_nogil
Expand Down Expand Up @@ -39,7 +40,7 @@ async def task1():
domain="application")

for i in range(n):
@spawn(T[i], vcus=cost)
@spawn(T[i], vcus=cost, placement=[gpu])
def task1():
nvtx.push_range(message="TASK", color="blue", domain="application")
if args.empty:
Expand Down
2 changes: 1 addition & 1 deletion src/python/parla/cython/device_manager.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class PyDeviceManager:
# self.register_devices_to_cpp()

# Initialize Device Hardware Queues
self.stream_pool = stream_pool.CyStreamPool(self.get_devices(DeviceType.GPU))
self.stream_pool = stream_pool.StreamPool(self.get_devices(DeviceType.GPU))

def __dealloc__(self):
for arch in self.py_registered_archs:
Expand Down
1 change: 1 addition & 0 deletions src/python/parla/cython/stream_pool.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ cdef extern from "include/device_contexts.hpp":

cdef class CyStreamPool:
cdef InnerStreamPool* _c_pool
cdef object StreamClass
cdef dict _pool
cdef int _per_device
cdef list _device_list
179 changes: 176 additions & 3 deletions src/python/parla/cython/stream_pool.pyx
Original file line number Diff line number Diff line change
@@ -1,15 +1,188 @@
from typing import Type
from ..common.globals import cupy
from parla.common.globals import _Locals as Locals

"""
cdef class Stream:
cdef object _stream
cdef int _device
cdef object _device

def __init__(self, device: int = 0, stream=None):
def __init__(self, device = None , stream=None):
self._stream = stream
self._device = device

def __repr__(self):
return f"Stream({self._device})"

def __str__(self):
return self.__repr__()

def __enter__(self):
# print("Entering Stream: ", self, flush=True)
pass

def __exit__(self, exc_type, exc_val, exc_tb):
# print("Exiting Stream: ", self, flush=True)
pass

@property
def device(self):
return self._device

@property
def stream(self):
return self._stream

def synchronize(self):
pass

def create_event(self):
return None

def wait_event(self):
pass

@property
def ptr(self):
return None


cdef class CupyStream(Stream):

def __init__(self, device: int = 0, stream = None):
def __init__(self, device = None , stream = None):
super().__init__(device=device, stream=stream)
"""
class Stream:
    """Placeholder stream whose operations are all no-ops.

    It exposes the same surface as ``CupyStream`` (context manager,
    ``device``/``stream``/``ptr`` properties, ``synchronize``,
    ``create_event``, ``wait_event``) without touching any GPU library.
    """

    def __init__(self, device=None, stream=None, non_blocking=True):
        """Record the device and underlying stream handle.

        Args:
            device: Object exposing ``.device.id``, or ``None``.
            stream: Underlying stream handle; stored as-is.
            non_blocking: Accepted for signature parity with ``CupyStream``;
                not used by this class.
        """
        self._device = device
        # ``device`` defaults to None, so it must not be dereferenced
        # unconditionally (the old code raised AttributeError on Stream()).
        self._device_id = None if device is None else device.device.id
        self._stream = stream

    def __repr__(self):
        return f"Stream({self._device})"

    def __str__(self):
        return self.__repr__()

    def __enter__(self):
        # Return self so ``with stream as s:`` binds the stream, matching
        # CupyStream.__enter__.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to restore; returning None lets exceptions propagate.
        return None

    @property
    def device(self):
        """The device object passed at construction (may be ``None``)."""
        return self._device

    @property
    def stream(self):
        """The underlying stream handle passed at construction."""
        return self._stream

    def synchronize(self):
        """No-op."""
        return None

    def create_event(self):
        """Return ``None``; this stream has no events."""
        return None

    def wait_event(self, event=None):
        """No-op.

        ``event`` is optional so both the old zero-argument call and the
        ``CupyStream``-style ``wait_event(event)`` call are accepted.
        """
        return None

    @property
    def ptr(self):
        """Always ``None``; there is no native stream pointer."""
        return None


class CupyStream(Stream):
    """Stream backed by a ``cupy.cuda.Stream`` on a specific device.

    Used as a context manager it makes its device and its stream current,
    and registers itself with ``Locals`` for the duration of the block.
    Attributes not defined here are forwarded to the wrapped cupy stream.
    """

    def __init__(self, device=None, stream=None, non_blocking=True):
        """
        Initialize a Parla Stream object.

        Args:
            device: Object exposing ``.device.id`` (presumably a Parla GPU
                device wrapping a cupy device -- confirm against callers).
                When ``None``, ``cupy.cuda.Device()`` is used.
            stream: Existing stream to wrap. When ``None`` a new
                ``cupy.cuda.Stream`` is created on the selected device.
            non_blocking: Forwarded to ``cupy.cuda.Stream`` when a new
                stream is created.

        Raises:
            ValueError: If ``stream`` is given without ``device``.
        """
        if device is None and stream is not None:
            raise ValueError("Device must be specified if stream is specified.")

        if device is None:
            self._device = cupy.cuda.Device()
            self._device_id = self._device.id
        else:
            self._device = device
            self._device_id = device.device.id

        # Create the stream while its device is active.
        with cupy.cuda.Device(self._device_id):
            if stream is None:
                self._stream = cupy.cuda.Stream(non_blocking=non_blocking)
            else:
                self._stream = stream

    def __repr__(self):
        return f"Stream({self._device}, {self._stream})"

    def __str__(self):
        return self.__repr__()

    def __enter__(self):
        # Set the device to the stream's device.
        self.active_device = cupy.cuda.Device(self._device_id)
        self.active_device.__enter__()

        # Set the stream to the current stream. If that fails, undo the
        # device switch so the caller's device context is not left changed.
        try:
            self._stream.__enter__()
        except BaseException as exc:
            self.active_device.__exit__(type(exc), exc, exc.__traceback__)
            raise

        Locals.push_stream(self)

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        ret_stream = False
        ret_device = False

        # Same order as before (stream, device, Locals), but each later step
        # now runs even if an earlier one raises, so the device context and
        # the Locals stream stack are always unwound.
        try:
            # Restore the stream to the previous stream.
            ret_stream = self._stream.__exit__(exc_type, exc_value, traceback)
        finally:
            try:
                # Restore the device to the previous device.
                ret_device = self.active_device.__exit__(
                    exc_type, exc_value, traceback
                )
            finally:
                Locals.pop_stream()

        return ret_stream and ret_device

    @property
    def device(self):
        """The device object this stream was created for."""
        return self._device

    @property
    def stream(self):
        """The wrapped stream object."""
        return self._stream

    def synchronize(self):
        """Call ``synchronize()`` on the wrapped stream."""
        self._stream.synchronize()

    def create_event(self):
        """Create and return a ``cupy.cuda.Event`` with this stream's device active."""
        active_device = cupy.cuda.Device(self._device_id)
        with active_device:
            new_event = cupy.cuda.Event(
                block=True, disable_timing=True, interprocess=False
            )
        return new_event

    def wait_event(self, event):
        """Call ``wait_event(event)`` on the wrapped stream."""
        self._stream.wait_event(event)

    @property
    def ptr(self):
        """The ``ptr`` attribute of the wrapped stream."""
        return self._stream.ptr

    # TODO(wlr): What is the performance impact of this?
    def __getattr__(self, name):
        """Forward unknown attributes to the wrapped stream.

        Python calls ``__getattr__`` only after normal lookup has failed, so
        attributes defined on this class are never routed through here. The
        method was previously misspelled ``__getatrr__`` and therefore never
        ran; its ``hasattr(self, name)`` check is dropped because, under the
        correct name, it would re-enter this method and recurse forever.
        """
        # If ``_stream`` itself is missing (e.g. __init__ failed part-way),
        # ``self._stream`` below would re-enter this method; stop here.
        if name == "_stream":
            raise AttributeError(name)
        return getattr(self._stream, name)


cdef class CyStreamPool:

Expand Down
2 changes: 1 addition & 1 deletion src/python/parla/cython/tasks.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1273,7 +1273,7 @@ class GPUEnvironment(TerminalEnvironment):
super(GPUEnvironment, self).__init__(device, blocking=blocking)

stream_pool = get_stream_pool()
stream = stream_pool.get_stream(dev=device)
stream = stream_pool.get_stream(device=device)
self.stream_list.append(stream)

self.event_dict['default'] = stream.create_event()
Expand Down
Loading