Skip to content

Commit

Permalink
chore: cublas device handle
Browse files Browse the repository at this point in the history
  • Loading branch information
brodeynewman committed Nov 12, 2024
1 parent 982644d commit 695a028
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 17 deletions.
8 changes: 5 additions & 3 deletions codegen/manual_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -865,17 +865,19 @@ extern "C"
cublasStatus_t cublasCreate_v2(cublasHandle_t* handle)
{
cublasStatus_t return_value;

std::cout << "cublas handle: " << handle << std::endl;
cublasHandle_t h;

if (rpc_start_request(0, RPC_cublasCreate_v2) < 0 ||
rpc_write(0, handle, sizeof(cublasHandle_t)) < 0 ||
rpc_wait_for_response(0) < 0 ||
rpc_read(0, &h, sizeof(cublasHandle_t)) ||
rpc_end_response(0, &return_value) < 0)
{
return CUBLAS_STATUS_INTERNAL_ERROR;
}

// it's important to create the cublas handle on the device and update the pointer on the client
*handle = h;

return return_value;
}

Expand Down
19 changes: 5 additions & 14 deletions codegen/manual_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ int handle_cudaMemcpy(void *conn)

result = cudaMemcpy(host_data, src, count, cudaMemcpyDeviceToHost);

std::cout << "DONE COPYING " << result << std::endl;
std::cout << "DONE COPYING COUNT" << count << std::endl;

if (rpc_start_response(conn, request_id) < 0 ||
rpc_write(conn, host_data, count) < 0)
goto ERROR_1;
Expand Down Expand Up @@ -599,10 +602,7 @@ int handle___cudaRegisterVar(void *conn)

int handle_cublasCreate_v2(void *conn)
{
cublasHandle_t* handle;

if (rpc_read(conn, handle, sizeof(cublasHandle_t)) < 0)
return -1;
cublasHandle_t *handle;

int request_id = rpc_end_request(conn);
if (request_id < 0)
Expand All @@ -611,6 +611,7 @@ int handle_cublasCreate_v2(void *conn)
cublasStatus_t result = cublasCreate(handle);

if (rpc_start_response(conn, request_id) < 0 ||
rpc_write(conn, handle, sizeof(cublasHandle_t)) < 0 ||
rpc_write(conn, &result, sizeof(cublasStatus_t)) < 0 ||
rpc_end_response(conn, &result) < 0)
return -1;
Expand Down Expand Up @@ -690,16 +691,6 @@ int handle_cublasSgemm_v2(void *conn)

std::cout << "Calling cublasSgemm with handle: " << handle << std::endl;

printf("Calling cublasSgemm with the following parameters:\n");
printf(" Handle: %p\n", handle);
printf(" transa: %d, transb: %d\n", transa, transb);
printf(" m: %d, n: %d, k: %d\n", m, n, k);
printf(" alpha: %f\n", alpha);
printf(" A: %p, lda: %d\n", A, lda);
printf(" B: %p, ldb: %d\n", B, ldb);
printf(" beta: %f\n", beta);
printf(" C: %p, ldc: %d\n", C, ldc);

// Perform cublasSgemm
cublasStatus_t result = cublasSgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc);

Expand Down

0 comments on commit 695a028

Please sign in to comment.