From b93f07bb3f133a3b6b2af2e08357382858bf8ee3 Mon Sep 17 00:00:00 2001 From: Mohammed Ghannam Date: Sun, 7 Apr 2024 17:45:21 +0200 Subject: [PATCH] Reference counted ScipPtr --- src/model.rs | 7 +++++++ src/scip.rs | 16 ++++++++++------ src/solution.rs | 19 ++++++++++--------- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/model.rs b/src/model.rs index d0180f2..81c2685 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1787,4 +1787,11 @@ mod tests { assert_eq!(second_solved.status(), Status::Optimal); assert!((second_solved.obj_val() - expected_obj).abs() <= 1e-6); } + + #[test] + fn solution_after_model_drop() { + let mut model = create_model(); + let sol = model.solve().best_sol().unwrap(); // Temporary value returned from `model.solve()` is dropped. + dbg!(sol); + } } diff --git a/src/scip.rs b/src/scip.rs index a066c78..83c38be 100644 --- a/src/scip.rs +++ b/src/scip.rs @@ -6,6 +6,7 @@ use crate::{ }; use crate::{scip_call, HeurTiming, Heuristic}; use core::panic; +use std::cell::RefCell; use std::collections::BTreeMap; use std::ffi::{c_int, CStr, CString}; use std::mem::MaybeUninit; @@ -15,7 +16,7 @@ use std::rc::Rc; #[derive(Debug)] pub(crate) struct ScipPtr { pub(crate) raw: *mut ffi::SCIP, - consumed: bool, + uses: Rc>, vars_added_in_solving: Vec<*mut ffi::SCIP_VAR>, } @@ -26,15 +27,17 @@ impl ScipPtr { let scip_ptr = unsafe { scip_ptr.assume_init() }; ScipPtr { raw: scip_ptr, - consumed: false, + uses: Rc::new(RefCell::new(1)), vars_added_in_solving: Vec::new(), } } pub(crate) fn clone(&self) -> Self { + let uses = self.uses.clone(); + *uses.borrow_mut() += 1; ScipPtr { raw: self.raw, - consumed: true, + uses, vars_added_in_solving: Vec::new(), } } @@ -184,7 +187,7 @@ impl ScipPtr { let sol = unsafe { ffi::SCIPgetBestSol(self.raw) }; Solution { - scip_ptr: self.raw, + scip_ptr: self.clone(), raw: sol, } } @@ -489,7 +492,7 @@ impl ScipPtr { scip_call! { ffi::SCIPcreateSol(self.raw, sol.as_mut_ptr(), std::ptr::null_mut()) } let sol = unsafe { sol.assume_init() }; Ok(Solution { - scip_ptr: self.raw, + scip_ptr: self.clone(), raw: sol, }) } @@ -992,7 +995,8 @@ impl ScipPtr { impl Drop for ScipPtr { fn drop(&mut self) { - if self.consumed { + *self.uses.borrow_mut() -= 1; + if *self.uses.borrow() > 0 { return; } // Rust Model struct keeps at most one copy of each variable and constraint pointers diff --git a/src/solution.rs b/src/solution.rs index 9d1f714..03b7079 100644 --- a/src/solution.rs +++ b/src/solution.rs @@ -3,28 +3,29 @@ use std::rc::Rc; use crate::variable::Variable; use crate::{ffi, scip_call_panic}; +use crate::scip::ScipPtr; /// A wrapper for a SCIP solution. -#[derive(PartialEq, Eq)] + pub struct Solution { - pub(crate) scip_ptr: *mut ffi::SCIP, + pub(crate) scip_ptr: ScipPtr, pub(crate) raw: *mut ffi::SCIP_SOL, } impl Solution { /// Returns the objective value of the solution. pub fn obj_val(&self) -> f64 { - unsafe { ffi::SCIPgetSolOrigObj(self.scip_ptr, self.raw) } + unsafe { ffi::SCIPgetSolOrigObj(self.scip_ptr.raw, self.raw) } } /// Returns the value of a variable in the solution. pub fn val(&self, var: Rc) -> f64 { - unsafe { ffi::SCIPgetSolVal(self.scip_ptr, self.raw, var.raw) } + unsafe { ffi::SCIPgetSolVal(self.scip_ptr.raw, self.raw, var.raw) } } /// Sets the value of a variable in the solution. pub fn set_val(&self, var: Rc, val: f64) { - scip_call_panic!(ffi::SCIPsetSolVal(self.scip_ptr, self.raw, var.raw, val)); + scip_call_panic!(ffi::SCIPsetSolVal(self.scip_ptr.raw, self.raw, var.raw, val)); } } @@ -33,12 +34,12 @@ impl fmt::Debug for Solution { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let obj_val = self.obj_val(); writeln!(f, "Solution with obj val: {obj_val}")?; - let vars = unsafe { ffi::SCIPgetVars(self.scip_ptr) }; - let n_vars = unsafe { ffi::SCIPgetNVars(self.scip_ptr) }; + let vars = unsafe { ffi::SCIPgetVars(self.scip_ptr.raw) }; + let n_vars = unsafe { ffi::SCIPgetNVars(self.scip_ptr.raw) }; for i in 0..n_vars { let var = unsafe { *vars.offset(i as isize) }; - let val = unsafe { ffi::SCIPgetSolVal(self.scip_ptr, self.raw, var) }; - let eps = unsafe { ffi::SCIPepsilon(self.scip_ptr) }; + let val = unsafe { ffi::SCIPgetSolVal(self.scip_ptr.raw, self.raw, var) }; + let eps = unsafe { ffi::SCIPepsilon(self.scip_ptr.raw) }; if val > eps || val < -eps { let name_ptr = unsafe { ffi::SCIPvarGetName(var) }; // from CString