Skip to content

Commit

Permalink
Skip duplicate transpose restrictions (#1645)
Browse files (browse the repository at this point in the history)
* cpu - skip duplicate output restrictions (rstr)

* cuda - skip duplicate output restrictions (rstr)

* hip - skip duplicate output restrictions (rstr)
  • Loading branch information
jeremylt authored Aug 21, 2024
1 parent 4b3e95d commit f8a0df5
Show file tree
Hide file tree
Showing 10 changed files with 306 additions and 88 deletions.
70 changes: 54 additions & 16 deletions backends/blocked/ceed-blocked-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
//------------------------------------------------------------------------------
// Setup Input/Output Fields
//------------------------------------------------------------------------------
static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, const CeedInt block_size,
CeedElemRestriction *block_rstr, CeedVector *e_vecs_full, CeedVector *e_vecs, CeedVector *q_vecs,
CeedInt start_e, CeedInt num_fields, CeedInt Q) {
static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bool is_input, bool *skip_rstr, CeedInt *e_data_out_indices,
bool *apply_add_basis, const CeedInt block_size, CeedElemRestriction *block_rstr, CeedVector *e_vecs_full,
CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q) {
Ceed ceed;
CeedSize e_size, q_size;
CeedInt num_comp, size, P;
Expand Down Expand Up @@ -135,7 +135,7 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
break;
}
}
// Drop duplicate input restrictions
// Drop duplicate restrictions
if (is_input) {
for (CeedInt i = 0; i < num_fields; i++) {
CeedVector vec_i;
Expand All @@ -151,11 +151,33 @@ static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op, bo
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
if (vec_i == vec_j && rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i], &e_vecs_full[j]));
CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e]));
skip_rstr[j] = true;
}
}
}
} else {
for (CeedInt i = num_fields - 1; i >= 0; i--) {
CeedVector vec_i;
CeedElemRestriction rstr_i;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i));
for (CeedInt j = i - 1; j >= 0; j--) {
CeedVector vec_j;
CeedElemRestriction rstr_j;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
if (vec_i == vec_j && rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
CeedCallBackend(CeedVectorReferenceCopy(e_vecs_full[i + start_e], &e_vecs_full[j + start_e]));
skip_rstr[j] = true;
apply_add_basis[i] = true;
e_data_out_indices[j] = i;
}
}
}
}
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -189,6 +211,9 @@ static int CeedOperatorSetup_Blocked(CeedOperator op) {
CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs_full));

CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_data_out_indices));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->e_vecs_out));
Expand All @@ -200,11 +225,12 @@ static int CeedOperatorSetup_Blocked(CeedOperator op) {

// Set up infield and outfield pointer arrays
// Infields
CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, true, impl->skip_rstr_in, block_size, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_in,
impl->q_vecs_in, 0, num_input_fields, Q));
CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, true, impl->skip_rstr_in, NULL, NULL, block_size, impl->block_rstr, impl->e_vecs_full,
impl->e_vecs_in, impl->q_vecs_in, 0, num_input_fields, Q));
// Outfields
CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, false, NULL, block_size, impl->block_rstr, impl->e_vecs_full, impl->e_vecs_out,
impl->q_vecs_out, num_input_fields, num_output_fields, Q));
CeedCallBackend(CeedOperatorSetupFields_Blocked(qf, op, false, impl->skip_rstr_out, impl->e_data_out_indices, impl->apply_add_basis_out, block_size,
impl->block_rstr, impl->e_vecs_full, impl->e_vecs_out, impl->q_vecs_out, num_input_fields,
num_output_fields, Q));

