Create the cuBLAS handle on the server and send it back to the client.

Review fixes folded into this patch:
  * client (cublasCreate_v2): compare rpc_read() against 0 like the sibling
    rpc_* calls in the same condition. A bare truthiness test would treat any
    non-zero return (for example a byte count) as a failure.
  * server (handle_cublasCreate_v2): give cublasCreate() the address of a real
    cublasHandle_t instead of an uninitialized cublasHandle_t*, and send that
    object's bytes back to the client.
Line breaks, blank context lines and indentation were reconstructed from the
hunk line counts; the index lines are carried over from the original patch.

diff --git a/codegen/manual_client.cpp b/codegen/manual_client.cpp
index 91c4c62..26992dc 100755
--- a/codegen/manual_client.cpp
+++ b/codegen/manual_client.cpp
@@ -865,17 +865,19 @@ extern "C"
 cublasStatus_t cublasCreate_v2(cublasHandle_t* handle)
 {
     cublasStatus_t return_value;
-
-    std::cout << "cublas handle: " << handle << std::endl;
+    cublasHandle_t h = nullptr; // receives the handle value sent back by the server
 
     if (rpc_start_request(0, RPC_cublasCreate_v2) < 0 ||
-        rpc_write(0, handle, sizeof(cublasHandle_t)) < 0 ||
         rpc_wait_for_response(0) < 0 ||
+        rpc_read(0, &h, sizeof(cublasHandle_t)) < 0 ||
         rpc_end_response(0, &return_value) < 0)
     {
         return CUBLAS_STATUS_INTERNAL_ERROR;
     }
 
+    // it's important to create the cublas handle on the device and update the pointer on the client
+    *handle = h;
+
     return return_value;
 }
 
diff --git a/codegen/manual_server.cpp b/codegen/manual_server.cpp
index 5540e78..07c0c8f 100755
--- a/codegen/manual_server.cpp
+++ b/codegen/manual_server.cpp
@@ -53,6 +53,9 @@ int handle_cudaMemcpy(void *conn)
 
     result = cudaMemcpy(host_data, src, count, cudaMemcpyDeviceToHost);
 
+    std::cout << "DONE COPYING " << result << std::endl;
+    std::cout << "DONE COPYING COUNT" << count << std::endl;
+
     if (rpc_start_response(conn, request_id) < 0 ||
         rpc_write(conn, host_data, count) < 0)
         goto ERROR_1;
@@ -599,10 +602,7 @@ int handle___cudaRegisterVar(void *conn)
 
 int handle_cublasCreate_v2(void *conn)
 {
-    cublasHandle_t* handle;
-
-    if (rpc_read(conn, handle, sizeof(cublasHandle_t)) < 0)
-        return -1;
+    cublasHandle_t handle = nullptr; // real storage for cublasCreate() to fill in
 
     int request_id = rpc_end_request(conn);
     if (request_id < 0)
@@ -611,6 +611,7 @@ int handle_cublasCreate_v2(void *conn)
-    cublasStatus_t result = cublasCreate(handle);
+    cublasStatus_t result = cublasCreate(&handle);
 
     if (rpc_start_response(conn, request_id) < 0 ||
+        rpc_write(conn, &handle, sizeof(cublasHandle_t)) < 0 ||
         rpc_write(conn, &result, sizeof(cublasStatus_t)) < 0 ||
         rpc_end_response(conn, &result) < 0)
         return -1;
@@ -690,16 +691,6 @@ int handle_cublasSgemm_v2(void *conn)
 
     std::cout << "Calling cublasSgemm with handle: " << handle << std::endl;
 
-    printf("Calling cublasSgemm with the following parameters:\n");
-    printf(" Handle: %p\n", handle);
-    printf(" transa: %d, transb: %d\n", transa, transb);
-    printf(" m: %d, n: %d, k: %d\n", m, n, k);
-    printf(" alpha: %f\n", alpha);
-    printf(" A: %p, lda: %d\n", A, lda);
-    printf(" B: %p, ldb: %d\n", B, ldb);
-    printf(" beta: %f\n", beta);
-    printf(" C: %p, ldc: %d\n", C, ldc);
-
     // Perform cublasSgemm
     cublasStatus_t result = cublasSgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc);
 