Skip to content

Commit

Permalink
nvector test: CUDA/HIP agnostic
Browse files Browse the repository at this point in the history
  • Loading branch information
jsdomine committed Jul 6, 2023
1 parent 921873b commit 3c38974
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions examples/nvector/parhyp/test_nvector_parhyp.c
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,9 @@ int check_ans(realtype ans, N_Vector X, sunindextype local_length)
Xvec = N_VGetVector_ParHyp(X);
Xdata = hypre_VectorData(hypre_ParVectorLocalVector(Xvec));
// if CUDA, malloc host -> cudamemcpy -> check -> free
#if defined(SUNDIALS_HYPRE_BACKENDS_CUDA)
#if defined(SUNDIALS_HYPRE_BACKENDS_CUDA_OR_HIP)
realtype *host_data = (realtype*)malloc(sizeof(realtype)*local_length);
cudaMemcpy(host_data,Xdata,sizeof(realtype)*local_length,cudaMemcpyDeviceToHost);
NV_ADD_LANG_PREFIX_PH(Memcpy)(host_data,Xdata,sizeof(realtype)*local_length,NV_ADD_LANG_PREFIX_PH(MemcpyDeviceToHost));
Xdata = host_data;
#endif

Expand All @@ -297,7 +297,7 @@ int check_ans(realtype ans, N_Vector X, sunindextype local_length)
failure += SUNRCompare(Xdata[i], ans);
}

#if defined(SUNDIALS_HYPRE_BACKENDS_CUDA)
#if defined(SUNDIALS_HYPRE_BACKENDS_CUDA_OR_HIP)
free(host_data);
#endif

Expand Down Expand Up @@ -327,11 +327,11 @@ void set_element_range(N_Vector X, sunindextype is, sunindextype ie,
Xvec = N_VGetVector_ParHyp(X);
Xdata = hypre_VectorData(hypre_ParVectorLocalVector(Xvec));

#if defined(SUNDIALS_HYPRE_BACKENDS_CUDA)
#if defined(SUNDIALS_HYPRE_BACKENDS_CUDA_OR_HIP)
int sub_len = ie-is+1;
realtype *host_data = (realtype*)malloc(sizeof(realtype)*sub_len);
for(i = 0; i < sub_len; i++) host_data[i] = val;
cudaMemcpy(Xdata+is,host_data,sizeof(realtype)*sub_len,cudaMemcpyHostToDevice);
NV_ADD_LANG_PREFIX_PH(Memcpy)(Xdata+is,host_data,sizeof(realtype)*sub_len,NV_ADD_LANG_PREFIX_PH(MemcpyHostToDevice));
free(host_data);
#else
for(i = is; i <= ie; i++) Xdata[i] = val;
Expand All @@ -347,9 +347,9 @@ realtype get_element(N_Vector X, sunindextype i)
Xvec = N_VGetVector_ParHyp(X);
Xdata = hypre_VectorData(hypre_ParVectorLocalVector(Xvec));

#if defined(SUNDIALS_HYPRE_BACKENDS_CUDA)
#if defined(SUNDIALS_HYPRE_BACKENDS_CUDA_OR_HIP)
realtype host_data;
cudaMemcpy(&host_data,Xdata,sizeof(realtype),cudaMemcpyDeviceToHost);
NV_ADD_LANG_PREFIX_PH(Memcpy)(&host_data,Xdata,sizeof(realtype),NV_ADD_LANG_PREFIX_PH(MemcpyDeviceToHost));
return host_data;
#else
return Xdata[i];
Expand All @@ -371,8 +371,8 @@ double max_time(N_Vector X, double time)
void sync_device(N_Vector x)
{
/* not running on GPU, just return */
#if defined(SUNDIALS_HYPRE_BACKENDS_CUDA)
cudaDeviceSynchronize();
#if defined(SUNDIALS_HYPRE_BACKENDS_CUDA_OR_HIP)
NV_ADD_LANG_PREFIX_PH(DeviceSynchronize)();
#endif
return;
}

0 comments on commit 3c38974

Please sign in to comment.