Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA: Allow for more thread blocks than the X dimension of the block grid #41

Merged
merged 2 commits into from
Apr 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/backend/cuda/genpup.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,11 @@ def generate_kernels(b, darray):
display("}\n\n")

# generate the host function
OUTFILE.write("void yaksuri_cudai_%s(const void *inbuf, void *outbuf, uintptr_t count, yaksuri_cudai_md_s *md, int n_threads, int n_blocks, int device)\n" % funcprefix)
OUTFILE.write("void yaksuri_cudai_%s(const void *inbuf, void *outbuf, uintptr_t count, yaksuri_cudai_md_s *md, int n_threads, int n_blocks_x, int n_blocks_y, int n_blocks_z, int device)\n" % funcprefix)
OUTFILE.write("{\n")
OUTFILE.write(" void *args[] = { &inbuf, &outbuf, &count, &md };\n")
OUTFILE.write(" cudaError_t cerr = cudaLaunchKernel((const void *) yaksuri_cudai_kernel_%s,\n" % funcprefix)
OUTFILE.write(" dim3(n_blocks), dim3(n_threads), args, 0, yaksuri_cudai_global.stream[device]);\n")
OUTFILE.write(" dim3(n_blocks_x, n_blocks_y, n_blocks_z), dim3(n_threads), args, 0, yaksuri_cudai_global.stream[device]);\n")
OUTFILE.write(" YAKSURI_CUDAI_CUDA_ERR_CHECK(cerr);\n")
OUTFILE.write("}\n\n")

Expand Down Expand Up @@ -427,7 +427,9 @@ def switcher(typelist, pupstr, nests):
OUTFILE.write("uintptr_t count, ")
OUTFILE.write("yaksuri_cudai_md_s *md, ")
OUTFILE.write("int n_threads, ")
OUTFILE.write("int n_blocks, ")
OUTFILE.write("int n_blocks_x, ")
OUTFILE.write("int n_blocks_y, ")
OUTFILE.write("int n_blocks_z, ")
OUTFILE.write("int device);\n")

OUTFILE.write("\n")
Expand Down
6 changes: 2 additions & 4 deletions src/backend/cuda/include/yaksuri_cudai.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
#define CUDA_P2P_DISABLED (2)
#define CUDA_P2P_CLIQUES (3)

#define YAKSURI_CUDAI_THREAD_BLOCK_SIZE (256)

/* *INDENT-OFF* */
#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -85,9 +83,9 @@ typedef struct yaksuri_cudai_md_s {

typedef struct yaksuri_cudai_type_s {
void (*pack) (const void *inbuf, void *outbuf, uintptr_t count, yaksuri_cudai_md_s * md,
int n_threads, int n_blocks, int device);
int n_threads, int n_blocks_x, int n_blocks_y, int n_blocks_z, int device);
void (*unpack) (const void *inbuf, void *outbuf, uintptr_t count, yaksuri_cudai_md_s * md,
int n_threads, int n_blocks, int device);
int n_threads, int n_blocks_x, int n_blocks_y, int n_blocks_z, int device);
yaksuri_cudai_md_s *md;
pthread_mutex_t mdmutex;
uintptr_t num_elements;
Expand Down
70 changes: 56 additions & 14 deletions src/backend/cuda/pup/yaksuri_cudai_pup.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,42 @@
#include "yaksi.h"
#include "yaksuri_cudai.h"

/* Number of threads per thread block; get_thread_block_dims() reports this
 * value as n_threads and divides the element count by it. */
#define THREAD_BLOCK_SIZE (256)
/* Upper bounds used when splitting the total block count across the x, y and
 * z grid dimensions.  MAX_GRIDSZ_X is an unsigned long long constant, so
 * products such as MAX_GRIDSZ_X * MAX_GRIDSZ_Y are evaluated in 64 bits.
 * NOTE(review): the values look like CUDA's documented per-dimension grid
 * limits -- confirm against the targeted compute capabilities. */
#define MAX_GRIDSZ_X ((1ULL << 31) - 1)
#define MAX_GRIDSZ_Y (65535)
#define MAX_GRIDSZ_Z (65535)

/*
 * Compute the launch geometry used to pack/unpack "count" instances of "type".
 *
 * One thread is requested per element (count * num_elements of the CUDA type
 * descriptor), grouped into thread blocks of THREAD_BLOCK_SIZE threads.  The
 * resulting block count is placed in the x dimension when it fits, and is
 * otherwise spread over x/y and finally x/y/z so that no dimension exceeds
 * MAX_GRIDSZ_X / MAX_GRIDSZ_Y / MAX_GRIDSZ_Z.
 *
 * Because of the ceiling divisions, n_blocks_x * n_blocks_y * n_blocks_z may
 * be larger than the number of blocks strictly required.
 * NOTE(review): this presumably relies on the kernels ignoring surplus
 * threads -- confirm against the generated kernels.
 *
 * No range error is reported (this was raised during review): the block count
 * is at most ceil((2^64 - 1) / THREAD_BLOCK_SIZE) = 2^56, which is below
 * MAX_GRIDSZ_X * MAX_GRIDSZ_Y * MAX_GRIDSZ_Z (about 2^63), so the 3-D branch
 * can always represent it.  An element count too large for the grid would
 * already have overflowed the 64-bit multiplication below.
 *
 * Parameters:
 *   count       - number of instances of "type" to process
 *   type        - datatype whose CUDA private data supplies num_elements
 *   n_threads   - [out] threads per block (always THREAD_BLOCK_SIZE)
 *   n_blocks_x  - [out] grid size in x
 *   n_blocks_y  - [out] grid size in y
 *   n_blocks_z  - [out] grid size in z
 *
 * Returns YAKSA_SUCCESS.
 */
