Skip to content

Commit

Permalink
cpu - skip duplicate output rstr
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Aug 13, 2024
1 parent 311be37 commit a3f234e
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 57 deletions.
70 changes: 54 additions & 16 deletions backends/blocked/ceed-blocked-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
//------------------------------------------------------------------------------
// Setup Input/Output Fields
//------------------------------------------------------------------------------
static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, const CeedInt block_size,
CeedElemRestriction *block_rstr, CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs,
CeedInt start_e, CeedInt num_fields, CeedInt Q) {
static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, CeedInt *e_data_out_indices,
bool *apply_add_basis, const CeedInt block_size, CeedElemRestriction *block_rstr, CeedVector *e_vecs_full,
CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) {
Ceed ceed;
CeedSize e_size, q_size;
CeedInt num_comp, size, P;
Expand Down Expand Up @@ -135,7 +135,7 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
break;
}
}
// Drop duplicate input restrictions
// Drop duplicate restrictions
if (is_input) {
for (CeedInt i = 0; i < num_fields; i++) {
CeedVector vec_i;
Expand All @@ -151,11 +151,33 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
if (vec_i == vec_j && rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i], &e_vecs_full[j]));
CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e]));
skip_rstr[j] = true;
}
}
}
} else {
for (CeedInt i = num_fields - 1; i >= 0; i--) {
CeedVector vec_i;
CeedElemRestriction rstr_i;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i));
for (CeedInt j = i - 1; j >= 0; j--) {
CeedVector vec_j;
CeedElemRestriction rstr_j;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
if (vec_i == vec_j && rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e]));
skip_rstr[j] = true;
apply_add_basis[i] = true;
e_data_out_indices[j] = i;
}
}
}
}
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -189,6 +211,9 @@ static int CeedOperatorSetup_Blocked(CeedOperator op) {
CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full));

CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_data_out_indices));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out));
Expand All @@ -200,11 +225,12 @@ static int CeedOperatorSetup_Blocked(CeedOperator op) {

// Set up infield and outfield pointer arrays
// Infields
CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, true, impl->skip_rstr_in, block_size, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_in,
impl->q_vecs_in, 0, num_input_fields, Q));
CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, true, impl->skip_rstr_in, NULL, NULL, block_size, impl->block_rstr, impl->e_vecs_full,
impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q));
// Outfields
CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, false, NULL, block_size, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_out,
impl->q_vecs_out, num_input_fields, num_output_fields, Q));
CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, false, impl->skip_rstr_out, impl->e_data_out_indices, impl->apply_add_basis_out, block_size,
impl->block_rstr, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields,
num_output_fields, Q));

// Identity QFunctions
if (impl->is_identity_qf) {
Expand Down Expand Up @@ -310,8 +336,8 @@ static inline int CeedOperatorInputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunc
// Output Basis Action
//------------------------------------------------------------------------------
static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields,
CeedInt block_size, CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op,
CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) {
CeedInt block_size, CeedInt num_input_fields, CeedInt num_output_fields, bool *apply_add_basis,
CeedOperator op, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) {
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedInt elem_size, num_comp;
CeedEvalMode eval_mode;
Expand All @@ -334,7 +360,11 @@ static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFun
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER,
&e_data_full[i + num_input_fields][(CeedSize)e * elem_size * num_comp]));
CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
if (apply_add_basis[i]) {
CeedCallBackend(CeedBasisApplyAdd(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
} else {
CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
}
break;
// LCOV_EXCL_START
case CEED_EVAL_WEIGHT: {
Expand Down Expand Up @@ -405,8 +435,12 @@ static int CeedOperatorApplyAdd_Blocked(CeedOperator op, CeedVector in_vec, Ceed
CeedCallBackend(CeedOperatorSetupInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data_full, impl, request));

// Output Evecs
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields]));
for (CeedInt i = num_output_fields - 1; i >= 0; i--) {
if (impl->skip_rstr_out[i]) {
e_data_full[i + num_input_fields] = e_data_full[impl->e_data_out_indices[i] + num_input_fields];
} else {
CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields]));
}
}

// Loop through elements
Expand All @@ -430,14 +464,15 @@ static int CeedOperatorApplyAdd_Blocked(CeedOperator op, CeedVector in_vec, Ceed
}

