Skip to content

Commit

Permalink
formatted with black
Browse files Browse the repository at this point in the history
  • Loading branch information
benvanwerkhoven committed Jan 14, 2025
1 parent 5c4ba1a commit 68f7ce6
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 160 deletions.
83 changes: 29 additions & 54 deletions kernel_tuner/backends/hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@

hipSuccess = 0


def hip_check(call_result):
"""helper function to check return values of hip calls"""
err = call_result[0]
result = call_result[1:]
if len(result) == 1:
Expand All @@ -41,6 +43,7 @@ def hip_check(call_result):
raise RuntimeError(str(err))
return result


class HipFunctions(GPUBackend):
"""Class that groups the HIP functions on maintains state about the device."""

Expand All @@ -59,7 +62,9 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
:type iterations: int
"""
if not hip or not hiprtc:
raise ImportError("Unable to import HIP Python, check https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-hip-python.")
raise ImportError(
"Unable to import HIP Python, check https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-hip-python."
)

# embedded in try block to be able to generate documentation
# and run tests without HIP Python installed
Expand All @@ -69,7 +74,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props, device))

self.name = props.name.decode('utf-8')
self.name = props.name.decode("utf-8")
self.max_threads = props.maxThreadsPerBlock
self.device = device
self.compiler_options = compiler_options or []
Expand All @@ -81,7 +86,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
env["compiler_options"] = compiler_options
self.env = env

# Create stream and events
# Create stream and events
self.stream = hip_check(hip.hipStreamCreate())
self.start = hip_check(hip.hipEventCreate())
self.end = hip_check(hip.hipEventCreate())
Expand All @@ -108,40 +113,34 @@ def ready_argument_list(self, arguments):
"""
logging.debug("HipFunction ready_argument_list called")
prepared_args = []

for arg in arguments:
dtype_str = str(arg.dtype)

# Handle numpy arrays
if isinstance(arg, np.ndarray):
if dtype_str in dtype_map.keys():
# Allocate device memory
device_ptr = hip_check(hip.hipMalloc(arg.nbytes))

# Copy data to device using hipMemcpy
hip_check(hip.hipMemcpy(
device_ptr,
arg,
arg.nbytes,
hip.hipMemcpyKind.hipMemcpyHostToDevice
))

hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))

prepared_args.append(device_ptr)
else:
raise TypeError(f"Unknown dtype {dtype_str} for ndarray")

# Handle numpy scalar types
elif isinstance(arg, np.generic):
# Convert numpy scalar to corresponding ctypes
ctype_arg = dtype_map[dtype_str](arg)
prepared_args.append(ctype_arg)

else:
raise ValueError(f"Invalid argument type {type(arg)}, {arg}")

return prepared_args


