Skip to content

Commit

Permalink
hip - skip duplicate output rstr
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Aug 13, 2024
1 parent d8d280d commit 79df7ad
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 14 deletions.
63 changes: 50 additions & 13 deletions backends/hip-ref/ceed-hip-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ static int CeedOperatorDestroy_Hip(CeedOperator op) {

// Apply data
CeedCallBackend(CeedFree(&impl->skip_rstr_in));
CeedCallBackend(CeedFree(&impl->skip_rstr_out));
CeedCallBackend(CeedFree(&impl->apply_add_basis_out));
for (CeedInt i = 0; i < impl->num_inputs + impl->num_outputs; i++) {
CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[i]));
}
Expand Down Expand Up @@ -96,8 +98,8 @@ static int CeedOperatorDestroy_Hip(CeedOperator op) {
//------------------------------------------------------------------------------
// Setup infields or outfields
//------------------------------------------------------------------------------
static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, CeedVector *e_vecs,
CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) {
static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool is_input, bool is_at_points, bool *skip_rstr, bool *apply_add_basis,
CeedVector *e_vecs, CeedVector *q_vecs, CeedInt start_e, CeedInt num_fields, CeedInt Q, CeedInt num_elem) {
Ceed ceed;
CeedQFunctionField *qf_fields;
CeedOperatorField *op_fields;
Expand Down Expand Up @@ -183,7 +185,7 @@ static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool i
break;
}
}
// Drop duplicate input restrictions
// Drop duplicate restrictions
if (is_input) {
for (CeedInt i = 0; i < num_fields; i++) {
CeedVector vec_i;
Expand All @@ -198,11 +200,31 @@ static int CeedOperatorSetupFields_Hip(CeedQFunction qf, CeedOperator op, bool i
CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
if (vec_i == vec_j && rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i], &e_vecs[j]));
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e]));
skip_rstr[j] = true;
}
}
}
} else {
for (CeedInt i = num_fields - 1; i >= 0; i--) {
CeedVector vec_i;
CeedElemRestriction rstr_i;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[i], &vec_i));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr_i));
for (CeedInt j = i - 1; j >= 0; j--) {
CeedVector vec_j;
CeedElemRestriction rstr_j;

CeedCallBackend(CeedOperatorFieldGetVector(op_fields[j], &vec_j));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[j], &rstr_j));
if (vec_i == vec_j && rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(e_vecs[i + start_e], &e_vecs[j + start_e]));
skip_rstr[j] = true;
apply_add_basis[i] = true;
}
}
}
}
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -233,6 +255,8 @@ static int CeedOperatorSetup_Hip(CeedOperator op) {
// Allocate
CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out));
Expand All @@ -242,10 +266,10 @@ static int CeedOperatorSetup_Hip(CeedOperator op) {
// Set up infield and outfield e_vecs and q_vecs
// Infields
CeedCallBackend(
CeedOperatorSetupFields_Hip(qf, op, true, false, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem));
CeedOperatorSetupFields_Hip(qf, op, true, false, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields, Q, num_elem));
// Outfields
CeedCallBackend(
CeedOperatorSetupFields_Hip(qf, op, false, false, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields, Q, num_elem));
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, false, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out,
num_input_fields, num_output_fields, Q, num_elem));

CeedCallBackend(CeedOperatorSetSetupDone(op));
return CEED_ERROR_SUCCESS;
Expand Down Expand Up @@ -430,7 +454,11 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect
case CEED_EVAL_DIV:
case CEED_EVAL_CURL:
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
if (impl->apply_add_basis_out[i]) {
CeedCallBackend(CeedBasisApplyAdd(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
} else {
CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
}
break;
// LCOV_EXCL_START
case CEED_EVAL_WEIGHT: {
Expand All @@ -451,6 +479,7 @@ static int CeedOperatorApplyAdd_Hip(CeedOperator op, CeedVector in_vec, CeedVect
if (eval_mode == CEED_EVAL_NONE) {
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields]));
}
if (impl->skip_rstr_out[i]) continue;
// Get output vector
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
// Restrict
Expand Down Expand Up @@ -498,6 +527,8 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) {
// Allocate
CeedCallBackend(CeedCalloc(num_input_fields + num_output_fields, &impl->e_vecs));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->skip_rstr_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->apply_add_basis_out));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->input_states));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_in));
CeedCallBackend(CeedCalloc(CEED_FIELD_MAX, &impl->q_vecs_out));
Expand All @@ -506,11 +537,11 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) {

// Set up infield and outfield e_vecs and q_vecs
// Infields
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, true, impl->skip_rstr_in, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields,
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, true, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields,
max_num_points, num_elem));
// Outfields
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, NULL, impl->e_vecs, impl->q_vecs_out, num_input_fields, num_output_fields,
max_num_points, num_elem));
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out,
num_input_fields, num_output_fields, max_num_points, num_elem));

CeedCallBackend(CeedOperatorSetSetupDone(op));
return CEED_ERROR_SUCCESS;
Expand Down Expand Up @@ -634,8 +665,13 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec,
case CEED_EVAL_DIV:
case CEED_EVAL_CURL:
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i],
impl->e_vecs[i + impl->num_inputs]));
if (impl->apply_add_basis_out[i]) {
CeedCallBackend(CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem,
impl->q_vecs_out[i], impl->e_vecs[i + impl->num_inputs]));
} else {
CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[i],
impl->e_vecs[i + impl->num_inputs]));
}
break;
// LCOV_EXCL_START
case CEED_EVAL_WEIGHT: {
Expand All @@ -656,6 +692,7 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec,
if (eval_mode == CEED_EVAL_NONE) {
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs[i + impl->num_inputs], &e_data[i + num_input_fields]));
}
if (impl->skip_rstr_out[i]) continue;
// Get output vector
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
// Restrict
Expand Down
2 changes: 1 addition & 1 deletion backends/hip-ref/ceed-hip-ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ typedef struct {
} CeedOperatorAssemble_Hip;

typedef struct {
bool *skip_rstr_in;
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
uint64_t *input_states; // State tracking for passive inputs
CeedVector *e_vecs; // E-vectors, inputs followed by outputs
CeedVector *q_vecs_in; // Input Q-vectors needed to apply operator
Expand Down

0 comments on commit 79df7ad

Please sign in to comment.