From 80db90e526c73c71398a0d018adb94c06edd8858 Mon Sep 17 00:00:00 2001 From: Jeremy L Thompson Date: Mon, 9 Sep 2024 12:26:07 -0600 Subject: [PATCH] rust - fix CeedOperatorFieldGet* --- rust/libceed/src/basis.rs | 7 ++ rust/libceed/src/elem_restriction.rs | 7 ++ rust/libceed/src/operator.rs | 129 ++++++++++++++++++--------- 3 files changed, 100 insertions(+), 43 deletions(-) diff --git a/rust/libceed/src/basis.rs b/rust/libceed/src/basis.rs index 4c11fb79b4..b308076ab6 100644 --- a/rust/libceed/src/basis.rs +++ b/rust/libceed/src/basis.rs @@ -173,6 +173,13 @@ impl<'a> Basis<'a> { }) } + pub(crate) fn from_raw(ptr: bind_ceed::CeedBasis) -> crate::Result { + Ok(Self { + ptr, + _lifeline: PhantomData, + }) + } + pub fn create_tensor_H1_Lagrange( ceed: &crate::Ceed, dim: usize, diff --git a/rust/libceed/src/elem_restriction.rs b/rust/libceed/src/elem_restriction.rs index 950a840403..e800f56c1e 100644 --- a/rust/libceed/src/elem_restriction.rs +++ b/rust/libceed/src/elem_restriction.rs @@ -193,6 +193,13 @@ impl<'a> ElemRestriction<'a> { }) } + pub(crate) fn from_raw(ptr: bind_ceed::CeedElemRestriction) -> crate::Result { + Ok(Self { + ptr, + _lifeline: PhantomData, + }) + } + pub fn create_oriented( ceed: &crate::Ceed, nelem: usize, diff --git a/rust/libceed/src/operator.rs b/rust/libceed/src/operator.rs index bd5d7f37f4..a1e5db55cb 100644 --- a/rust/libceed/src/operator.rs +++ b/rust/libceed/src/operator.rs @@ -17,6 +17,9 @@ use crate::prelude::*; #[derive(Debug)] pub struct OperatorField<'a> { pub(crate) ptr: bind_ceed::CeedOperatorField, + pub(crate) vector: crate::Vector<'a>, + pub(crate) elem_restriction: crate::ElemRestriction<'a>, + pub(crate) basis: crate::Basis<'a>, _lifeline: PhantomData<&'a ()>, } @@ -24,6 +27,39 @@ pub struct OperatorField<'a> { // Implementations // ----------------------------------------------------------------------------- impl<'a> OperatorField<'a> { + pub(crate) fn from_raw( + ptr: bind_ceed::CeedOperatorField, + ceed: crate::Ceed, + ) -> crate::Result { + let vector = { + let mut vector_ptr = std::ptr::null_mut(); + let ierr = unsafe { bind_ceed::CeedOperatorFieldGetVector(ptr, &mut vector_ptr) }; + ceed.check_error(ierr)?; + crate::Vector::from_raw(vector_ptr)? + }; + let elem_restriction = { + let mut elem_restriction_ptr = std::ptr::null_mut(); + let ierr = unsafe { + bind_ceed::CeedOperatorFieldGetElemRestriction(ptr, &mut elem_restriction_ptr) + }; + ceed.check_error(ierr)?; + crate::ElemRestriction::from_raw(elem_restriction_ptr)? + }; + let basis = { + let mut basis_ptr = std::ptr::null_mut(); + let ierr = unsafe { bind_ceed::CeedOperatorFieldGetBasis(ptr, &mut basis_ptr) }; + ceed.check_error(ierr)?; + crate::Basis::from_raw(basis_ptr)? + }; + Ok(Self { + ptr, + vector, + elem_restriction, + basis, + _lifeline: PhantomData, + }) + } + /// Get the name of an OperatorField /// /// ``` @@ -110,24 +146,21 @@ impl<'a> OperatorField<'a> { /// inputs[1].elem_restriction().is_none(), /// "Incorrect field ElemRestriction" /// ); + /// + /// let outputs = op.outputs()?; + /// + /// assert!( + /// outputs[0].elem_restriction().is_some(), + /// "Incorrect field ElemRestriction" + /// ); /// # Ok(()) /// # } /// ``` pub fn elem_restriction(&self) -> ElemRestrictionOpt { - let mut ptr = std::ptr::null_mut(); - unsafe { - bind_ceed::CeedOperatorFieldGetElemRestriction(self.ptr, &mut ptr); - } - if ptr == unsafe { bind_ceed::CEED_ELEMRESTRICTION_NONE } { + if self.elem_restriction.ptr == unsafe { bind_ceed::CEED_ELEMRESTRICTION_NONE } { ElemRestrictionOpt::None } else { - let slice = unsafe { - std::slice::from_raw_parts( - &ptr as *const bind_ceed::CeedElemRestriction as *const crate::ElemRestriction, - 1 as usize, - ) - }; - ElemRestrictionOpt::Some(&slice[0]) + ElemRestrictionOpt::Some(&self.elem_restriction) } } @@ -172,20 +205,10 @@ impl<'a> OperatorField<'a> { /// # } /// ``` pub fn basis(&self) -> BasisOpt { - let mut ptr = std::ptr::null_mut(); - unsafe { - bind_ceed::CeedOperatorFieldGetBasis(self.ptr, &mut ptr); - } - if ptr == unsafe { bind_ceed::CEED_BASIS_NONE } { + if self.basis.ptr == unsafe { bind_ceed::CEED_BASIS_NONE } { BasisOpt::None } else { - let slice = unsafe { - std::slice::from_raw_parts( - &ptr as *const bind_ceed::CeedBasis as *const crate::Basis, - 1 as usize, - ) - }; - BasisOpt::Some(&slice[0]) + BasisOpt::Some(&self.basis) } } @@ -222,26 +245,20 @@ impl<'a> OperatorField<'a> { /// /// assert!(inputs[0].vector().is_active(), "Incorrect field Vector"); /// assert!(inputs[1].vector().is_none(), "Incorrect field Vector"); + /// + /// let outputs = op.outputs()?; + /// + /// assert!(outputs[0].vector().is_active(), "Incorrect field Vector"); /// # Ok(()) /// # } /// ``` pub fn vector(&self) -> VectorOpt { - let mut ptr = std::ptr::null_mut(); - unsafe { - bind_ceed::CeedOperatorFieldGetVector(self.ptr, &mut ptr); - } - if ptr == unsafe { bind_ceed::CEED_VECTOR_ACTIVE } { + if self.vector.ptr == unsafe { bind_ceed::CEED_VECTOR_ACTIVE } { VectorOpt::Active - } else if ptr == unsafe { bind_ceed::CEED_VECTOR_NONE } { + } else if self.vector.ptr == unsafe { bind_ceed::CEED_VECTOR_NONE } { VectorOpt::None } else { - let slice = unsafe { - std::slice::from_raw_parts( - &ptr as *const bind_ceed::CeedVector as *const crate::Vector, - 1 as usize, - ) - }; - VectorOpt::Some(&slice[0]) + VectorOpt::Some(&self.vector) } } } @@ -814,7 +831,7 @@ impl<'a> Operator<'a> { /// # Ok(()) /// # } /// ``` - pub fn inputs(&self) -> crate::Result<&[crate::OperatorField]> { + pub fn inputs(&self) -> crate::Result> { // Get array of raw C pointers for inputs let mut num_inputs = 0; let mut inputs_ptr = std::ptr::null_mut(); @@ -831,11 +848,24 @@ impl<'a> Operator<'a> { // Convert raw C pointers to fixed length slice let inputs_slice = unsafe { std::slice::from_raw_parts( - inputs_ptr as *const crate::OperatorField, + inputs_ptr as *mut bind_ceed::CeedOperatorField, num_inputs as usize, ) }; - Ok(inputs_slice) + // And finally build vec + let ceed = { + let mut ptr = std::ptr::null_mut(); + let mut ptr_copy = std::ptr::null_mut(); + unsafe { + bind_ceed::CeedOperatorGetCeed(self.op_core.ptr, &mut ptr); + bind_ceed::CeedReferenceCopy(ptr, &mut ptr_copy); // refcount + } + crate::Ceed { ptr } + }; + let inputs = (0..num_inputs as usize) + .map(|i| crate::OperatorField::from_raw(inputs_slice[i], ceed.clone())) + .collect::>>()?; + Ok(inputs) } /// Get a slice of Operator outputs @@ -873,7 +903,7 @@ impl<'a> Operator<'a> { /// # Ok(()) /// # } /// ``` - pub fn outputs(&self) -> crate::Result<&[crate::OperatorField]> { + pub fn outputs(&self) -> crate::Result> { // Get array of raw C pointers for outputs let mut num_outputs = 0; let mut outputs_ptr = std::ptr::null_mut(); @@ -890,11 +920,24 @@ impl<'a> Operator<'a> { // Convert raw C pointers to fixed length slice let outputs_slice = unsafe { std::slice::from_raw_parts( - outputs_ptr as *const crate::OperatorField, + outputs_ptr as *mut bind_ceed::CeedOperatorField, num_outputs as usize, ) }; - Ok(outputs_slice) + // And finally build vec + let ceed = { + let mut ptr = std::ptr::null_mut(); + let mut ptr_copy = std::ptr::null_mut(); + unsafe { + bind_ceed::CeedOperatorGetCeed(self.op_core.ptr, &mut ptr); + bind_ceed::CeedReferenceCopy(ptr, &mut ptr_copy); // refcount + } + crate::Ceed { ptr } + }; + let outputs = (0..num_outputs as usize) + .map(|i| crate::OperatorField::from_raw(outputs_slice[i], ceed.clone())) + .collect::>>()?; + Ok(outputs) } /// Check if Operator is setup correctly