Skip to content

Commit

Permalink
Merge pull request #1678 from CEED/jeremy/vec-fix
Browse files Browse the repository at this point in the history
vec - fix poinwisemult length check
  • Loading branch information
jeremylt authored Oct 3, 2024
2 parents bdd4742 + 54404f0 commit e704478
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
15 changes: 14 additions & 1 deletion backends/cuda-ref/ceed-cuda-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -1721,6 +1721,19 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
// Work vector
CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_in));
CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_out));
{
CeedSize length_in, length_out;

CeedCallBackend(CeedVectorGetLength(active_e_vec_in, &length_in));
CeedCallBackend(CeedVectorGetLength(active_e_vec_out, &length_out));
// Need input e_vec to be longer
if (length_in < length_out) {
CeedVector temp = active_e_vec_in;

active_e_vec_in = active_e_vec_out;
active_e_vec_out = temp;
}
}

// Get point coordinates
if (!impl->point_coords_elem) {
Expand Down Expand Up @@ -1804,7 +1817,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
const CeedScalar *e_vec_array;

CeedCallBackend(CeedVectorGetArrayRead(active_e_vec_in, CEED_MEM_DEVICE, &e_vec_array));
CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array));
CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array));
break;
}
case CEED_EVAL_INTERP:
Expand Down
15 changes: 14 additions & 1 deletion backends/hip-ref/ceed-hip-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -1718,6 +1718,19 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce
// Work vector
CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_in));
CeedCallBackend(CeedGetWorkVector(ceed, impl->max_active_e_vec_len, &active_e_vec_out));
{
CeedSize length_in, length_out;

CeedCallBackend(CeedVectorGetLength(active_e_vec_in, &length_in));
CeedCallBackend(CeedVectorGetLength(active_e_vec_out, &length_out));
// Need input e_vec to be longer
if (length_in < length_out) {
CeedVector temp = active_e_vec_in;

active_e_vec_in = active_e_vec_out;
active_e_vec_out = temp;
}
}

// Get point coordinates
if (!impl->point_coords_elem) {
Expand Down Expand Up @@ -1801,7 +1814,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce
const CeedScalar *e_vec_array;

CeedCallBackend(CeedVectorGetArrayRead(active_e_vec_in, CEED_MEM_DEVICE, &e_vec_array));
CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array));
CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array));
break;
}
case CEED_EVAL_INTERP:
Expand Down
8 changes: 4 additions & 4 deletions interface/ceed-vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -862,10 +862,10 @@ int CeedVectorPointwiseMult(CeedVector w, CeedVector x, CeedVector y) {
CeedCall(CeedVectorGetLength(w, &length_w));
CeedCall(CeedVectorGetLength(x, &length_x));
CeedCall(CeedVectorGetLength(y, &length_y));
CeedCheck(length_x >= length_x && length_y >= length_w, ceed, CEED_ERROR_UNSUPPORTED,
"Cannot multiply vectors of different lengths."
" x length: %" CeedSize_FMT " y length: %" CeedSize_FMT,
length_x, length_y);
CeedCheck(length_x >= length_w && length_y >= length_w, ceed, CEED_ERROR_UNSUPPORTED,
"Cannot pointwise multiply vectors of incompatible lengths."
" w length: %" CeedSize_FMT " x length: %" CeedSize_FMT " y length: %" CeedSize_FMT,
length_w, length_x, length_y);

CeedCall(CeedGetParent(w->ceed, &ceed_parent_w));
CeedCall(CeedGetParent(x->ceed, &ceed_parent_x));
Expand Down

0 comments on commit e704478

Please sign in to comment.