static int get_thread_block_dims(uint64_t count, yaksi_type_s * type, int *n_threads,
                                 int *n_blocks_x, int *n_blocks_y, int *n_blocks_z)
{
    yaksuri_cudai_type_s *cuda_type = (yaksuri_cudai_type_s *) type->backend.cuda.priv;

    *n_threads = THREAD_BLOCK_SIZE;

    /* total number of thread blocks, rounded up to cover a partial block */
    uint64_t n_elements = count * cuda_type->num_elements;
    uint64_t n_blocks = YAKSU_CEIL(n_elements, THREAD_BLOCK_SIZE);

    if (n_blocks <= MAX_GRIDSZ_X) {
        /* everything fits in a 1-D grid */
        *n_blocks_x = (int) n_blocks;
        *n_blocks_y = 1;
        *n_blocks_z = 1;
    } else if (n_blocks <= MAX_GRIDSZ_X * MAX_GRIDSZ_Y) {
        /* 2-D grid: pick x so that the matching y stays within MAX_GRIDSZ_Y */
        *n_blocks_x = (int) YAKSU_CEIL(n_blocks, MAX_GRIDSZ_Y);
        *n_blocks_y = (int) YAKSU_CEIL(n_blocks, (uint64_t) (*n_blocks_x));
        *n_blocks_z = 1;
    } else {
        /* 3-D grid.  The number of blocks per z-slice is at least 2^31 in this
         * branch (n_blocks > MAX_GRIDSZ_X * MAX_GRIDSZ_Y), so it must be kept
         * in a 64-bit variable; storing it in an int would overflow. */
        uint64_t n_blocks_xy = YAKSU_CEIL(n_blocks, MAX_GRIDSZ_Z);
        *n_blocks_x = (int) YAKSU_CEIL(n_blocks_xy, MAX_GRIDSZ_Y);
        *n_blocks_y = (int) YAKSU_CEIL(n_blocks_xy, (uint64_t) (*n_blocks_x));

        uint64_t slice_blocks = (uint64_t) (*n_blocks_x) * (uint64_t) (*n_blocks_y);
        *n_blocks_z = (int) YAKSU_CEIL(n_blocks, slice_blocks);
    }

    return YAKSA_SUCCESS;
}

int yaksuri_cudai_pup_is_supported(yaksi_type_s * type, bool * is_supported)
{
int rc = YAKSA_SUCCESS;
Expand Down Expand Up @@ -68,9 +104,10 @@ int yaksuri_cudai_ipack(const void *inbuf, void *outbuf, uintptr_t count, yaksi_
rc = yaksuri_cudai_md_alloc(type);
YAKSU_ERR_CHECK(rc, fn_fail);

int n_threads = YAKSURI_CUDAI_THREAD_BLOCK_SIZE;
int n_blocks = count * cuda_type->num_elements / YAKSURI_CUDAI_THREAD_BLOCK_SIZE;
n_blocks += ! !(count * cuda_type->num_elements % YAKSURI_CUDAI_THREAD_BLOCK_SIZE);
int n_threads;
int n_blocks_x, n_blocks_y, n_blocks_z;
rc = get_thread_block_dims(count, type, &n_threads, &n_blocks_x, &n_blocks_y, &n_blocks_z);
YAKSU_ERR_CHECK(rc, fn_fail);

if ((inattr.type == cudaMemoryTypeManaged && outattr.type == cudaMemoryTypeManaged) ||
(inattr.type == cudaMemoryTypeDevice && outattr.type == cudaMemoryTypeManaged) ||
Expand All @@ -86,7 +123,8 @@ int yaksuri_cudai_ipack(const void *inbuf, void *outbuf, uintptr_t count, yaksi_
YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);
}

cuda_type->pack(inbuf, outbuf, count, cuda_type->md, n_threads, n_blocks, target);
cuda_type->pack(inbuf, outbuf, count, cuda_type->md, n_threads, n_blocks_x, n_blocks_y,
n_blocks_z, target);
} else if (inattr.type == cudaMemoryTypeManaged && outattr.type == cudaMemoryTypeDevice) {
target = outattr.device;
cerr = cudaSetDevice(target);
Expand All @@ -98,7 +136,8 @@ int yaksuri_cudai_ipack(const void *inbuf, void *outbuf, uintptr_t count, yaksi_
YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);
}

cuda_type->pack(inbuf, outbuf, count, cuda_type->md, n_threads, n_blocks, target);
cuda_type->pack(inbuf, outbuf, count, cuda_type->md, n_threads, n_blocks_x, n_blocks_y,
n_blocks_z, target);
} else if ((outattr.type == cudaMemoryTypeDevice && inattr.device != outattr.device) ||
(outattr.type == cudaMemoryTypeHost)) {
assert(inattr.type == cudaMemoryTypeDevice);
Expand All @@ -113,8 +152,8 @@ int yaksuri_cudai_ipack(const void *inbuf, void *outbuf, uintptr_t count, yaksi_
YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);
}

