Skip to content

Commit

Permalink
gpu - reuse evecs where able
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Aug 23, 2024
1 parent 229d7ba commit 3ceddb9
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 10 deletions.
48 changes: 43 additions & 5 deletions backends/cuda-ref/ceed-cuda-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,52 @@ static int CeedOperatorSetup_Cuda(CeedOperator op) {
impl->num_inputs = num_input_fields;
impl->num_outputs = num_output_fields;

// Set up infield and outfield e_vecs and q_vecs
// Set up infield and outfield e-vecs and q-vecs
// Infields
CeedCallBackend(
CeedOperatorSetupFields_Cuda(qf, op, true, false, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem));
// Outfields
CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, false, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out,
num_input_fields, num_output_fields, Q, num_elem));

// Reuse active e-vecs where able
{
CeedInt num_used = 0;
CeedElemRestriction *rstr_used = NULL;

for (CeedInt i = 0; i < num_input_fields; i++) {
bool is_used = false;
CeedVector vec_i;
CeedElemRestriction rstr_i;

CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i));
if (vec_i != CEED_VECTOR_ACTIVE) continue;
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i));
for (CeedInt j = 0; j < num_used; j++) {
if (rstr_i == rstr_used[i]) is_used = true;
}
if (is_used) continue;
num_used++;
if (num_used == 1) CeedCallBackend(CeedCalloc(num_used, &rstr_used));
else CeedCallBackend(CeedRealloc(num_used, &rstr_used));
rstr_used[num_used - 1] = rstr_i;
for (CeedInt j = num_output_fields - 1; j >= 0; j--) {
CeedEvalMode eval_mode;
CeedVector vec_j;
CeedElemRestriction rstr_j;

CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec_j));
if (vec_j != CEED_VECTOR_ACTIVE) continue;
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode));
if (eval_mode == CEED_EVAL_NONE) continue;
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &rstr_j));
if (rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(impl->e_vecs[i], &impl->e_vecs[j + impl->num_inputs]));
}
}
}
CeedCallBackend(CeedFree(&rstr_used));
}
CeedCallBackend(CeedOperatorSetSetupDone(op));
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -310,7 +348,7 @@ static inline int CeedOperatorSetupInputs_Cuda(CeedInt num_input_fields, CeedQFu
uint64_t state;

CeedCallBackend(CeedVectorGetState(vec, &state));
if (state != impl->input_states[i] && !impl->skip_rstr_in[i]) {
if ((state != impl->input_states[i] || vec == in_vec) && !impl->skip_rstr_in[i]) {
CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs[i], request));
}
impl->input_states[i] = state;
Expand Down Expand Up @@ -435,6 +473,9 @@ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVec
// Q function
CeedCallBackend(CeedQFunctionApply(qf, num_elem * Q, impl->q_vecs_in, impl->q_vecs_out));

// Restore input arrays
CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl));

// Output basis apply if needed
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedEvalMode eval_mode;
Expand Down Expand Up @@ -490,9 +531,6 @@ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVec

CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request));
}

// Restore input arrays
CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl));
return CEED_ERROR_SUCCESS;
}

Expand Down
48 changes: 43 additions & 5 deletions backends/hip-ref/ceed-hip-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,52 @@ static int CeedOperatorSetup_Hip(CeedOperator op) {
impl->num_inputs = num_input_fields;
impl->num_outputs = num_output_fields;

// Set up infield and outfield e_vecs and q_vecs
// Set up infield and outfield e-vecs and q-vecs
// Infields
CeedCallBackend(
CeedOperatorSetupFields_Hip(qf, op, true, false, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem));
// Outfields
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, false, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out,
num_input_fields, num_output_fields, Q, num_elem));

// Reuse active e-vecs where able
{
CeedInt num_used = 0;
CeedElemRestriction *rstr_used = NULL;

for (CeedInt i = 0; i < num_input_fields; i++) {
bool is_used = false;
CeedVector vec_i;
CeedElemRestriction rstr_i;

CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i));
if (vec_i != CEED_VECTOR_ACTIVE) continue;
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i));
for (CeedInt j = 0; j < num_used; j++) {
if (rstr_i == rstr_used[i]) is_used = true;
}
if (is_used) continue;
num_used++;
if (num_used == 1) CeedCallBackend(CeedCalloc(num_used, &rstr_used));
else CeedCallBackend(CeedRealloc(num_used, &rstr_used));
rstr_used[num_used - 1] = rstr_i;
for (CeedInt j = num_output_fields - 1; j >= 0; j--) {
CeedEvalMode eval_mode;
CeedVector vec_j;
CeedElemRestriction rstr_j;

CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec_j));
if (vec_j != CEED_VECTOR_ACTIVE) continue;
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode));
if (eval_mode == CEED_EVAL_NONE) continue;
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &rstr_j));
if (rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(impl->e_vecs[i], &impl->e_vecs[j + impl->num_inputs]));
}
}
}
CeedCallBackend(CeedFree(&rstr_used));
}
CeedCallBackend(CeedOperatorSetSetupDone(op));
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -309,7 +347,7 @@ static inline int CeedOperatorSetupInputs_Hip(CeedInt num_input_fields, CeedQFun
uint64_t state;

CeedCallBackend(CeedVectorGetState(vec, &state));
if (state != impl->input_states[i] && !impl->skip_rstr_in[i]) {
if ((state != impl->input_states[i] || vec == in_vec) && !impl->skip_rstr_in[i]) {
CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs[i], request));
}
impl->input_states[i] = state;
Expand Down Expand Up @@ -434,6 +472,9 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect
// Q function
CeedCallBackend(CeedQFunctionApply(qf, num_elem * Q, impl->q_vecs_in, impl->q_vecs_out));

// Restore input arrays
CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl));

// Output basis apply if needed
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedEvalMode eval_mode;
Expand Down Expand Up @@ -489,9 +530,6 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect

CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request));
}

// Restore input arrays
CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl));
return CEED_ERROR_SUCCESS;
}

Expand Down

0 comments on commit 3ceddb9

Please sign in to comment.