diff --git a/backends/hip-ref/ceed-hip-ref-operator.c b/backends/hip-ref/ceed-hip-ref-operator.c index bb5d09816d..6045d0ad96 100644 --- a/backends/hip-ref/ceed-hip-ref-operator.c +++ b/backends/hip-ref/ceed-hip-ref-operator.c @@ -27,6 +27,8 @@ static int CeedOperatorDestroy_Hip(CeedOperator op) { // Apply data CeedCallBackend(CeedFree(&impl->skip_rstr_in)); + CeedCallBackend(CeedFree(&impl->skip_rstr_out)); + CeedCallBackend(CeedFree(&impl->apply_add_basis_out)); for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[i])); } @@ -96,8 +98,8 @@ static int CeedOperatorDestroy_Hip(CeedOperator op) { //------------------------------------------------------------------------------ // Setup infields or outfields //------------------------------------------------------------------------------ -static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, CeedVector *e_vecs, - CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) { +static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, bool *apply_add_basis, + CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) { Ceed ceed; CeedQFunctionField *qf_fields; CeedOperatorField *op_fields; @@ -183,7 +185,7 @@ static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool i break; } } - // Drop duplicate input restrictions + // Drop duplicate restrictions if (is_input) { for (CeedInt i = 0; i < num_fields; i++) { CeedVector vec_i; @@ -198,11 +200,31 @@ static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool i CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); if (vec_i == vec_j && rstr_i == rstr_j) { - CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); + CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e])); skip_rstr[j] = true; } } } + } else { + for (CeedInt i = num_fields - 1; i >= 0; i--) { + CeedVector vec_i; + CeedElemRestriction rstr_i; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); + for (CeedInt j = i - 1; j >= 0; j--) { + CeedVector vec_j; + CeedElemRestriction rstr_j; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); + if (vec_i == vec_j && rstr_i == rstr_j) { + CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e])); + skip_rstr[j] = true; + apply_add_basis[i] = true; + } + } + } } return CEED_ERROR_SUCCESS; } @@ -233,6 +255,8 @@ static int CeedOperatorSetup_Hip(CeedOperator op) { // Allocate CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); @@ -242,10 +266,10 @@ static int CeedOperatorSetup_Hip(CeedOperator op) { // Set up infield and outfield e_vecs and q_vecs // Infields CeedCallBackend( - CeedOperatorSetupFields_Hip(qf, op, true, false, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem)); + CeedOperatorSetupFields_Hip(qf, op, true, false, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem)); // Outfields - CeedCallBackend( - CeedOperatorSetupFields_Hip(qf, op, false, false, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, Q, num_elem)); + CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, false, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out, + num_input_fields, num_output_fields, Q, num_elem)); CeedCallBackend(CeedOperatorSetSetupDone(op)); return CEED_ERROR_SUCCESS; @@ -430,7 +454,11 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect case CEED_EVAL_DIV: case CEED_EVAL_CURL: CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); - CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); + if (impl->apply_add_basis_out[i]) { + CeedCallBackend(CeedBasisApplyAdd(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); + } else { + CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); + } break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { @@ -451,6 +479,7 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect if (eval_mode == CEED_EVAL_NONE) { CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); } + if (impl->skip_rstr_out[i]) continue; // Get output vector CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); // Restrict @@ -498,6 +527,8 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) { // Allocate CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out)); @@ -506,11 +537,11 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) { // Set up infield and outfield e_vecs and q_vecs // Infields - CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, true, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, + CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, true, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, max_num_points, num_elem)); // Outfields - CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, - max_num_points, num_elem)); + CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out, + num_input_fields, num_output_fields, max_num_points, num_elem)); CeedCallBackend(CeedOperatorSetSetupDone(op)); return CEED_ERROR_SUCCESS; @@ -634,8 +665,13 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, case CEED_EVAL_DIV: case CEED_EVAL_CURL: CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); - CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i], - impl->e_vecs[i + impl->num_inputs])); + if (impl->apply_add_basis_out[i]) { + CeedCallBackend(CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, + impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs])); + } else { + CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i], + impl->e_vecs[i + impl->num_inputs])); + } break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { @@ -656,6 +692,7 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec, if (eval_mode == CEED_EVAL_NONE) { CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields])); } + if (impl->skip_rstr_out[i]) continue; // Get output vector CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); // Restrict diff --git a/backends/hip-ref/ceed-hip-ref.h b/backends/hip-ref/ceed-hip-ref.h index 5199ce8767..59f8d809c2 100644 --- a/backends/hip-ref/ceed-hip-ref.h +++ b/backends/hip-ref/ceed-hip-ref.h @@ -132,7 +132,7 @@ typedef struct { } CeedOperatorAssemble_Hip; typedef struct { - bool *skip_rstr_in; + bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out; uint64_t *input_states; // State tracking for passive inputs CeedVector *e_vecs; // E-vectors, inputs followed by outputs CeedVector *q_vecs_in; // Input Q-vectors needed to apply operator