From d94da85e13e86f044488a6bde0fd16e316b50eaa Mon Sep 17 00:00:00 2001 From: Jeremy L Thompson Date: Tue, 17 Sep 2024 13:54:57 -0600 Subject: [PATCH] sycl - more consistency --- .../ceed-sycl-gen-operator-build.sycl.cpp | 8 ++++---- .../sycl-ref/ceed-sycl-ref-operator.sycl.cpp | 18 ++++++++++++------ 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/backends/sycl-gen/ceed-sycl-gen-operator-build.sycl.cpp b/backends/sycl-gen/ceed-sycl-gen-operator-build.sycl.cpp index b14e155124..ee7aab812c 100644 --- a/backends/sycl-gen/ceed-sycl-gen-operator-build.sycl.cpp +++ b/backends/sycl-gen/ceed-sycl-gen-operator-build.sycl.cpp @@ -274,9 +274,9 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) { // Get elem_size, eval_mode, num_comp CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); - CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp)); CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); + CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); // Set field constants if (eval_mode != CEED_EVAL_WEIGHT) { @@ -334,9 +334,9 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) { // Get elem_size, eval_mode, num_comp CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); - CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp)); CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr)); + CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); // Set field constants CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis)); @@ -401,8 +401,8 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) { // Get elem_size, eval_mode, num_comp CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr)); CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); - CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp)); + CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); // Restriction if (eval_mode != CEED_EVAL_WEIGHT && !((eval_mode == CEED_EVAL_NONE) && use_collograd_parallelization)) { @@ -677,8 +677,8 @@ extern "C" int CeedOperatorBuildKernel_Sycl_gen(CeedOperator op) { // Get elem_size, eval_mode, num_comp CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr)); CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size)); - CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp)); + CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); // Basis action code << " // EvalMode: " << CeedEvalModes[eval_mode] << "\n"; switch (eval_mode) { diff --git a/backends/sycl-ref/ceed-sycl-ref-operator.sycl.cpp b/backends/sycl-ref/ceed-sycl-ref-operator.sycl.cpp index a784e1e53f..6074612f27 100644 --- a/backends/sycl-ref/ceed-sycl-ref-operator.sycl.cpp +++ b/backends/sycl-ref/ceed-sycl-ref-operator.sycl.cpp @@ -89,11 +89,13 @@ static int CeedOperatorDestroy_Sycl(CeedOperator op) { CeedCallSycl(ceed, sycl::free(impl->diag->d_interp_out, sycl_data->sycl_context)); CeedCallSycl(ceed, sycl::free(impl->diag->d_grad_in, sycl_data->sycl_context)); CeedCallSycl(ceed, sycl::free(impl->diag->d_grad_out, sycl_data->sycl_context)); - CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr)); - CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr)); CeedCallBackend(CeedVectorDestroy(&impl->diag->elem_diag)); CeedCallBackend(CeedVectorDestroy(&impl->diag->point_block_elem_diag)); + CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->diag_rstr)); + CeedCallBackend(CeedElemRestrictionDestroy(&impl->diag->point_block_diag_rstr)); + CeedCallBackend(CeedBasisDestroy(&impl->diag->basis_in)); + CeedCallBackend(CeedBasisDestroy(&impl->diag->basis_out)); } CeedCallBackend(CeedFree(&impl->diag)); @@ -745,8 +747,8 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) { CeedCallBackend(CeedCalloc(1, &impl->diag)); CeedOperatorDiag_Sycl *diag = impl->diag; - diag->basis_in = basis_in; - diag->basis_out = basis_out; + CeedCallBackend(CeedBasisReferenceCopy(basis_in, &diag->basis_in)); + CeedCallBackend(CeedBasisReferenceCopy(basis_out, &diag->basis_out)); diag->h_eval_mode_in = eval_mode_in; diag->h_eval_mode_out = eval_mode_out; diag->num_eval_mode_in = num_eval_mode_in; @@ -825,6 +827,10 @@ static inline int CeedOperatorAssembleDiagonalSetup_Sycl(CeedOperator op) { CeedCallBackend(CeedElemRestrictionDestroy(&rstr_out)); } + // Cleanup + CeedCallBackend(CeedBasisDestroy(&basis_in)); + CeedCallBackend(CeedBasisDestroy(&basis_out)); + // Wait for all copies to complete and handle exceptions CeedCallSycl(ceed, sycl::event::wait_and_throw(copy_events)); return CEED_ERROR_SUCCESS; @@ -1043,9 +1049,9 @@ static int CeedSingleOperatorAssembleSetup_Sycl(CeedOperator op) { CeedCallBackend(CeedOperatorFieldGetBasis(input_fields[i], &basis)); CeedCheck(!basis_in || basis_in == basis, ceed, CEED_ERROR_BACKEND, "Backend does not implement operator assembly with multiple active bases"); if (!basis_in) CeedCallBackend(CeedBasisReferenceCopy(basis, &basis_in)); - CeedCallBackend(CeedBasisGetDimension(basis, &dim)); - CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts)); CeedCallBackend(CeedBasisDestroy(&basis)); + CeedCallBackend(CeedBasisGetDimension(basis_in, &dim)); + CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts)); CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode)); if (eval_mode != CEED_EVAL_NONE) { CeedCallBackend(CeedRealloc(num_B_in_mats_to_load + 1, &eval_mode_in));