// Identity QFunctions
if (impl->is_identity_qf) {
Expand Down Expand Up @@ -310,8 +336,8 @@ static inline int CeedOperatorInputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunc
// Output Basis Action
//------------------------------------------------------------------------------
static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFunctionField *qf_output_fields, CeedOperatorField *op_output_fields,
CeedInt block_size, CeedInt num_input_fields, CeedInt num_output_fields, CeedOperator op,
CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) {
CeedInt block_size, CeedInt num_input_fields, CeedInt num_output_fields, bool *apply_add_basis,
CeedOperator op, CeedScalar *e_data_full[2 * CEED_FIELD_MAX], CeedOperator_Blocked *impl) {
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedInt elem_size, num_comp;
CeedEvalMode eval_mode;
Expand All @@ -334,7 +360,11 @@ static inline int CeedOperatorOutputBasis_Blocked(CeedInt e, CeedInt Q, CeedQFun
CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedVectorSetArray(impl->e_vecs_out[i], CEED_MEM_HOST, CEED_USE_POINTER,
&e_data_full[i + num_input_fields][(CeedSize)e * elem_size * num_comp]));
CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
if (apply_add_basis[i]) {
CeedCallBackend(CeedBasisApplyAdd(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
} else {
CeedCallBackend(CeedBasisApply(basis, block_size, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs_out[i]));
}
break;
// LCOV_EXCL_START
case CEED_EVAL_WEIGHT: {
Expand Down Expand Up @@ -405,8 +435,12 @@ static int CeedOperatorApplyAdd_Blocked(CeedOperator op, CeedVector in_vec, Ceed
CeedCallBackend(CeedOperatorSetupInputs_Blocked(num_input_fields, qf_input_fields, op_input_fields, in_vec, false, e_data_full, impl, request));

// Output Evecs
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields]));
for (CeedInt i = num_output_fields - 1; i >= 0; i--) {
if (impl->skip_rstr_out[i]) {
e_data_full[i + num_input_fields] = e_data_full[impl->e_data_out_indices[i] + num_input_fields];
} else {
CeedCallBackend(CeedVectorGetArrayWrite(impl->e_vecs_full[i + impl->num_inputs], CEED_MEM_HOST, &e_data_full[i + num_input_fields]));
}
}

// Loop through elements
Expand All @@ -430,14 +464,15 @@ static int CeedOperatorApplyAdd_Blocked(CeedOperator op, CeedVector in_vec, Ceed
}

// Output basis apply
CeedCallBackend(CeedOperatorOutputBasis_Blocked(e, Q, qf_output_fields, op_output_fields, block_size, num_input_fields, num_output_fields, op,
e_data_full, impl));
CeedCallBackend(CeedOperatorOutputBasis_Blocked(e, Q, qf_output_fields, op_output_fields, block_size, num_input_fields, num_output_fields,
impl->apply_add_basis_out, op, e_data_full, impl));
}

// Output restriction
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedVector vec;

if (impl->skip_rstr_out[i]) continue;
// Restore evec
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_full[i + impl->num_inputs], &e_data_full[i + num_input_fields]));
// Get output vector
Expand Down Expand Up @@ -671,6 +706,9 @@ static int CeedOperatorDestroy_Blocked(CeedOperator op) {
CeedCallBackend(CeedOperatorGetData(op, &impl));

CeedCallBackend(CeedFree(&impl->skip_rstr_in));
CeedCallBackend(CeedFree(&impl->skip_rstr_out));
CeedCallBackend(CeedFree(&impl->e_data_out_indices));
CeedCallBackend(CeedFree(&impl->apply_add_basis_out));
for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) {
CeedCallBackend(CeedElemRestrictionDestroy(&impl->block_rstr[i]));
CeedCallBackend(CeedVectorDestroy(&impl->e_vecs_full[i]));
Expand Down
3 changes: 2 additions & 1 deletion backends/blocked/ceed-blocked.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ typedef struct {

typedef struct {
bool is_identity_qf, is_identity_rstr_op;
bool *skip_rstr_in;
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
CeedInt *e_data_out_indices;
uint64_t *input_states; /* State counter of inputs */
CeedVector *e_vecs_full; /* Full E-vectors, inputs followed by outputs */
CeedVector *e_vecs_in; /* Element block input E-vectors */
Expand Down
63 changes: 50 additions & 13 deletions backends/cuda-ref/ceed-cuda-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ static int CeedOperatorDestroy_Cuda(CeedOperator op) {

// Apply data
CeedCallBackend(CeedFree(&impl->skip_rstr_in));
CeedCallBackend(CeedFree(&impl->skip_rstr_out));
CeedCallBackend(CeedFree(&impl->apply_add_basis_out));
for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) {
CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[i]));
}
Expand Down Expand Up @@ -97,8 +99,8 @@ static int CeedOperatorDestroy_Cuda(CeedOperator op) {
//------------------------------------------------------------------------------
// Setup infields or outfields
//------------------------------------------------------------------------------
static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, CeedVector *e_vecs,
CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) {
static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, bool *apply_add_basis,
CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) {
Ceed ceed;
CeedQFunctionField *qf_fields;
CeedOperatorField *op_fields;
Expand Down Expand Up @@ -184,7 +186,7 @@ static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool
break;
}
}
// Drop duplicate input restrictions
// Drop duplicate restrictions
if (is_input) {
for (CeedInt i = 0; i < num_fields; i++) {
CeedVector vec_i;
Expand All @@ -199,11 +201,31 @@ static int CeedOperatorSetupFields_Cuda(CeedQFunction qf, CeedOperator op, bool
CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
if (vec_i == vec_j && rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e]));
skip_rstr[j] = true;
}
}
}
} else {
for (CeedInt i = num_fields - 1; i >= 0; i--) {
CeedVector vec_i;
CeedElemRestriction rstr_i;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i));
for (CeedInt j = i - 1; j >= 0; j--) {
CeedVector vec_j;
CeedElemRestriction rstr_j;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
if (vec_i == vec_j && rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e]));
skip_rstr[j] = true;
apply_add_basis[i] = true;
}
}
}
}
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -234,6 +256,8 @@ static int CeedOperatorSetup_Cuda(CeedOperator op) {
// Allocate
CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out));
Expand All @@ -243,10 +267,10 @@ static int CeedOperatorSetup_Cuda(CeedOperator op) {
// Set up infield and outfield e_vecs and q_vecs
// Infields
CeedCallBackend(
CeedOperatorSetupFields_Cuda(qf, op, true, false, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem));
CeedOperatorSetupFields_Cuda(qf, op, true, false, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem));
// Outfields
CeedCallBackend(
CeedOperatorSetupFields_Cuda(qf, op, false, false, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, Q, num_elem));
CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, false, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out,
num_input_fields, num_output_fields, Q, num_elem));