// Output basis apply
CeedCallBackend(CeedOperatorOutputBasis_Blocked(e, Q, qf_output_fields, op_output_fields, block_size, num_input_fields, num_output_fields, op,
e_data_full, impl));
CeedCallBackend(CeedOperatorOutputBasis_Blocked(e, Q, qf_output_fields, op_output_fields, block_size, num_input_fields, num_output_fields,
impl->apply_add_basis_out, op, e_data_full, impl));
}

// Output restriction
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedVector vec;

if (impl->skip_rstr_out[i]) continue;
// Restore evec
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_full[i + impl->num_inputs], &e_data_full[i + num_input_fields]));
// Get output vector
Expand Down Expand Up @@ -671,6 +706,9 @@ static int CeedOperatorDestroy_Blocked(CeedOperator op) {
CeedCallBackend(CeedOperatorGetData(op, &impl));

CeedCallBackend(CeedFree(&impl->skip_rstr_in));
CeedCallBackend(CeedFree(&impl->skip_rstr_out));
CeedCallBackend(CeedFree(&impl->e_data_out_indices));
CeedCallBackend(CeedFree(&impl->apply_add_basis_out));
for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) {
CeedCallBackend(CeedElemRestrictionDestroy(&impl->block_rstr[i]));
CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i]));
Expand Down
3 changes: 2 additions & 1 deletion backends/blocked/ceed-blocked.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ typedef struct {

typedef struct {
bool is_identity_qf, is_identity_rstr_op;
bool *skip_rstr_in;
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
CeedInt *e_data_out_indices;
uint64_t *input_states; /* State counter of inputs */
CeedVector *e_vecs_full; /* Full E-vectors, inputs followed by outputs */
CeedVector *e_vecs_in; /* Element block input E-vectors */
Expand Down
80 changes: 67 additions & 13 deletions backends/opt/ceed-opt-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
//------------------------------------------------------------------------------
// Setup Input/Output Fields
//------------------------------------------------------------------------------
static int CeedOperatorSetupFields_Opt(CeedQFunction qf, CeedOperator op, bool is_input, const CeedInt block_size, CeedElemRestriction *block_rstr,
CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields,
CeedInt Q) {
static int CeedOperatorSetupFields_Opt(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, bool *apply_add_basis,
const CeedInt block_size, CeedElemRestriction *block_rstr, CeedVector *e_vecs_full, CeedVector *e_vecs,
CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) {
Ceed ceed;
CeedSize e_size, q_size;
CeedInt num_comp, size, P;
Expand Down Expand Up @@ -161,6 +161,49 @@ static int CeedOperatorSetupFields_Opt(CeedQFunction qf, CeedOperator op, bool i
}
}
}
// Drop duplicate restrictions
if (is_input) {
for (CeedInt i = 0; i < num_fields; i++) {
CeedVector vec_i;
CeedElemRestriction rstr_i;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i));
for (CeedInt j = i + 1; j < num_fields; j++) {
CeedVector vec_j;
CeedElemRestriction rstr_j;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
if (vec_i == vec_j && rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e]));
skip_rstr[j] = true;
}
}
}
} else {
for (CeedInt i = num_fields - 1; i >= 0; i--) {
CeedVector vec_i;
CeedElemRestriction rstr_i;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i));
for (CeedInt j = i - 1; j >= 0; j--) {
CeedVector vec_j;
CeedElemRestriction rstr_j;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
if (vec_i == vec_j && rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e]));
skip_rstr[j] = true;
apply_add_basis[i] = true;
}
}
}
}
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -194,6 +237,9 @@ static int CeedOperatorSetup_Opt(CeedOperator op) {
CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->block_rstr));
CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full));

CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out));
Expand All @@ -205,11 +251,11 @@ static int CeedOperatorSetup_Opt(CeedOperator op) {

// Set up infield and outfield pointer arrays
// Infields
CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, true, block_size, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_in, impl->q_vecs_in, 0,
num_input_fields, Q));
CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, true, impl->skip_rstr_in, NULL, block_size, impl->block_rstr, impl->e_vecs_full,
impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q));
// Outfields
CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, false, block_size, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out,
num_input_fields, num_output_fields, Q));
CeedCallBackend(CeedOperatorSetupFields_Opt(qf, op, false, impl->skip_rstr_out, impl->apply_add_basis_out, block_size, impl->block_rstr,
impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields, num_output_fields, Q));

