Skip to content

Commit

Permalink
wip - delete after OperatorFieldGet*
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Aug 20, 2024
1 parent 49a690b commit 02da2a3
Show file tree
Hide file tree
Showing 9 changed files with 297 additions and 104 deletions.
56 changes: 43 additions & 13 deletions backends/blocked/ceed-blocked-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
// Empty case - won't occur
break;
}
CeedCallBackend(CeedElemRestrictionDestroy(&rstr));
CeedCallBackend(CeedElemRestrictionCreateVector(block_rstr[i + start_e], NULL, &e_vecs_full[i + start_e]));
}

Expand All @@ -122,6 +123,7 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
CeedCallBackend(CeedQFunctionFieldGetSize(qf_fields[i], &size));
CeedCallBackend(CeedBasisGetNumNodes(basis, &P));
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisDestroy(&basis));
e_size = (CeedSize)P * num_comp * block_size;
CeedCallBackend(CeedVectorCreate(ceed, e_size, &e_vecs[i]));
q_size = (CeedSize)Q * size * block_size;
Expand All @@ -132,6 +134,7 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
q_size = (CeedSize)Q * block_size;
CeedCallBackend(CeedVectorCreate(ceed, q_size, &q_vecs[i]));
CeedCallBackend(CeedBasisApply(basis, block_size, CEED_NOTRANSPOSE, CEED_EVAL_WEIGHT, CEED_VECTOR_NONE, q_vecs[i]));
CeedCallBackend(CeedBasisDestroy(&basis));
break;
}
}
Expand All @@ -154,7 +157,11 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e]));
skip_rstr[j] = true;
}
CeedCallBackend(CeedVectorDestroy(&vec_j));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j));
}
CeedCallBackend(CeedVectorDestroy(&vec_i));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
}
} else {
for (CeedInt i = num_fields - 1; i >= 0; i--) {
Expand All @@ -176,7 +183,11 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
apply_add_basis[i] = true;
e_data_out_indices[j] = i;
}
CeedCallBackend(CeedVectorDestroy(&vec_j));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_j));
}
CeedCallBackend(CeedVectorDestroy(&vec_i));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_i));
}
}
return CEED_ERROR_SUCCESS;
Expand Down Expand Up @@ -259,13 +270,15 @@ static inline int CeedOperatorSetupInputs_Blocked(CeedInt num_input_fields, Ceed
CeedVector in_vec, bool skip_active, CeedScalar *e_data_full[2 * CEED_FIELD_MAX],
CeedOperator_Blocked *impl, CeedRequest *request) {
for (CeedInt i = 0; i < num_input_fields; i++) {
bool is_active;
uint64_t state;
CeedEvalMode eval_mode;
CeedVector vec;

// Get input vector
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) {
is_active = vec == CEED_VECTOR_ACTIVE;
if (is_active) {
if (skip_active) continue;
else vec = in_vec;
}
Expand All @@ -282,6 +295,7 @@ static inline int CeedOperatorSetupInputs_Blocked(CeedInt num_input_fields, Ceed
// Get evec
CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_full[i], CEED_MEM_HOST, (const CeedScalar **)&e_data_full[i]));
}
if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
}
return CEED_ERROR_SUCCESS;
}
Expand All @@ -300,15 +314,19 @@ static inline int CeedOperatorInputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunc

// Skip active input
if (skip_active) {
bool is_active;
CeedVector vec;

CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) continue;
is_active = vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedVectorDestroy(&vec));
if (is_active) continue;
}