CeedCallBackend(CeedOperatorSetSetupDone(op));
return CEED_ERROR_SUCCESS;
Expand Down Expand Up @@ -431,7 +455,11 @@ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVec
case CEED_EVAL_DIV:
case CEED_EVAL_CURL:
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
if (impl->apply_add_basis_out[i]) {
CeedCallBackend(CeedBasisApplyAdd(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
} else {
CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
}
break;
// LCOV_EXCL_START
case CEED_EVAL_WEIGHT: {
Expand All @@ -452,6 +480,7 @@ static int CeedOperatorApplyAdd_Cuda(CeedOperator op, CeedVector in_vec, CeedVec
if (eval_mode == CEED_EVAL_NONE) {
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields]));
}
if (impl->skip_rstr_out[i]) continue;
// Get output vector
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
// Restrict
Expand Down Expand Up @@ -499,6 +528,8 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) {
// Allocate
CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out));
Expand All @@ -507,11 +538,11 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) {

// Set up infield and outfield e_vecs and q_vecs
// Infields
CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, true, true, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields,
CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, true, true, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields,
max_num_points, num_elem));
// Outfields
CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, true, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields,
max_num_points, num_elem));
CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out,
num_input_fields, num_output_fields, max_num_points, num_elem));

CeedCallBackend(CeedOperatorSetSetupDone(op));
return CEED_ERROR_SUCCESS;
Expand Down Expand Up @@ -635,8 +666,13 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec,
case CEED_EVAL_DIV:
case CEED_EVAL_CURL:
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i],
impl->e_vecs[i + impl->num_inputs]));
if (impl->apply_add_basis_out[i]) {
CeedCallBackend(CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem,
impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
} else {
CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i],
impl->e_vecs[i + impl->num_inputs]));
}
break;
// LCOV_EXCL_START
case CEED_EVAL_WEIGHT: {
Expand All @@ -657,6 +693,7 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec,
if (eval_mode == CEED_EVAL_NONE) {
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields]));
}
if (impl->skip_rstr_out[i]) continue;
// Get output vector
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
// Restrict
Expand Down
2 changes: 1 addition & 1 deletion backends/cuda-ref/ceed-cuda-ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ typedef struct {
} CeedOperatorAssemble_Cuda;

typedef struct {
bool *skip_rstr_in;
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
uint64_t *input_states; // State tracking for passive inputs
CeedVector *e_vecs; // E-vectors, inputs followed by outputs
CeedVector *q_vecs_in; // Input Q-vectors needed to apply operator
Expand Down
Loading

0 comments on commit f8a0df5

Please sign in to comment.