// Identity QFunctions
if (impl->is_identity_qf) {
Expand Down Expand Up @@ -251,7 +297,7 @@ static inline int CeedOperatorSetupInputs_Opt(CeedInt num_input_fields, CeedQFun
if (vec != CEED_VECTOR_ACTIVE) {
// Restrict
CeedCallBackend(CeedVectorGetState(vec, &state));
if (state != impl->input_states[i] && impl->block_rstr[i]) {
if (state != impl->input_states[i] && impl->block_rstr[i] && !impl->skip_rstr_in[i]) {
CeedCallBackend(CeedElemRestrictionApply(impl->block_rstr[i], CEED_NOTRANSPOSE, vec, impl->e_vecs_full[i], request));
}
impl->input_states[i] = state;
Expand Down Expand Up @@ -327,8 +373,8 @@ static inline int CeedOperatorInputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunction
// Output Basis Action
//------------------------------------------------------------------------------
static inline int CeedOperatorOutputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields,
CeedInt block_size, CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op,
CeedVector out_vec, CeedOperator_Opt *impl, CeedRequest *request) {
CeedInt block_size, CeedInt num_input_fields, CeedInt num_output_fields, bool *apply_add_basis,
bool *skip_rstr, CeedOperator op, CeedVector out_vec, CeedOperator_Opt *impl, CeedRequest *request) {
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedEvalMode eval_mode;
CeedVector vec;
Expand All @@ -347,7 +393,11 @@ static inline int CeedOperatorOutputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunctio
case CEED_EVAL_DIV:
case CEED_EVAL_CURL:
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
if (apply_add_basis[i]) {
CeedCallBackend(CeedBasisApplyAdd(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
} else {
CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
}
break;
// LCOV_EXCL_START
case CEED_EVAL_WEIGHT: {
Expand All @@ -356,6 +406,7 @@ static inline int CeedOperatorOutputBasis_Opt(CeedInt e, CeedInt Q, CeedQFunctio
}
}
// Restrict output block
if (skip_rstr[i]) continue;
// Get output vector
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) vec = out_vec;
Expand Down Expand Up @@ -448,8 +499,8 @@ static int CeedOperatorApplyAdd_Opt(CeedOperator op, CeedVector in_vec, CeedVect
}

// Output basis apply and restriction
CeedCallBackend(CeedOperatorOutputBasis_Opt(e, Q, qf_output_fields, op_output_fields, block_size, num_input_fields, num_output_fields, op,
out_vec, impl, request));
CeedCallBackend(CeedOperatorOutputBasis_Opt(e, Q, qf_output_fields, op_output_fields, block_size, num_input_fields, num_output_fields,
impl->apply_add_basis_out, impl->skip_rstr_out, op, out_vec, impl, request));
}

// Restore input arrays
Expand Down Expand Up @@ -694,6 +745,9 @@ static int CeedOperatorDestroy_Opt(CeedOperator op) {
CeedCallBackend(CeedFree(&impl->block_rstr));
CeedCallBackend(CeedFree(&impl->e_vecs_full));
CeedCallBackend(CeedFree(&impl->input_states));
CeedCallBackend(CeedFree(&impl->skip_rstr_in));
CeedCallBackend(CeedFree(&impl->skip_rstr_out));
CeedCallBackend(CeedFree(&impl->apply_add_basis_out));

for (CeedInt i = 0; i < impl->num_inputs; i++) {
CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_in[i]));
Expand Down
1 change: 1 addition & 0 deletions backends/opt/ceed-opt.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ typedef struct {

typedef struct {
bool is_identity_qf, is_identity_rstr_op;
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
CeedElemRestriction *block_rstr; /* Blocked versions of restrictions */
CeedVector *e_vecs_full; /* Full E-vectors, inputs followed by outputs */
uint64_t *input_states; /* State counter of inputs */
Expand Down
Loading

0 comments on commit a3f234e

Please sign in to comment.