def compile(self, kernel_instance):
"""Call the HIP compiler to compile the kernel, return the function.
Expand All @@ -159,28 +158,22 @@ def compile(self, kernel_instance):
kernel_name = kernel_instance.name
if 'extern "C"' not in kernel_string:
kernel_string = 'extern "C" {\n' + kernel_string + "\n}"

# Create program
prog = hip_check(hiprtc.hiprtcCreateProgram(
kernel_string.encode(),
kernel_name.encode(),
0,
[],
[]
))
prog = hip_check(hiprtc.hiprtcCreateProgram(kernel_string.encode(), kernel_name.encode(), 0, [], []))

try:
# Get device properties
props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props, 0))

# Setup compilation options
arch = props.gcnArchName
cflags = [b"--offload-arch=" + arch]
cflags.extend([opt.encode() if isinstance(opt, str) else opt for opt in self.compiler_options])

# Compile program
err, = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags)
(err,) = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags)
if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS:
# Get compilation log if there's an error
log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog))
Expand Down Expand Up @@ -208,19 +201,19 @@ def compile(self, kernel_instance):
def start_event(self):
    """Record the event that marks the beginning of a measurement.

    Emits a debug log line, records ``self.start`` on ``self.stream`` and
    hands the value returned by the HIP call to ``hip_check``.
    """
    logging.debug("HipFunction start_event called")

    # Keep the raw HIP return value in a named local before checking it.
    record_result = hip.hipEventRecord(self.start, self.stream)
    hip_check(record_result)

def stop_event(self):
    """Record the event that marks the end of a measurement.

    Emits a debug log line, records ``self.end`` on ``self.stream`` and
    hands the value returned by the HIP call to ``hip_check``.
    """
    logging.debug("HipFunction stop_event called")

    # Keep the raw HIP return value in a named local before checking it.
    record_result = hip.hipEventRecord(self.end, self.stream)
    hip_check(record_result)

def kernel_finished(self):
"""Returns True if the kernel has finished, False otherwise."""
logging.debug("HipFunction kernel_finished called")

# ROCm HIP returns (hipError_t, bool) for hipEventQuery
status = hip.hipEventQuery(self.end)
if status[0] == hip.hipError_t.hipSuccess:
Expand All @@ -233,7 +226,7 @@ def kernel_finished(self):
def synchronize(self):
    """Block until the device has finished its outstanding tasks.

    Emits a debug log line, calls ``hip.hipDeviceSynchronize`` and hands the
    value it returns to ``hip_check``.
    """
    logging.debug("HipFunction synchronize called")

    # Keep the raw HIP return value in a named local before checking it.
    sync_result = hip.hipDeviceSynchronize()
    hip_check(sync_result)

def run_kernel(self, func, gpu_args, threads, grid, stream=None):
Expand All @@ -242,7 +235,7 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
:param func: A HIP kernel compiled for this specific kernel configuration
:type func: hipFunction_t
:param gpu_args: List of arguments to pass to the kernel. Can be DeviceArray
:param gpu_args: List of arguments to pass to the kernel. Can be DeviceArray
objects or ctypes values
:type gpu_args: list
Expand Down Expand Up @@ -272,7 +265,7 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
sharedMemBytes=self.smem_size,
stream=stream,
kernelParams=None,
extra=tuple(gpu_args)
extra=tuple(gpu_args),
)
)

Expand Down Expand Up @@ -303,12 +296,7 @@ def memcpy_dtoh(self, dest, src):
"""
logging.debug("HipFunction memcpy_dtoh called")

hip_check(hip.hipMemcpy(
dest,
src,
dest.nbytes,
hip.hipMemcpyKind.hipMemcpyDeviceToHost
))
hip_check(hip.hipMemcpy(dest, src, dest.nbytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))

def memcpy_htod(self, dest, src):
"""Perform a host to device memory copy.
Expand All @@ -321,12 +309,7 @@ def memcpy_htod(self, dest, src):
"""
logging.debug("HipFunction memcpy_htod called")

hip_check(hip.hipMemcpy(
dest,
src,
src.nbytes,
hip.hipMemcpyKind.hipMemcpyHostToDevice
))
hip_check(hip.hipMemcpy(dest, src, src.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))

def copy_constant_memory_args(self, cmem_args):
"""Adds constant memory arguments to the most recently compiled module.
Expand All @@ -343,18 +326,10 @@ def copy_constant_memory_args(self, cmem_args):
# Iterate over dictionary
for symbol_name, data in cmem_args.items():
# Get symbol pointer and size using hipModuleGetGlobal
dptr, _ = hip_check(hip.hipModuleGetGlobal(
self.current_module,
symbol_name.encode()
))
dptr, _ = hip_check(hip.hipModuleGetGlobal(self.current_module, symbol_name.encode()))

# Copy data to the global memory location
hip_check(hip.hipMemcpy(
dptr,
data,
data.nbytes,
hip.hipMemcpyKind.hipMemcpyHostToDevice
))
hip_check(hip.hipMemcpy(dptr, data, data.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))

def copy_shared_memory_args(self, smem_args):
"""Add shared memory arguments to the kernel."""
Expand Down
Loading

0 comments on commit 68f7ce6

Please sign in to comment.