// Get elem_size, eval_mode, size
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size));
// Basis action
Expand All @@ -324,6 +342,7 @@ static inline int CeedOperatorInputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunc
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedVectorSetArray(impl->e_vecs_in[i], CEED_MEM_HOST, CEED_USE_POINTER, &e_data_full[i][(CeedSize)e * elem_size * num_comp]));
CeedCallBackend(CeedBasisApply(basis, block_size, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs_in[i], impl->q_vecs_in[i]));
CeedCallBackend(CeedBasisDestroy(&basis));
break;
case CEED_EVAL_WEIGHT:
break; // No action
Expand All @@ -347,6 +366,7 @@ static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFun
// Get elem_size, eval_mode, size
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
// Basis action
switch (eval_mode) {
Expand All @@ -365,6 +385,7 @@ static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFun
} else {
CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
}
CeedCallBackend(CeedBasisDestroy(&basis));
break;
// LCOV_EXCL_START
case CEED_EVAL_WEIGHT: {
Expand All @@ -386,10 +407,13 @@ static inline int CeedOperatorRestoreInputs_Blocked(CeedInt num_input_fields, Ce

// Skip active inputs
if (skip_active) {
bool is_active;
CeedVector vec;

CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) continue;
is_active = vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedVectorDestroy(&vec));
if (is_active) continue;
}
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
Expand Down Expand Up @@ -470,18 +494,21 @@ static int CeedOperatorApplyAdd_Blocked(CeedOperator op, CeedVector in_vec, Ceed

// Output restriction
for (CeedInt i = 0; i < num_output_fields; i++) {
bool is_active;
CeedVector vec;

if (impl->skip_rstr_out[i]) continue;
// Restore evec
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_full[i + impl->num_inputs], &e_data_full[i + num_input_fields]));
// Get output vector
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
is_active = vec == CEED_VECTOR_ACTIVE;
// Active
if (vec == CEED_VECTOR_ACTIVE) vec = out_vec;
if (is_active) vec = out_vec;
// Restrict
CeedCallBackend(
CeedElemRestrictionApply(impl->block_rstr[i + impl->num_inputs], CEED_TRANSPOSE, impl->e_vecs_full[i + impl->num_inputs], vec, request));
if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
}

// Restore input arrays
Expand Down Expand Up @@ -533,14 +560,14 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator o
CeedInt field_size;
CeedVector vec;

// Get input vector
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
// Check if active input
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) {
CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size));
CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0));
qf_size_in += field_size;
}
CeedCallBackend(CeedVectorDestroy(&vec));
}
CeedCheck(qf_size_in > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
impl->qf_size_in = qf_size_in;
Expand All @@ -552,13 +579,13 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator o
CeedInt field_size;
CeedVector vec;

// Get output vector
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
// Check if active output
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
if (vec == CEED_VECTOR_ACTIVE) {
CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &field_size));
qf_size_out += field_size;
}
CeedCallBackend(CeedVectorDestroy(&vec));
}
CeedCheck(qf_size_out > 0, ceed, CEED_ERROR_BACKEND, "Cannot assemble QFunction without active inputs and outputs");
impl->qf_size_out = qf_size_out;
Expand Down Expand Up @@ -601,13 +628,15 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator o

// Assemble QFunction
for (CeedInt i = 0; i < num_input_fields; i++) {
bool is_active;
CeedInt field_size;
CeedVector vec;

// Get input vector
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
// Check if active input
if (vec != CEED_VECTOR_ACTIVE) continue;
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
is_active = vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedVectorDestroy(&vec));
if (!is_active) continue;
CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &field_size));
for (CeedInt field = 0; field < field_size; field++) {
// Set current portion of input to 1.0
Expand All @@ -633,6 +662,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator o
CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[out], &field_size));
l_vec_array += field_size * Q * block_size; // Advance the pointer by the size of the output
}
CeedCallBackend(CeedVectorDestroy(&vec));
}
// Apply QFunction
CeedCallBackend(CeedQFunctionApply(qf, Q * block_size, impl->q_vecs_in, impl->q_vecs_out));
Expand Down Expand Up @@ -664,12 +694,12 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Blocked(CeedOperator o
for (CeedInt out = 0; out < num_output_fields; out++) {
CeedVector vec;

// Get output vector
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec));
// Check if active output
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[out], &vec));
if (vec == CEED_VECTOR_ACTIVE) {
CeedCallBackend(CeedVectorTakeArray(impl->q_vecs_out[out], CEED_MEM_HOST, NULL));
}
CeedCallBackend(CeedVectorDestroy(&vec));
}
}

Expand Down
Loading

0 comments on commit 02da2a3

Please sign in to comment.