diff --git a/backends/blocked/ceed-blocked-operator.c b/backends/blocked/ceed-blocked-operator.c index 788533cbff..80b1f44865 100644 --- a/backends/blocked/ceed-blocked-operator.c +++ b/backends/blocked/ceed-blocked-operator.c @@ -16,9 +16,9 @@ //------------------------------------------------------------------------------ // Setup Input/Output Fields //------------------------------------------------------------------------------ -static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, const CeedInt block_size, - CeedElemRestriction *block_rstr, CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs, - CeedInt start_e, CeedInt num_fields, CeedInt Q) { +static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, CeedInt *e_data_out_indices, + bool *apply_add_basis, const CeedInt block_size, CeedElemRestriction *block_rstr, CeedVector *e_vecs_full, + CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) { Ceed ceed; CeedSize e_size, q_size; CeedInt num_comp, size, P; @@ -135,7 +135,7 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo break; } } - // Drop duplicate input restrictions + // Drop duplicate restrictions if (is_input) { for (CeedInt i = 0; i < num_fields; i++) { CeedVector vec_i; @@ -151,11 +151,33 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); if (vec_i == vec_j && rstr_i == rstr_j) { CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); - CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i], &e_vecs_full[j])); + CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e])); skip_rstr[j] = true; } } } + } else { + for (CeedInt i = num_fields - 1; i >= 0; i--) { + CeedVector vec_i; + CeedElemRestriction rstr_i; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); + for (CeedInt j = i - 1; j >= 0; j--) { + CeedVector vec_j; + CeedElemRestriction rstr_j; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); + if (vec_i == vec_j && rstr_i == rstr_j) { + CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); + CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e])); + skip_rstr[j] = true; + apply_add_basis[i] = true; + e_data_out_indices[j] = i; + } + } + } } return CEED_ERROR_SUCCESS; } @@ -189,6 +211,9 @@ static int CeedOperatorSetup_Blocked(CeedOperator op) { CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_data_out_indices)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); @@ -200,11 +225,12 @@ static int CeedOperatorSetup_Blocked(CeedOperator op) { // Set up infield and outfield pointer arrays // Infields - CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, true, impl->skip_rstr_in, block_size, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_in, - impl->q_vecs_in, 0, num_input_fields, Q)); + CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, true, impl->skip_rstr_in, NULL, NULL, block_size, impl->block_rstr, impl->e_vecs_full, + impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q)); // Outfields - CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, false, NULL, block_size, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_out, - impl->q_vecs_out, num_input_fields, num_output_fields, Q)); + CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, false, impl->skip_rstr_out, impl->e_data_out_indices, impl->apply_add_basis_out, block_size, + impl->block_rstr, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields, + num_output_fields, Q)); // Identity QFunctions if (impl->is_identity_qf) { @@ -310,8 +336,8 @@ static inline int CeedOperatorInputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunc // Output Basis Action //------------------------------------------------------------------------------ static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, - CeedInt block_size, CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op, - CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) { + CeedInt block_size, CeedInt num_input_fields, CeedInt num_output_fields, bool *apply_add_basis, + CeedOperator op, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) { for (CeedInt i = 0; i < num_output_fields; i++) { CeedInt elem_size, num_comp; CeedEvalMode eval_mode; @@ -334,7 +360,11 @@ static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFun CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); CeedCallBackend(CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][(CeedSize)e * elem_size * num_comp])); - CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); + if (apply_add_basis[i]) { + CeedCallBackend(CeedBasisApplyAdd(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); + } else { + CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); + } break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { @@ -405,8 +435,12 @@ static int CeedOperatorApplyAdd_Blocked(CeedOperator op, CeedVector in_vec, Ceed CeedCallBackend(CeedOperatorSetupInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data_full, impl, request)); // Output Evecs - for (CeedInt i = 0; i < num_output_fields; i++) { - CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields])); + for (CeedInt i = num_output_fields - 1; i >= 0; i--) { + if (impl->skip_rstr_out[i]) { + e_data_full[i + num_input_fields] = e_data_full[impl->e_data_out_indices[i] + num_input_fields]; + } else { + CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields])); + } } // Loop through elements @@ -430,14 +464,15 @@ static int CeedOperatorApplyAdd_Blocked(CeedOperator op, CeedVector in_vec, Ceed } // Output basis apply - CeedCallBackend(CeedOperatorOutputBasis_Blocked(e, Q, qf_output_fields, op_output_fields, block_size, num_input_fields, num_output_fields, op, - e_data_full, impl)); + CeedCallBackend(CeedOperatorOutputBasis_Blocked(e, Q, qf_output_fields, op_output_fields, block_size, num_input_fields, num_output_fields, + impl->apply_add_basis_out, op, e_data_full, impl)); } // Output restriction for (CeedInt i = 0; i < num_output_fields; i++) { CeedVector vec; + if (impl->skip_rstr_out[i]) continue; // Restore evec CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_full[i + impl->num_inputs], &e_data_full[i + num_input_fields])); // Get output vector @@ -671,6 +706,9 @@ static int CeedOperatorDestroy_Blocked(CeedOperator op) { CeedCallBackend(CeedOperatorGetData(op, &impl)); CeedCallBackend(CeedFree(&impl->skip_rstr_in)); + CeedCallBackend(CeedFree(&impl->skip_rstr_out)); + CeedCallBackend(CeedFree(&impl->e_data_out_indices)); + CeedCallBackend(CeedFree(&impl->apply_add_basis_out)); for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { CeedCallBackend(CeedElemRestrictionDestroy(&impl->block_rstr[i])); CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i])); diff --git a/backends/blocked/ceed-blocked.h b/backends/blocked/ceed-blocked.h index 5876e969b7..f04307abdc 100644 --- a/backends/blocked/ceed-blocked.h +++ b/backends/blocked/ceed-blocked.h @@ -17,7 +17,8 @@ typedef struct { typedef struct { bool is_identity_qf, is_identity_rstr_op; - bool *skip_rstr_in; + bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out; + CeedInt *e_data_out_indices; uint64_t *input_states; /* State counter of inputs */ CeedVector *e_vecs_full; /* Full E-vectors, inputs followed by outputs */ CeedVector *e_vecs_in; /* Element block input E-vectors */ diff --git a/backends/opt/ceed-opt-operator.c b/backends/opt/ceed-opt-operator.c index eaacaedc12..c006399c5b 100644 --- a/backends/opt/ceed-opt-operator.c +++ b/backends/opt/ceed-opt-operator.c @@ -16,9 +16,9 @@ //------------------------------------------------------------------------------ // Setup Input/Output Fields //------------------------------------------------------------------------------ -static int CeedOperatorSetupFields_Opt(CeedQFunction qf, CeedOperator op, bool is_input, const CeedInt block_size, CeedElemRestriction *block_rstr, - CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, - CeedInt Q) { +static int CeedOperatorSetupFields_Opt(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, bool *apply_add_basis, + const CeedInt block_size, CeedElemRestriction *block_rstr, CeedVector *e_vecs_full, CeedVector *e_vecs, + CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) { Ceed ceed; CeedSize e_size, q_size; CeedInt num_comp, size, P; @@ -161,6 +161,49 @@ static int CeedOperatorSetupFields_Opt(CeedQFunction qf, CeedOperator op, bool i } } } + // Drop duplicate restrictions + if (is_input) { + for (CeedInt i = 0; i < num_fields; i++) { + CeedVector vec_i; + CeedElemRestriction rstr_i; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); + for (CeedInt j = i + 1; j < num_fields; j++) { + CeedVector vec_j; + CeedElemRestriction rstr_j; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); + if (vec_i == vec_j && rstr_i == rstr_j) { + CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); + CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e])); + skip_rstr[j] = true; + } + } + } + } else { + for (CeedInt i = num_fields - 1; i >= 0; i--) { + CeedVector vec_i; + CeedElemRestriction rstr_i; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); + for (CeedInt j = i - 1; j >= 0; j--) { + CeedVector vec_j; + CeedElemRestriction rstr_j; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); + if (vec_i == vec_j && rstr_i == rstr_j) { + CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); + CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e])); + skip_rstr[j] = true; + apply_add_basis[i] = true; + } + } + } + } return CEED_ERROR_SUCCESS; } @@ -194,6 +237,9 @@ static int CeedOperatorSetup_Opt(CeedOperator op) { CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->block_rstr)); CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); @@ -205,11 +251,11 @@ static int CeedOperatorSetup_Opt(CeedOperator op) { // Set up infield and outfield pointer arrays // Infields - CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, true, block_size, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, - num_input_fields, Q)); + CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, true, impl->skip_rstr_in, NULL, block_size, impl->block_rstr, impl->e_vecs_full, + impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q)); // Outfields - CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, false, block_size, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, - num_input_fields, num_output_fields, Q)); + CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, false, impl->skip_rstr_out, impl->apply_add_basis_out, block_size, impl->block_rstr, + impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields, num_output_fields, Q)); // Identity QFunctions if (impl->is_identity_qf) { @@ -251,7 +297,7 @@ static inline int CeedOperatorSetupInputs_Opt(CeedInt num_input_fields, CeedQFun if (vec != CEED_VECTOR_ACTIVE) { // Restrict CeedCallBackend(CeedVectorGetState(vec, &state)); - if (state != impl->input_states[i] && impl->block_rstr[i]) { + if (state != impl->input_states[i] && impl->block_rstr[i] && !impl->skip_rstr_in[i]) { CeedCallBackend(CeedElemRestrictionApply(impl->block_rstr[i], CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request)); } impl->input_states[i] = state; @@ -327,8 +373,8 @@ static inline int CeedOperatorInputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunction // Output Basis Action //------------------------------------------------------------------------------ static inline int CeedOperatorOutputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, - CeedInt block_size, CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op, - CeedVector out_vec, CeedOperator_Opt *impl, CeedRequest *request) { + CeedInt block_size, CeedInt num_input_fields, CeedInt num_output_fields, bool *apply_add_basis, + bool *skip_rstr, CeedOperator op, CeedVector out_vec, CeedOperator_Opt *impl, CeedRequest *request) { for (CeedInt i = 0; i < num_output_fields; i++) { CeedEvalMode eval_mode; CeedVector vec; @@ -347,7 +393,11 @@ static inline int CeedOperatorOutputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunctio case CEED_EVAL_DIV: case CEED_EVAL_CURL: CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); - CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); + if (apply_add_basis[i]) { + CeedCallBackend(CeedBasisApplyAdd(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); + } else { + CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); + } break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { @@ -356,6 +406,7 @@ static inline int CeedOperatorOutputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunctio } } // Restrict output block + if (skip_rstr[i]) continue; // Get output vector CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); if (vec == CEED_VECTOR_ACTIVE) vec = out_vec; @@ -448,8 +499,8 @@ static int CeedOperatorApplyAdd_Opt(CeedOperator op, CeedVector in_vec, CeedVect } // Output basis apply and restriction - CeedCallBackend(CeedOperatorOutputBasis_Opt(e, Q, qf_output_fields, op_output_fields, block_size, num_input_fields, num_output_fields, op, - out_vec, impl, request)); + CeedCallBackend(CeedOperatorOutputBasis_Opt(e, Q, qf_output_fields, op_output_fields, block_size, num_input_fields, num_output_fields, + impl->apply_add_basis_out, impl->skip_rstr_out, op, out_vec, impl, request)); } // Restore input arrays @@ -694,6 +745,9 @@ static int CeedOperatorDestroy_Opt(CeedOperator op) { CeedCallBackend(CeedFree(&impl->block_rstr)); CeedCallBackend(CeedFree(&impl->e_vecs_full)); CeedCallBackend(CeedFree(&impl->input_states)); + CeedCallBackend(CeedFree(&impl->skip_rstr_in)); + CeedCallBackend(CeedFree(&impl->skip_rstr_out)); + CeedCallBackend(CeedFree(&impl->apply_add_basis_out)); for (CeedInt i = 0; i < impl->num_inputs; i++) { CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i])); diff --git a/backends/opt/ceed-opt.h b/backends/opt/ceed-opt.h index b40124fb99..d5f7399a89 100644 --- a/backends/opt/ceed-opt.h +++ b/backends/opt/ceed-opt.h @@ -21,6 +21,7 @@ typedef struct { typedef struct { bool is_identity_qf, is_identity_rstr_op; + bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out; CeedElemRestriction *block_rstr; /* Blocked versions of restrictions */ CeedVector *e_vecs_full; /* Full E-vectors, inputs followed by outputs */ uint64_t *input_states; /* State counter of inputs */ diff --git a/backends/ref/ceed-ref-operator.c b/backends/ref/ceed-ref-operator.c index d605c943e7..de79e96d5b 100644 --- a/backends/ref/ceed-ref-operator.c +++ b/backends/ref/ceed-ref-operator.c @@ -16,8 +16,9 @@ //------------------------------------------------------------------------------ // Setup Input/Output Fields //------------------------------------------------------------------------------ -static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, CeedVector *e_vecs_full, CeedVector *e_vecs, - CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) { +static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, CeedInt *e_data_out_indices, + bool *apply_add_basis, CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, + CeedInt num_fields, CeedInt Q) { Ceed ceed; CeedSize e_size, q_size; CeedInt num_comp, size, P; @@ -78,7 +79,7 @@ static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op, bool i break; } } - // Drop duplicate input restrictions + // Drop duplicate restrictions if (is_input) { for (CeedInt i = 0; i < num_fields; i++) { CeedVector vec_i; @@ -94,11 +95,33 @@ static int CeedOperatorSetupFields_Ref(CeedQFunction qf, CeedOperator op, bool i CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); if (vec_i == vec_j && rstr_i == rstr_j) { CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); - CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i], &e_vecs_full[j])); + CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e])); skip_rstr[j] = true; } } } + } else { + for (CeedInt i = num_fields - 1; i >= 0; i--) { + CeedVector vec_i; + CeedElemRestriction rstr_i; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); + for (CeedInt j = i - 1; j >= 0; j--) { + CeedVector vec_j; + CeedElemRestriction rstr_j; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); + if (vec_i == vec_j && rstr_i == rstr_j) { + CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); + CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e])); + skip_rstr[j] = true; + apply_add_basis[i] = true; + e_data_out_indices[j] = i; + } + } + } } return CEED_ERROR_SUCCESS; } @@ -128,6 +151,9 @@ static int CeedOperatorSetup_Ref(CeedOperator op) { CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_data_out_indices)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); @@ -139,11 +165,11 @@ static int CeedOperatorSetup_Ref(CeedOperator op) { // Set up infield and outfield e_vecs and q_vecs // Infields - CeedCallBackend( - CeedOperatorSetupFields_Ref(qf, op, true, impl->skip_rstr_in, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q)); + CeedCallBackend(CeedOperatorSetupFields_Ref(qf, op, true, impl->skip_rstr_in, NULL, NULL, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, + num_input_fields, Q)); // Outfields - CeedCallBackend(CeedOperatorSetupFields_Ref(qf, op, false, NULL, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields, - num_output_fields, Q)); + CeedCallBackend(CeedOperatorSetupFields_Ref(qf, op, false, impl->skip_rstr_out, impl->e_data_out_indices, impl->apply_add_basis_out, + impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields, num_output_fields, Q)); // Identity QFunctions if (impl->is_identity_qf) { @@ -252,7 +278,7 @@ static inline int CeedOperatorInputBasis_Ref(CeedInt e, CeedInt Q, CeedQFunction // Output Basis Action //------------------------------------------------------------------------------ static inline int CeedOperatorOutputBasis_Ref(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, - CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op, + CeedInt num_input_fields, CeedInt num_output_fields, bool *apply_add_basis, CeedOperator op, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Ref *impl) { for (CeedInt i = 0; i < num_output_fields; i++) { CeedInt elem_size, num_comp; @@ -276,7 +302,11 @@ static inline int CeedOperatorOutputBasis_Ref(CeedInt e, CeedInt Q, CeedQFunctio CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); CeedCallBackend(CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i + num_input_fields][(CeedSize)e * elem_size * num_comp])); - CeedCallBackend(CeedBasisApply(basis, 1, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); + if (apply_add_basis[i]) { + CeedCallBackend(CeedBasisApplyAdd(basis, 1, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); + } else { + CeedCallBackend(CeedBasisApply(basis, 1, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i])); + } break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { @@ -350,8 +380,12 @@ static int CeedOperatorApplyAdd_Ref(CeedOperator op, CeedVector in_vec, CeedVect CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data_full, impl, request)); // Output Evecs - for (CeedInt i = 0; i < num_output_fields; i++) { - CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields])); + for (CeedInt i = num_output_fields - 1; i >= 0; i--) { + if (impl->skip_rstr_out[i]) { + e_data_full[i + num_input_fields] = e_data_full[impl->e_data_out_indices[i] + num_input_fields]; + } else { + CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields])); + } } // Loop through elements @@ -375,8 +409,8 @@ static int CeedOperatorApplyAdd_Ref(CeedOperator op, CeedVector in_vec, CeedVect } // Output basis apply - CeedCallBackend( - CeedOperatorOutputBasis_Ref(e, Q, qf_output_fields, op_output_fields, num_input_fields, num_output_fields, op, e_data_full, impl)); + CeedCallBackend(CeedOperatorOutputBasis_Ref(e, Q, qf_output_fields, op_output_fields, num_input_fields, num_output_fields, + impl->apply_add_basis_out, op, e_data_full, impl)); } // Output restriction @@ -384,6 +418,7 @@ static int CeedOperatorApplyAdd_Ref(CeedOperator op, CeedVector in_vec, CeedVect CeedVector vec; CeedElemRestriction elem_rstr; + if (impl->skip_rstr_out[i]) continue; // Restore Evec CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_full[i + impl->num_inputs], &e_data_full[i + num_input_fields])); // Get output vector @@ -590,8 +625,9 @@ static int CeedOperatorLinearAssembleQFunctionUpdate_Ref(CeedOperator op, CeedVe //------------------------------------------------------------------------------ // Setup Input/Output Fields //------------------------------------------------------------------------------ -static int CeedOperatorSetupFieldsAtPoints_Ref(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, CeedVector *e_vecs_full, - CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) { +static int CeedOperatorSetupFieldsAtPoints_Ref(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, bool *apply_add_basis, + CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, + CeedInt Q) { Ceed ceed; CeedSize e_size, q_size; CeedInt max_num_points, num_comp, size, P; @@ -685,7 +721,7 @@ static int CeedOperatorSetupFieldsAtPoints_Ref(CeedQFunction qf, CeedOperator op if (e_vecs[i]) CeedCallBackend(CeedVectorSetValue(e_vecs[i], 0.0)); if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorSetValue(q_vecs[i], 0.0)); } - // Drop duplicate input restrictions + // Drop duplicate restrictions if (is_input) { for (CeedInt i = 0; i < num_fields; i++) { CeedVector vec_i; @@ -701,10 +737,32 @@ static int CeedOperatorSetupFieldsAtPoints_Ref(CeedQFunction qf, CeedOperator op CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); if (vec_i == vec_j && rstr_i == rstr_j) { CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); + CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e])); skip_rstr[j] = true; } } } + } else { + for (CeedInt i = num_fields - 1; i >= 0; i--) { + CeedVector vec_i; + CeedElemRestriction rstr_i; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i)); + for (CeedInt j = i - 1; j >= 0; j--) { + CeedVector vec_j; + CeedElemRestriction rstr_j; + + CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j)); + CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j)); + if (vec_i == vec_j && rstr_i == rstr_j) { + CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j])); + CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e])); + skip_rstr[j] = true; + apply_add_basis[i] = true; + } + } + } } return CEED_ERROR_SUCCESS; } @@ -734,6 +792,8 @@ static int CeedOperatorSetupAtPoints_Ref(CeedOperator op) { CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out)); + CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in)); CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out)); @@ -745,11 +805,11 @@ static int CeedOperatorSetupAtPoints_Ref(CeedOperator op) { // Set up infield and outfield pointer arrays // Infields - CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, true, impl->skip_rstr_in, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, + CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, true, impl->skip_rstr_in, NULL, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q)); // Outfields - CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, false, NULL, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields, - num_output_fields, Q)); + CeedCallBackend(CeedOperatorSetupFieldsAtPoints_Ref(qf, op, false, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs_full, + impl->e_vecs_out, impl->q_vecs_out, num_input_fields, num_output_fields, Q)); // Identity QFunctions if (impl->is_identity_qf) { @@ -828,8 +888,8 @@ static inline int CeedOperatorInputBasisAtPoints_Ref(CeedInt e, CeedInt num_poin //------------------------------------------------------------------------------ static inline int CeedOperatorOutputBasisAtPoints_Ref(CeedInt e, CeedInt num_points_offset, CeedInt num_points, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields, CeedInt num_input_fields, CeedInt num_output_fields, - CeedOperator op, CeedVector out_vec, CeedVector point_coords_elem, CeedOperator_Ref *impl, - CeedRequest *request) { + bool *apply_add_basis, bool *skip_rstr, CeedOperator op, CeedVector out_vec, + CeedVector point_coords_elem, CeedOperator_Ref *impl, CeedRequest *request) { for (CeedInt i = 0; i < num_output_fields; i++) { CeedRestrictionType rstr_type; CeedEvalMode eval_mode; @@ -849,8 +909,13 @@ static inline int CeedOperatorOutputBasisAtPoints_Ref(CeedInt e, CeedInt num_poi case CEED_EVAL_DIV: case CEED_EVAL_CURL: CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); - CeedCallBackend( - CeedBasisApplyAtPoints(basis, 1, &num_points, CEED_TRANSPOSE, eval_mode, point_coords_elem, impl->q_vecs_out[i], impl->e_vecs_out[i])); + if (apply_add_basis[i]) { + CeedCallBackend(CeedBasisApplyAddAtPoints(basis, 1, &num_points, CEED_TRANSPOSE, eval_mode, point_coords_elem, impl->q_vecs_out[i], + impl->e_vecs_out[i])); + } else { + CeedCallBackend( + CeedBasisApplyAtPoints(basis, 1, &num_points, CEED_TRANSPOSE, eval_mode, point_coords_elem, impl->q_vecs_out[i], impl->e_vecs_out[i])); + } break; // LCOV_EXCL_START case CEED_EVAL_WEIGHT: { @@ -859,6 +924,7 @@ static inline int CeedOperatorOutputBasisAtPoints_Ref(CeedInt e, CeedInt num_poi } } // Restrict output block + if (skip_rstr[i]) continue; // Get output vector CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type)); CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); @@ -920,7 +986,8 @@ static int CeedOperatorApplyAddAtPoints_Ref(CeedOperator op, CeedVector in_vec, // Output basis apply and restriction CeedCallBackend(CeedOperatorOutputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_output_fields, op_output_fields, num_input_fields, - num_output_fields, op, out_vec, impl->point_coords_elem, impl, request)); + num_output_fields, impl->apply_add_basis_out, impl->skip_rstr_out, op, out_vec, + impl->point_coords_elem, impl, request)); num_points_offset += num_points; } @@ -1292,7 +1359,8 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref(CeedOperator op, Ce // -- Output basis apply and restriction CeedCallBackend(CeedOperatorOutputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_output_fields, op_output_fields, num_input_fields, - num_output_fields, op, out_vec, impl->point_coords_elem, impl, request)); + num_output_fields, impl->apply_add_basis_out, impl->skip_rstr_out, op, out_vec, + impl->point_coords_elem, impl, request)); // -- Grab diagonal value for (CeedInt j = 0; j < num_output_fields; j++) { @@ -1389,6 +1457,9 @@ static int CeedOperatorDestroy_Ref(CeedOperator op) { CeedCallBackend(CeedOperatorGetData(op, &impl)); CeedCallBackend(CeedFree(&impl->skip_rstr_in)); + CeedCallBackend(CeedFree(&impl->skip_rstr_out)); + CeedCallBackend(CeedFree(&impl->e_data_out_indices)); + CeedCallBackend(CeedFree(&impl->apply_add_basis_out)); for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) { CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i])); } diff --git a/backends/ref/ceed-ref.h b/backends/ref/ceed-ref.h index ff8e9fa773..880b4f89af 100644 --- a/backends/ref/ceed-ref.h +++ b/backends/ref/ceed-ref.h @@ -49,7 +49,8 @@ typedef struct { typedef struct { bool is_identity_qf, is_identity_rstr_op; - bool *skip_rstr_in; + bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out; + CeedInt *e_data_out_indices; uint64_t *input_states; /* State counter of inputs */ CeedVector *e_vecs_full; /* Full E-vectors, inputs followed by outputs */ CeedVector *e_vecs_in; /* Single element input E-vectors */