cuda_type->pack(inbuf, device_tmpbuf, count, cuda_type->md, n_threads, n_blocks,
target);
cuda_type->pack(inbuf, device_tmpbuf, count, cuda_type->md, n_threads, n_blocks_x,
n_blocks_y, n_blocks_z, target);
cerr = cudaMemcpyAsync(outbuf, device_tmpbuf, count * type->size, cudaMemcpyDefault,
yaksuri_cudai_global.stream[target]);
YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);
Expand Down Expand Up @@ -184,9 +223,10 @@ int yaksuri_cudai_iunpack(const void *inbuf, void *outbuf, uintptr_t count, yaks
rc = yaksuri_cudai_md_alloc(type);
YAKSU_ERR_CHECK(rc, fn_fail);

int n_threads = YAKSURI_CUDAI_THREAD_BLOCK_SIZE;
int n_blocks = count * cuda_type->num_elements / YAKSURI_CUDAI_THREAD_BLOCK_SIZE;
n_blocks += ! !(count * cuda_type->num_elements % YAKSURI_CUDAI_THREAD_BLOCK_SIZE);
int n_threads;
int n_blocks_x, n_blocks_y, n_blocks_z;
rc = get_thread_block_dims(count, type, &n_threads, &n_blocks_x, &n_blocks_y, &n_blocks_z);
YAKSU_ERR_CHECK(rc, fn_fail);

if ((inattr.type == cudaMemoryTypeManaged && outattr.type == cudaMemoryTypeManaged) ||
(inattr.type == cudaMemoryTypeManaged && outattr.type == cudaMemoryTypeDevice) ||
Expand All @@ -202,7 +242,8 @@ int yaksuri_cudai_iunpack(const void *inbuf, void *outbuf, uintptr_t count, yaks
YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);
}

cuda_type->unpack(inbuf, outbuf, count, cuda_type->md, n_threads, n_blocks, target);
cuda_type->unpack(inbuf, outbuf, count, cuda_type->md, n_threads, n_blocks_x,
n_blocks_y, n_blocks_z, target);
} else if (inattr.type == cudaMemoryTypeDevice && outattr.type == cudaMemoryTypeManaged) {
target = inattr.device;
cerr = cudaSetDevice(target);
Expand All @@ -214,7 +255,8 @@ int yaksuri_cudai_iunpack(const void *inbuf, void *outbuf, uintptr_t count, yaks
YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);
}

cuda_type->unpack(inbuf, outbuf, count, cuda_type->md, n_threads, n_blocks, target);
cuda_type->unpack(inbuf, outbuf, count, cuda_type->md, n_threads, n_blocks_x,
n_blocks_y, n_blocks_z, target);
} else if ((inattr.type == cudaMemoryTypeDevice && inattr.device != outattr.device) ||
(inattr.type == cudaMemoryTypeHost)) {
assert(outattr.type == cudaMemoryTypeDevice);
Expand All @@ -233,8 +275,8 @@ int yaksuri_cudai_iunpack(const void *inbuf, void *outbuf, uintptr_t count, yaks
yaksuri_cudai_global.stream[target]);
YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);

cuda_type->unpack(device_tmpbuf, outbuf, count, cuda_type->md, n_threads, n_blocks,
target);
cuda_type->unpack(device_tmpbuf, outbuf, count, cuda_type->md, n_threads, n_blocks_x,
n_blocks_y, n_blocks_z, target);
} else {
rc = YAKSA_ERR__INTERNAL;
goto fn_fail;
Expand Down
1 change: 1 addition & 0 deletions src/util/yaksu_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

/* Larger / smaller of two values.  Each argument may be evaluated twice, so
 * do not pass expressions with side effects. */
#define YAKSU_MAX(x, y) ((y) < (x) ? (x) : (y))
#define YAKSU_MIN(x, y) ((y) > (x) ? (x) : (y))
/* Integer quotient x / y, plus one when the division leaves a non-zero
 * remainder (i.e. the quotient rounded up for non-negative operands).
 * Arguments are evaluated more than once. */
#define YAKSU_CEIL(x, y) (((x) / (y)) + (((x) % (y)) != 0))

#define YAKSU_ERR_CHKANDJUMP(check, rc, errcode, label) \
do { \
Expand Down