Skip to content

Commit

Permalink
sycl - more consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Sep 17, 2024
1 parent 9a34422 commit d94da85
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
8 changes: 4 additions & 4 deletions backends/sycl-gen/ceed-sycl-gen-operator-build.sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,9 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
// Get elem_size, eval_mode, num_comp
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));

// Set field constants
if (eval_mode != CEED_EVAL_WEIGHT) {
Expand Down Expand Up @@ -334,9 +334,9 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
// Get elem_size, eval_mode, num_comp
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));

// Set field constants
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
Expand Down Expand Up @@ -401,8 +401,8 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
// Get elem_size, eval_mode, num_comp
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));

// Restriction
if (eval_mode != CEED_EVAL_WEIGHT && !((eval_mode == CEED_EVAL_NONE) && use_collograd_parallelization)) {
Expand Down Expand Up @@ -677,8 +677,8 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) {
// Get elem_size, eval_mode, num_comp
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
// Basis action
code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n";
switch (eval_mode) {
Expand Down
18 changes: 12 additions & 6 deletions backends/sycl-ref/ceed-sycl-ref-operator.sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,13 @@ static int CeedOperatorDestroy_Sycl(CeedOperator op) {
CeedCallSycl(ceed, sycl::free(impl->diag->d_interp_out, sycl_data->sycl_context));
CeedCallSycl(ceed, sycl::free(impl->diag->d_grad_in, sycl_data->sycl_context));
CeedCallSycl(ceed, sycl::free(impl->diag->d_grad_out, sycl_data->sycl_context));
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr));
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr));

CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag));
CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag));
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr));
CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr));
CeedCallBackend(CeedBasisDestroy(&impl->diag->basis_in));
CeedCallBackend(CeedBasisDestroy(&impl->diag->basis_out));
}
CeedCallBackend(CeedFree(&impl->diag));

Expand Down Expand Up @@ -745,8 +747,8 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
CeedCallBackend(CeedCalloc(1, &impl->diag));
CeedOperatorDiag_Sycl *diag = impl->diag;

diag->basis_in = basis_in;
diag->basis_out = basis_out;
CeedCallBackend(CeedBasisReferenceCopy(basis_in, &diag->basis_in));
CeedCallBackend(CeedBasisReferenceCopy(basis_out, &diag->basis_out));
diag->h_eval_mode_in = eval_mode_in;
diag->h_eval_mode_out = eval_mode_out;
diag->num_eval_mode_in = num_eval_mode_in;
Expand Down Expand Up @@ -825,6 +827,10 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) {
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out));
}

// Cleanup
CeedCallBackend(CeedBasisDestroy(&basis_in));
CeedCallBackend(CeedBasisDestroy(&basis_out));

// Wait for all copies to complete and handle exceptions
CeedCallSycl(ceed, sycl::event::wait_and_throw(copy_events));
return CEED_ERROR_SUCCESS;
Expand Down Expand Up @@ -1043,9 +1049,9 @@ static int CeedSingleOperatorAssembleSetup_Sycl(CeedOperator op) {
CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis));
CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases");
if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in));
CeedCallBackend(CeedBasisGetDimension(basis, &dim));
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
CeedCallBackend(CeedBasisDestroy(&basis));
CeedCallBackend(CeedBasisGetDimension(basis_in, &dim));
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));
if (eval_mode != CEED_EVAL_NONE) {
CeedCallBackend(CeedRealloc(num_B_in_mats_to_load + 1, &eval_mode_in));
Expand Down

0 comments on commit d94da85

Please sign in to comment.